import sympy
import cvxpy as cp
import time
from sympy import symbols, Matrix
import numpy as np
from itertools import product
from sklearn.cluster import KMeans
from collections import defaultdict

import logging
# Module-level logging setup. basicConfig runs at import time, so importing this
# module configures the root logger at DEBUG level (basicConfig is a no-op if the
# root logger already has handlers).
logging.basicConfig(level=logging.DEBUG)
# Module logger used by every class below for progress/diagnostic messages.
logger = logging.getLogger(__name__)

class Problem:
    """Base class for polynomial problems solved through a moment-matrix relaxation.

    Subclasses are expected to set ``self.f_alpha`` (a mapping from exponent
    tuples to polynomial coefficients), populate ``self.ind2`` (an integer
    array whose rows are the exponent vectors of the basis monomials, e.g. in a
    ``generate_basis()`` method) and implement :meth:`form_new_constraints`.
    """

    def __init__(self, name, verbose=False):
        self.name = name
        self.potential = None   # polynomial being minimized; set by subclasses
        self.f_alpha = None     # {exponent tuple: coefficient}; set by subclasses
        self.ind2 = None        # basis exponent rows; set by subclasses
        self.verbose = verbose  # forwarded to the CVXPY solver

    def form_new_constraints(self, M, y, mul2y, ind4):
        """Return a list of problem-specific CVXPY constraints.

        Args:
            M: the symmetric moment-matrix variable.
            y: the moment-vector variable, one entry per row of ``ind4``.
            mul2y: maps an exponent tuple (a row of ``ind4``) to its index in ``y``.
            ind4: array of the unique exponent vectors appearing in ``M``.
        """
        raise NotImplementedError()

    @staticmethod
    def _pairwise_sums(ind2):
        """Return the unique rows of ``ind2[i] + ind2[j]`` over all pairs ``(i, j)``.

        These are exactly the exponents of the entries of ``m(x) m(x)^T``.
        Rows come back sorted lexicographically (``np.unique`` with ``axis=0``).
        """
        n_vars = ind2.shape[1]
        sums = (ind2[:, None, :] + ind2[None, :, :]).reshape(-1, n_vars)
        return np.unique(sums, axis=0)

    @staticmethod
    def _missing_monoms(f_alpha, ind4):
        """Return the exponent tuples of ``f_alpha`` that are not rows of ``ind4``."""
        available = {tuple(row) for row in np.asarray(ind4).tolist()}
        return [monom for monom in f_alpha if tuple(monom) not in available]

    def solve(self, solver=None):
        """Build and solve ``min sum_alpha f_alpha * y_alpha`` s.t. ``M >> 0`` and subclass constraints.

        Args:
            solver: CVXPY solver name passed to ``Problem.solve``. Defaults to
                ``cp.MOSEK``, the solver this method always used.

        Returns:
            ``(prob, M, y, mul2y, ind4, sol_time)``: the CVXPY problem, the
            moment-matrix and moment-vector variables, the exponent-tuple ->
            ``y``-index map, the exponent rows indexing ``y``, and the
            wall-clock solve time in seconds.

        Raises:
            ValueError: if ``self.ind2`` has not been generated.
            AssertionError: if ``self.f_alpha`` is unset or has a monomial that
                does not appear in the moment matrix.
        """
        assert self.f_alpha is not None, 'f_alpha is None'
        ind2 = self.ind2
        if ind2 is None:
            raise ValueError('ind2 is None; did you call generate_basis()?')
        n_basis = ind2.shape[0]
        logger.debug(f'Using {n_basis} basis functions')

        # every entry of M = m(x) m(x)^T has an exponent that is a pairwise basis sum
        ind4 = self._pairwise_sums(ind2)

        # The objective is expressed through y, so every monomial of f must be a
        # row of ind4. (`tuple in ndarray` would only test whether ANY single
        # entry of ind4 matches, which is almost always true.)
        missing = self._missing_monoms(self.f_alpha, ind4)
        assert not missing, f'Monom {np.array(missing[0])} not in ind4'

        # y[k] is the moment of the monomial with exponent ind4[k]
        mul2y = {tuple(ind4[k, :]): k for k in range(ind4.shape[0])}
        assert len(mul2y) == ind4.shape[0]

        # M[i, j] must equal the y entry of the exponent ind2[i] + ind2[j]
        row_idx, col_idx, y_idx = [], [], []
        for id1 in range(n_basis):
            for id2 in range(n_basis):
                row_idx.append(id1)
                col_idx.append(id2)
                y_idx.append(mul2y[tuple(ind2[id1, :] + ind2[id2, :])])

        # define variables
        M = cp.Variable((n_basis, n_basis), symmetric=True)
        y = cp.Variable((ind4.shape[0]))

        constraints = [M >> 0]
        constraints += [M[row_idx, col_idx] == y[y_idx]]
        constraints += self.form_new_constraints(M, y, mul2y, ind4)

        # Objective: sum_alpha f_alpha * y_alpha. Coefficients may be sympy
        # numbers, so convert them to a plain float array for CVXPY.
        monoms = list(self.f_alpha)
        coefs = np.array([float(self.f_alpha[m]) for m in monoms], dtype=float)
        obj_idx = np.array([mul2y[m] for m in monoms])
        prob = cp.Problem(cp.Minimize(cp.sum(cp.multiply(coefs, y[obj_idx]))), constraints)

        if solver is None:
            solver = cp.MOSEK
        t0 = time.perf_counter()
        prob.solve(solver=solver, verbose=self.verbose)
        sol_time = time.perf_counter() - t0

        if prob.value is None:
            logger.warning(f'Solver returned no objective value (status: {prob.status})')
        else:
            logger.debug(f'Achieved optimal value: {prob.value:.4f}')
        return prob, M, y, mul2y, ind4, sol_time

