# -*- coding: utf-8 -*-
"""
Created on Sat Apr 13 15:17:38 2024

@author: admin-01
"""

import numpy as np
import pandas as pd
from runestimator import run, run_twostage, run_unbalanced,\
    generate, generate_randm
from utils import distance
from tqdm import tqdm

# ---------------------------------------------------------------------------
# Experiment configuration
# ---------------------------------------------------------------------------
n = 1000  # placeholder only: the sweep below rebinds n to each value of narray
epsilon = 1  # forwarded to the estimator and used in the output filename
delta = 1e-5  # forwarded to the estimator

# Data distribution code (also forwarded to generate_randm):
#   1: 1-d uniform in [-1, 1]
#   2: 1-d Gaussian in [-1, 1] (presumably truncated/clipped -- confirm in
#      runestimator)
#   3: 2-d uniform
#      NOTE(review): the reference mean used for this case is np.ones(1),
#      which is 1-d -- confirm the intended dimension.
#   9: IPUMS data read from a CSV file; reference mean hard-coded as 51291.25
distribution = 9

# Estimator:
#   1: Huber loss minimization (current method; uses run_unbalanced)
#   2: two-stage approach (baseline; uses run_twostage)
method = 1

# If True, samples are meant to be divided among users at random (unequal
# per-user counts); if False, divided evenly.
# NOTE(review): this flag is not read anywhere in this script -- the trials
# always call generate_randm. Confirm before relying on unbalanced=False.
unbalanced = True

n_trials = 100  # independent trials per (n, gamma) setting

gammaarray = [1, 2, 3, 4, 5, 6, 7, 8]  # values forwarded to generate_randm
narray = [2000, 5000, 10000]  # numbers of users n to sweep over

# Average number of samples per user; the total sample count is N = n * mavg.
# With a balanced split every user has exactly mavg samples; with an
# unbalanced split some users have more than mavg and others fewer.
mavg = 100
# Load the IPUMS data when distribution 9 is selected (machine-specific
# absolute path).
# NOTE(review): d_ipums is not referenced anywhere else in this file and is
# not passed to generate_randm -- confirm whether this load is still needed.
if distribution == 9:
    d_ipums = pd.read_csv(filepath_or_buffer="E:/data/ipums.csv")
def _distribution_params(distribution, n, mavg):
    """Return the tuned hyper-parameters ``(C, tau, Rc, mu)`` for one setting.

    Parameters
    ----------
    distribution : int
        Distribution code (1, 2, 3 or 9).
    n : int
        Number of users; only distribution 3 looks at it, to pick ``tau``.
    mavg : float
        Average number of samples per user; for codes 1-3 ``tau`` scales as
        ``1 / sqrt(mavg)``.

    Returns
    -------
    tuple
        ``C`` (divided by the per-user weights to form ``Tarray`` for
        method 1), ``tau`` (passed to ``run_twostage`` for method 2; ``None``
        when no tuned value exists for this ``n``), ``Rc`` (passed to both
        estimators) and ``mu``, the reference mean the estimate is scored
        against.

    Raises
    ------
    ValueError
        If ``distribution`` is not one of the supported codes.
    """
    if distribution == 1:
        return 5, 3 / np.sqrt(mavg), 1, np.zeros(1)
    if distribution == 2:
        return 10, 7 / np.sqrt(mavg), 1, np.zeros(1)
    if distribution == 3:
        # tau was tuned per n. Previously an untuned n silently reused the
        # tau left over from the preceding n (or hit a NameError); now it is
        # reported as None and rejected below if the estimator needs it.
        tau_scale = {2000: 7, 5000: 9, 10000: 12}.get(n)
        tau = None if tau_scale is None else tau_scale / np.sqrt(mavg)
        return 10, tau, 3, np.ones(1)
    if distribution == 9:
        return 100000, 1e5, 50000, 51291.25 * np.ones(1)
    raise ValueError("unsupported distribution: {}".format(distribution))


# Validate the estimator choice once, before any (slow) trials run.
if method not in (1, 2):
    raise ValueError("unsupported method: {}".format(method))

# Sweep over n and gamma. For each setting run n_trials independent trials and
# record the mean and population standard deviation (np.std, ddof=0) of
#     err = distance(mu0, mu) ** 2 + randerr
# where mu0 and randerr are the 2nd and 3rd values returned by the estimator.
# `res` maps "gamma" -> gammaarray, n -> per-gamma MSEs, and
# "std_<n>" -> per-gamma standard deviations.
res = {"gamma": gammaarray}
for n in narray:
    N = n * mavg  # total sample count across all n users
    # Hyper-parameters depend on n but not on gamma, so compute them once.
    C, tau, Rc, mu = _distribution_params(distribution, n, mavg)
    if method == 2 and tau is None:
        raise ValueError(
            "no tuned tau for distribution {} with n={}".format(distribution, n)
        )
    msearray = []
    stdarray = []
    for gamma in gammaarray:
        print("gamma: {}, N: {}".format(gamma, N))
        errs = []
        for _trial in tqdm(range(n_trials)):
            # m_vec is presumably the per-user sample counts -- confirm in
            # runestimator.generate_randm.
            D, m_vec = generate_randm(n, N, gamma, distribution)
            if method == 1:
                # Per-user weight sqrt(m_vec), capped at 15.
                weights = np.minimum(np.sqrt(m_vec), 15)
                Tarray = C / weights
                ans, mu0, randerr = run_unbalanced(D, Tarray, weights, Rc, epsilon, delta)
            else:  # method == 2 (validated above)
                ans, mu0, randerr = run_twostage(D, tau, Rc, epsilon, delta)
            err = distance(mu0, mu) ** 2 + randerr
            errs.append(err)
        mse = np.mean(errs)
        std = np.std(errs)
        print("m={}, mse: {}".format(mavg, mse))
        msearray.append(mse)
        stdarray.append(std)
    res[n] = np.array(msearray)
    res["std_{}".format(n)] = np.array(stdarray)

# Persist the sweep as a CSV (one row per gamma). The filename encodes the
# distribution code, epsilon, and which estimator produced the numbers;
# any other method value writes nothing.
res = pd.DataFrame(res)
_out_templates = {
    1: 'result_unbal_{}_{}_new.csv',
    2: "result_unbal_{}_{}_baseline.csv",
}
if method in _out_templates:
    out_path = _out_templates[method].format(distribution, epsilon)
    res.to_csv(out_path, index=None)