import matplotlib as mpl
mpl.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.cm as cm
from mpl_toolkits.basemap import Basemap
import netCDF4
import numpy as np
import pandas as pd
import math as ma
import calendar as cd
from scipy.spatial.distance import pdist, squareform


def nan_helper(y):
    """Locate the NaN entries of *y*.

    Returns
    -------
    is_nan : ndarray of bool
        Element-wise ``np.isnan(y)``.
    index : callable
        ``index(mask)`` returns ``mask.nonzero()[0]``, i.e. the first-axis
        positions where *mask* is true (e.g. ``index(is_nan)`` or
        ``index(~is_nan)``).
    """
    def index(logical):
        return logical.nonzero()[0]

    is_nan = np.isnan(y)
    return is_nan, index


def find_index_nearest(array, value):
    """Return the index of the element of *array* closest to *value*.

    *array* may be an ndarray (masked arrays keep their masked-argmin
    semantics) or any sequence accepted by NumPy, such as a list or tuple.
    Ties resolve to the first occurrence, as with ``argmin``. The index is
    into the flattened array.
    """
    # asanyarray lets plain lists/tuples work while preserving ndarray
    # subclasses such as numpy.ma.MaskedArray.
    values = np.asanyarray(array)
    return np.abs(values - value).argmin()

def months_to_seasons(array):
    """Average a monthly series into the four meteorological seasons.

    Parameters
    ----------
    array : array_like
        Data whose first axis is time in monthly steps. Assumes index 0 is
        January, index 11 is December, index 12 is January again, and so on
        (the same month-of-year convention as the ``range(t, ntime, 12)``
        stepping used elsewhere in this module) -- TODO confirm with callers.
        Masked entries are treated as NaN.

    Returns
    -------
    ndarray
        Shape ``(4,) + array.shape[1:]``, holding the NaN-ignoring mean of
        every time step in spring (MAM), summer (JJA), autumn (SON) and
        winter (DJF), in that order. A season with no time steps is NaN.
    """
    winter = [12, 1, 2]
    spring = [3, 4, 5]
    summer = [6, 7, 8]
    autumn = [9, 10, 11]
    seasons = [spring, summer, autumn, winter]

    data = np.ma.filled(np.ma.asarray(array, dtype=float), np.nan)
    # Calendar month (1..12) of each step along the time axis.
    months = np.arange(data.shape[0]) % 12 + 1

    out = np.full((len(seasons),) + data.shape[1:], np.nan)
    for k, season in enumerate(seasons):
        selected = np.isin(months, season)
        if selected.any():
            out[k] = np.nanmean(data[selected], axis=0)
    return out

def histo(map_obs, map_mod, mask_05, nb_zones, zones, ind_zones, bins, noun, yvec):
    """Plot per-region histograms of observed vs. modelled values and save a PNG.

    Draws a 4x3 grid with one panel per region (at most 12). For panel
    ``zo`` the region mask is ``mask_region[ind_zones[zo], :, :]`` from the
    netCDF file *mask_05*. Cells where ``map_obs * mask`` equals 0 are set to
    NaN, and the same NaNs are propagated into *map_mod*, so both histograms
    cover the same cells. Panels without a region are hidden.

    Parameters
    ----------
    map_obs, map_mod : array
        Observed field (plotted as 'ESA-CCI') and modelled field (plotted as
        'ORCHIDEE'). Presumably on the same (lat, lon) grid as the mask --
        TODO confirm.
    mask_05 : str
        Path of a netCDF file containing the variable ``"mask_region"``.
    nb_zones : int
        Unused; kept for interface compatibility.
    zones : sequence of str
        Panel titles, one per region.
    ind_zones : sequence of int
        First-axis index into ``mask_region`` for each panel.
    bins
        Forwarded to ``Axes.hist``.
    noun, yvec : str
        Used in the figure title and the output file name.
    """
    fig, axes = plt.subplots(4, 3, figsize=(20, 15))
    fig.suptitle(noun+' Distribution Comparison of Soil Moisture [m3/m3]'+yvec, fontsize=18)
    n_panels = min(axes.size, len(zones), len(ind_zones))
    # Open the mask file once; the context manager closes it even if
    # plotting raises.
    with netCDF4.Dataset(mask_05) as nc:
        mask_var = nc.variables["mask_region"]
        for zo, ax in enumerate(axes.flat):
            if zo >= n_panels:
                ax.set_visible(False)
                continue
            reg_mask_05 = mask_var[ind_zones[zo], :, :]
            print(zones[zo])
            aux_obs = map_obs*reg_mask_05
            aux_obs[aux_obs == 0] = np.nan
            # Adding and subtracting aux_obs copies its NaN pattern onto the model.
            aux_mod = map_mod+aux_obs-aux_obs
            ax.hist(aux_obs.flatten(), bins, facecolor='Maroon', alpha=0.5, label='ESA-CCI')
            ax.hist(aux_mod.flatten(), bins, facecolor='DarkGreen', alpha=0.5, label='ORCHIDEE')
            ax.set_title(zones[zo], fontsize=20)
            if zo == 0:
                ax.legend(loc='upper right', prop={'size': 18})
    fig.subplots_adjust(left=0.15)
    # NOTE(review): the file name always ends in '_winter' regardless of the
    # data passed in -- confirm this is intended.
    fig.savefig('/home/wflmd/PC/seasoncycle/output/hist_regions_'+noun+'_'+yvec+'_winter.png')
    plt.close(fig)

