"""
Aina Johannessen.
2017/2018 - Master's thesis

"""

import argparse                                         #
import pygrib                                           #
import matplotlib                                       #
import matplotlib.pyplot as plt                         #
from mpl_toolkits.basemap import Basemap                #
from mpl_toolkits.basemap import shiftgrid              #
from scipy.ndimage.filters import gaussian_filter       #
import numpy as np                                      #
import warnings                                         #
#import matplotlib.colors as colors                     #
#import spharm  # Conda has it. Converts from spectral grid to gaussian grid, http://code.google.com/p/pyspharm
#matplotlib.colors import Colormap 
#Colormap.set_under(cmap,color='k')


def get_data( MSL_infile, Z_infile, w250_infile ):
    """
    ------------------------------------------------------------------
    Mission of function:
    - Reads the U/V wind components, the geopotential and the mean sea
      level pressure from three GRIB files (first matching message each)
    - Shifts every field (row by row) with basemap's shiftgrid so the
      data also covers longitudes west of lon=0
    - Always closes the three GRIB files, also if reading fails

    PARAMETERS
    inputs:     - MSL_infile:  GRIB file with "Mean sea level pressure"
                - Z_infile:    GRIB file with "Geopotential"
                - w250_infile: GRIB file with "U component of wind" and
                               "V component of wind"
    returns:    - u_data, v_data: wind components (units as stored in file)
                - msl_data: MSLP divided by 100 (Pa -> hPa)
                - z_data:   geopotential divided by 9.81
                - lat, lon: grid of the u-component, lon shifted like the data
    -------------------------------------------------------------------
    """
    #--OPEN FILES----------------------------------------------------------------------------------------------------
    # Each opened file is registered here so the `finally` clause closes it
    # even if a later open, a message lookup or the grid shift raises.
    opened = []
    try:
        obj_wind = pygrib.open( w250_infile )      # file containing the U/V wind components
        opened.append( obj_wind )
        obj_msl = pygrib.open( MSL_infile )
        opened.append( obj_msl )
        obj_z = pygrib.open( Z_infile )
        opened.append( obj_z )

        # First message matching each parameter name
        u_obj = obj_wind(name = "U component of wind")[0]
        v_obj = obj_wind(name = "V component of wind")[0]
        z_obj = obj_z(name = "Geopotential")[0]
        msl_obj = obj_msl(name = "Mean sea level pressure")[0]

        print(u_obj)
        print(v_obj)
        print(z_obj)
        print(msl_obj)

        #---Extract data------------------------------------------------------------------------------------------------
        msl_data = msl_obj.values/100.      # divide by 100 to go from Pa to hPa
        u_data = u_obj.values
        v_data = v_obj.values
        z_data = z_obj.values/9.81          # divide by g; assumes geopotential in m**2 s**-2 -> metres (TODO confirm)

        # Each field carries its own longitude array so that each one can be shifted
        # consistently with its data. Only the u-wind grid is returned.
        _, plon = msl_obj.latlons()         # grid of the MSLP field
        lat, lon = u_obj.latlons()          # grid of the u-wind field (returned)
        _, lonv = v_obj.latlons()           # grid of the v-wind field
        _, lonz = z_obj.latlons()           # grid of the geopotential field

        #---Shift grid in lon direction to get values where they are needed-------------------------------------------
        # The data starts at lon=0 (over England) and continues with increasing lon, so
        # there are no negative longitudes (west of England). shiftgrid(180, ..., start=False)
        # rearranges each latitude row so that it ends at lon=180, which puts the western
        # hemisphere at negative longitudes that the Atlantic map can use.
        for i in range( lon.shape[0] ):
            u_data[i,:], lon[i,:] = shiftgrid(180, u_data[i,:], lon[i,:], start=False)
            v_data[i,:], lonv[i,:] = shiftgrid(180, v_data[i,:], lonv[i,:], start=False)
            msl_data[i,:], plon[i,:] = shiftgrid(180, msl_data[i,:], plon[i,:], start=False)
            z_data[i,:], lonz[i,:] = shiftgrid(180, z_data[i,:], lonz[i,:], start=False)
    finally:
        # Close in opening order: wind, MSL, Z
        for grib_file in opened:
            grib_file.close()

    return u_data, v_data, msl_data, z_data, lat, lon

