"""
Aina Johannessen. 
2017/2018 - Masterthesis

"""


import argparse                                         #
import pygrib                                           #
import matplotlib                                       #
import matplotlib.pyplot as plt                         #
from mpl_toolkits.basemap import Basemap                #
from mpl_toolkits.basemap import shiftgrid              #
from scipy.ndimage.filters import gaussian_filter       #
import numpy as np                                      #
import warnings                                         #
#import matplotlib.colors as colors                     #
#import spharm  # Conda has it. Converts from spectral grid to gaussian grid, http://code.google.com/p/pyspharm
#matplotlib.colors import Colormap 
#Colormap.set_under(cmap,color='k')


def _shift_rows_to_180(data, lons):
    """Shift every row of `data` and `lons` in place with Basemap's shiftgrid so the grid ends at lon=180."""
    for row in range(data.shape[0]):
        data[row, :], lons[row, :] = shiftgrid(180, data[row, :], lons[row, :], start=False)


def get_data(tcwv_infile, msl_infile, wind_infile):
    """
    ------------------------------------------------------------------
    Mission of function:
    - Reads TCWV, mean sea level pressure and 850 hPa u/v wind from GRIB files
    - Shifts every field in longitude (see NOTE below)
    - Called from my custom TCWV() function

    PARAMETERS
    inputs:     - tcwv_infile: GRIB file with a "Total column water vapour" message
                - msl_infile:  GRIB file with a "Mean sea level pressure" message
                - wind_infile: GRIB file with "U component of wind" and
                               "V component of wind" messages at level 850
    returns:    - tcwv_data, msl_data (divided by 100, Pa -> hPa), u_data, v_data,
                  lat, lon (lat/lon grid of the TCWV message, lon after shifting)
    other:      - All opened GRIB files are closed, also if reading fails.
    -------------------------------------------------------------------
    """
    print("in getdata")
    opened = []  # every GRIB file opened so far; closed in the finally block
    try:
        obj_tcwv = pygrib.open(tcwv_infile)
        opened.append(obj_tcwv)
        tcwv_obj = obj_tcwv(name="Total column water vapour")[0]
        tcwv_data = tcwv_obj.values
        lat, lon = tcwv_obj.latlons()
        print("got tcwv_obj")

        obj_msl = pygrib.open(msl_infile)
        opened.append(obj_msl)
        msl_obj = obj_msl(name="Mean sea level pressure")[0]
        msl_data = msl_obj.values / 100.  # Pa -> hPa
        _, plon = msl_obj.latlons()
        print("got msl_obj")

        obj_wind = pygrib.open(wind_infile)
        opened.append(obj_wind)
        u_obj = obj_wind(name="U component of wind", level=850)[0]
        u_data = u_obj.values
        _, ulon = u_obj.latlons()

        v_obj = obj_wind(name="V component of wind", level=850)[0]
        v_data = v_obj.values
        _, vlon = v_obj.latlons()

        print(tcwv_obj)
        print(msl_obj)
    finally:
        for grbs in opened:
            grbs.close()

    # NOTE: Not every Basemap needs this, but the one I use does. The data
    # start at lon=0 (over England) and continue with increasing lon, so
    # without the shift there is no data at negative longitudes, i.e. west
    # of England. Each field is shifted together with its own lon array.
    _shift_rows_to_180(tcwv_data, lon)
    _shift_rows_to_180(msl_data, plon)
    _shift_rows_to_180(u_data, ulon)
    _shift_rows_to_180(v_data, vlon)

    return tcwv_data, msl_data, u_data, v_data, lat, lon

