import numpy as np
import re, sys, os, glob

from util_py.csv_readwrite import csv_data_reader
from util_py.drive_testing import run_pattern_identifiers
import matplotlib.pyplot as plt
from util_py.sorting import sort_all_by_index, get_running_average


# Single-letter tags that precede a bracketed id in result file names.
# Judging by the names (and the "average for each 'e' using the 'r'" note in
# get_final_values), 'e' marks the experiment id and 'r' the run id.
filename_runid = 'r'
filename_expid = 'e'
# Captures the text between the brackets of "e[...]".
expidpatt = re.compile(r'{}\[([^\[]*)\]'.format(filename_expid))
# Captures the text between the brackets of "-r[...]" (a leading dash is required).
runidpatt = re.compile(r'-{}\[([^\[]*)\]'.format(filename_runid))
# Splits a generic "name[value]" token into (name, value).
idvalpair = re.compile(r'(.*)\[(.*)\]')


def get_value_of(flnm, idvpatt):
    """Return the first match of *idvpatt* in *flnm* converted to ``int``.

    *idvpatt* may be a compiled pattern or a pattern string. With one
    capturing group, the captured text of the first match is converted.
    Returns ``None`` when the pattern does not occur in *flnm*.
    """
    found = re.findall(idvpatt, flnm)
    return int(found[0]) if found else None

def _list_matching_dirs(base_dir, dir_nm_patt):
    """Return, sorted by name, the entries of *base_dir* whose whole name matches *dir_nm_patt*."""
    dir_nms = []
    # sorted() makes the row order of the final result independent of the
    # (arbitrary) order in which the filesystem lists entries.
    for dr in sorted(os.listdir(base_dir)):
        if dir_nm_patt.fullmatch(dr):
            print('accepted folder: "{}"'.format(dr))
            dir_nms.append(dr)
    return dir_nms


def _peak_values(dat, plot_with, sort_by, shorten, average, n_iters2avg_by):
    """Smooth one file's columns in place and return the max of each plotted y column."""
    # sort it by iter just in case
    sort_all_by_index(dat, sort_by)
    # Drop the first n_iters2avg_by-1 entries of the "shorten" columns,
    # presumably so they stay aligned with the running-average columns
    # (assumes get_running_average returns that many fewer entries — TODO confirm).
    for sh in shorten:
        dat[sh] = dat[sh][(n_iters2avg_by - 1):]
    for cl in average:
        dat[cl] = get_running_average(dat[cl], n_iters2avg_by)
    # Only the y column of each (x, y) pair is used.
    return [np.max(dat[yser]) for _xser, yser in plot_with]


def get_final_values(base_dir, dirnmpatt_str, plot_with,
                     colnames=('n_itr', 'objval'), sort_by='n_itr',
                     shorten=('n_itr',), average=('objval',),
                     n_iters2avg_by=10):
    """Collect, per experiment id, the run-averaged peak of each plotted series.

    Every entry directly under *base_dir* whose name fully matches the regex
    *dirnmpatt_str* is scanned for ``*.csv`` files. Each non-empty file is
    read with ``csv_data_reader`` and sorted by *sort_by*. Its *average*
    columns are replaced by a running average of width *n_iters2avg_by*.
    It is then reduced to the maximum of every y column named in
    *plot_with*. Files are grouped by the experiment id parsed from their
    file name (``e[...]``; files without one share the ``None`` group).
    Each group's per-file maxima are averaged.

    The last five parameters used to be read from module globals that only
    exist when this file runs as a script. Their defaults equal the values
    set in the ``__main__`` block.

    Args:
        base_dir: directory holding the result directories.
        dirnmpatt_str: regex that directory names must fully match. Use
            ``re.escape`` for literal names: brackets and dots are regex syntax.
        plot_with: sequence of ``(x_column, y_column)`` pairs; only y is used.
        colnames: column names handed to ``csv_data_reader``.
        sort_by: column handed to ``sort_all_by_index``.
        shorten: columns whose first ``n_iters2avg_by - 1`` entries are dropped.
        average: columns replaced by their running average.
        n_iters2avg_by: running-average window length.

    Returns:
        ``numpy.ndarray`` of shape ``(n_experiment_ids, len(plot_with))``.
        Rows are in first-seen order of the experiment ids. The shape is
        ``(0, len(plot_with))`` when nothing was found.
    """
    dir_nm_patt = re.compile(dirnmpatt_str)
    n_plots = len(plot_with)

    dir_nms = _list_matching_dirs(base_dir, dir_nm_patt)
    if not dir_nms:
        return np.zeros((0, n_plots))

    dt_rdr = csv_data_reader(list(colnames))
    finval = {}  # experiment id -> list of per-file peak lists

    for dir_nm in dir_nms:
        # Escape the directory part: the names contain '[' and ']', which
        # glob would otherwise treat as character classes.
        dir_path = os.path.join(base_dir, dir_nm)
        file_patt = os.path.join(glob.escape(dir_path), '*.csv')

        path_files = sorted(glob.glob(file_patt))
        if not path_files:
            print("file pattern='{}' did not match any files, skipping dir {}.".format(
                    file_patt, dir_nm), file=sys.stderr)
            continue

        for t_fl in path_files:
            dat = {}
            dt_rdr.read_data_from_file(t_fl, dat)
            if not dat:
                continue

            t_fin = _peak_values(dat, plot_with, sort_by, shorten, average,
                                 n_iters2avg_by)

            # Group runs of the same experiment ('e') so they can be averaged.
            # Parse only the file name, so an "e[" in a parent directory
            # cannot be mistaken for the experiment id.
            exptid = get_value_of(os.path.basename(t_fl), expidpatt)
            finval.setdefault(exptid, []).append(t_fin)

    final_values = np.zeros((len(finval), n_plots))
    for ndx, runs in enumerate(finval.values()):
        final_values[ndx] = np.average(np.array(runs), axis=0)

    return final_values