def set_background():
    """
    ------------------------------------------------------------------
    Mission of function:
    - Builds the Mercator background map for the Atlantic domain
      (lon -90..50, lat 10..70, low-resolution coastlines)
    - Draws coastlines, the map boundary, grey continents, and labelled
      parallels (left edge) and meridians (bottom edge) every 20 degrees
    - Called from Z_MSLP()

    PARAMETERS
    inputs:     - None
    returns:    - m: the Basemap instance, used by the caller to project
                     lon/lat to map coordinates and to draw contours
    other:      - None

    USES
    installed functions:  - Basemap
    custom functions:     - None
    -------------------------------------------------------------------
    """
    # Other domains used earlier:
    #   wider Atlantic: lon -170..50, lat 10..80, resolution 'h'
    #   Norway:         lon -30..40,  lat 40..70, resolution 'h'
    m = Basemap( projection = 'merc', resolution = 'l', area_thresh = 10000.,
                 llcrnrlon = -90., llcrnrlat = 10.,
                 urcrnrlon = 50., urcrnrlat = 70. )

    land_grey = [ 75./255., 75/255., 75/255. ]
    parallel_grey = [ 55./255., 55/255., 55/255. ]
    meridian_grey = [ 55./255., 55./255., 55./255. ]

    # Coastlines above the filled land (zorder 5 vs 1)
    m.drawcoastlines( linewidth = 0.5, linestyle = 'solid', color = "k", zorder = 5 )
    m.drawmapboundary()
    m.fillcontinents( color = land_grey, zorder = 1 )

    # Parallels every 20 degrees, labelled on the left side
    parallel_lats = np.arange( -90., 120., 20. )
    m.drawparallels( parallel_lats, color = parallel_grey, labels = [ 1, 0, 0, 0 ],
                     linewidth = 0.1, fontsize = 12 )

    # Meridians every 20 degrees, labelled along the bottom
    meridian_lons = np.arange( 0., 360, 20. )
    m.drawmeridians( meridian_lons, color = meridian_grey, labels = [ 0, 0, 0, 1 ],
                     linewidth = 0.1, fontsize = 12 )

    return m

def Z_MSLP( fig_path_and_name, MSL_infile, Z_infile, w250_infile, time, date ):
    """
    ------------------------------------------------------------------
    Mission of function:
    - Reads the fields with get_data() and draws them on the map made
      by set_background()
    - Black contours: MSLP every 5 hPa
    - Dotted contours: z_data every 30 units, red above 5500 and blue
      at or below 5500
    - Shading: wind speed from the wind file, bins 35-50-65-80 (m/s)
    - Saves the figure to fig_path_and_name with dpi 200

    PARAMETERS
    inputs:     - fig_path_and_name: output path passed to fig.savefig
                - MSL_infile, Z_infile, w250_infile: GRIB files, see get_data()
                - time: string appended to the title, e.g. "1800"
                - date: YYYYMMDD (int or str); shown as YYYY-MM-DD in the title
    returns:    - None
    -------------------------------------------------------------------
    """
    u_data, v_data, msl_data, z_data, lat, lon = get_data( MSL_infile, Z_infile, w250_infile )

    #------BACKGROUND MAP-------------------------------------------------------------------
    fig = plt.figure()
    m = set_background()

    #------PLOT DATA-------------------------------------------------------------------
    x, y = m( lon, lat )
    # The fields can be smoothed here before contouring with scipy's
    # gaussian_filter (e.g. sigma=2 for Z, sigma=5 for MSLP). Smoothing is
    # currently disabled.

    #--->PLOTS PRESSURE ISOBARS every 5 hPa-----------------------------
    pcont_val = np.arange( 800, 2000, 5 )
    CS = m.contour( x, y, msl_data, pcont_val, colors = "k", linewidths = 0.5, zorder=9 )
    plt.clabel( CS, fontsize=7, inline=1, fmt = '%1.0f', zorder=9 )

    #--->PLOT Z------------------------------------------------------
    # Split Z at 5500 into two fields. Masked cells become NaN, and contour leaves
    # NaN cells blank, so each field only draws its own range. NaN input cells
    # compare False and therefore end up NaN in both fields.
    z_split = 5500.
    at_or_below = z_data <= z_split
    z_data_red = np.copy( z_data )
    z_data_red[ at_or_below ] = np.nan
    z_data_blue = np.copy( z_data )
    z_data_blue[ ~at_or_below ] = np.nan

    tz1 = np.arange( 5500, 6500, 30 )     # levels for the red (above split) field
    tz2 = np.arange( 4500, 5500, 30 )     # levels for the blue (at/below split) field

    CS = m.contour( x, y, z_data_red, tz1, linestyles = ':', colors = "r", linewidths = 0.5, zorder=8 )
    plt.clabel( CS, fontsize=7, inline=1, fmt = '%1.0f', zorder=8 )

    CS = m.contour( x, y, z_data_blue, tz2, linestyles = ':', colors = "b", linewidths = 0.5, zorder=8 )
    plt.clabel( CS, fontsize=7, inline=1, fmt = '%1.0f', zorder=8 )

    #--->WIND SPEED (shaded)------------------------------------------
    wind_size = np.sqrt( u_data**2 + v_data**2 )

    C = np.array( [ [206,226,255],
                    [141,165,253],
                    [241,208,181] ] )
    C = np.divide( C, 255. )  # matplotlib expects RGB values between 0 and 1

    contourwind = (35,50,65,80)

    CS = m.contourf( x, y, wind_size, contourwind, colors=C, alpha = 0.7, zorder=2 )
    cbar = plt.colorbar( CS, fraction=0.046, pad=0.01 )
    cbar.ax.tick_params( labelsize=15 )
    cbar.set_label( 'Wind speed [ m/s ]', fontsize = 15 )

    # zorder sets the stacking of the layers: the shading (2) sits above the
    # continents (1), the Z contours (8) and the isobars (9) are drawn on top.

    #--->TITLE: date as YYYY-MM-DD followed by the time--------------
    date_format = "{0}{1}{2}{3}-{4}{5}-{6}{7} ".format(*str(date))
    plt.title( date_format + time+" UTC", fontsize=15, position=(0.17, 0.93), bbox=dict(facecolor='white', alpha=0.5))

    fig.set_size_inches( 12.80, 7.15 )
    fig.savefig( fig_path_and_name, dpi = 200 )
    plt.close( fig )
    
    


