from dataset.dataset import Dataset
from dataset.domain import Domain
import dataset.cdp2adp as cdp2adp
from dataset.workloads import downward_closure
from mechanism import ddrp, ddrp_SOR
from algebra import VStack, MarginalWorkload, ResidualWorkload2, _construct_contrast_basis, Workload
from linear_operator.operators import KroneckerProductLinearOperator
import json
import pandas as pd
import itertools
import numpy as np
import torch
import seaborn as sns
import time
import matplotlib.pyplot as plt
np.set_printoptions(suppress=True)
import cvxpy as cp
import sys
# sys.path.append('./../../private-pgm/src/')
# from mbi import FactoredInference
# import reconstruction
import torch
from dual_ascent import dualAscent

    
# Load the cleaned dataset and its domain description from disk.
data = 'adult'
clean_prefix = "../../hd-datasets/clean/" + data
with open(clean_prefix + "-domain.json") as domain_file:
    domain = json.load(domain_file)
data_raw = pd.read_csv(clean_prefix + ".csv")

# Rename every column to its positional index (as a string) and re-key the
# domain dictionary the same way, so data and domain share column labels.
col_map = dict(zip(data_raw.columns, map(str, range(len(data_raw.columns)))))
domain = {col_map[name]: domain[name] for name in data_raw.columns}
data_raw.columns = [col_map[name] for name in data_raw.columns]
dat = Dataset(df=data_raw, domain=Domain.fromdict(domain))

# set parameters
epsilon = 1.0
delta = 1e-9
# presumably converts the (epsilon, delta) budget into a zCDP rho -- confirm
# against dataset.cdp2adp
rho = cdp2adp.cdp_rho(epsilon, delta)
num_marg = 3       # size of each column combination in the target workload
step_size = 1e-5   # passed as `t` to the dual ascent solver

# init output: one (num_cols, running_time, error, method) tuple per result
output = []

print((data, num_marg))

# Each entry is a number of leading columns; the experiment below is run once
# per entry on the columns '0' .. str(entry - 1).
if data == 'adult':
    columns = [4, 8, 12, 14]
elif data == 'titanic':
    columns = [4, 6, 8]
else:
    # Fail here with a clear message instead of a NameError on `columns`
    # further down the script.
    raise ValueError(
        f"no column counts configured for dataset {data!r}; "
        "expected 'adult' or 'titanic'"
    )
    
def _mean_l1_error(estimates, truth, n_records):
    """Return the mean over workload entries of the per-record L1 error.

    For every index ``idx`` in ``range(len(truth))`` the L1 norm of
    ``estimates[idx] - truth[idx]`` is divided by ``n_records``; the mean of
    those values is returned.

    Only ``len(truth)`` and integer indexing are used, so any container
    supporting those operations is accepted.
    """
    per_entry = [
        torch.linalg.vector_norm(estimates[idx] - truth[idx], 1).item() / n_records
        for idx in range(len(truth))
    ]
    return np.mean(per_entry)


for i in columns:
    n_records = dat.df.shape[0]

    # set target workload and get measurements
    # T_Q = [('0', '1',), ('1', '2',), ('0', '2',)]
    # Target workload: every num_marg-sized combination of the first i columns.
    T_Q = list(itertools.combinations([str(num) for num in range(i)], num_marg))
    marginals = VStack([MarginalWorkload(tup, dat.domain) for tup in T_Q])
    T_M = downward_closure(T_Q)
    residuals = VStack([ResidualWorkload2(tup, dat.domain) for tup in T_M])
    # rho is divided evenly across the entries of T_M.
    y, sigmas = residuals.getAnswers(dat, rho=rho / len(T_M), return_sigma=True)

    # raw and heuristic error
    inferred = (marginals @ residuals.pinv()) @ y
    # presumably sigma = 0 yields the noise-free answers -- confirm in getAnswers
    true_answers = marginals.getAnswers(dat, sigma=0)
    errors_raw = _mean_l1_error(inferred, true_answers, n_records)
    output.append((i, 0, errors_raw, 'Raw'))

    # Heuristic-Z: clamp negative entries to zero.
    inferred_zero = [torch.max(ytau, torch.tensor(0)) for ytau in inferred]
    errors_zero = _mean_l1_error(inferred_zero, true_answers, n_records)
    output.append((i, 0, errors_zero, 'Heuristic-Z'))

    # Heuristic-Z+N: rescale the clamped vector so its total equals the total
    # of the raw estimate.
    # NOTE(review): if every clamped entry is zero the denominator is 0 and
    # the result is non-finite; this is not guarded here.
    inferred_normalized = [
        clamped * raw.sum() / clamped.sum()
        for raw, clamped in zip(inferred, inferred_zero)
    ]
    errors_normalized = _mean_l1_error(inferred_normalized, true_answers, n_records)
    output.append((i, 0, errors_normalized, 'Heuristic-Z+N'))

    # dual ascent
    da = dualAscent(T_M, T_Q, y, sigmas, domain=dat.domain)
    da.solve(t=step_size, rounds=4000)
    da_y_opt = [da.y_opt_dict[tup] for tup in da.R_Q]
    inferred_da = (da.marginals @ da.residuals_all.pinv()) @ da_y_opt
    errors_da = _mean_l1_error(inferred_da, true_answers, n_records)
    output.append((i, da.running_time, errors_da, 'Dual Ascent'))
    
    
import os

# Collect the (num_cols, running_time, error, method) tuples into a table.
# Passing `columns=` to the constructor (instead of assigning afterwards)
# also works when `output` is empty.
outDF = pd.DataFrame(output, columns=['num_cols', 'running_time', 'error', 'method'])

# Create the target directory first so a missing folder does not discard the
# results of the whole run.
out_dir = 'results/ascent'
os.makedirs(out_dir, exist_ok=True)
outDF.to_csv(out_dir + '/by_col_comparison_' + str(num_marg) + 'way_' + data + '.csv')