def set_background():
    """
    ------------------------------------------------------------------
    Mission of function:
    - Generates the background map (Mercator projection, Atlantic sector
      lon -90..50, lat 10..70)
    - Called from my custom TCWV() function

    PARAMETERS
    inputs:     - None
    returns:    - m: the Basemap instance, with coastlines, map boundary,
                     grey continents, parallels and meridians drawn
    other:      - None

    USES
    installed functions:  - Basemap
                          - numpy
    custom functions:     - None
    source:               - None
    -------------------------------------------------------------------
    """
    # Grey shades as RGB fractions. Float division on purpose: under
    # Python 2, 75/255 == 0, so the old mixed "75./255., 75/255., 75/255."
    # lists evaluated to [0.29, 0, 0] (dark red) instead of grey.
    land_grey = [75. / 255.] * 3
    grid_grey = [55. / 255.] * 3

    # ------------Atlantic-----------------
    m = Basemap(llcrnrlon=-90., llcrnrlat=10., urcrnrlon=50., urcrnrlat=70.,
                resolution='l', area_thresh=10000., projection='merc')
    # ------------Norway (alternative domain)-------------------
    # m = Basemap(llcrnrlon=-30., llcrnrlat=40., urcrnrlon=40., urcrnrlat=70.,
    #             resolution='h', area_thresh=10000., projection='merc')

    # Coastlines (zorder=3) are drawn above the filled continents (zorder=1).
    m.drawcoastlines(linewidth=0.5, linestyle='solid', color="k", zorder=3)
    m.drawmapboundary()
    m.fillcontinents(color=land_grey, zorder=1)

    # --------draw parallels every 10 deg, labelled on the left------
    circles = np.arange(-90., 90. + 10., 10.)  # -90 ... 90 inclusive
    m.drawparallels(circles, color=grid_grey, labels=[1, 0, 0, 0],
                    linewidth=0.1, fontsize=7)

    # --------draw meridians every 10 deg, labelled at the bottom-----
    meridians = np.arange(0., 360, 10.)
    m.drawmeridians(meridians, color=grid_grey, labels=[0, 0, 0, 1],
                    linewidth=0.1, fontsize=7)
    return m



def _tcwv_palette():
    """Return (colors, levels) for the shaded TCWV field: an (N, 3) array of RGB fractions and the contourf levels."""
    rgb = [[51, 86, 14],     # green darkest, 1
           [71, 119, 28],    # green dark, 2
           [91, 152, 36],    # green, 3
           [127, 183, 133],  # green light, 4
           [155, 193, 224],  # blue, 1
           [170, 216, 236],  # light blue, 2
           [183, 239, 248],  # lighter blue, 3
           [200, 255, 255],  # lighter lighter blue, 4
           [223, 255, 254],  # lightest blue, 5
           [247, 255, 255],  # blue almost white, 6
           [255, 255, 188],  # lightest yellow, 1
           [255, 254, 94],   # light yellow, 2
           [251, 243, 57],   # yellow, 3
           [240, 203, 48],   # light orange, 4
           [230, 163, 38],   # orange, 5
           [223, 124, 45],   # dark orange, 6
           [218, 87, 49],    # almost red, 7
           [211, 60, 64],    # red, 8
           [200, 57, 87],    # darker red, 9
           [167, 49, 68],    # darkest red, 10
           [137, 49, 68]]    # darkest red, 11
    # NOTE(review): 21 levels give 20 filled intervals, so contourf only uses
    # the first 20 colours; the last one is unused unless extend='max' is set.
    levels = [0, 0.5, 1, 2, 5, 10, 12, 15, 16, 18, 20., 25, 30., 35., 40.,
              45., 50., 55., 60., 65., 80]
    # RGB has to be between 0 and 1 in matplotlib
    return np.divide(np.array(rgb), 255.), levels


