'''Functions for policy evaluation experiments. '''

import numpy as np

from scipy.optimize import linear_sum_assignment
from scipy.special import expit, logit
from scipy.stats import bernoulli

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.linear_model import LogisticRegression

from sklearn.metrics import mean_squared_error

import data
import utils
import estimation



def policy_eval_helper(num_runs = 2, eval_pol_phi=1, eval_pol_b=0.5, h=1, outcome_degree=2,
                      mu_degree=2, S=10, method='WDM',right_node_num=20,
                      theta_true=None, true_pol_type='logistic',
                         cov_dist='gaussian',
                         cov_mean=0, cov_sigma=1, 
                         noise_sigma=1, test_size=10, train_size=1000,
                       pol_theta=0.5, pol_b = 0.5, fit_outcome_degree=2, bootstrap=True):
    
    '''
    Evaluate rho in Algorithm 1 over multiple realizations of train/test data.
    
    For each of ``num_runs`` runs, fresh training data (W, Z, Y) and fresh test
    covariates are drawn via ``data.generate_data``. The value of the evaluation
    policy is then computed once with the oracle outcome model
    (``estimation.eval_value``) and once per estimator ('direct', 'WDM', 'GRDR')
    via ``estimation.perturbation_alg_bootstrap``. Every value is divided by
    ``min(test_size, right_node_num)`` before being stored. Per-run values and
    bias / variance / MSE summaries (relative to the oracle) are printed.
    
    Logistic policy: expit(pol_theta * W + pol_b)
    
    Parameters
    ----------
    num_runs: number of independent train/test realizations.
    eval_pol_phi, eval_pol_b: passed as ``phi`` / ``b`` to
        ``estimation.my_logistic_policy`` to build the evaluation policy.
    h, S, mu_degree, right_node_num, bootstrap: forwarded unchanged to
        ``estimation.perturbation_alg_bootstrap`` (``mu_degree`` and
        ``right_node_num`` are also forwarded to ``estimation.eval_value``).
    outcome_degree: ``degree`` used for data generation and for the oracle
        evaluation.
    fit_outcome_degree: ``outcome_degree`` handed to the estimators.
    method: unused; kept so existing callers passing it keep working.
    theta_true: outcome-model parameter used for the oracle model and for
        data generation.
    true_pol_type: ``pol_type`` used when generating training data.
    cov_dist, cov_mean, cov_sigma, noise_sigma: forwarded to
        ``data.generate_data`` for the training data (``cov_mean`` and
        ``cov_sigma`` also for the test covariates).
    test_size, train_size: number of test / training samples requested.
    pol_theta: true policy theta
    pol_b: true policy b
    
    Returns
    -------
    dict with keys 'direct', 'WDM', 'GRDR' and 'oracle', each mapping to a
    numpy array holding one normalized value per run.
    '''
    
    estimator_names = ('direct', 'WDM', 'GRDR')
    
    oracle_model = estimation.oracle_outcome_model(true_theta=theta_true)
    eval_pol = estimation.my_logistic_policy(phi=eval_pol_phi, b=eval_pol_b)
    
    # Loop-invariant normalizer applied to every stored value.
    normalizer = float(min(test_size, right_node_num))
    
    rho_oracle_list = []
    rho_lists = {name: [] for name in estimator_names}
    
    for run in range(num_runs):
        
        W_samples, Z_samples, Y_samples = data.generate_data(
            degree=outcome_degree,
            theta_true=theta_true,
            cov_dist=cov_dist,
            cov_mean=cov_mean, cov_sigma=cov_sigma,
            noise_sigma=noise_sigma,
            pol_type=true_pol_type,
            num_samples=train_size,
            pol_theta=pol_theta, pol_b=pol_b,
            test_data=False)
        
        # Use the caller-supplied theta_true here; the previously referenced
        # OUTCOME_THETA is not defined in this module and raised NameError.
        # NOTE(review): test covariates are always drawn with
        # cov_dist='gaussian', independent of the cov_dist argument -- confirm
        # this is intended.
        test_W_samples = data.generate_data(
            degree=outcome_degree, theta_true=theta_true,
            cov_dist='gaussian',
            cov_mean=cov_mean, cov_sigma=cov_sigma,
            num_samples=test_size, test_data=True)
        
        # Oracle rho: direct evaluation with the true outcome model.
        rho_oracle = estimation.eval_value(
            eval_pol=eval_pol, right_node_num=right_node_num, method='direct',
            test_W_samples=test_W_samples, outcome_degree=outcome_degree,
            mu_degree=mu_degree, outcome_model=oracle_model)
        
        # Estimated rho for each method, evaluated in a fixed order
        # (direct, WDM, GRDR) on the same train/test draw.
        rho_estimates = {}
        for name in estimator_names:
            rho_estimates[name] = estimation.perturbation_alg_bootstrap(
                eval_pol=eval_pol, h=h,
                W_samples=W_samples, Z_samples=Z_samples,
                Y_samples=Y_samples, outcome_degree=fit_outcome_degree,
                test_W_samples=test_W_samples,
                mu_degree=mu_degree, S=S, method=name,
                right_node_num=right_node_num, bootstrap=bootstrap)
        
        rho_oracle_list.append(rho_oracle/normalizer)
        for name in estimator_names:
            rho_lists[name].append(rho_estimates[name]/normalizer)
        
    rho_oracle_mean = np.mean(rho_oracle_list)
    
    print('rho_oracle_list ', rho_oracle_list)
    for name in estimator_names:
        print('rho_' + name + '_list ', rho_lists[name])
    
    for name in estimator_names:
        bias = np.mean(rho_lists[name]) - rho_oracle_mean
        variance = np.var(rho_lists[name])
        mse = mean_squared_error(rho_oracle_list, rho_lists[name])
        # A blank line separates the per-run lists from the summary block.
        lead = '\n' if name == estimator_names[0] else ''
        print(lead + 'Bias (' + name + '): ', bias,
              'Variance (' + name + '): ', variance, 'MSE: ', mse)
    
    
    results_dict = {name: np.array(rho_lists[name]) for name in estimator_names}
    results_dict['oracle'] = np.array(rho_oracle_list)
    
    return results_dict