# -*- coding: utf-8 -*-
"""
Created on Fri Oct 25 17:51:50 2024

@author: shara
"""
from polar_modified import *
import matplotlib.pyplot as plt
import numpy as np
from numpy import genfromtxt
from scipy.optimize import minimize_scalar
from scipy.stats import bootstrap
import montecarlo_biawgn_capacity
import scienceplots
import os
# Select the 'science' and 'grid' style sheets (presumably the ones made
# available by the scienceplots import above -- confirm).
plt.style.use(['science', 'grid'])

# Font overrides are set after the style sheets are loaded, same order as before.
plt.rcParams.update({
    'font.family': 'serif',
    'font.serif': ['Times New Roman'],
})

# Prefix prepended to every data path built by get_filename_channel.
filename_prefix = './../'

def safe_xlogx(x):
    """Return ``x * log2(x)``, with the value at ``x == 0`` defined as 0.

    The explicit zero case avoids evaluating ``log2(0)``.
    """
    return 0 if x == 0 else x * np.log2(x)

def binary_entropy(p):
    """Return ``-p*log2(p) - (1-p)*log2(1-p)``, the binary entropy of ``p`` in bits.

    Both terms go through ``safe_xlogx`` so that ``p == 0`` and ``p == 1``
    evaluate to 0 rather than hitting ``log2(0)``.
    """
    complement = 1 - p
    return -safe_xlogx(p) - safe_xlogx(complement)

def PFR_UB(MI, N):
    """Upper bound on the one-shot redundancy, applied to blocklength ``N``.

    ``MI`` is the single-letter mutual information of the channel.
    The returned value is ``(log2(MI * N + 1) + 5) / N``.
    """
    log_term = np.log2(MI * N + 1)
    return (log_term + 5) / N

# One figure with three rows of axes; the script below draws one channel
# on each of axs[0], axs[1] and axs[2].
fig, axs = plt.subplots(nrows=3, sharex='col', sharey='row')

def get_filename_channel( channel, N, arg ):
    """Build the path of the CSV file holding the trials for one setting.

    The path is ``filename_prefix`` followed by
    ``PolarSim_Rate_Data/<TAG>_trials_N_<N>_<label>_<arg rounded to 4 places>.csv``
    where ``<TAG>``/``<label>`` are ``BSC``/``p`` or ``BEC``/``eps``.  Any
    other ``channel`` value falls through to ``BIAWGN``/``sig``.
    """
    if channel == 'BSC':
        tag, label = "BSC", "p"
    elif channel == 'BEC':
        tag, label = "BEC", "eps"
    else:
        tag, label = "BIAWGN", "sig"
    relative_path = f"PolarSim_Rate_Data/{tag}_trials_N_{N!s}_{label}_{round(arg, 4)!s}.csv"
    return filename_prefix + relative_path

def simulate_trials_for_arg(chan, N, p, num_trials):
    """Run ``num_trials`` simulated transmissions and return the per-trial rates.

    Parameters
    ----------
    chan : callable
        Channel function, called as ``chan(x, p)`` and also handed to
        ``polar_channel_mc``.
    N : int
        Blocklength; ``int(log2(N))`` is passed to ``polar_channel_mc``.
    p : float
        Channel parameter forwarded to ``chan`` and ``polar_channel_mc``.
    num_trials : int
        Number of independent trials.

    Returns
    -------
    numpy.ndarray
        Array of length ``num_trials`` with one rate per trial.

    The average rate and the standard deviation over the trials are printed.
    """
    # One slot per trial, plus a running average accumulated as we go.
    trial_rates = np.zeros(num_trials, dtype=np.float64)
    mean_rate = 0
    # Probability table for size N, folded so each entry is min(a, 1 - a).
    raw_table = polar_channel_mc(int(np.log2(N)), chan, p, 2001)
    folded = np.array([min(a, 1 - a) for a in raw_table])
    for trial in range(num_trials):
        # Random binary input realization, sent through the channel.
        x_n_true = np.random.randint(2, size=N).astype(np.float64)
        y_n = chan(x_n_true, p)
        # Common randomness shared with the decoder.
        zn = np.random.uniform(0, 1, N)
        # Henry Pfister decoder with modified random choice.
        _uhat, _xhat, delta = polar_decode_with_cr(
            y_n, zn, np.full(N, 0.5, dtype=np.float64)
        )
        neg_log_total = 0
        for i in range(N):
            # Positions flagged with delta == 1 are charged -log2(1/2 - q),
            # all others -log2(1/2 + q), where q is the folded table entry.
            prob = 1/2 - folded[i] if delta[i] == 1 else 1/2 + folded[i]
            neg_log_total += -1 * np.log2(prob)
        rate = (neg_log_total + 1) / N
        trial_rates[trial] = rate
        mean_rate += rate / num_trials
    print('Average Rate Is:', mean_rate)
    print('Standard Deviation', np.std(trial_rates))
    return trial_rates