def TCWV(fig_path_and_name, tcwv_infile, MSL_infile, wind_infile, time, date):
    """
    ------------------------------------------------------------------
    Mission of function:
    - Plots TCWV (shaded), mean sea level pressure (isobars every 5 hPa)
      and 850 hPa wind (vectors) on the background map
    - Saves the figure to fig_path_and_name at 600 dpi

    PARAMETERS
    inputs:     - fig_path_and_name: output path passed to fig.savefig
                - tcwv_infile, MSL_infile, wind_infile: GRIB files read by get_data()
                - time: time string shown in the title (e.g. "1800")
                - date: YYYYMMDD (int or str), shown as "DD / MM - YYYY"
    returns:    - None
    -------------------------------------------------------------------
    """
    print("in TCWV")
    tcwv_data, msl_data, u_data, v_data, lat, lon = get_data(tcwv_infile, MSL_infile, wind_infile)
    print("back in TCWV")

    #------BACKGROUND MAP----------------------------------------------------
    fig = plt.figure()
    m = set_background()

    #------PLOT DATA---------------------------------------------------------
    x, y = m(lon, lat)  # lon/lat -> map projection coordinates

    #--->------TCWV------------------------------------
    colors, contours = _tcwv_palette()
    CS = m.contourf(x, y, tcwv_data, contours, colors=colors, alpha=1.0, zorder=2)
    cbar = plt.colorbar(CS, fraction=0.046, pad=0.01)
    cbar.set_label('Total Column Water Vapor (TCWV) [kg $m^{-2}$]')

    #--->-----Pressure isobars-------------------------
    # pblurred = gaussian_filter(msl_data, sigma=7)  # optional smoothing
    pblurred = msl_data
    pcont_val = np.arange(800, 2000, 5)  # every 5 hPa
    CS = m.contour(x, y, pblurred, pcont_val, colors="k", linewidths=0.7, zorder=4)
    plt.clabel(CS, fontsize=7, inline=1, fmt='%1.0f')

    #--->-----wind-------------------------
    scale = 70.      # m s-1 per inch of arrow length (scale_units='inches')
    key_speed = 15.  # reference-arrow length in m s-1; the key label is built from it
    barb_step = 30   # plot only every 30th grid point in each direction
    # Plain 2-D slicing. The old x[np.meshgrid(...)] indexed with a *list*,
    # which NumPy >= 1.23 interprets as an array index, not a tuple index.
    thin = np.s_[::barb_step, ::barb_step]

    qv = m.quiver(x[thin], y[thin], u_data[thin], v_data[thin], scale=scale,
                  scale_units='inches', linewidths=0.8, zorder=5)

    # Frame drawn around the quiver key (figure coordinates).
    fig.patches.append(plt.Rectangle((0.863, 0.89), 0.083, 0.042,
                                     fill=False, color='k', alpha=0.8, zorder=9,
                                     transform=fig.transFigure, figure=fig))

    plt.quiverkey(qv, 1.08, 1.05, key_speed, "%g m $s^{-1}$" % key_speed,
                  coordinates='axes', zorder=1, color='k', labelcolor='k',
                  labelpos="S", labelsep=0.05)

    # YYYYMMDD -> "DD / MM - YYYY"
    date_format = "{6}{7} / {4}{5} - {0}{1}{2}{3}".format(*str(date))
    plt.title("TCWV(kg $m^{-2}$ ; shaded), wind at 850hpa(m $s^{-1}$ ; vector) \n ECMWF-Analysis, Valid: "+time+" UTC, "+date_format, fontsize=11)

    fig.set_size_inches(12.80, 7.15)
    fig.savefig(fig_path_and_name, dpi=600)
    plt.close(fig)
    
    
def user_interface(argv=None):
    """
    ------------------------------------------------------------------
    Mission of function:
    - Handles Terminal input from user.
    - Makes it possible to run for a user defined date and time
    - Builds the GRIB input paths and the figure output path, then calls TCWV()

    PARAMETERS
    inputs:     - argv: list of command-line arguments to parse; None (default)
                        means the arguments given in the Terminal (sys.argv[1:])
    returns:    - None
    -------------------------------------------------------------------
    """
    #----Sets up user interface from Terminal----------------------------
    parser = argparse.ArgumentParser(
        description='Plot TCWV, mean sea level pressure and 850 hPa wind from ECMWF GRIB files.')
    # --date is required: without it the file names would contain "None".
    parser.add_argument("--date", type=int, required=True,
                        choices=list(range(20171215, 20171227)),  # 15-26 Dec 2017
                        help="the date you want (YYYYMMDD).")
    # A positional argument is always required, so it carries no default.
    parser.add_argument("time", type=str,
                        help="the time you want, as it appears in the GRIB file names, e.g. 1800")
    args = parser.parse_args(argv)

    #---PATHS-----------------------------------------------------------
    path_gribs = "../gribs/"
    path_figures = "../figs/"

    main_name = str(args.date) + "_" + args.time  # example: 20160930_1800

    tcwv_infile = path_gribs + "param_tcwv_" + main_name + ".grib"
    MSL_infile = path_gribs + "param_msl_" + main_name + ".grib"
    wind_infile = path_gribs + "uv850hpa_" + main_name + ".grib"
    print("using: " + wind_infile)

    #-->OUTFILE PATH
    fig_path_and_name = path_figures + "TCWV_" + main_name

    #-------Calls main function-------------------------------------
    TCWV(fig_path_and_name, tcwv_infile, MSL_infile, wind_infile, args.time, args.date)

warnings.filterwarnings("ignore",category=matplotlib.mplDeprecation)
user_interface()
