# -*- coding: utf-8 -*-
"""
Created on Wed Oct 30 11:29:08 2024

@author: shara
"""

# -*- coding: utf-8 -*-
"""
Created on Thu Oct 24 23:40:23 2024

@author: shara
"""
from polar_modified import *
import matplotlib.pyplot as plt
import numpy as np
import random
import time
import math
from scipy.optimize import minimize_scalar
import pandas as pd
import psutil, os
import pickle
p = psutil.Process(os.getpid())
# Raise this process's scheduling priority so timing measurements are less
# disturbed by other work. psutil's *_PRIORITY_CLASS constants only exist on
# Windows; on other platforms nice() takes an integer niceness instead (and
# lowering it usually needs elevated privileges), so we leave priority alone.
if hasattr(psutil, "HIGH_PRIORITY_CLASS"):
    p.nice(psutil.HIGH_PRIORITY_CLASS)
# Set experiment seed (global NumPy RNG used by np.random.* calls below)
np.random.seed(101)

# HELPER FUNCTIONS

# Convert n to its binary representation, zero-padded to at least num_bits bits
def binary(n, num_bits):
    '''
    Return the binary digits of n as a NumPy array, most significant bit first.

          Arguments:
                  n (int64): Non-negative number to encode
                  num_bits (int64): Minimum width; shorter encodings are
                                    left-padded with zeros, longer ones are
                                    returned in full (no truncation)

          Returns:
                  x (int64[:]): Numpy array of encoded bits
    '''
    digits = format(n, 'b').zfill(num_bits)
    return np.fromiter(map(int, digits), dtype=np.int64, count=len(digits))

# Compute the binary entropy of p
def binary_entropy(p):
    '''
    Binary entropy H(p) = -p*log2(p) - (1-p)*log2(1-p), in bits.

          Arguments:
                  p (float or array-like): Probability/probabilities, expected in [0, 1]

          Returns:
                  h (float64 or float64[:]): H(p). The endpoints use their limiting
                  value H(0) = H(1) = 0 instead of the NaN that 0*log2(0) would give.
                  Scalar input yields a NumPy scalar; array input yields an array.
    '''
    p = np.asarray(p, dtype=np.float64)
    # 0 * log2(0) evaluates to 0 * -inf = nan; silence the warnings here and
    # patch the endpoints to their limit below.
    with np.errstate(divide='ignore', invalid='ignore'):
        h = -p * np.log2(p) - (1 - p) * np.log2(1 - p)
    h = np.where((p == 0) | (p == 1), 0.0, h)
    # h[()] unwraps a 0-d array to a scalar and leaves n-d arrays unchanged.
    return h[()]

# Given an array of binary-entropy values H, numerically invert H to get the matching p values (each folded to p <= 0.5)
def inverse_H(arr):
    '''
    Numerically invert the binary entropy function.

          Arguments:
                  arr (iterable of float): Target entropy values, expected in [0, 1]

          Returns:
                  p (float64[:]): For each target h, a value in [1e-15, 0.5] whose
                  binary entropy approximates h (up to the optimizer's tolerance)
    '''
    def _closest_p(target):
        # |H(q) - target| is minimized over (0, 1). H is symmetric about 0.5,
        # so the bounded search may return either q or 1 - q.
        res = minimize_scalar(lambda q: abs(binary_entropy(q) - target),
                              bounds=(1e-15, 1-1e-15), method='bounded')
        return res.x

    candidates = np.array([_closest_p(h) for h in arr])
    # Fold every solution onto the lower branch p <= 0.5.
    return np.where(candidates > 0.5, 1 - candidates, candidates)

