from dataset.dataset import Dataset
from dataset.domain import Domain
import dataset.cdp2adp as cdp2adp
from dataset.workloads import downward_closure
from algebra import VStack, MarginalWorkload, ResidualWorkload2, _construct_contrast_basis, Workload
from linear_operator.operators import KroneckerProductLinearOperator
import json
import pandas as pd
import itertools
import numpy as np
import torch
import seaborn as sns
import time
import matplotlib.pyplot as plt
np.set_printoptions(suppress=True)
import sys
import torch
from tqdm import tqdm
import scipy
import argparse

import warnings
warnings.filterwarnings("ignore")

from dual_ascent import dualAscent
from utils import *

# sys.path.append('./../../private-pgm/src/')
# from mbi import FactoredInference

def scoreWorstMarginal(candidate, data, M_plus, y, norm = 1):
    """
    Scores a candidate marginal by the distance between its exact answers and the
    answers reconstructed from the current measurements (dense-matrix variant)
    :param candidate: tuple of attribute indices identifying the marginal
    :param data: dataset the marginal is evaluated on; data.domain defines the workload
    :param M_plus: matrix right-multiplied onto the dense workload -- presumably the
                   pseudoinverse of the measured workload; confirm against caller
    :param y: sequence of tensors; concatenated along dim 0 before use
    :param norm: order handed to torch.linalg.norm
    :return: float; norm of (exact answers - reconstructed answers)
    """
    workload = MarginalWorkload(candidate, data.domain)
    # sigma = 0 requests the answers without noise
    exact = workload.getAnswers(data, sigma = 0)
    reconstruction_map = workload.toDense() @ M_plus
    stacked_y = torch.cat(y, dim = 0)
    reconstructed = reconstruction_map @ stacked_y
    gap = exact - reconstructed
    return torch.linalg.norm(gap, norm).item()

def scoreWorstMarginalKron(candidate, data, M_plus, y, norm = 1):
    """
    Scores a candidate marginal by the distance between its exact answers and the
    answers reconstructed from the current measurements, composing the workload
    with M_plus through the @ operator instead of materialising it with toDense()
    :param candidate: tuple of attribute indices identifying the marginal
    :param data: dataset the marginal is evaluated on; data.domain defines the workload
    :param M_plus: operator right-multiplied onto the workload -- presumably the
                   pseudoinverse of the measured workload; confirm against caller
    :param y: measurements the composed operator is applied to
    :param norm: order handed to torch.linalg.norm
    :return: float; norm of (exact answers - reconstructed answers)
    """
    workload = MarginalWorkload(candidate, data.domain)
    # sigma = 0 requests the answers without noise
    exact = workload.getAnswers(data, sigma = 0)
    reconstruction_map = workload @ M_plus
    reconstructed = reconstruction_map @ y
    gap = exact - reconstructed
    return torch.linalg.norm(gap, norm).item()

def exponential(R, scores, sensitivity, epsilon):
    """
    Exponential mechanism: draws one element of R with probability
    softmax(epsilon * scores / (2 * sensitivity))
    :param R: indexable collection of candidates
    :param scores: numpy array with one score per candidate, in the same order as R
    :param sensitivity: scalar; sensitivity of the score function
    :param epsilon: scalar; privacy parameter of the selection
    :return: the sampled element of R
    """
    scale = 0.5*epsilon/sensitivity
    selection_probs = scipy.special.softmax(scale*scores)
    positions = range(len(R))
    # Uses numpy's global random state, so seeding np.random fixes the draw
    drawn = np.random.choice(positions, 1, p=selection_probs)
    return R[drawn[0]]

# pseudoinverse from residuals
def pinvFromResiduals(M, y, sigmas):
    """
    Collapses repeated measurements of the same column tuple into a single
    measurement by taking their unweighted average
    :param M: stacked workload exposing cols() (one column tuple per measurement)
              and workloads (the per-measurement workloads, same order)
    :param y: list of measured answers, aligned with M.cols()
    :param sigmas: list of noise scales, aligned with M.cols()
    :return: (M_gen, y_gen, sigmas_gen) with one entry per distinct column tuple,
             in order of first occurrence in M.cols(); M_gen is a VStack
    """
    T_M = M.cols()

    # Group measurement positions by column tuple in a single pass. A dict keeps
    # insertion order, so the output order is deterministic (iterating a set of
    # string tuples depends on hash randomization) and we avoid rescanning T_M
    # once per distinct tuple.
    positions_by_cols = {}
    for idx, cols in enumerate(T_M):
        positions_by_cols.setdefault(cols, []).append(idx)

    M_gen = []
    y_gen = []
    sigmas_gen = []
    for positions in positions_by_cols.values():
        count = len(positions)
        # All measurements in a group share the same workload; keep the first
        M_gen.append(M.workloads[positions[0]])
        y_gen.append(sum(y[idx] for idx in positions) / count)
        # Scale of the mean of `count` terms: sqrt(sum((sigma_i / count)^2))
        sigmas_gen.append(sum((sigmas[idx] / count) ** 2 for idx in positions) ** 0.5)
    M_gen = VStack(M_gen)
    return M_gen, y_gen, sigmas_gen

