import argparse
import itertools
import json
import os
import time

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from linear_operator.operators import KroneckerProductLinearOperator

import dataset.cdp2adp as cdp2adp
from algebra import VStack, MarginalWorkload, ResidualWorkload2, _construct_contrast_basis, Workload
from dataset.dataset import Dataset
from dataset.domain import Domain
from dataset.workloads import downward_closure
from dual_ascent import dualAscent
from utils import getOptimalSigmasCF

np.set_printoptions(suppress=True)

# Command-line interface for a single experiment run.
parser = argparse.ArgumentParser()
parser.add_argument('--data', default='titanic',
                    help='dataset to run OMaRR on')
parser.add_argument('--epsilon', type=float, default=1.0,
                    help='privacy budget')
parser.add_argument('--marginals', type=int, default=2,
                    help='number of marginals')
# Kept as a string: it is concatenated directly into the output filename.
parser.add_argument('--iteration', default='0',
                    help='experiment iteration')
# NOTE(review): args.sigmas is not read anywhere in this script; the
# dual-ascent loop hardcodes its list of sigma schemes.
parser.add_argument('--sigmas', default='RP',
                    help='sigmas')
args = parser.parse_args()
    
# Read in data: the <data>-domain.json file (keyed by original column name)
# and the <data>.csv records. Every column is then renamed to its positional
# index as a string ('0', '1', ...) so that marginal tuples can later be
# built directly from range(number_of_columns).
data = args.data
with open(f"../../hd-datasets/clean/{data}-domain.json") as f:
    domain = json.load(f)
data_raw = pd.read_csv(f"../../hd-datasets/clean/{data}.csv")
col_map = dict(zip(data_raw.columns, map(str, range(len(data_raw.columns)))))
domain = {col_map[name]: domain[name] for name in data_raw.columns}
data_raw.columns = [col_map[name] for name in data_raw.columns]
dat = Dataset(df=data_raw, domain=Domain.fromdict(domain))

# Experiment parameters.
epsilon = args.epsilon
delta = 1e-9
# Privacy parameter handed to getOptimalSigmasCF; presumably the zCDP rho
# matching (epsilon, delta) — see cdp2adp.cdp_rho.
rho = cdp2adp.cdp_rho(epsilon, delta)
num_marg = args.marginals
j = args.iteration  # run id (a string), appended to the output filename

# One row per evaluated method:
# (itr, epsilon, running_time, error_l1, error_l2, method, sigmas)
output = []

print((data, num_marg))

# Target workload T_Q: every num_marg-way combination of the (renamed) columns.
i = dat.df.shape[1]
T_Q = list(itertools.combinations(map(str, range(i)), num_marg))
marginals = VStack([MarginalWorkload(cols, dat.domain) for cols in T_Q])

# Measurement set T_M: the downward closure of T_Q, one residual workload per tuple.
T_M = downward_closure(T_Q)
residuals = VStack([ResidualWorkload2(cols, dat.domain) for cols in T_M])

# getOptimalSigmasCF returns sigmas keyed by tuple; reorder them to follow T_M
# before measuring (y is presumably the noisy residual answers).
opt_sigmas = getOptimalSigmasCF(marginals=T_Q, rho=rho, domain=dat.domain)
opt_sigmas_inorder = [opt_sigmas[cols] for cols in T_M]
y, sigmas = residuals.getAnswers(dat, sigma=opt_sigmas_inorder, return_sigma=True)

# Raw and heuristic error: reconstruct the target marginals from the residual
# answers y, then score the raw reconstruction and two post-processing
# heuristics against the exact (sigma=0) marginal answers.


def _mean_scaled_errors(estimates, truths, num_records):
    """Return ``(mean L1 error, mean L2 error)`` across a list of marginals.

    For each marginal the L1/L2 norm of ``estimate - truth`` is divided by
    ``num_records`` (a per-record error) before averaging over marginals.

    Raises:
        ValueError: if ``estimates`` and ``truths`` have different lengths,
            since pairing them up would then silently drop marginals.
    """
    estimates = list(estimates)
    truths = list(truths)
    if len(estimates) != len(truths):
        raise ValueError(
            f"got {len(estimates)} estimated marginals but {len(truths)} true ones"
        )
    diffs = [est - tru for est, tru in zip(estimates, truths)]
    err_l1 = np.mean([torch.linalg.vector_norm(d, 1).item() / num_records for d in diffs])
    err_l2 = np.mean([torch.linalg.vector_norm(d, 2).item() / num_records for d in diffs])
    return err_l1, err_l2