class SimpleQuadratic(Problem):
    """Toy problem with potential ``(x - w)**2 + (w * x)**2``.

    ``x`` is the decision variable and ``w`` a noise variable whose pure
    moments are pinned to those of ``Uniform(-1, 1)``. Exponent tuples are
    ordered ``(x, w)``.
    """

    def __init__(self, config):
        """Build the potential, its coefficient map and the monomial basis.

        Args:
            config: dict with ``'m2_max_degree'`` (maximum total degree of the
                basis monomials) and optional ``'verbose'``.
        """
        super().__init__('SimpleQuadratic', verbose=config.get('verbose', False))

        self.m2_max_degree = config['m2_max_degree']

        w = sympy.symbols('w')
        x = sympy.symbols('x')
        u = sympy.Matrix([x, w])  # generator order fixes the (x, w) exponent layout
        self.potential = sympy.Poly((x - w) ** 2 + (w * x) ** 2, u)
        self.func_n_coeffs = len(self.potential.coeffs())
        logger.debug(f'Potential has {self.func_n_coeffs} non-zero terms')

        # f_alpha: exponent tuple -> coefficient (missing monomials read as 0)
        self.f_alpha = defaultdict(int)
        for coef, monom in zip(self.potential.coeffs(), self.potential.monoms()):
            self.f_alpha[tuple(monom)] = coef

        # generate basis variables
        self.generate_basis()

    def generate_basis(self):
        """Set ``self.ind2`` to every ``(x, w)`` exponent of total degree <= ``m2_max_degree``.

        Rows are sorted lexicographically and stored as ints.
        """
        deg = self.m2_max_degree
        exponents = [
            (i, j)
            for i in range(deg + 1)
            for j in range(deg + 1)
            if i + j <= deg
        ]
        self.ind2 = np.unique(np.array(exponents), axis=0).astype(int)
        logger.debug(f'{self.name}: Generated {self.ind2.shape[0]} basis functions')

    def form_new_constraints(self, M, y, mul2y, ind4):
        """Return constraints fixing every pure-``w`` moment in ``y`` to E[w^k].

        Args:
            M: moment-matrix variable (not used here).
            y: moment-vector variable.
            mul2y: exponent tuple -> index into ``y``.
            ind4: array of exponent rows ``(x_pow, w_pow)``.
        """
        new_constraints = []

        def get_w_mom_value(alpha: tuple):
            """Return E[w ** alpha[0]] for w ~ Uniform(-1, 1)."""
            assert len(alpha) == 1
            # odd moments of a symmetric distribution vanish
            if any(alpha_val % 2 for alpha_val in alpha):
                return 0
            # Every exponent is even here, so E[w^k] = 1 / (k + 1). (The former
            # `assert all([~x ...])` could never fail: ~0 == -1 is truthy.)
            return 1 / (sum(alpha) + 1)

        # add moment constraints for rows with zero x exponent
        for pows in ind4:
            if sum(pows[:1]) == 0:
                w_alpha = tuple(int(c) for c in pows[-1:])
                mom_val = get_w_mom_value(w_alpha)
                logger.debug(f'{self.name}: Adding moment constraint for {pows} with value {mom_val}')
                v = mul2y[tuple(float(c) for c in pows)]
                new_constraints.append(y[v] == mom_val)
        return new_constraints