# inverse variance weighting from residuals
def ivwFromResiduals(M, y, sigmas):
    """
    Collapses repeated measurements of the same column tuple into a single
    measurement by inverse-variance weighting
    :param M: stacked workload exposing cols() (one column tuple per measurement)
              and workloads (the per-measurement workloads, same order)
    :param y: list of measured answers, aligned with M.cols()
    :param sigmas: list of noise scales, aligned with M.cols()
    :return: (M_gen, y_gen, sigmas_gen) with one entry per distinct column tuple,
             in order of first occurrence in M.cols(); M_gen is a VStack
    """
    T_M = M.cols()

    # Group measurement positions by column tuple in a single pass. A dict keeps
    # insertion order, so the output order is deterministic (iterating a set of
    # string tuples depends on hash randomization) and we avoid rescanning T_M
    # once per distinct tuple.
    positions_by_cols = {}
    for idx, cols in enumerate(T_M):
        positions_by_cols.setdefault(cols, []).append(idx)

    M_gen = []
    y_gen = []
    sigmas_gen = []
    for positions in positions_by_cols.values():
        # All measurements in a group share the same workload; keep the first
        M_gen.append(M.workloads[positions[0]])
        # Weighted mean with weights 1 / sigma_i^2
        numerator = sum(y[idx] / sigmas[idx] ** 2 for idx in positions)
        denominator = sum(sigmas[idx] ** -2 for idx in positions)
        y_gen.append(numerator/denominator)
        # Scale of the weighted mean: (sum of 1 / sigma_i^2) ** -0.5
        sigmas_gen.append(denominator ** -0.5)
    M_gen = VStack(M_gen)
    return M_gen, y_gen, sigmas_gen
  
class scalableMWEMpinv:
    def __init__(self, target_marginals, rho, rounds):
        """
        Instantiates scalable MWEM mechanism using pseudoinverse reconstruction
        :param target_marginals: list of tuples of indices
        :param rho: scalar; privacy budget
        :param rounds: int; number of select-and-measure rounds (must be nonzero,
                       the per-round budget divides by it)
        """
        
        self.target_marginals = target_marginals.copy()
        self.rounds = rounds
        self.rho = rho
        # fraction of rho spent on the initial measurements
        self.gamma = 0.1
        # fraction of each round's budget spent on selection; the rest goes to measurement
        self.alpha = 0.5
        self.rho_init = self.rho * self.gamma
        self.rho_round = (self.rho - self.rho_init) / self.rounds
        # tuple length of the marginals measured before the first round
        self.initialization = 0
                 
    def run(self, data, return_marginals = False):
        """
        Runs the mechanism: measures the initial marginals, then for each round
        scores the remaining target marginals against the current reconstruction,
        selects one with the exponential mechanism, measures it and adds its
        residual decomposition to the running measurement set
        :param data: dataset to measure; data.domain defines the workloads
        :param return_marginals: if True, also return the marginal-level measurements
        :return: (M, y, sigmas) residual workloads, answers and noise scales; with
                 return_marginals = True, ((M, y, sigmas), (marginals, marginals_y, marginals_sigmas))
        :raises ValueError: if the downward closure of the targets has no marginal
                            of length self.initialization
        """
        ## marginals of length self.initialization taken from the downward closure of the targets
        init_idx = [tup for tup in downward_closure(self.target_marginals) if len(tup) == self.initialization]
        if not init_idx:
            # Without an initial measurement M, y and sigmas would never be defined
            raise ValueError("no marginal of length %d in the downward closure of the target marginals" % self.initialization)
        
        ## create M, y, sigmas; rho_init is split evenly across the initial marginals
        for idx, tup in enumerate(init_idx):
            init_marginal = MarginalWorkload(tup, data.domain)
            y_init_marginal, sigma_init_marginal = init_marginal.getAnswers(data, rho = self.rho_init/len(init_idx), return_sigma = True)
            M_init, y_init, sigmas_init = init_marginal.decomposeIntoResiduals(y = y_init_marginal, sigma = sigma_init_marginal)
            if idx == 0:
                M, y, sigmas = M_init, y_init, sigmas_init
                marginals, marginals_y, marginals_sigmas = VStack([init_marginal]), [y_init_marginal], [sigma_init_marginal]
            else:
                M += M_init
                y += y_init
                sigmas += sigmas_init
                marginals = marginals.append(init_marginal)
                marginals_y.append(y_init_marginal)
                marginals_sigmas.append(sigma_init_marginal)     
        
        # Work on a copy: selected candidates are removed below, and aliasing
        # self.target_marginals would leave a depleted pool for any later run()
        candidates = list(self.target_marginals)
        
        for _ in tqdm(range(self.rounds)):
            # pseudoinverse from residuals
            M_gen, y_gen, sigmas_gen = pinvFromResiduals(M, y, sigmas)
            
            # get scores; the pseudoinverse is the same for every candidate, so build it once per round
            M_gen_pinv = M_gen.pinv()
            scores = np.array([scoreWorstMarginalKron(cand, data, M_gen_pinv, y_gen, norm = 1) for cand in candidates])
            
            # run exp mechanism and measure selected workload
            # epsilon passed is sqrt(8 * alpha * rho_round) -- presumably the epsilon
            # whose zCDP cost eps^2 / 8 equals the selection share of the round budget
            c_star = exponential(candidates, scores, 1, (self.alpha * self.rho_round * 8) ** 0.5)
            print(c_star, scores[candidates.index(c_star)], scores.max())
            candidates.remove(c_star)
            c_star_wkload = MarginalWorkload(c_star, data.domain)
            marginal_answers, marginal_sigma = c_star_wkload.getAnswers(data, rho = (1 - self.alpha) * self.rho_round, return_sigma = True)
            c_star_residuals, residual_answers, residual_sigmas = c_star_wkload.decomposeIntoResiduals(y = marginal_answers, sigma = marginal_sigma)
            
            # Add residuals to M, y, sigmas
            M += c_star_residuals
            y += residual_answers
            sigmas += residual_sigmas
            marginals = marginals.append(c_star_wkload)
            marginals_y.append(marginal_answers)
            marginals_sigmas.append(marginal_sigma)
        
        if return_marginals:
            return ((M, y, sigmas), (marginals, marginals_y, marginals_sigmas))
        else:
            return M, y, sigmas
          
