## Dependencies
import itertools
import os
import pickle
from importlib import reload
from itertools import chain

import numpy as np
import seaborn as snsn
import torch
import torch.nn as nn
from tqdm import trange

## Utils
import data_gen_utils as utils

# Settings for the weighted-voting-game dataset generated by this script.
PARAMS = dict(
    player_list=np.arange(4, 11),       # Game sizes (number of players) to generate
    n_games_play=5000,                  # Number of games to generate per game size
    max_len=20,                         # Sequence length: player slots in every stored row
    prob_q='gauss',                     # Quota distribution, forwarded to utils.gen_quota
    alpha=1,                            # `a` parameter of the Beta draw used for the weights
    beta=1,                             # `b` parameter of the Beta draw used for the weights
    loc=1,                              # Constant added to every drawn weight
    log=False,                          # Print each generated game while running
    path='',                            # Where the output file is written (prepended to `filename`)
    filename='4to10play_train.pickle',  # Name of the pickle written at the end
    least_core=True,                    # Whether to compute the least-core solution
    shapley=True,                       # Whether to compute Shapley values
    banzhaf=True,                       # Whether to compute Banzhaf indices
    n_samples=1000,                     # Passed to utils.sample_permutations for games with > 9 players
    n_resamples=10,                     # Number of sampled Shapley estimates that get averaged
)

# Width of every stored row and total number of games across all game sizes.
N_max = PARAMS['max_len']
G = int(len(PARAMS['player_list']) * PARAMS['n_games_play'])

# Player count of every game in generation order: each entry of `player_list`
# appears `n_games_play` times in a row.
n_player_repeats = np.repeat(PARAMS['player_list'], repeats=PARAMS['n_games_play']).astype(int)

# Enumerate every non-empty coalition (subset of player indices) for each
# game size in `player_list`.
combs_dict = {}
for play in PARAMS['player_list']:
    # dtype='object' lets the array hold index tuples of different lengths.
    combs = np.array(list(chain.from_iterable(itertools.combinations(np.arange(play), k) for k in range(1, play + 1))), dtype='object')
    combs_dict[play] = combs

# Enumerate all player orderings for the game sizes that get an exact Shapley
# computation. The sizes are taken from `player_list` (instead of a fixed
# range) so that every size looked up in the main loop is present and no
# unused size is enumerated. Games with more than 9 players are approximated
# from sampled permutations in the main loop (its `N > 9` branch), so they are
# skipped here; keep both cutoffs in sync.
perms_dict = {}
for play in PARAMS['player_list']:
    if play > 9:
        continue
    perms = np.array(list(itertools.permutations(range(int(play)), int(play))), dtype='object')
    perms_dict[int(play)] = perms

# Per-game storage. Every row has N_max player slots; slots that are not
# assigned to one of the game's players stay zero, as do the outputs of any
# solution concept that is switched off in PARAMS.
W = np.zeros((G, N_max))              # Raw player weights
X = np.zeros((G, N_max))              # Player weights divided by the game's quota
q = np.zeros((G))                     # Quota of every game
sol_stack = np.zeros((G, N_max + 1))  # Labels: least-core payoffs, epsilon in the last column
Y_shap = np.zeros((G, N_max))         # Shapley values
Y_banz = np.zeros((G, N_max))         # Banzhaf indices
C_min_win = []  # Minimal set of winning coalitions
player_to_index = {} # Map players to shuffled index in array