def seasonal_cycle(map_mod, areacell, mask_05, nb_zones, zones, ind_zones, noun, ntime, isimu):
    """Return the area-weighted seasonal cycle of *map_mod* for each region.

    For region ``zo`` the mask ``mask_region[ind_zones[zo], :, :]`` is read
    from the netCDF file *mask_05*. Zero and masked mask cells become NaN,
    so they drop out of the NaN-ignoring reductions.

    Parameters
    ----------
    map_mod : array
        Field with time on axis 0 and two spatial axes after it.
    areacell : array
        Cell areas, broadcastable against one time slice of *map_mod*.
    mask_05 : str
        Path of a netCDF file containing the variable ``"mask_region"``.
    nb_zones : int
        Number of regions (columns of the result).
    ind_zones : sequence of int
        First-axis index into ``mask_region`` for each region.
    ntime : int
        Number of time steps used when stepping by 12.
    zones, noun, isimu
        Unused; kept for interface compatibility.

    Returns
    -------
    ndarray, shape (12, nb_zones)
        ``tab[t, zo]`` is the NaN-ignoring sum over the region of
        ``areacell * mean(map_mod[t::12][:ntime])``. This is an area-weighted
        *sum*, not divided by the region area.
    """
    tab = np.zeros((12, nb_zones))
    with netCDF4.Dataset(mask_05) as nc:
        mask_var = nc.variables["mask_region"]
        for zo in range(nb_zones):
            # Convert to a plain float array (np.float_ no longer exists in
            # NumPy 2). Masked cells count as outside the region.
            reg_mask_05 = np.ma.filled(
                np.ma.asarray(mask_var[ind_zones[zo], :, :], dtype=np.float64), np.nan)
            reg_mask_05[reg_mask_05 == 0] = np.nan
            areacellreg = areacell*reg_mask_05
            aux_mod = map_mod*reg_mask_05
            mod_season = np.zeros((12,) + aux_mod.shape[1:])
            for t in range(12):
                # Every 12th step starting at t, i.e. the same calendar month.
                mod_season[t] = np.nanmean(aux_mod[t:ntime:12], axis=0)
            tab[:, zo] = np.nansum(areacellreg*mod_season, axis=(1, 2))
    return tab

def mean_map(map_mod, lat, lon, noun, isimu, areacell, ymin, ymax, cmap, units_lab):
    """Draw *map_mod* on a global Robinson map and save it as a PNG.

    NaN cells of *map_mod* are masked and not drawn. The title shows the
    area-weighted mean over the cells that have data, or ``nan`` if no valid
    cell has area.

    Parameters
    ----------
    map_mod : array
        Field to plot; presumably 2-D (lat, lon) matching *lat*/*lon* --
        TODO confirm.
    lat, lon : array
        Coordinates passed to ``Basemap.pcolormesh`` with ``latlon=True``.
    noun, isimu : str
        Used in the output file name; *isimu* also prefixes the title.
    areacell : array
        Cell areas, broadcastable against *map_mod*.
    ymin, ymax
        Colour-scale limits.
    cmap
        Colormap.
    units_lab : str
        Appended to the title after the mean value.
    """
    aux_mod = np.ma.array(map_mod, mask=np.isnan(map_mod))
    plt.figure()
    m = Basemap(projection='robin', lon_0=0)
    m.drawparallels(np.arange(-90., 90., 30.))
    m.drawmeridians(np.arange(-180., 180., 60.))
    m.drawcoastlines(linewidth=0.5)
    cmesh = m.pcolormesh(lon, lat, aux_mod, shading="flat", cmap=cmap, latlon=True, vmin=ymin, vmax=ymax)
    m.colorbar(cmesh, pad=0.08)
    # Normalise by the area of the cells that contribute to the numerator.
    # Including area under NaN/masked cells would bias the mean toward zero.
    invalid = np.ma.getmaskarray(aux_mod)
    valid_area = np.nansum(np.where(invalid, 0.0, np.ma.filled(areacell, 0.0)))
    if valid_area == 0:
        mean_value = np.nan
    else:
        mean_value = np.nansum(aux_mod*areacell)/valid_area
    plt.title(isimu+'Mean = %.3f' % (mean_value) + units_lab, fontsize=20)
    plt.savefig('/home/wflmd/PC/seasoncycle/output/Map_meanSM_'+noun+'_'+isimu+'.png')
    plt.close()