# Recursively compute the capacity of BEC subchannels
def I_bec(N, i, eps):
    '''
    Capacity of subchannel i (0-based) of a length-N polar transform of a BEC.

          Arguments:
                  N (int): Block length (halved at each recursion level until 1)
                  i (int): 0-based subchannel index
                  eps (float): Erasure probability of the underlying BEC

          Returns:
                  I (float): Capacity of the requested subchannel
    '''
    # Base case: the untransformed BEC has capacity 1 - eps.
    if N == 1:
        return 1.0 - eps
    half = int(N / 2)
    if (i + 1) % 2 == 1:
        # Odd 1-based position: capacity of the parent channel, squared.
        return I_bec(half, int((i + 1) / 2), eps) ** 2
    # Even 1-based position: 2I - I^2 of the parent channel's capacity I.
    parent = I_bec(half, int(i / 2), eps)
    return (2 * parent) - (parent ** 2)



# Simulate num_trials trials over channel `chan` with block length N and channel parameter p
# Returns (array of size num_trials containing rates, list of per-trial execution times)
def simulate_trials_for_arg(chan, N, p, num_trials):
    '''
    Run num_trials independent trials of the polar channel-simulation scheme
    over channel `chan` and report the rate achieved in each trial.

          Arguments:
                  chan (callable): Channel function, called as chan(x, p) on an
                                   array of N input bits (e.g. bsc_p1)
                  N (int): Block length; int(np.log2(N)) is passed to
                           polar_channel_mc, so presumably a power of two
                  p (float): Channel parameter forwarded to chan and polar_channel_mc
                  num_trials (int): Number of independent trials

          Returns:
                  results_arr (float64[:]): Rate achieved in each trial
                  exec_times (list of float): Seconds spent in each trial; the
                      one-off computation of the probability table is excluded
    '''
    results_arr = np.zeros(num_trials, dtype=np.float64)
    exec_times = []
    # One-off preprocessing: probability table for block length N.
    # NOTE(review): the meaning of the last argument (2001) is defined by
    # polar_channel_mc in polar_modified -- confirm before changing it.
    p_n = polar_channel_mc(int(np.log2(N)), chan, p, 2001)
    # Fold every entry onto [0, 0.5].
    p_n = np.array([min(a, 1 - a) for a in p_n])
    for t in range(num_trials):
        # perf_counter is monotonic and high-resolution, unlike time.time().
        t_0 = time.perf_counter()
        # Generate P1 domain input realizations for channel simulation
        x_n_true = np.random.randint(2, size=N).astype(np.float64)
        y_n = chan(x_n_true, p)
        # Generate Common randomness
        zn = np.random.uniform(0, 1, N)
        # Run Henry Pfister decoder with modified random choice
        _uhat, _xhat, delta = polar_decode_with_cr(y_n, zn, np.full(N, 0.5, dtype=np.float64))
        # Index i contributes -log2(1/2 - p_n[i]) when delta[i] == 1 and
        # -log2(1/2 + p_n[i]) otherwise. Vectorized; assumes delta and p_n
        # both have length N -- TODO confirm against polar_modified.
        probs = np.where(np.asarray(delta) == 1, 0.5 - p_n, 0.5 + p_n)
        rate_sum = -np.sum(np.log2(probs))
        # NOTE(review): the origin of the "+1" offset is not evident here -- confirm.
        result_rate = (rate_sum + 1) / N
        exec_times.append(time.perf_counter() - t_0)
        results_arr[t] = result_rate

    avg_rate = results_arr.mean() if num_trials > 0 else 0.0
    print('Average Rate Is:', avg_rate)
    return results_arr, exec_times


#BSC
# For each block length N and BSC parameter p, run the simulation and pickle
# the per-trial execution times to 'Exec_Times_BSC_N=<N>_p=<p>.pickle' in the
# current working directory. The per-trial rates are returned but not saved.
for N in [2**12]:
    num_trials = 1000

    for p in [0.01, 0.25, 0.49]:
        exec_times_filename = f'Exec_Times_BSC_N={N}_p={round(p, 4)}.pickle'
        results_for_p, exec_times_p = simulate_trials_for_arg(bsc_p1, N, p, num_trials)
        with open(exec_times_filename, 'wb') as f:
            # Serialize the list of per-trial times with the newest pickle protocol.
            pickle.dump(exec_times_p, f, protocol=pickle.HIGHEST_PROTOCOL)