class SNL(Problem):
    """Sensor network localization (SNL) posed as a moment relaxation.

    Sensor coordinates ``x_ij`` (sensor i, spatial dim j) are the decision
    variables; ``w_0 .. w_{n_noise_dims-1}`` are noise variables that linearly
    perturb some sensor-sensor squared distances. Exponent tuples follow the
    ordering of ``self.u_vec``: all sensor coordinates (row-major), then noise.
    """

    @staticmethod
    def generate_sensor_anchor_positions(n_sol_dims, n_spatial_dims, n_anchors, radius, seed):
        """Sample sensor and anchor positions uniformly from ``[-1, 1]^n_spatial_dims``.

        Re-samples (at most 51 draws) until more than half of all pairwise
        Euclidean distances -- diagonal included -- are below ``radius``.
        Seeds the global NumPy RNG with ``seed``.

        Returns:
            ``(x_gt, a_true_np)`` of shapes ``(n_sol_dims, n_spatial_dims)`` and
            ``(n_anchors, n_spatial_dims)``.

        Raises:
            RuntimeError: if no acceptable configuration is found.
        """
        np.random.seed(seed)
        try_count = 0
        while True:
            x_gt = np.random.uniform(-1, 1, size=(n_sol_dims, n_spatial_dims))
            a_true_np = np.random.uniform(-1, 1, size=(n_anchors, n_spatial_dims))

            pos = np.concatenate([x_gt, a_true_np], axis=0)
            # all pairwise Euclidean distances at once (the diagonal is exactly 0)
            diff = pos[:, None, :] - pos[None, :, :]
            dists = np.sqrt(np.sum(diff ** 2, axis=-1))
            if np.mean(dists < radius) > 0.50:  # more than half of distances are valid
                break
            try_count += 1
            if try_count > 50:
                raise RuntimeError(f'Could not generate valid sensor positions for radius {radius}')
        assert x_gt.shape == (n_sol_dims, n_spatial_dims)
        assert a_true_np.shape == (n_anchors, n_spatial_dims)
        return x_gt, a_true_np

    @staticmethod
    def do_clustering(x_gt, n_clusters, seed):
        """Cluster sensors with k-means and build a ring cluster-adjacency matrix.

        Returns:
            ``(x_gt_cluster, cc_adj)``: the k-means label of each sensor and an
            ``(n_clusters, n_clusters)`` float array with ones on the diagonal
            and between ring neighbours ``c`` and ``(c +/- 1) mod n_clusters``.
        """
        kmeans = KMeans(n_clusters=n_clusters, random_state=seed).fit(x_gt)
        x_gt_cluster = kmeans.predict(x_gt)
        logger.debug(f'Cluster assignments: {x_gt_cluster}')

        # hardcoded ring connectivity: (0, 1), ..., (n_clusters - 1, 0), plus self-loops
        cc_adj = np.zeros((n_clusters, n_clusters))
        for i in range(n_clusters):
            for j in range(n_clusters):
                if (i - j) % n_clusters in (0, 1, n_clusters - 1):
                    cc_adj[i, j] = 1
        return x_gt_cluster, cc_adj

    def __init__(self, config):
        """Build the SNL potential, cluster graph and basis from ``config``.

        Keys read: ``n_sol_dims`` (number of sensors), ``n_spatial_dims``,
        ``n_noise_dims``, ``seed``, ``radius``, ``eps`` (noise scale),
        ``n_perturb_per_ndim``, ``noise_model`` (only ``'linear'`` is
        implemented), ``use_hard_eq``, ``n_hard_eq_constr``,
        ``use_cluster_basis``, ``n_clusters`` and optional ``verbose``.
        """
        super().__init__('SNL', verbose=config.get('verbose', False))

        self.n_sol_dims = config['n_sol_dims']
        self.n_spatial_dims = config['n_spatial_dims']
        self.n_noise_dims = config['n_noise_dims']

        self.seed = config['seed']
        self.radius = config['radius']
        self.eps = config['eps']
        self.n_perturb = config['n_perturb_per_ndim']
        self.noise_model = config['noise_model']
        self.use_hard_eq = config['use_hard_eq']
        self.n_hard_eq_constr = config['n_hard_eq_constr']
        self.n_anchors = self.n_spatial_dims + 1

        # sympy variables: x[i, j] is coordinate j of sensor i; w are the noise variables
        self.x = Matrix([[
            symbols('x{}{}'.format(i+1, j+1)) for j in range(self.n_spatial_dims)
        ] for i in range(self.n_sol_dims)])
        self.w = sympy.symbols('w:{}'.format(self.n_noise_dims))
        # variable order used by the Poly below, hence by every exponent tuple
        self.u_vec = list(np.array(self.x).flatten()) + list(np.array(self.w).flatten())

        # generate initial sensor and anchor positions
        self.x_gt, self.a_true_np = SNL.generate_sensor_anchor_positions(
            self.n_sol_dims, self.n_spatial_dims, self.n_anchors, self.radius, self.seed
        )

        # assign each sensor to one cluster and build the cluster-cluster graph
        self.use_cluster_basis = config['use_cluster_basis']
        self.n_clusters = config['n_clusters']
        if not self.use_cluster_basis and self.n_clusters > 1:
            logger.warning('Ignoring n_clusters > 1; use_cluster_basis is False')
            self.n_clusters = 1
        self.x_gt_cluster, self.cc_adj = SNL.do_clustering(self.x_gt, self.n_clusters, self.seed)

        # symbolic squared distances among sensors (variables) and anchors (constants)
        n_nodes = self.n_sol_dims + self.n_anchors
        pos_var = np.concatenate([np.array(self.x), self.a_true_np], axis=0)
        D_mat = np.empty((n_nodes, n_nodes)).astype(object)
        for i in range(n_nodes):
            for j in range(n_nodes):
                if i == j:
                    D_mat[i, j] = 0
                else:
                    D_mat[i, j] = sum((pos_var[i] - pos_var[j]) ** 2)
        assert D_mat.dtype == object

        # ground-truth sensor-sensor squared distances
        true_D_ss = np.zeros((self.n_sol_dims, self.n_sol_dims))
        for s_i in range(self.n_sol_dims):
            for s_j in range(self.n_sol_dims):
                if s_i == s_j:
                    continue
                true_D_ss[s_i, s_j] = np.linalg.norm(self.x_gt[s_i] - self.x_gt[s_j], ord=2) ** 2
        assert true_D_ss.dtype == np.float64

        # ground-truth sensor-anchor squared distances
        true_D_sa = np.zeros((self.n_sol_dims, self.n_anchors))
        for s_i in range(self.n_sol_dims):
            for a_j in range(self.n_anchors):
                true_D_sa[s_i, a_j] = np.linalg.norm(self.x_gt[s_i] - self.a_true_np[a_j], ord=2) ** 2
        assert true_D_sa.dtype == np.float64

        # symbolic blocks; these are views into D_mat
        D_ss = D_mat[:self.n_sol_dims, :self.n_sol_dims]
        D_sa = D_mat[:self.n_sol_dims, self.n_sol_dims:]
        assert D_ss.shape == true_D_ss.shape
        assert D_sa.shape == true_D_sa.shape

        # Keep only terms inside the sensing radius.
        # NOTE(review): these masks compare SQUARED distances to `radius`, while
        # generate_sensor_anchor_positions compares plain distances -- confirm intent.
        mask_D_ss = (true_D_ss < self.radius).astype(np.float32)
        mask_D_sa = (true_D_sa < self.radius).astype(np.float32)
        logger.debug(f'Keeping {100 * np.mean(mask_D_ss):.2f}% of sensor-sensor terms')
        logger.debug(f'Keeping {100 * np.mean(mask_D_sa):.2f}% of sensor-anchor terms')

        # Only keep sensor-sensor terms within a cluster or between clusters that
        # are adjacent in cc_adj; this sparsifies the polynomial potential.
        assert self.noise_model in ['linear', 'outlier']
        total_cluster_mask_D_ss = np.zeros_like(mask_D_ss)
        for c_i in range(self.n_clusters):
            s_inds = np.where(self.x_gt_cluster == c_i)[0]

            # every pair of sensors inside cluster c_i (diagonal included)
            cluster_mask = np.zeros_like(true_D_ss).astype(bool)
            cluster_mask[np.ix_(s_inds, s_inds)] = True

            # plus pairs from c_i into each cluster adjacent to it
            for c_j in range(self.n_clusters):
                if c_i != c_j and self.cc_adj[c_i, c_j] == 1:
                    s_inds_j = np.where(self.x_gt_cluster == c_j)[0]
                    cluster_mask[np.ix_(s_inds, s_inds_j)] = True

            total_cluster_mask_D_ss = np.logical_or(total_cluster_mask_D_ss, cluster_mask)

            # w_{c_i} perturbs n_perturb randomly chosen entries of cluster_mask.
            # NOTE(review): the RNG is re-seeded for every cluster and the candidate
            # entries include the diagonal (s, s) -- confirm both are intended.
            np.random.seed(self.seed)
            if self.noise_model == 'linear':
                ii, jj = np.where(cluster_mask)
                n_ind = np.random.choice(len(ii), size=self.n_perturb, replace=False)
                D_ss[ii[n_ind], jj[n_ind]] += self.eps * self.w[c_i]
            else:
                raise NotImplementedError()

        potential = 0.0
        if self.use_cluster_basis:
            logger.debug('Using cluster basis; only keeping sensor-sensor interactions within a cluster or across adjacent clusters')
            ss_weight = total_cluster_mask_D_ss * mask_D_ss
        else:
            ss_weight = np.ones_like(total_cluster_mask_D_ss) * mask_D_ss
        potential += np.sum(ss_weight * np.square(true_D_ss - D_ss))

        if self.use_hard_eq:
            logger.debug('Using hard-equality constraints; only keeping sensor-sensor interactions')
        else:
            potential += np.sum(mask_D_sa * np.square(true_D_sa - D_sa))

        potential = sympy.Poly(potential, self.u_vec)
        # NOTE(review): simplify on an already-canonical Poly is presumably a no-op;
        # the code below relies on it still returning a Poly.
        self.potential = sympy.simplify(potential)
        self.func_n_coeffs = len(self.potential.coeffs())
        logger.debug(f'Potential has {self.func_n_coeffs} non-zero terms')

        # f_alpha: exponent tuple -> coefficient (missing monomials read as 0)
        self.f_alpha = defaultdict(int)
        for coef, monom in zip(self.potential.coeffs(), self.potential.monoms()):
            self.f_alpha[tuple(monom)] = coef

        # generate basis variables
        self.generate_basis()

    def generate_basis(self):
        """Set ``self.ind2`` to the degree-<=2 monomials allowed by the cluster graph.

        From all monomials of degree <= 2 in the ``d`` variables, keep the
        constant, every monomial whose variables share one cluster label, and
        every monomial whose two variables lie in clusters adjacent in
        ``self.cc_adj``. Noise variable ``w_i`` carries cluster label ``i``.
        With a single cluster the result must equal the full basis.
        """
        # every noise dimension is available to each cluster
        d = self.n_sol_dims * self.n_spatial_dims + self.n_noise_dims
        # rows: the zero exponent and the d unit exponents (degree <= 1 monomials)
        ind1 = np.vstack([np.zeros(d), np.eye(d)])

        # full degree <= 2 basis, only used to sanity-check the single-cluster case
        full_ind2 = np.unique(
            np.vstack([
                ind1[id1, :] + ind1[id2, :]
                for id1 in range(ind1.shape[0])
                for id2 in range(ind1.shape[0])
            ]),
            axis=0,
        )

        # cluster label of each variable index (sensor coords, then noise dims)
        labels = {}
        for s_i in range(self.n_sol_dims):
            c_k = self.x_gt_cluster[s_i]
            for d_i in range(self.n_spatial_dims):
                k = s_i * self.n_spatial_dims + d_i
                labels[k] = {
                    'type': 'sensor',
                    'cluster': c_k,
                }
        for w_i in range(self.n_noise_dims):
            k = self.n_sol_dims * self.n_spatial_dims + w_i
            labels[k] = {
                'type': 'noise',
                'cluster': w_i,
            }

        # assert that all dimensions are labeled
        assert set(labels.keys()) == set(range(d))

        # rebuild ind2, keeping only monomials consistent with the cluster graph
        ind2 = [
            np.zeros(d),
        ]
        for id1 in range(ind1.shape[0]):
            for id2 in range(ind1.shape[0]):
                i2_cand = ind1[id1, :] + ind1[id2, :]
                set_inds = set(np.where(i2_cand)[0])
                cluster_labels = set([labels[i]['cluster'] for i in set_inds])

                # (xi, xj) in the same cluster, or (xi, wi) sharing a label
                is_same_cluster = len(cluster_labels) == 1

                # (xi, xj) in adjacent clusters
                if len(cluster_labels) == 2:
                    c_i, c_j = tuple(cluster_labels)
                    is_cc_map = self.cc_adj[c_i, c_j] == 1
                else:
                    is_cc_map = False

                if is_same_cluster or is_cc_map:
                    ind2.append(i2_cand)

        ind2 = np.vstack(ind2)
        ind2 = np.unique(ind2, axis=0)

        if self.n_clusters == 1:
            assert len(full_ind2) == len(ind2), f'Full basis {full_ind2.shape} does not match cluster basis {ind2.shape}'
        self.ind2 = ind2.astype(int)

    @staticmethod
    def _single_var_moment_index(ind4, mul2y, k, power):
        """Return the ``y`` index of the monomial ``u_k ** power`` (all other exponents 0).

        Raises:
            IndexError: if that monomial is not a row of ``ind4``.
        """
        others = np.arange(ind4.shape[1]) != k
        mask = (ind4[:, k] == power) & (ind4[:, others].sum(axis=1) == 0)
        i4 = np.where(mask)[0][0]
        return mul2y[tuple(float(c) for c in ind4[i4])]

    def form_new_constraints(self, M, y, mul2y, ind4):
        """Return the noise-moment and (optional) hard-equality constraints.

        * Each exponent row with no sensor variables gets ``y`` fixed to the
          matching moment of the noise, assumed i.i.d. ``Uniform(-1, 1)``.
        * If ``use_hard_eq`` is set, ``n_hard_eq_constr`` randomly chosen
          sensors get E[x] fixed to the ground truth and E[x^2] to its square.

        Args:
            M: moment-matrix variable (not used here).
            y: moment-vector variable.
            mul2y: exponent tuple -> index into ``y``.
            ind4: array of exponent rows (sensor coordinates first, then noise).
        """
        new_constraints = []
        n_x = self.n_sol_dims * self.n_spatial_dims

        def get_w_mom_value(alpha: tuple):
            """Return E[prod_i w_i ** alpha_i] for i.i.d. w_i ~ Uniform(-1, 1)."""
            assert len(alpha) == self.n_noise_dims
            # any odd exponent makes the moment vanish by symmetry
            if any(alpha_val % 2 for alpha_val in alpha):
                return 0
            # independence: the joint moment factorizes into E[w_i^a] = 1 / (a + 1)
            mom = 1.0
            for alpha_val in alpha:
                mom /= alpha_val + 1
            return mom

        # moment constraints for rows with no sensor variables
        for pows in ind4:
            if sum(pows[:n_x]) == 0:
                w_alpha = tuple(int(c) for c in pows[n_x:])
                mom_val = get_w_mom_value(w_alpha)
                v = mul2y[tuple(float(c) for c in pows)]
                new_constraints.append(y[v] == mom_val)

        # hard-equality constraints, only when enabled
        n_hard_constr = self.n_hard_eq_constr * self.use_hard_eq
        assert n_hard_constr <= self.n_sol_dims, f'Cannot add more than {self.n_sol_dims} hard constraints'
        np.random.seed(self.seed)
        hard_sensor_inds = np.random.choice(self.n_sol_dims, size=n_hard_constr, replace=False)
        logger.debug(f'Adding {n_hard_constr} hard-equality constraints')
        logger.debug(f'Sensors with hard-equality constraints: {hard_sensor_inds}')

        anc = 0  # constraints are written relative to the first anchor
        for s_i in hard_sensor_inds:
            # (E[x_{s_i,sd}] - a_sd) == (x_gt_sd - a_sd), i.e. the mean is pinned
            for sd in range(self.n_spatial_dims):
                k = s_i * self.n_spatial_dims + sd
                v = self._single_var_moment_index(ind4, mul2y, k, 1)
                d_ik = self.x_gt[s_i, sd] - self.a_true_np[anc, sd]
                new_constraints.append(y[v] - self.a_true_np[anc, sd] - d_ik == 0.0)

            # With the mean pinned above, E[x^2] == x_gt^2 is equivalent to
            # Var[x] = E[x^2] - E[x]^2 = 0 while staying linear in y.
            for sd in range(self.n_spatial_dims):
                k = s_i * self.n_spatial_dims + sd
                v = self._single_var_moment_index(ind4, mul2y, k, 2)
                new_constraints.append(y[v] - self.x_gt[s_i, sd] ** 2 == 0)

            logger.debug(f'Added hard constraint for sensor {s_i} to anchor {anc}')
            logger.debug(f'True sensor position: {self.x_gt[s_i]}')
        return new_constraints