def season_map(map_mod, lat, lon, noun, isimu, areacell, ymin, ymax, cmap, units_lab):
    """Draw *map_mod* on a global map with its area-weighted mean in the title.

    Currently renders exactly the same figure as ``mean_map``, so it
    delegates to it. Keeping a single implementation stops the two copies
    from drifting apart.

    NOTE(review): the output goes to the same ``Map_meanSM_<noun>_<isimu>.png``
    path as ``mean_map``. Calling both with the same *noun*/*isimu* therefore
    overwrites one file with the other -- confirm the intended file name.
    """
    mean_map(map_mod, lat, lon, noun, isimu, areacell, ymin, ymax, cmap, units_lab)

def timeserie(map_mod, mask_05, nb_zones, zones, ind_zones, noun, nmonths, isimu,
              map_obs=None, yvec=None):
    """Plot regional-mean time series of the model (and optionally obs) and save a PNG.

    Draws a 4x3 grid with one panel per region (at most 12). For panel
    ``zo`` the mask ``mask_region[ind_zones[zo], :, :]`` is read from the
    netCDF file *mask_05*. Zero and masked mask cells become NaN, so the
    NaN-ignoring spatial mean covers only the region. Panels without a
    region are hidden.

    Parameters
    ----------
    map_mod : array
        Model field with time on axis 0 and spatial axes 1 and 2 (plotted as
        'ORCHIDEE').
    mask_05 : str
        Path of a netCDF file containing the variable ``"mask_region"``.
    nb_zones : int
        Unused; kept for interface compatibility.
    zones : sequence of str
        Panel titles.
    ind_zones : sequence of int
        First-axis index into ``mask_region`` for each panel.
    noun, isimu : str
        Used in the title and the output file name.
    nmonths : array
        x-values of the time series, one per time step.
    map_obs : array, optional
        Observation field laid out like *map_mod* (plotted as 'ESA-CCI').
        Omitted from the plot when None.
    yvec : str, optional
        Tag used in the output file name. Defaults to *isimu*; this default
        is an assumption -- TODO confirm.
    """
    if yvec is None:
        yvec = isimu
    fig, axes = plt.subplots(4, 3, figsize=(20, 15))
    fig.suptitle(noun+' Timeserie '+isimu, fontsize=18)
    n_panels = min(axes.size, len(zones), len(ind_zones))
    with netCDF4.Dataset(mask_05) as nc:
        mask_var = nc.variables["mask_region"]
        for zo, ax in enumerate(axes.flat):
            if zo >= n_panels:
                ax.set_visible(False)
                continue
            reg_mask_05 = np.ma.filled(
                np.ma.asarray(mask_var[ind_zones[zo], :, :], dtype=np.float64), np.nan)
            # Outside-region cells must be NaN, not 0, otherwise nanmean
            # averages zeros from the rest of the grid into the regional mean.
            reg_mask_05[reg_mask_05 == 0] = np.nan
            print(zones[zo])
            if map_obs is not None:
                obs = np.nanmean(map_obs*reg_mask_05, axis=(1, 2))
                ax.plot(nmonths, obs, color="black", label='ESA-CCI')
            mod = np.nanmean(map_mod*reg_mask_05, axis=(1, 2))
            ax.plot(nmonths, mod, color="dodgerblue", label='ORCHIDEE')
            ax.set_title(zones[zo], fontsize=20)
            if zo == 0:
                ax.legend(loc='upper right', prop={'size': 18})
    fig.subplots_adjust(left=0.15)
    fig.savefig('/home/wflmd/PC/seasoncycle/output/timeserie_'+noun+'_'+yvec+'_test.png')
    plt.close(fig)

#def season_maps(map_mod,lat,lon,noun,isimu,areacell,ymin,ymax,cmap,units_lab):