def plot_for_channel( channel, arg, N_range, MI, ax ):
    """Plot median simulated redundancy and the ``PFR_UB`` curve against log2(N).

    For each blocklength in ``N_range`` whose trial CSV exists (path from
    ``get_filename_channel``), the per-trial rates are loaded and the median
    rate minus ``MI`` is plotted in red.  A shaded band shows a
    percentile-bootstrap confidence interval for the median of
    ``rates - MI``.  ``PFR_UB(MI, N)`` is drawn in black for the same
    blocklengths.

    Parameters
    ----------
    channel : str
        Channel name forwarded to ``get_filename_channel``.
    arg : float
        Channel parameter forwarded to ``get_filename_channel``.
    N_range : iterable of int
        Candidate blocklengths.
    MI : float
        Mutual information subtracted from the rates.
    ax : matplotlib.axes.Axes
        Axes to draw on.

    Blocklengths without a data file are skipped.  The x values are built
    only from the blocklengths that were actually loaded, so x and y always
    have matching lengths even when some files are missing.
    """
    loaded_N = []
    rates = []
    pfr_ubs = []
    bootstrap_ci_low = []
    bootstrap_ci_high = []
    for N in N_range:
        filename = get_filename_channel(channel, N, arg)
        if not os.path.isfile(filename):
            continue
        raw_rates = genfromtxt(filename, delimiter=",")
        # scipy's bootstrap expects a sequence of samples, hence the 1-tuple.
        ci_lims = bootstrap(
            (raw_rates - MI,), np.median, n_resamples=20, method='percentile'
        )
        loaded_N.append(N)
        bootstrap_ci_low.append(ci_lims.confidence_interval.low)
        bootstrap_ci_high.append(ci_lims.confidence_interval.high)
        rates.append(np.median(raw_rates) - MI)
        pfr_ubs.append(PFR_UB(MI, N))

    ax.tick_params(axis='both', which='major', color='0', labelsize=6)
    ax.tick_params(axis='both', which='minor', color='0.3')
    # The on/off flag is passed positionally: its keyword was renamed from
    # `b` to `visible` in Matplotlib 3.5, and `b=` fails on newer releases.
    ax.grid(True, which='major', color='0.65', linestyle='-', linewidth = 0.2)

    log_n = np.log2(loaded_N)
    ax.fill_between( log_n, bootstrap_ci_low, bootstrap_ci_high, color = 'red', alpha = 0.3 )
    ax.plot( log_n, rates, label = 'PolarSim', color = 'red')
    ax.plot( log_n, pfr_ubs, label = 'SFRL Upper Bound', color = 'black' )
# Blocklengths N = 2^6, 2^7, ..., 2^14.
N_range = [2**i for i in np.arange(6, 15)]

# BSC: parameter p and the mutual information 1 - h(p).
p = 0.0505
MI_bsc = 1 - binary_entropy(p)

# BIAWGN: parameter sig; mutual information from the Monte Carlo helper module.
sig = 0.5074
MI_sig = montecarlo_biawgn_capacity.biawgn_capacity(sig)

# BEC: parameter eps and the mutual information 1 - eps.
eps = 0.202
MI_bec = 1 - eps

# One panel per channel, top to bottom.
plot_for_channel( 'BSC', p, N_range, MI_bsc, axs[0] )
plot_for_channel( 'BIAWGN', sig, N_range, MI_sig, axs[1] )
plot_for_channel( 'BEC', eps, N_range, MI_bec, axs[2] )

# Raw string: '\l' is not a valid escape sequence in a normal string literal
# and triggers a warning on recent Python versions; the label text is unchanged.
fig.supxlabel(r'$\log n$', size = 10)
fig.supylabel('Redundancy', size = 10)

fig.savefig("Figure4.pdf")