import numpy as np
import pandas as pd
import matplotlib
import scienceplots

import matplotlib.pyplot as plt
plt.style.use(['science', 'grid'])
from scipy.optimize import minimize_scalar
import os

def binary_entropy(p):
    """Return the binary entropy H(p) = -p log2 p - (1 - p) log2(1 - p), in bits.

    Works elementwise on scalars and array-likes. Uses the standard limit
    0 * log2(0) = 0, so H(0) = H(1) = 0 rather than nan. Inputs outside
    [0, 1] produce nan.

    Args:
        p: Probability (scalar or array-like).

    Returns:
        A numpy scalar for scalar input, otherwise an ndarray shaped like ``p``.
    """
    p = np.asarray(p, dtype=float)
    # At p == 0 or p == 1 one term is 0 * log2(0) = 0 * -inf = nan in IEEE
    # arithmetic; silence those warnings and substitute the limit below.
    with np.errstate(divide='ignore', invalid='ignore'):
        h = -p * np.log2(p) - (1 - p) * np.log2(1 - p)
    h = np.where((p == 0) | (p == 1), 0.0, h)
    # [()] unwraps a 0-d array into a numpy scalar and is a no-op otherwise.
    return h[()]

def inverse_H(arr):
    """Invert the binary entropy function on its increasing branch.

    For each target value ``h`` in *arr*, return the ``p`` in ``[1e-15, 0.5]``
    with ``binary_entropy(p) == h``. Because H(p) == H(1 - p), this is the
    smaller of the two pre-images.

    H is strictly increasing on (0, 1/2], so the root is bracketed there and
    found with Brent's method to near machine precision. Minimising
    ``|H(p) - h|`` over (0, 1) instead is not unimodal, since it has two
    symmetric minima. The bounded minimiser's default ``xatol=1e-5`` also
    caps accuracy, which matters in relative terms when ``p`` is small.

    Args:
        arr: Scalar or iterable of target entropies in bits. Values ``>= 1``
            map to 0.5. Values at or below ``binary_entropy(1e-15)`` map to
            the lower bracket end ``1e-15``. NaN maps to NaN.

    Returns:
        1-D float ndarray with one entry per (flattened) input value.
    """
    from scipy.optimize import brentq

    lo, hi = 1e-15, 0.5
    h_lo = binary_entropy(lo)

    def gap(p, target):
        # Root of this in [lo, hi] is the inverse entropy of `target`.
        return binary_entropy(p) - target

    vals = []
    for target in np.atleast_1d(np.asarray(arr, dtype=float)).ravel():
        if np.isnan(target):
            vals.append(np.nan)
        elif target >= 1:
            # H attains its maximum 1 at p = 1/2.
            vals.append(hi)
        elif target <= h_lo:
            # No sign change inside the bracket; clamp to its lower end.
            vals.append(lo)
        else:
            vals.append(brentq(gap, lo, hi, args=(target,)))
    return np.array(vals, dtype=float)

## Exact BSC old scheme
# Rate grid used by the plotting functions: 50 points 0, 0.02, ..., 0.98.
R = np.arange(0, 1, 1 / 50)

# Inverse capacity function: for each rate R, the BSC crossover probability
# epsilon in [0, 1/2] satisfying 1 - H(epsilon) = R.
H_inv = inverse_H(1 - R)  # NOTE these are the epsilons for BSC!