def _renormalize(clipped, raw):
    """Rescale a clipped marginal so that its total equals ``raw.sum()``.

    ``clipped`` is expected to be ``raw`` with negative entries set to zero.
    If it has no positive mass, every raw entry was <= 0 and so
    ``raw.sum() <= 0``. The original formula would divide by zero there
    (NaN/inf); instead the all-zero vector is returned, the non-negative
    vector whose total is closest to ``raw.sum()``.
    """
    mass = clipped.sum()
    if mass <= 0:
        return clipped
    return clipped * raw.sum() / mass


num_records = dat.df.shape[0]
inferred = (marginals @ residuals.pinv()) @ y
true_answers = marginals.getAnswers(dat, sigma=0)

# Raw: the unconstrained reconstruction.
errors_raw1, errors_raw2 = _mean_scaled_errors(inferred, true_answers, num_records)
output.append((j, round(epsilon, 4), 0, errors_raw1, errors_raw2, 'Raw', None))

# Heuristic-Z: clip negative entries to zero.
inferred_zero = [torch.clamp(ytau, min=0) for ytau in inferred]
errors_zero1, errors_zero2 = _mean_scaled_errors(inferred_zero, true_answers, num_records)
output.append((j, round(epsilon, 4), 0, errors_zero1, errors_zero2, 'Heuristic-Z', None))

# Heuristic-Z+N: clip, then rescale each marginal back to its raw total.
inferred_normalized = [_renormalize(clipped, raw) for clipped, raw in zip(inferred_zero, inferred)]
errors_normalized1, errors_normalized2 = _mean_scaled_errors(inferred_normalized, true_answers, num_records)
output.append((j, round(epsilon, 4), 0, errors_normalized1, errors_normalized2, 'Heuristic-Z+N', None))


# Dual ascent: for each sigma scheme, post-process the residual answers y with
# dualAscent and record its per-record L1/L2 error and running time.
# Schemes handled below: 'RP', 'proportional', 'uniform', 'mixed'; only
# 'proportional' is currently run.
# NOTE(review): the --sigmas command-line flag is parsed but not consulted
# here; wire it in if the scheme should be selectable per run.
num_records = dat.df.shape[0]
for sig in ['proportional']:

    # Per-residual sigmas handed to dualAscent, one entry per tuple in T_M.
    if sig == 'RP':
        da_sigmas = sigmas  # the sigmas returned alongside the measurements y
    elif sig == 'proportional':
        da_sigmas = [2 ** len(tup) for tup in T_M]
    elif sig == 'uniform':
        da_sigmas = [30 for _ in T_M]
    elif sig == 'mixed':
        da_sigmas = [(opt_sigmas[tup] * (4 ** len(tup))) / 10 for tup in T_M]
    else:
        # Without this branch an unknown name would silently reuse the previous
        # iteration's da_sigmas (or raise NameError on the first iteration).
        raise ValueError(f"unknown sigma scheme: {sig!r}")

    # dual ascent
    da = dualAscent(T_M, T_Q, y, da_sigmas, domain=dat.domain)
    da.solveLooping(rounds=4000, true_answers=true_answers, num_records=num_records)
    da_y_opt = [da.y_opt_dict[tup] for tup in da.R_Q]
    inferred_da = (da.marginals @ da.residuals_all.pinv()) @ da_y_opt

    # Same per-record error metric as the Raw/Heuristic rows.
    errors_da1 = np.mean([torch.linalg.vector_norm(inferred_da[idx] - truth, 1).item() / num_records
                          for idx, truth in enumerate(true_answers)])
    errors_da2 = np.mean([torch.linalg.vector_norm(inferred_da[idx] - truth, 2).item() / num_records
                          for idx, truth in enumerate(true_answers)])
    output.append((j, round(epsilon, 4), da.running_time, errors_da1, errors_da2, 'Dual Ascent', sig))

# Collect the result rows (one per evaluated method) into a DataFrame.
outDF = pd.DataFrame(output)
outDF.columns = ['itr', 'epsilon', 'running_time', 'error_l1', 'error_l2', 'method', 'sigmas']

# Write results/ascent/supported_<data><k>way_<eps>_itr<j>.csv. Create the
# directory first so that a long run does not fail at the very last step.
out_dir = 'results/ascent/'
os.makedirs(out_dir, exist_ok=True)
outDF.to_csv(out_dir + 'supported_' + data + str(num_marg) + 'way_' + str(round(epsilon, 2)) + '_itr' + j + '.csv')