def user_interface():
    """
     ------------------------------------------------------------------
    Mission of function:
    - Handles Terminal input from user.
    - Makes it possible to run for a user defined date and time
    - Builds the GRIB input paths (../gribs/) and the figure path
      (../figs/Z_<date>_<time>), then calls Z_MSLP()

    Usage example:  python <this script> --date 20160930 1800
     -------------------------------------------------------------------
    """

    #----Sets up user interface from Terminal----------------------------
    parser = argparse.ArgumentParser(
        description='Plot MSLP, Z and wind speed from GRIB files for one date and time.' )
    # --date is required: without it the file names would contain "None".
    parser.add_argument( "--date", type = int, required = True,
        choices= [ 20160920,20160921,20160922,20160923,20160924,20160925,20160926,20160927, 20160928, 20160929, 20160930, 20161001],
        help = "the date you want. " )
    # Positional, hence always required (argparse would ignore a default here).
    parser.add_argument( "time", type = str,
        help = "the time you want, as written in the GRIB file names. ex 1800" )

    args = parser.parse_args( )

    #----------SET PATHS-------------------------------------------
    path_gribs = "../gribs/"
    path_figures = "../figs/"

    main_name = str(args.date) + "_" + args.time                                        #example: 20160930_1800

    MSL_infile = path_gribs + "param_msl_" + main_name + ".grib"
    Z_infile = path_gribs + "Z_1000_500_" + main_name + ".grib"
    w250_infile = path_gribs + "wind250_" + main_name + ".grib"

    fig_path_and_name = path_figures + "Z_" + main_name

    #-------Calls main function-------------------------------------
    Z_MSLP( fig_path_and_name, MSL_infile, Z_infile, w250_infile, args.time, args.date )

if __name__ == "__main__":
    # `matplotlib.mplDeprecation` was deprecated (3.6) and later removed; newer
    # versions expose the same warning class as MatplotlibDeprecationWarning.
    # Fall back to the plain DeprecationWarning if neither name exists.
    _mpl_deprecation = getattr( matplotlib, "MatplotlibDeprecationWarning",
                                getattr( matplotlib, "mplDeprecation", DeprecationWarning ) )
    warnings.filterwarnings( "ignore", category=_mpl_deprecation )
    user_interface()

    
    
    
    
