#!/usr/bin/env python

#--------------------------------------------------------------
#
# creates jobsfile for subsetting to the domain
#
# jobsfile will contain a list composed of variable name, directories in which there are data for given variable, and target directory 
# each line will then be passed to the subset.sh script
#
# Piotr, Jan 2022
#
# status: works, but needs some error catching and help??
#--------------------------------------------------------------


import sys, shutil
import os.path
import os
import glob
import netCDF4
import numpy as np

# usage:
# python makejobs_subset.py searchpath outputrootdir subsetcode varnames frequencies activities jobsfile cdocmd

# fail early with a usage message instead of an IndexError when arguments are missing
if len(sys.argv) != 9:
    print("usage: python makejobs_subset.py searchpath outputrootdir subsetcode "
          "varnames frequencies activities jobsfile cdocmd\nexiting...")
    sys.exit(1)

searchpath=sys.argv[1]
outputrootdir=sys.argv[2] #where the data goes. This will substitute /terra/data in the search path
subsetcode=sys.argv[3]    #name of the subset; becomes the last directory level of the output path
varnames=sys.argv[4]      #comma-separated list of variables
frequencies=sys.argv[5]   #comma-separated list of frequencies (e.g. day,mon)
activities=sys.argv[6]    #name "activities" is used here for compatibility with terms used to describe data model on terra, but it corresponds to rcps for gcm data; it's a comma-separated list of rcps
jobsfile=sys.argv[7]      #name of the jobs file to be created
cdocmd=sys.argv[8]        #full cdo command with coordinates

varlist=varnames.split(",")
activitylist=activities.split(",") #rcps for gcm data
freqlist=frequencies.split(",")    #frequencies, e.g. day, mon

print("processing\n  searchpath: {}\n  variables: {} \n  frequencies: {}\n  activities: {}".format(searchpath, varnames,frequencies,activities))

#start from a clean jobsfile: remove any leftover from a previous run
if os.path.exists(jobsfile):
   os.remove(jobsfile)





# examples of paths
# projections data
# /terra/data/cmip5/global/rcp85/MIROC5/r1i1p1/day/native
# observed data
# /terra/data/observed/africa/gridded/CHC/CHIRPS-2.0-0p05/day/native
# reanalysis data
# /terra/data/reanalysis/global/reanalysis/ECMWF/ERA5/day/native/

#making sure the path does not end in /
#rstrip also handles an empty searchpath and repeated trailing slashes,
#where indexing searchpath[-1] would raise or miss a slash
searchpath=searchpath.rstrip("/")


#extracting elements from the source search path
parts = searchpath.split('/')

#check the searchpath; a valid path has at least 7 levels (see examples above)
if len(parts)<7:
   print("expected at least 7 levels in the searchpath, got {}\n{}\nexiting...".format(len(parts), searchpath))
   sys.exit(1)  #non-zero exit status so callers/schedulers can detect the failure

datatype = parts[-7]     #e.g. cmip5, observed, reanalysis
domain = parts[-6]       #e.g. global, africa
activity = parts[-5]     #rcp or gridded for observational data (overwritten by the activity loop below)
organization = parts[-4] #e.g. MIROC5, CHC; may contain glob wildcards
product = parts[-3]      #ripf member, e.g. r1i1p1; special values: "first", "*"
freq = parts[-2]         #e.g. day (overwritten by the frequency loop below)
grid = parts[-1]         #e.g. native



#---------------------------------------------------------------------------------------------
#finding files