# Command-line options for one experiment run, registered in display order
parser = argparse.ArgumentParser()
_cli_options = [
    ('--data', {'default': 'titanic', 'help': 'dataset to run method on'}),
    ('--epsilon', {'default': 1.0, 'help': 'privacy budget', 'type': float}),
    ('--marginals', {'default': 2, 'help': 'number of marginals', 'type': int}),
    # kept as a string: it is only ever spliced into output file names
    ('--iteration', {'default': '0', 'help': 'experiment iteration'}),
    ('--rounds', {'default': 10, 'help': 'rounds', 'type': int}),
]
for _flag, _settings in _cli_options:
    parser.add_argument(_flag, **_settings)
args = parser.parse_args()

# read in data: <name>-domain.json holds the per-column domain, <name>.csv the records
data = args.data
_clean_dir = "../../hd-datasets/clean/"
with open(_clean_dir + data + "-domain.json") as f:
    domain = json.load(f)
data_raw = pd.read_csv(_clean_dir + data + ".csv")

# Rename every column to its position as a string ("0", "1", ...) and re-key
# the domain the same way so that both agree
_original_columns = list(data_raw.columns)
col_map = {name: str(position) for position, name in enumerate(_original_columns)}
domain = {col_map[name]: domain[name] for name in _original_columns}
data_raw.columns = [col_map[name] for name in _original_columns]
dat = Dataset(df = data_raw, domain = Domain.fromdict(domain))

# set parameters
epsilon, num_marg, rounds = args.epsilon, args.marginals, args.rounds
j = args.iteration
delta = 1e-9
# convert the (epsilon, delta) target into the rho handed to the mechanism
rho = cdp2adp.cdp_rho(epsilon, delta)

# target workload: every num_marg-way marginal over the dataset columns
T_Q = list(itertools.combinations(map(str, dat.df.columns), num_marg))
marginals = VStack([MarginalWorkload(tup, dat.domain) for tup in T_Q])

# run mechanism; res holds the residual-level measurements, mar the marginal-level ones
mech_scale = scalableMWEMpinv(T_Q, rho, rounds)
res, mar = mech_scale.run(dat, return_marginals = True)
(M, y, sigmas), (mar_M, mar_y, mar_sigmas) = res, mar

# calc error
# Each row appended to output is (itr, epsilon, running_time, error_l1, error_l2, method)
output = []
num_records = dat.df.shape[0]