## 1 - Plot the scheme with and without post-correction (n = 1024)
import SIMULATOR_exact_BSC_MIX_scheme3
def orig_scheme_postcorrectedandnot(n=1024):
    """Plot the fraction of incorrect bits vs. rate, before and after post-correction.

    Reads two CSV files from the working directory:

    * ``f"{n}_scheme3_numincorrect.csv"`` with columns ``'Rate'`` and
      ``'Num Incorrect'`` (plotted as "Prior to Post-Correction");
    * ``f"{n}EXACT_center_hist_numincorrect.csv"`` with columns
      ``'Total Rate'`` and ``'Num Incorrect'`` (plotted as
      "With Post-Correction").

    Both counts are divided by ``n``. They are drawn together with the
    module-level lower-bound curve ``H_inv`` over the rate grid ``R``, and
    the figure is shown.

    Data for other block lengths is presumably produced by
    ``SIMULATOR_exact_BSC_MIX_scheme3.test_capacity_curve(n)``. Confirm that
    it writes the file names above before relying on it.

    Args:
        n: Block length. It selects the CSV files and normalises the counts.
            Defaults to 1024.
    """
    before = pd.read_csv(f"{n}_scheme3_numincorrect.csv")
    after = pd.read_csv(f"{n}EXACT_center_hist_numincorrect.csv")

    plt.plot(before['Rate'].values, before['Num Incorrect'].values / n)
    plt.plot(after['Total Rate'].values, after['Num Incorrect'].values / n)
    plt.plot(R, H_inv)
    # A bare plt.grid() *toggles* visibility and would switch off the grid
    # already enabled by the 'grid' style; request it explicitly.
    plt.grid(True)
    plt.xlabel("Rate")
    plt.ylabel(r"Fraction incorrect X $\rightarrow$ Y")
    plt.title(f"Fraction of incorrect bits vs Rate for n = {n} Exact BSC")
    plt.legend(["Prior to Post-Correction", "With Post-Correction", "Lower Bound"])
    plt.show()

## 2 - Plot the original scheme without post-correction (n = 8192)
def orig_scheme_uncorrectedonly(n=8192):
    """Plot the fraction of incorrect bits vs. rate without post-correction.

    Reads ``f"{n}_scheme3_numincorrect.csv"`` (columns ``'Rate'`` and
    ``'Num Incorrect'``) from the working directory. It plots the counts
    divided by ``n`` against the rate, together with the module-level
    lower-bound curve ``H_inv`` over the rate grid ``R``, and shows the
    figure.

    Data for other block lengths is presumably produced by
    ``SIMULATOR_exact_BSC_MIX_scheme3.test_capacity_curve(n)``. Confirm that
    it writes the file name above before relying on it.

    Args:
        n: Block length. It selects the CSV file and normalises the counts.
            Defaults to 8192.
    """
    df = pd.read_csv(f"{n}_scheme3_numincorrect.csv")
    rate = df['Rate'].values
    frac_incorrect = df['Num Incorrect'].values / n

    plt.plot(rate, frac_incorrect)
    plt.plot(R, H_inv)
    # A bare plt.grid() *toggles* visibility and would switch off the grid
    # already enabled by the 'grid' style; request it explicitly.
    plt.grid(True)
    plt.xlabel("Rate")
    plt.ylabel(r"Fraction incorrect X $\rightarrow$ Y")
    plt.title(f"Fraction of incorrect bits vs Rate for n = {n} Exact BSC")
    plt.legend(["Simulator", "Lower Bound"])
    plt.show()


# get_BSC_comparerate()

## 5 - Original Scheme Histogram from CSV file
import SIMULATORHISTOGRAM_code
from numpy import genfromtxt
def histogram_from_file(N=8192, k=0.2):
    """Histogram simulated distances next to a matching binomial sample.

    Top panel: histogram (20 bins) of the values read from
    ``simulator_histogram.csv``, labelled $d_H(X^n,Y^n)$.
    Bottom panel: histogram (20 bins) of the same number of draws from
    Binomial(N, epsilon), where epsilon = inverse_H(1 - k), labelled
    $d_H(0^n,Z^n)$.
    The figure is saved to ``Polar_Lattice_Uncorrected_Histogram.pdf``.

    Args:
        N: Number of binomial trials per draw. Defaults to 8192.
        k: Value plugged into ``inverse_H(1 - k)``. It plays the role of the
            rate, as in the module-level ``H_inv = inverse_H(1 - R)``.
            Defaults to 0.2.
    """
    # inverse_H returns an array; take its single entry as a scalar.
    epsilon = inverse_H([1 - k])[0]

    fig, axs = plt.subplots(2, sharex='col')
    for ax in axs:
        ax.tick_params(axis='both', which='major', color='0', labelsize=5)
        ax.tick_params(axis='both', which='minor', color='0.3')
        # Pass visibility positionally: the ``b=`` keyword was renamed
        # ``visible=`` in matplotlib 3.5 and later removed.
        ax.grid(True, which='major', color='0.65', linestyle='-', linewidth=0.2)

    # genfromtxt already returns an ndarray.
    num_incorrect = genfromtxt('simulator_histogram.csv', delimiter=',')

    # NOTE(review): unseeded, so the bottom panel differs from run to run.
    binomial_rv = np.random.binomial(N, epsilon, len(num_incorrect)).astype(int)

    axs[0].hist(num_incorrect, bins=20)
    axs[0].set_xlabel('$d_H(X^n,Y^n)$', size=8)

    axs[1].hist(binomial_rv, bins=20, color='orange')
    axs[1].set_xlabel("$d_H(0^n,Z^n)$", size=8)

    fig.supylabel('Frequency', size=8, x=0.05)

    plt.tight_layout()
    fig.savefig("Polar_Lattice_Uncorrected_Histogram.pdf")