if __name__ == "__main__":

    num_bins = 20
    n_iters2avg_by = 10
    colnames = ['n_itr', 'objval']
    sort_by = colnames[0]
    shorten = [colnames[0]]
    average = [colnames[1]]
    plot_with = [[colnames[0], colnames[1]]]

    base_dir = '../results/'

    # NOTE(review): this argument is required but never read below; the result
    # directories are hard-coded. The check is kept so the command-line
    # contract does not change — confirm whether it is still wanted.
    if len(sys.argv) < 2:
        print("Need as first argument the name of the subdir in '../results/' that contains the csv files." )
        sys.exit(1)

    # These are the directories that contain all the sample run data.
    # get_final_values() treats its second argument as a regex, so the literal
    # names are escaped. Unescaped, "[50]" is a character class that never
    # matches the real name, and "[gmix-rot-sft]" contains the invalid range
    # "x-r", so re.compile raises re.error.
    # First we start with the P_alt based results.
    dirnm_P_alt = 'E[50]-R[5]-N[25]-M[50]-lvl[3]-dim[20]-J[3]-maxstep[0.05]-Pstar[gmix-rot-sft]'
    alt_final_rn = get_final_values(base_dir, re.escape(dirnm_P_alt), plot_with)

    # Then the P_star output.
    dirnm_P_star = 'E[50]-R[5]-N[25]-M[50]-lvl[3]-dim[20]-J[3]-maxstep[0.05]-Pstar[gmix]'
    star_final_rn = get_final_values(base_dir, re.escape(dirnm_P_star), plot_with)
    print(star_final_rn.shape)

    ln = star_final_rn.shape[0]
    n_alt = alt_final_rn.shape[0]
    print("N total star {}".format(ln))
    if ln == 0 or n_alt == 0:
        # Without this guard, indexing the quantile of an empty array raises
        # IndexError and the alt ratio divides by zero.
        print("no results for P_star ({}) or P_alt ({}) under '{}'; nothing to compare.".format(
              ln, n_alt, base_dir), file=sys.stderr)
        sys.exit(1)

    # Sort every plotted series independently: row k of star_sorted holds the
    # k-th smallest value of each column, so star_sorted[k] has shape (n_plots,).
    star_sorted = np.sort(star_final_rn, axis=0)
    qntlndx = int(np.floor(0.95 * ln))  # always < ln because ln >= 1
    print("qntl ndx: {}".format(qntlndx))
    qntl = star_sorted[qntlndx]
    print("95th quanitle of star is: {} - shp {}".format(qntl, qntl.shape))

    ngtr = np.where(star_final_rn > qntl)
    print("n grtr {} / {}".format(ngtr[0].shape, star_final_rn.shape))

    for n in range(qntlndx, ln):
        print("sorted [{}] = {} ".format(n, star_sorted[n]))

    nsmlr = np.where(alt_final_rn <= qntl)
    print("n alt smaller than qntl: {} / {} = {}".format(nsmlr[0].shape[0], n_alt,
          nsmlr[0].shape[0] / n_alt))

    fig, axen = plt.subplots(nrows=1, ncols=1, figsize=(5., 5.), squeeze=False)

    denis = True
    axen[0, 0].hist(star_final_rn, num_bins, facecolor='blue', alpha=0.5, density=denis)

    axen[0, 0].hist(alt_final_rn, num_bins, facecolor='red', alpha=0.5, density=denis)

    # Vertical line at the quantile of the first plotted series.
    axen[0, 0].plot([qntl[0], qntl[0]], [0, 70.], 'k--')
    axen[0, 0].set_ylim((0, 69.))
    axen[0, 0].set_xlim((0, .2))
    axen[0, 0].set_xticks(np.arange(0, .24, 0.05))
    axen[0, 0].get_yaxis().set_visible(False)
    # Raw string: "\h" is not a valid escape sequence in a normal literal.
    axen[0, 0].set_xlabel(r"$\hat{R}_n$")
    plt.show()