if __name__ == '__main__':
    # Sensor network localization settings. Kept for reference only: the SNL
    # run below is currently commented out, so this dict is not used.
    snl_config = {
        # fixed parameters
        'n_sol_dims': 5,
        'n_spatial_dims': 2,
        'n_noise_dims': 1,
        'n_clusters': 1,
        # noise model
        # linear: each noise variable perturbs n_perturb_per_ndim sensor-sensor distances
        # outlier: not implemented (SNL raises NotImplementedError)
        'seed': 42,
        'noise_model': 'linear',
        'eps': 0.100,
        'n_perturb_per_ndim': 1,
        'radius': 4.000,
        'use_hard_eq': True,
        'n_hard_eq_constr': 2,
        # CVXPY params
        'verbose': True,
        'use_cluster_basis': True,
    }
    # To run SNL instead:
    #   snl_problem = SNL(snl_config)
    #   prob, M, y, mul2y, ind4, sol_time = snl_problem.solve()

    # Toy problem that is actually solved when the module is run as a script.
    quad_config = {'m2_max_degree': 6, 'verbose': True}
    quad_problem = SimpleQuadratic(quad_config)
    prob, M, y, mul2y, ind4, sol_time = quad_problem.solve()

    # TODO: exercise every combination of
    #   n_clusters == 1 / n_clusters > 1   x   use_cluster_basis True / False