#iterating through directories to get ones with data
#njobs counts the lines written to the jobsfile; used later as the sbatch array size
njobs=0
with open(jobsfile, "w") as outf:
    #iterating through actvities (or rcps)
    for activity in activitylist:
        print("activity,org: {},{}".format(activity, organization))
        #glob pattern down to the organization/model level; "organization" comes from
        #the original searchpath and may itself contain wildcards
        searchpath="{}/{}/{}/{}/{}".format("/".join(parts[0:-7]),datatype,domain,activity,organization)
        print(searchpath)
        allorganizations=sorted(glob.glob(searchpath))
        print("{} {}".format(searchpath, len(allorganizations)))
        #iterating through organizations (or models)
        for orgdir in allorganizations:
            print("processing {}".format(orgdir))
            #keep only the last path component (the organization/model name)
            orgdir=orgdir.split("/")[-1]
            #list all products (ripf members) available for this model
            searchpath="{}/{}/{}/{}/{}/*".format("/".join(parts[0:-7]), datatype,domain,activity,orgdir)
            allproducts=np.array(glob.glob(searchpath))
            if product=="first":
                #sorting, intended for cmip5, cmip6 and cordex data
                #turn e.g. "r10i1p1" into ['','10','1','1'], zero-pad each number to two
                #digits, and sort on the concatenation so that r2... sorts before r10...
                ripf=[os.path.basename(x) for x in allproducts]
                ripflist=[x.replace("r"," ").replace("i"," ").replace("p"," ").replace("f"," ").split(" ") for x in ripf]
                sortind=np.argsort(["".join([x.zfill(2) for x in z]) for z in ripflist])
                allproducts=allproducts[sortind]

            #iterating through products (or ripf codes)
            if len(allproducts)>0:
                if product=="first":
                    #keep only the lowest-numbered member
                    allproducts=[allproducts[0]]
                    print("picking one member only:", allproducts)
                elif  product != "*":
                    #a specific member was requested
                    #NOTE(review): its existence under this model is not verified here;
                    #the per-variable glob below will simply find no files if it is absent
                    allproducts=[product]
                for productdir in allproducts:
                    productdir=productdir.split("/")[-1]
                    for freq in freqlist:
                        for var in varlist:
                            #directory expected to hold data files named <var>_*
                            inputdir="{}/{}/{}/{}/{}/{}/{}/{}".format("/".join(parts[0:-7]), datatype,domain,activity,orgdir,productdir,freq,grid)
                            searchpath="{}/{}_*".format(inputdir,var)
                            print("searching: {}".format(searchpath))
                            filesindir=glob.glob(searchpath)

                            if len(filesindir)==0:
                                print("found 0 files, skipping")
                            else:
                                print("found {} files, adding directory to jobsfile".format(len(filesindir)))

                                #
                                #outputdir is constructed below
                                #inputdir will be e.g.: 
                                # /terra/data/cmip5/global/rcp85/MIROC5/r1i1p1/day/native
                                # or /terra/data/observed/africa/gridded/CHC/CHIRPS-2.0-0p05/day/native
                                # outputdir will be constructed by substituting outputroot dir for /terra/data/
                                
                                outputdir=inputdir.split("/")
                                #keep the 6 levels datatype/.../freq; the grid level is replaced by subsetcode
                                outputdir="{}/{}/{}".format(outputrootdir, "/".join(outputdir[-7:-1]),subsetcode)
                                #
                                # writing to jobsfile
                                #NOTE(review): cdocmd is written unquoted; if it contains spaces
                                #it reaches do_subset.sh as several arguments -- confirm the
                                #script reassembles them (e.g. via "$@" or shift)
                                subsetcmd="./do_subset.sh {} {} {} {}\n".format(var, inputdir, outputdir, cdocmd)
                                outf.write(subsetcmd)
                                njobs=njobs+1
            print("")

#make the jobsfile group-accessible; chown can fail (group missing on this host,
#or insufficient privileges) -- warn and continue rather than abort, since the
#jobsfile itself has already been written and is usable
try:
    shutil.chown(jobsfile, user=None, group="afriverse")
    os.chmod(jobsfile, 0o775)
except (LookupError, OSError) as e:
    print("warning: could not set group/permissions on {}: {}".format(jobsfile, e))

batchsize=20
if njobs > 0:
    print("submitting {} jobs in batches of {}".format(njobs,batchsize))
    cmd='sbatch --array=1-{}%{} run.slrm {}'.format(njobs,batchsize, jobsfile)
    print(cmd)
    os.system(cmd)
else:
    #sbatch rejects an empty array specification (1-0), so skip submission entirely
    print("no jobs found, nothing to submit")



