import time
from datetime import datetime

import numpy as np
import pandas as pd
import scipy
from scipy.optimize import minimize
from scipy.stats import multivariate_normal, laplace

from models.grb_offline import GRB_Model
from models.utils import project_onto_ball, project_onto_ball_first_quadrant, read_problem


class NoisyDualMD(object):
    def __init__(self, exp_params, algo_params):
        """Keep both parameter dicts and expose their entries as attributes.

        Every item of ``algo_params`` and then every item of ``exp_params`` is
        copied onto the instance with ``setattr``; when a key occurs in both,
        the ``exp_params`` value is the one that remains.

        Args:
            exp_params: mapping of experiment settings.
            algo_params: mapping of algorithm settings.

        Between them the two mappings must provide ``n`` and ``gamma``, which
        are multiplied into ``self.ngamma``.
        """
        self.x = None
        self.algo_params, self.exp_params = algo_params, exp_params
        # algo_params goes first so that exp_params wins on duplicate keys
        for params in (algo_params, exp_params):
            for name, value in params.items():
                setattr(self, name, value)
        self.ngamma = self.n * self.gamma

    def init_for_parallel(self) -> None:
        """Prepare all state that ``Solve_x_p`` reads.

        In order: stack the per-request data into arrays, obtain the
        non-private reference solution from ``GRB_Model``, choose ``p_init`` if
        none was supplied, set up the Bregman divergence and norm constants for
        the selected ``update_rule``, and derive the noise distribution and
        step size ``eta`` for the selected ``private`` mode.

        The ``private`` branches are evaluated one after another and the later
        ones (2 and 3) overwrite values set earlier, including ``T``, ``eta``
        and ``p_init``, so the statement order in this method matters.
        """
        # --- Misc Init
        # request[0] / request[1] / request[2] are stored as objective
        # coefficients, constraint matrices and per-variable (lb, ub) domains.
        self.obj_coeffs = np.asarray([request[0] for request in self.requests])
        self.As = np.asarray([request[1] for request in self.requests])
        # swap the last two axes of every stacked matrix (per-request transpose)
        self.As_transposed = np.transpose(self.As, axes=[0, 2, 1])
        # NOTE(review): the flag is raised if ANY request carries bounds, but the
        # comprehensions below iterate request[2] of EVERY request; a request whose
        # bounds are None would raise TypeError. Confirm bounds are all-or-nothing.
        self.indiv_assign_bds_flag = np.any([request[2] is not None for request in self.requests])  # flag up means lbs & ubs play a role
        if self.indiv_assign_bds_flag:
            self.indiv_assign_lbs = np.asarray([[domain[0] for domain in request[2]] for request in self.requests])
            self.indiv_assign_ubs = np.asarray([[domain[1] for domain in request[2]] for request in self.requests])
        else:
            self.indiv_assign_lbs = np.zeros((self.n, self.s))   # default bounds for each assignment variable: [0, 1]
            self.indiv_assign_ubs = np.ones((self.n, self.s))
        # written back into exp_params, which is handed to GRB_Model just below
        self.exp_params['indiv_assign_bds_flag'] = self.indiv_assign_bds_flag
        self.exp_params['indiv_assign_lbs'] = self.indiv_assign_lbs
        self.exp_params['indiv_assign_ubs'] = self.indiv_assign_ubs

        # --- non-private case
        # reference solution; the triple is unpacked as (x, prices, objective value)
        self.x_non_private, self.prices_non_private, self.objval_non_private = GRB_Model(self.exp_params, self.algo_params).Solve_x_p()
        print('non_private shadow prices:', self.prices_non_private)
        if self.p_init is None:
            # uniform start vector whose entries sum to K_multiplier * dual_bd,
            # or to K_multiplier * ||prices_non_private||_1
            if self.dual_bd_for_p_init:
                self.p_init = self.K_multiplier * self.dual_bd / self.m * np.ones(self.m)
            else:
                self.p_init = self.K_multiplier * np.linalg.norm(self.prices_non_private, 1) / self.m * np.ones(self.m)
        print('p_init:', self.p_init)

        # --- Mirror Descent Init
        # NOTE(review): an update_rule matching neither branch leaves q, B_Phi, K
        # and alpha unset, so the b_qnorm line below raises AttributeError.
        if self.algo_params['update_rule'].startswith('md_ne'):
            # Bregman divergence is the summed scipy.special.kl_div; norms use q = inf
            self.q = np.inf
            self.B_Phi = lambda x, y: scipy.special.kl_div(x, y).sum()
            # presumably a bound on the squared inf-norm of a standard normal
            # vector of length m - TODO confirm
            self.std_normal_norm2 = 2 * np.log(2 * self.m)
            self.K_multiplier = self.algo_params['K_multiplier']
            self.K = self.K_multiplier * self.dual_bd
            self.alpha = 1 / self.K
            if self.algo_params['update_rule'] == 'md_ne_parameterized':
                self.alpha = np.min(self.b) / self.K
        elif self.algo_params['update_rule'] in ['md_l2', 'gd', 'projected_gd', 'projected_gd_positive', 'projected_gd_box']:
            # Bregman divergence is half the squared Euclidean distance; norms use q = 2
            self.q = 2
            self.B_Phi = lambda x, y: 1/2 * np.linalg.norm(x - y, 2)**2
            self.std_normal_norm2 = self.m
            self.K_multiplier = self.algo_params['K_multiplier']
            self.K = self.K_multiplier * self.dual_bd
            self.alpha = 1
        # shadow_price_update_rule may overwrite self.update_rule when private is 2 or 3
        self.update_shadow_prices = self.shadow_price_update_rule(self.update_rule)
        self.netvalue_to_decision = self.netvalue_to_decision_map()
        self.b_qnorm = np.linalg.norm(self.b, self.q)
        self.b_2norm = np.linalg.norm(self.b, 2)
        self.ngammab_qnorm = np.linalg.norm(self.ngamma * self.b, self.q)
        # G = max(||n*gamma*b||_q^2, ||n*gamma*b - b||_q^2); enters the step size eta below
        self.G = max(self.ngammab_qnorm ** 2, np.linalg.norm(self.ngamma * self.b - self.b, self.q) ** 2)

        # --- Privacy Init
        if self.private == 1:
            # c_eps_delta = ||b||_2^2 * min over eps in [0, vareps] of
            # (log(1/delta) + vareps - eps) / (2 * (vareps - eps) * eps), found numerically
            self.c_eps_delta = self.b_2norm ** 2 * minimize(fun=lambda eps: (np.log(1 / self.delta) + self.vareps - eps) / (2 * (self.vareps - eps) * eps),
                                                                         x0=self.vareps / 2,
                                                                         bounds=((0, self.vareps),),
                                                                         tol=1e-12).fun
            # self.T_private_min = self.G / (self.c_eps_delta * self.std_normal_norm2)
            # if self.private_iters:
            #     self.T_private = self.private_iters[1] - self.private_iters[0] + 1
            # else:
            #     self.T_private = self.T
            # isotropic Gaussian noise whose variance grows linearly with T
            self.sigma2 = self.T * self.c_eps_delta
            self.noise_var = multivariate_normal(mean=np.zeros(self.m), cov=np.eye(self.m) * self.sigma2)
            self.eta = np.sqrt(self.alpha * self.B_Phi(np.zeros(self.m), self.p_init)
                               / (self.T * (self.G + self.sigma2 * self.std_normal_norm2))
                               )

        if not self.private:
            # non-private run: the noise covariance is 1e-64 * I, and eta drops the noise term
            self.vareps = np.inf
            self.delta = 1
            self.noise_var = multivariate_normal(mean=np.zeros(self.m), cov=np.eye(self.m) * 1e-64)
            self.eta = np.sqrt(self.alpha * self.B_Phi(np.zeros(self.m), self.p_init)
                               / (self.T * self.G)
                               )

        # -- private == 2 means hsu2016
        # this branch replaces T, eta, p_init and the noise distribution
        if self.private == 2:
            self.w = np.fmax(np.max(self.ngamma * self.b), np.sum(np.max(np.max(self.As, axis=-1), axis=-1)))
            self.T = int(self.w ** 2)
            self.vareps_prime = self.b_2norm * self.vareps / np.sqrt(8 * self.T * np.log(2 / self.delta))
            self.delta_prime = self.delta / (2 * self.T)
            self.sigma2 = 2 * self.b_2norm ** 2 * np.log(1.25 / self.delta_prime) / self.vareps_prime ** 2
            self.noise_var = multivariate_normal(mean=np.zeros(self.m), cov=np.eye(self.m) * self.sigma2)
            self.eta = 2 * self.dual_bd / (np.sqrt(self.T) * (self.w + 1/self.vareps_prime * np.log(self.T * self.m / 0.05)))
            self.p_init = np.zeros(self.m)

        # -- private == 3 means huang2018 algo 1
        '''
        do not use this for experiments. The algo only works for problems with coefficients in [0,1]
        '''
        if self.private == 3 or self.update_rule == 'huang2018':
            m = self.m
            beta = 0.99
            # NOTE(review): approxi_param uses the T in effect before it is replaced
            # three statements below, while eta uses the new T - confirm this is intended.
            self.approxi_param = 20 * np.log(self.T) * np.sqrt(m * np.log(m+1) * np.log(6/beta)*np.log(2/self.delta)) / (self.n * self.gamma * self.vareps)
            self.p_max = np.max(self.obj_coeffs) * 4 * self.n / (self.n * self.gamma)
            self.p_init = self.p_max / (self.m + 1) * np.ones(self.m)
            # price vector extended by one extra coordinate, used by the huang2018 update
            self.prices_ext = self.p_max / (self.m + 1) * np.ones(self.m + 1)
            self.T = int(self.vareps**2 * self.n**2 / self.m)
            self.eta = np.log(self.m + 1) / (self.approxi_param * self.n * self.gamma * self.T)
            self.vareps_prime = self.vareps / (np.sqrt(8 * self.T * self.m * np.log(2 / self.delta)))
            # clipping level applied to the gradients in the huang2018 update
            self.Nabla_max = self.n + np.log(self.T) / self.vareps_prime
            self.noise_var = laplace(loc=np.zeros(self.m), scale=1 / self.vareps_prime)

    def shadow_price_update_rule(self, update_rule):
        def mirror_descent_ne(prices, gradients):
            obj_t = lambda p: self.eta * gradients @ p + self.B_Phi(p, prices)
            res = minimize(obj_t,
                           x0=prices,
                           method='SLSQP',
                           constraints={'type': 'ineq', 'fun': lambda p: self.K - p.sum()},  # 'ineq' means fun>=0
                           bounds=[(0, None)] * self.m,
                           tol=1e-16,
                           options={'ftol': self.ftol})
            prices, _fun_value = res.x, res.fun
            return prices

        def mirror_descent_ne_parameterized(prices, gradients):
            obj_t = lambda p: self.eta * gradients @ p + self.B_Phi(self.b * p, self.b * prices)
            res = minimize(obj_t,
                           x0=prices,
                           method='SLSQP',
                           constraints={'type': 'ineq', 'fun': lambda p: self.K - self.b @ p},  # 'ineq' means fun>=0
                           bounds=[(0, None)] * self.m,
                           tol=1e-16,
                           options={'ftol': self.ftol})
            prices, _fun_value = res.x, res.fun
            return prices

        def update_rule_huang2018(prices, gradients):
            gradients = np.clip(gradients, -self.Nabla_max, self.Nabla_max)
            gradients_ext = np.concatenate([gradients, [0]])
            prices_ext = np.concatenate([prices, self.prices_ext[-1:]])
            p_tmp = prices_ext * (1 - self.eta * gradients_ext)
            prices_ext = p_tmp / np.sum(p_tmp) * self.p_max
            self.gradients_ext = gradients_ext
            self.prices_ext = prices_ext
            return prices_ext[:-1]

        def mirror_descent_l2(prices, gradients):
            obj_t = lambda p: self.eta * gradients @ p + self.B_Phi(p, prices)
            res = minimize(obj_t,
                           x0=prices,
                           method='SLSQP',
                           bounds=[(0, None)] * self.m,
                           tol=1e-16,
                           options={'ftol': self.ftol})
            prices, _fun_value = res.x, res.fun
            return prices

        def gd(prices, gradients):
            return prices - self.eta * gradients

        def projected_gd(prices, gradients):
            return project_onto_ball(prices - self.eta * gradients, radius=self.K)

        def projected_gd_positive(prices, gradients):
            return project_onto_ball_first_quadrant(prices - self.eta * gradients, radius=self.K)

        def projected_gd_box(prices, gradients):
            return np.fmin(np.fmax(prices - self.eta * gradients, 0), 2 * self.dual_bd)

        if self.private == 2 or self.update_rule == 'hsu2016':
            self.update_rule = 'projected_gd_box'
            return projected_gd_box
        if self.private == 3 or self.update_rule == 'huang2018':
            self.update_rule = 'huang2018'
            return update_rule_huang2018
        if update_rule == 'md_ne':
            return mirror_descent_ne
        if update_rule == 'md_ne_parameterized':
            return mirror_descent_ne_parameterized
        if update_rule == 'md_l2':
            return mirror_descent_l2
        if update_rule == 'gd':
            return gd
        if update_rule == 'projected_gd':
            return projected_gd
        if update_rule == 'projected_gd_positive':
            return projected_gd_positive

    def netvalue_to_decision_map(self):
        n, s = self.n, self.s

        if (np.all(self.exp_params['indiv_totalassign_lbs'] == self.exp_params['indiv_totalassign_lbs'][0]) and
                np.all(self.exp_params['indiv_totalassign_ubs'] == self.exp_params['indiv_totalassign_ubs'][0])):
            lb, ub = self.exp_params['indiv_totalassign_lbs'][0], self.exp_params['indiv_totalassign_ubs'][0]
            lb, ub = int(lb), int(ub)

            if lb == 0 and ub == 1:
                def _rule(net_value):
                    xs_t = np.zeros((n, s))
                    max_each_row, idx = np.max(net_value, axis=1), np.argmax(net_value, axis=1)
                    xs_t[range(n), idx] = 1
                    xs_t[np.where(max_each_row < 0), :] = 0
                    return xs_t

            else:
                def _rule(net_values):
                    xs_t = np.zeros((n, s))
                    items_candidates = np.argpartition(net_values, kth=(-ub, -lb), axis=1)  # get top-(ub) items' indices for each agent
                    items_optionals, items_musts = items_candidates[:, s-ub:s-lb], items_candidates[:, s-lb:]
                    # fulfill their needs
                    np.put_along_axis(xs_t, items_musts, values=1, axis=1)
                    # fulfill optionals only if netvalue > 0
                    items_optional_net_values = np.take_along_axis(net_values, items_optionals, axis=1)
                    mx = np.ma.masked_array(items_optionals, mask=items_optional_net_values <= 0)
                    for x_t, mx_sub in zip(xs_t, mx):
                        x_t[mx_sub[~mx_sub.mask]] = 1
                    return xs_t

        else:
            # workforce_scheduling problem where each agent's lb&ub is different
            lbs, ubs = self.exp_params['indiv_totalassign_lbs'], self.exp_params['indiv_totalassign_ubs']
            def _rule(net_values):
                xs_t = []
                for net_value, lb, ub in zip(net_values, lbs, ubs):
                    x_t = np.zeros(s)
                    items_candidate = np.argpartition(net_value, kth=(-ub, -lb))   # get top-(ub) items' indices for each agent
                    items_optional, items_must = items_candidate[s-ub:s-lb], items_candidate[s-lb:]
                    # fulfill must-items
                    x_t[items_must] = 1
                    items_optional_net_value = net_value[items_optional]
                    # fulfill optional items if netvalue > 0
                    mx = np.ma.masked_array(items_optional, mask=items_optional_net_value <= 0)
                    x_t[mx[~mx.mask]] = 1
                    xs_t.append(x_t)
                xs_t = np.asarray(xs_t)
                return xs_t

        return _rule

    def evaluate_x(self, x, output_key_suffix=''):
        # - obj gap
        objval_private = np.sum(self.obj_coeffs * x)
        objavl_gap = (self.objval_non_private - objval_private) / self.objval_non_private

        # - misallocation
        misallocation = np.sum(np.abs(x - self.x_non_private)) / np.sum(self.x_non_private)

        output = {'objval_gap%': objavl_gap * 100,
                  'objval_private': objval_private,
                  'misallocation%': misallocation * 100
                  }

        # - violations
        Ax_b = np.sum([A @ x for A, x in zip(self.As, x)], axis=0) - self.ngamma * self.b
        res = self.evaluate_Ax_b(Ax_b)
        output.update(res)

        output = {k + output_key_suffix: v for k, v in output.items()}
        return output

    def evaluate_p(self, prices, output_key_suffix=''):
        # - prices gap
        l2_dual_prices = np.linalg.norm(prices - self.prices_non_private, 2) / np.linalg.norm(self.prices_non_private, 2)
        l1_dual_prices = np.linalg.norm(prices - self.prices_non_private, 1) / np.linalg.norm(self.prices_non_private, 1)
        mean_dual_prices = np.mean(prices - self.prices_non_private)

        output = {'p_gap_l2%': l2_dual_prices * 100,
                  'p_gap_l1%': l1_dual_prices * 100,
                  'p_gap_mean': mean_dual_prices}  # 'p_private': prices_prefavg,

        output = {k + output_key_suffix: v for k, v in output.items()}
        return output

    @staticmethod
    def evaluate_Ax_b(Ax_b, output_key_suffix=''):
        Ax_b_mean = np.mean(Ax_b)
        Ax_b_l1 = np.linalg.norm(Ax_b, 1)
        violations = np.fmax(Ax_b, 0)
        violations_sum = np.fmax(violations, 0).sum()
        violations_max = np.max(np.fmax(violations, 0))

        output = {'Ax-b_mean': Ax_b_mean,
                  'Ax-b_l1': Ax_b_l1,
                  'violations_total': violations_sum,
                  'violations_max': violations_max
                  }

        output = {k + output_key_suffix: v for k, v in output.items()}
        return output

    @staticmethod
    def print_log(output, suffix=''):
        log_string = ("t=%s, one batch takes %.2fs, perf_suffix: %s. "
                      "objval_gap: %.2f%%, "
                      "dual_vars accuracy: %.2f (mean), %.2f%% (l1), %.2f%% (l2), "
                      "Ax-b: %.2f (mean), %.2f (l1), violations: %.2f (sum), %.2f (max)"
                    % (output['t'], output['runtime'], suffix,
                       output[f'objval_gap%{suffix}'],
                       output[f'p_gap_mean{suffix}'], output[f'p_gap_l1%{suffix}'], output[f'p_gap_l2%{suffix}'],
                       output[f'Ax-b_mean{suffix}'], output[f'Ax-b_l1{suffix}'], output[f'violations_total{suffix}'], output[f'violations_max{suffix}']))
        print(log_string)

    def Solve_x_p(self):
        # print('exp_params:', self.exp_params)
        print('algo_params:', self.algo_params)

        prices = self.p_init
        x_prefavg = np.zeros((self.n, self.s))
        prices_prefavg = np.zeros(self.m)

        batch_size = self.batch_size
        xs_one_batch = []
        prices_one_batch = []
        prices_all_batches = []
        Ax_b_all_batches = []
        objval_all_batches = []

        # --- log
        t0 = time.time()
        csv_outputs = []

        for t in range(1, 1 + self.T):
            # --- solve dual subproblems
            net_profit = self.obj_coeffs - self.As_transposed @ prices
            xs_t = self.netvalue_to_decision(net_profit)

            # --- construct noisy gradients
            noise = self.noise_var.rvs()
            gradients = self.ngamma * self.b - np.sum([A @ x for A, x in zip(self.As, xs_t)], axis=0)
            noisy_gradients = gradients + noise

            # --- update shadow prices
            prices = self.update_shadow_prices(prices, noisy_gradients)

            xs_one_batch.append(xs_t)
            prices_one_batch.append(prices)

            # --- process the batch & log
            if t % batch_size == 0:
                x_onebatchavg = np.mean(xs_one_batch, axis=0)
                p_onebatchavg = np.mean(prices_one_batch, axis=0)
                x_prefavg = (x_prefavg * (t - batch_size) + np.sum(xs_one_batch, axis=0)) / t
                prices_prefavg = (prices_prefavg * (t - batch_size) + np.sum(prices_one_batch, axis=0)) / t

                prices_all_batches.append(prices_one_batch)
                xs_one_batch, prices_one_batch = [], []

                # - runtime
                runtime = time.time() - t0
                t0 = time.time()

                # --- log
                csv_output = {
                    # exp_setup
                    'vareps': self.vareps, 'delta': self.delta,
                    'm': self.m, 'n': self.n, 's': self.s,
                    'b': self.b, 'gamma': self.gamma, 'ngamma': self.ngamma,
                    'dual_bd': self.dual_bd, 'K': self.K,
                    'p_nonprivate_opt': self.prices_non_private,
                    'objval_nonprivate_opt': self.objval_non_private,
                    # algo_setup
                    'update_rule': self.update_rule, 'K_multiplier': self.K_multiplier, 'T': self.T,
                    'p_init': self.p_init,
                    'private': int(self.private),
                    't': t,
                    # algo perf
                    'runtime': runtime,
                }
                # - evaluate x & p
                x_perf_onebatchavg = self.evaluate_x(x_onebatchavg, output_key_suffix='_onebatchavg')
                p_perf_onebatchavg = self.evaluate_p(p_onebatchavg, output_key_suffix='_onebatchavg')
                csv_output.update(x_perf_onebatchavg)
                csv_output.update(p_perf_onebatchavg)

                x_perf_prefavg = self.evaluate_x(x_prefavg, output_key_suffix='_prefavg')
                p_perf_prefavg = self.evaluate_p(prices_prefavg, output_key_suffix='_prefavg')
                csv_output.update(x_perf_prefavg)
                csv_output.update(p_perf_prefavg)

                csv_outputs.append(csv_output)

                # - print log
                self.print_log(csv_output, suffix=self.log_suffix)     # suffix options: _onebatchavg, _prefavg
                print('prices_onebatchavg:', np.round(p_onebatchavg, 1))
                print('prices_prefavg:', np.round(prices_prefavg, 1))

                # - for suffix averaging
                Ax_b = np.sum([A @ x for A, x in zip(self.As, x_onebatchavg)], axis=0) - self.ngamma * self.b
                Ax_b_all_batches.append(Ax_b)
                objval_private = np.sum(self.obj_coeffs * x_onebatchavg)
                objval_all_batches.append(objval_private)

        self.x_prefavg = x_prefavg
        self.prices_prefavg = prices_prefavg
        csv_outputs[-1].update({'x_prefavg': x_prefavg})

        # --- suffix averaging perf
        # - last iter
        last_prices = prices_all_batches[-1][-1]
        net_profit = self.obj_coeffs - self.As_transposed @ last_prices
        x_last_iter = self.netvalue_to_decision(net_profit)
        x_perf_sufavg = self.evaluate_x(x_last_iter, output_key_suffix='_sufavg')
        p_perf_sufavg = self.evaluate_p(last_prices, output_key_suffix='_sufavg')
        csv_outputs[-1].update(x_perf_sufavg)
        csv_outputs[-1].update(p_perf_sufavg)

        # - suffix averaging
        print('start to evaluate performances of suffix averagings. It may take seconds to mins, depends on T and batch size')
        for k in range(len(Ax_b_all_batches) - 1):
            if k % 10 == 0:
                print(f' --- evaluating suffix averaging: batch {k} --- ')
            objval_sufavg = np.mean(objval_all_batches[-(k+1):], axis=0)
            objval_res = {'objval_private_sufavg': objval_sufavg,
                          'objval_gap%_sufavg': (objval_sufavg - self.objval_non_private) / self.objval_non_private * 100}

            Ax_b_sufavg = np.mean(Ax_b_all_batches[-(k+1):], axis=0)
            Ax_b_res = self.evaluate_Ax_b(Ax_b_sufavg, output_key_suffix='_sufavg')

            p_sufavg = np.mean(np.concatenate(prices_all_batches[-(k+1):]), axis=0)
            p_perf_sufavg = self.evaluate_p(p_sufavg, output_key_suffix='_sufavg')

            csv_outputs[-(k+2)].update(objval_res)
            csv_outputs[-(k+2)].update(Ax_b_res)
            csv_outputs[-(k+2)].update(p_perf_sufavg)

        # --- log
        end_time = datetime.now().strftime('%Y%m%d-%H%M%S-%f')
        df_csv_output = pd.DataFrame.from_records(csv_outputs)
        df_csv_output.to_csv(f'outputs/{self.update_rule}_T={self.T}_eps={self.vareps}_delta={self.delta}_gammaMean={np.mean(self.gamma)}_'
                             f'Kmulti={self.K_multiplier}_{end_time}.csv',
                             index=False)

        return self.x_prefavg, self.prices_prefavg

    def Run(self):
        """Initialise the solver state, then run the optimisation loop.

        The pair returned by ``Solve_x_p`` is unpacked and discarded; this
        method itself returns nothing.
        """
        self.init_for_parallel()
        _x_avg, _prices_avg = self.Solve_x_p()