# histogram_from_file()

## 8 - Produce Error Bars for Appendix Scheme
def error_bars_appendix(N=8192):
    """Plot the appendix (post-corrected) scheme with a 20-80 % rate band.

    Loads five runs from
    ``appendixscheme_data/{N}EXACT_center_hist_numincorrect{0..4}.csv``
    (columns ``'Total Rate'`` and ``'Num Incorrect'``). Row ``i`` of every
    run is treated as the same operating point. Across the runs the figure
    shows:

    * red line: median rate vs. median fraction incorrect (count / ``N``);
    * red band: 20th-80th percentile of the rate at each median fraction;
    * blue: uncorrected fractions from ``{N}epsilon_numincorrect.csv``
      (read from the working directory), plotted against the rate grid;
    * green: the module-level lower bound ``H_inv``.

    The figure is saved to ``CorrectedPolarLatticePlot.pdf``.

    Args:
        N: Block length. It selects the files and normalises the counts.
            Defaults to 8192.
    """
    # Same values as the module-level rate grid. The uncorrected data and
    # H_inv are plotted against it, so both must have 50 entries.
    R = np.arange(0, 1, 1/50)
    data_dir = 'appendixscheme_data'
    num_runs = 5

    run_rates, run_numincs = [], []
    for run in range(num_runs):
        df = pd.read_csv(os.path.join(data_dir, f'{N}EXACT_center_hist_numincorrect{run}.csv'))
        run_rates.append(df['Total Rate'].values)
        run_numincs.append(df['Num Incorrect'].values)

    # Stack into (num_runs, points) arrays. Truncating to the shortest run
    # ensures column i compares the same row across all runs.
    n_points = min(len(v) for v in run_rates + run_numincs)
    rates = np.array([v[:n_points] for v in run_rates], dtype=float)
    numincs = np.array([v[:n_points] for v in run_numincs], dtype=float)

    # Per-operating-point statistics across runs (axis 0 = run index).
    median_rate = np.median(rates, axis=0)
    median_frac = np.median(numincs, axis=0) / N
    rate_p20 = np.percentile(rates, 20, axis=0)
    rate_p80 = np.percentile(rates, 80, axis=0)

    # Uncorrected data for comparison, as a fraction of N.
    orig_frac = pd.read_csv(f"{N}epsilon_numincorrect.csv")['Num Incorrect'].values / N

    # Use a fresh figure so leftovers from an earlier plot in the same
    # session do not end up in the saved PDF.
    fig, ax = plt.subplots()
    # Horizontal band: spread of the rate across runs at each median fraction.
    ax.fill_betweenx(median_frac, rate_p20, rate_p80, where=(rate_p80 >= rate_p20), color='red', alpha=0.3)
    ax.plot(median_rate, median_frac, color='red')
    ax.plot(R, orig_frac, color='blue')
    ax.plot(R, H_inv, color='green')
    ax.set_xlabel("Rate")
    ax.set_ylabel("$p$")
    ax.tick_params(axis='both', which='major', color='0', labelsize=5)
    ax.tick_params(axis='both', which='minor', color='0.3')
    # A single explicit call replaces two problems: a bare grid() toggles
    # the style's grid off, and ``b=`` was renamed ``visible=`` in
    # matplotlib 3.5 and later removed.
    ax.grid(True, which='major', color='0.65', linestyle='-', linewidth=0.2)
    fig.savefig('CorrectedPolarLatticePlot.pdf')

#error_bars_appendix(8192)