import xarray as xr
import glob, os, sys, shutil
from matplotlib import pyplot as plt
import pandas as pd
import numpy as np
import geopandas as gpd
import xesmf as xe
import shapely


# Region polygons used for the spatial averaging below; one row per region,
# identified by its "NAME" column.
regs = gpd.read_file("../data/domain.geojson")

# Insert extra vertices so that no polygon edge is longer than 0.1 coordinate
# units (presumably degrees -- TODO confirm the CRS of domain.geojson).
_densified_geoms = shapely.segmentize(
    regs["geometry"].values,
    max_segment_length=0.1,
)
regs["geometry"] = _densified_geoms


# Script entry point.
#
# Usage: <script> ENSEMBLE EXPERIMENT
#
# For every daily tasmax file of the requested ensemble/experiment found under
# ../data, average the field over each polygon in `regs` and write the result
# to the same path with the "wcape" directory replaced by "event_domain".
# Files whose output already exists are skipped, so the script can be re-run.
if len(sys.argv)!=3:
    print("needs 2 arguments: ensemble experiment")
else:
    ensemble=sys.argv[1]
    experiment=sys.argv[2]

    # Only daily maximum temperature is processed; the name is used for the
    # file search, the variable lookup and the output variable name.
    var="tasmax"

    print(f"processing {var}")

    searchstring=f"../data/{ensemble}/*/{experiment}/*/*/day/wcape/{var}_*"

    print(searchstring)

    files=glob.glob(searchstring)
    print(len(files))

    for file in files:

        # Mirror the input directory layout under "event_domain".
        outdir=os.path.dirname(file).replace("wcape","event_domain")
        if not os.path.exists(outdir):
            # exist_ok guards against another job creating the directory
            # between the check above and this call.
            os.makedirs(outdir, exist_ok=True)
            # Make the new directory group-writable for the shared group.
            shutil.chown(outdir, user=None, group="afriverse")
            os.chmod(outdir, 0o775)

        outfile=f"{outdir}/{os.path.basename(file)}"
        if os.path.exists(outfile):
            # Already processed in an earlier run.
            continue

        print(file)
        # Open through a context manager so the NetCDF handle is released
        # after each file instead of accumulating over the whole loop.
        with xr.open_dataset(file) as source:
            ds=source
            if ensemble=="cordex":
                # Use the grid axes as "latitude"/"longitude" and drop the
                # original lat/lon variables.
                if "rlat" in ds.dims:
                    ds=ds.rename({"lat":"latgrid","rlat":"latitude","lon":"longrid","rlon":"longitude"}).drop_vars(["latgrid","longrid"])
                else:
                    # NOTE(review): original author was unsure why this works;
                    # in these files x and y are not the "correct" rlat and
                    # rlon found in other models -- verify before relying on it.
                    ds=ds.rename({"lat":"latgrid","y":"latitude","lon":"longrid","x":"longitude"}).drop_vars(["latgrid","longrid"])

            # Remove the reference to a time-bounds variable (presumably
            # because that variable is not carried into the output -- confirm).
            if "bounds" in ds["time"].attrs:
                print("deleting")
                del ds["time"].attrs["bounds"]

            # The averager is built per file from that file's own grid.
            savg = xe.SpatialAverager(ds, regs.geometry, geom_dim_name="NAME")
            out = savg(ds[var], skipna=True, na_thres=0.5)
            # Label each polygon average with the region name.
            out = out.assign_coords(domain=xr.DataArray(regs["NAME"], dims=("NAME",)))
            out=out.to_dataset(name=var)
            # Write while the source file is still open.
            out.to_netcdf(outfile)

        print("written",outfile)
        shutil.chown(outfile, user=None, group="afriverse")
        os.chmod(outfile, 0o775)