if __name__ == '__main__':
    import os

    # Problem options:
    #   workforce_scheduling_gurobi_demo: a scheduling problem from Gurobi demo
    #   assignp800:   assignment problem with 800 agents, 800 services, 10 types of resource
    #   assignp1500:  assignment problem with 1500 agents, 1500 services, 10 types of resource
    #   assignp3000:  assignment problem with 3000 agents, 3000 services, 10 types of resource
    wd = os.getcwd()
    problem = 'workforce_scheduling_gurobi_demo'
    exp_params = read_problem(problem, 0.1, dir='../')
    exp_params['problem'] = problem

    n, m, s = (exp_params[key] for key in ('n', 'm', 's'))

    obj_coeffs = np.asarray([request[0] for request in exp_params['requests']])

    # dual bound: sum of the 8 largest column-wise maxima of the objective coefficients
    dual_bd = np.sum(np.sort(np.max(obj_coeffs, axis=0))[-8:])

    algo_params = dict(
        update_rule='md_ne',        # md_ne, md_ne_parameterized, md_l2, projected_gd, gd, projected_gd_positive
        dual_bd=dual_bd,
        K_multiplier=1.1,
        dual_bd_for_p_init=True,
        p_init=None,
        T=10000,
        private=2,                  # 0=non-private, 1=our algo, 2=Hsu2017JDP_CvxP, 3=huang2018
        batch_size=100,             # batch size affects logs only.
        vareps=5,
        delta=0.01,
        # misc params
        ftol=1e-12,
        log_suffix='_prefavg',      # suffix options: _onebatchavg, _prefavg
    )

    cvx_problem = NoisyDualMD(exp_params, algo_params)
    cvx_problem.init_for_parallel()
    x, p = cvx_problem.Solve_x_p()