def _mean_error(estimates, order):
    """Mean over the target marginals of the order-norm error, divided by the number of records."""
    return np.mean([
        torch.linalg.vector_norm((estimates[idx] - true_answers[idx]), order).item() / num_records
        for idx in range(len(true_answers))
    ])


# raw and heuristic error
M_pinv, y_pinv, sigmas_pinv = pinvFromResiduals(M, y, sigmas)
inferred = (marginals @ M_pinv.pinv()) @ y_pinv
true_answers = marginals.getAnswers(dat, sigma = 0)
errors_raw1 = _mean_error(inferred, 1)
errors_raw2 = _mean_error(inferred, 2)
output.append((j, round(epsilon, 4), 0, errors_raw1, errors_raw2, 'Raw'))

# Heuristic-Z: clip negative entries to zero
inferred_zero = [torch.max(ytau, torch.tensor(0)) for ytau in inferred]
errors_zero1 = _mean_error(inferred_zero, 1)
errors_zero2 = _mean_error(inferred_zero, 2)
output.append((j, round(epsilon, 4), 0, errors_zero1, errors_zero2, 'Heuristic-Z'))

# Heuristic-Z+N: rescale the clipped marginal so it sums to the unclipped total
inferred_normalized = []
for unclipped, clipped in zip(inferred, inferred_zero):
    clipped_total = clipped.sum()
    if clipped_total > 0:
        inferred_normalized.append(clipped * unclipped.sum() / clipped_total)
    else:
        # Every entry was clipped to zero: rescaling would be 0 / 0 and turn the
        # whole mean error into NaN, so keep the all-zero marginal as is
        inferred_normalized.append(clipped)
errors_normalized1 = _mean_error(inferred_normalized, 1)
errors_normalized2 = _mean_error(inferred_normalized, 2)
output.append((j, round(epsilon, 4), 0, errors_normalized1, errors_normalized2, 'Heuristic-Z+N'))

# inverse-variance-weighted reconstruction
M_gen, y_gen, sigmas_gen = ivwFromResiduals(M, y, sigmas)
inferred_gen = (marginals @ M_gen.pinv()) @ y_gen
errors_gen1 = _mean_error(inferred_gen, 1)
errors_gen2 = _mean_error(inferred_gen, 2)
output.append((j, round(epsilon, 4), 0, errors_gen1, errors_gen2, 'ReM-MLE'))

T_M = M.cols()
# scale 2 ** (number of attributes) per measured residual; the measured sigmas are not used here
da_sigmas = [2 ** len(tup) for tup in T_M]

# dual ascent
da = dualAscent(T_M, T_Q, y, da_sigmas, domain = dat.domain)
# alternative solver: dualAscentBlockwise(T_M, T_Q, y, da_sigmas, domain = dat.domain)
da.solveLooping(rounds = 1000, 
                lam = -1, 
                t = 0.02, 
                t_div = 10 ** 0.5, 
                true_answers = true_answers, 
                num_records = num_records,
                reg_param = 40
               )
da_y_opt = [da.y_opt_dict[tup] for tup in da.R_Q]
inferred_da = (da.marginals @ da.residuals_all.pinv()) @ da_y_opt
errors_da1 = _mean_error(inferred_da, 1)
errors_da2 = _mean_error(inferred_da, 2)
output.append((j, round(epsilon, 4), da.running_time, errors_da1, errors_da2, 'Dual Ascent'))

import os  # needed only here, to create the result directories before writing

# One row per reconstruction method, tagged with the number of rounds used
outDF = pd.DataFrame(output, columns=['itr', 'epsilon', 'running_time', 'error_l1', 'error_l2', 'method'])
outDF['rounds'] = rounds

# File-name stem shared by the error table and the measurement dump
run_tag = 'scalable_mwem_' + data +  str(num_marg) + 'way_' + str(round(epsilon, 2)) + 'rounds' + str(args.rounds) + '_itr' + j

# Create results/ascent and results/ascent/marg_measurements if missing, so a
# finished run is not lost to a missing-directory error at the very last step
os.makedirs('results/ascent/marg_measurements', exist_ok=True)

outDF.to_csv('results/ascent/' + run_tag + '.csv')

# Marginal-level measurements: column tuples, measured answers and noise scales
outMarg = {}
outMarg['cols'] = mar_M.cols() 
outMarg['y'] = [x.tolist() for x in mar_y]
outMarg['sigmas'] = mar_sigmas

with open('results/ascent/marg_measurements/' + run_tag + '.json', 'w') as file:
    json.dump(outMarg, file)