for game in trange(G):

    N = n_player_repeats[game]

    # Define the weighted voting game: Beta-distributed draws rescaled to the
    # interval [loc, loc + 2N - 1].
    weights = (np.random.beta(a=PARAMS['alpha'], b=PARAMS['beta'], 
                              size=(N))) * ((2*N)-1) + PARAMS['loc']
    quota = utils.gen_quota(N, PARAMS['prob_q'])

    # Redraw the quota until the coalition of all players can reach it, so
    # that there is at least one winning coalition.
    while quota > weights.sum():
        quota = utils.gen_quota(N, prob_dist=PARAMS['prob_q'])

    # Generate set of winning and minimal winning coalitions. Each coalition
    # is a tuple of player indices used to select that coalition's weights.
    coals_win = [coal for coal in combs_dict[N] if weights[list(coal)].sum() >= quota]
    coals_min_win = utils.get_min_win_coals(coals_win, weights, quota)

    # Solve
    if PARAMS['least_core']:
        leastcore_sol = utils.solve_optimal_payoff(N, coals_min_win)

    if PARAMS['shapley']:
        if N > 9: # Approximate 
            shapley_sol_tmp = np.zeros((PARAMS['n_resamples'], N))
            for s in range(PARAMS['n_resamples']): 
                sampled_perms = utils.sample_permutations(N, PARAMS['n_samples'])
                shapley_sol_tmp[s, :] = utils.compute_shapley_vals(N, weights, quota, sampled_perms)
            # Average over resamples
            shapley_sol = shapley_sol_tmp.mean(axis=0)

        else:
            # Use all permutations 
            shapley_sol = utils.compute_shapley_vals(N, weights, quota, perms_dict[N])

    if PARAMS['banzhaf']:
        banzhaf_sol = utils.compute_banzhaf_index(N, weights, quota, coals_win)

    ###########################
    # Generate random permutation; store elements at these locs
    index = np.random.permutation(N_max)[:N] 
    player_to_index[game] = dict(zip(range(N), index))

    W[game, index] = weights
    X[game, index] = weights / quota
    q[game] = quota

    # Store solutions at random indices. Each store is guarded by the flag
    # that controls its computation: a disabled solution concept has no
    # result to store, and an unguarded store would raise NameError.
    if PARAMS['least_core']:
        sol_stack[game, index] = leastcore_sol[:-1] 
        sol_stack[game, N_max] = leastcore_sol[-1] # Store epsilon as the last element in the array
    if PARAMS['shapley']:
        Y_shap[game, index] = shapley_sol
    if PARAMS['banzhaf']:
        Y_banz[game, index] = banzhaf_sol

    # Store set of minimal winning coalitions
    C_min_win.append([list(coal) for coal in coals_min_win])

    if PARAMS['log']:
        print(f'Game number {game}')
        print(f'G = [ w = {weights} ; q = {quota} ]')
        # The payoff line only exists when the least core was solved.
        if PARAMS['least_core']:
            print(f'Payoffs full game [{N} players] y = {leastcore_sol[:-1]}, eps = {leastcore_sol[-1]} \n\n')

# Re-express every minimal winning coalition in shuffled slot positions and
# encode it as a 0/1 row, turning the variable-sized coalition lists into one
# dense array that can be used with PyTorch.
max_coal_set = max(len(game_coals) for game_coals in C_min_win)
c_tensor_onehot = np.zeros((G, max_coal_set, N_max))
c_min_win_shuf = []

for game in trange(G):
    # Translate each player id into the slot it was assigned for this game.
    column_of = player_to_index[game]
    shuf_win_coals_set = [
        [column_of.get(player) for player in coal]
        for coal in C_min_win[game]
    ]
    c_min_win_shuf.append(shuf_win_coals_set)

    # Row i of this game's slice marks the members of its i-th coalition;
    # rows beyond the game's coalition count stay all-zero.
    for i, shuf_coal in enumerate(shuf_win_coals_set):
        c_tensor_onehot[game, i, shuf_coal] = 1

# Store
# Bundle all generated data into one dictionary. The dense per-game arrays are
# converted to torch tensors; `C_min_tensor` is stored as a NumPy array and
# `C_min_shuffled` as a nested Python list.
data_dict = { 'W'                : torch.from_numpy(W),                       
              'q'                : torch.from_numpy(q),            
              'C_min_tensor'     : c_tensor_onehot, 
              'C_min_shuffled'   : c_min_win_shuf,
              'X'                : torch.from_numpy(X),  
              'sol_stack_lc'     : torch.from_numpy(sol_stack),
              'Y_shap'           : torch.from_numpy(Y_shap),
              'Y_banz'           : torch.from_numpy(Y_banz),
            }

# Build the output location with os.path.join so that `path` works with or
# without a trailing separator; an empty `path` writes to the working
# directory, exactly as plain concatenation did.
out_file = os.path.join(PARAMS['path'], PARAMS['filename'])
with open(out_file, 'wb') as handle:
    pickle.dump(data_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)
print('Data stored succesfully.')