from dataset.dataset import Dataset
from dataset.domain import Domain
import dataset.cdp2adp as cdp2adp
from dataset.workloads import downward_closure
from algebra import VStack, MarginalWorkload, ResidualWorkload2, _construct_contrast_basis, Workload
from linear_operator.operators import KroneckerProductLinearOperator
import json
import pandas as pd
import itertools
import numpy as np
import torch
import seaborn as sns
import time
import matplotlib.pyplot as plt
np.set_printoptions(suppress=True)
import sys
import torch
from tqdm import tqdm
import scipy
import argparse

sys.path.append('./../../private-pgm/src/')
from mbi import FactoredInference

# Command-line options. No `type=` converter is supplied and every default is
# a string, so each parsed value is a string.
parser = argparse.ArgumentParser()
for _flag, _default, _help in (
    ('--data', 'titanic', 'dataset to run method on'),
    ('--epsilon', '1.0', 'privacy budget'),
    ('--marginals', '2', 'number of marginals'),
    ('--iteration', '0', 'experiment iteration'),
    ('--rounds', '10', 'rounds'),
):
    parser.add_argument(_flag, default=_default, help=_help)
args = parser.parse_args()

# unpack the parsed options into the module-level names used by the script
data, epsilon = args.data, args.epsilon
num_marg, rounds = args.marginals, args.rounds
iteration = args.iteration

# Load the dataset: the domain description (JSON) and the records (CSV) share
# the same path prefix and differ only in their suffix.
_clean_prefix = "../../hd-datasets/clean/" + data
with open(_clean_prefix + "-domain.json") as f:
    domain = json.load(f)
data_raw = pd.read_csv(_clean_prefix + ".csv")

# Relabel each column with its position in the CSV header ("0", "1", ...) and
# re-key the domain mapping with the same labels.
col_map = dict(zip(data_raw.columns, map(str, itertools.count())))
domain = {col_map[name]: domain[name] for name in data_raw.columns}
data_raw.columns = [col_map[name] for name in data_raw.columns]
dat = Dataset(df=data_raw, domain=Domain.fromdict(domain))

# Read the stored measurements; the file name encodes dataset, marginal order,
# epsilon, number of rounds and iteration.
_measurement_path = ''.join([
    'results/ascent/marg_measurements/scalable_mwem_', data,
    num_marg, 'way_', epsilon,
    'rounds', rounds, '_itr', iteration, '.json',
])
with open(_measurement_path, 'r') as file:
    measurements = json.load(file)

# the three entries are indexed in parallel further down (same position in
# each refers to the same measurement)
cols = measurements['cols']
y = measurements['y']
sigmas = measurements['sigmas']

# Private-PGM
# Build one measurement tuple per entry of `cols`:
# (identity matrix sized to y[idx], y[idx] as an array, sigmas[idx], attribute tuple).
marg_measurements = []
for _idx, _attrs in enumerate(cols):
    _query = np.eye(len(y[_idx]))
    _values = np.array(y[_idx])
    marg_measurements.append((_query, _values, sigmas[_idx], tuple(_attrs)))
engine = FactoredInference(dat.domain, log=True)

# only the estimate() call is timed; engine construction is excluded
pgm_start = time.time()
model = engine.estimate(marg_measurements, engine='MD')
pgm_time = time.time() - pgm_start

# calc error
# Evaluate on every combination of `num_marg` columns (column labels as strings).
T_Q = list(itertools.combinations(map(str, dat.df.columns), int(num_marg)))
marginals = VStack([MarginalWorkload(tup, dat.domain) for tup in T_Q])
pgm = [model.project(tup).datavector() for tup in T_Q]
true_answers = marginals.getAnswers(dat, sigma=0)

# Per-marginal difference between the model's projection and the answers from
# getAnswers(sigma=0); each norm is divided by the number of rows in the data
# and then averaged over all marginals.
_n_records = dat.df.shape[0]
_residuals = [pgm[idx] - true_answers[idx].numpy() for idx in range(len(true_answers))]
errors_pgm1 = np.mean([np.linalg.norm(res, 1).item() / _n_records for res in _residuals])
errors_pgm2 = np.mean([np.linalg.norm(res, 2).item() / _n_records for res in _residuals])

# single-row result table for this run
output = [(iteration, epsilon, pgm_time, errors_pgm1, errors_pgm2, 'Private-PGM')]
outDF = pd.DataFrame(
    output,
    columns=['itr', 'epsilon', 'running_time', 'error_l1', 'error_l2', 'method'],
)
outDF['rounds'] = rounds

# write to file
# The file name encodes dataset, marginal order, epsilon, number of rounds and
# iteration, followed by a '_pgm' tag.
_result_path = ''.join([
    'results/ascent/scalable_mwem_', data, num_marg,
    'way_', epsilon, 'rounds', rounds, '_itr',
    iteration, '_pgm', '.csv',
])
outDF.to_csv(_result_path)