from __future__ import division, print_function, absolute_import

import matplotlib

# Process-wide matplotlib defaults (rcParams is global state).
# fonttype 42 requests TrueType (Type 42) font embedding instead of Type 3
# in PDF and PostScript output (see matplotlib rcParams documentation).
matplotlib.rcParams["pdf.fonttype"] = 42
matplotlib.rcParams["ps.fonttype"] = 42
matplotlib.rcParams["font.family"] = "serif"
matplotlib.rcParams["font.size"] = 10

# Grid resolution; used as the default `res` (normalizing divisor) of
# map_cartesian_to_simplex.
RES = 200

import sys

# Rebind stdout to stderr for the rest of the process, so every later print()
# writes to stderr. NOTE(review): the motivation is not visible here —
# presumably to keep stdout free for other output; confirm.
sys.stdout = sys.stderr

import numpy as np


def map_simplex_to_cartesian(simplex_point):
    """
    Project a point of the standard 2-simplex onto the cartesian right triangle.

    Dropping the first barycentric coordinate sends the simplex vertices
    (1,0,0) -> A=(0,0), (0,1,0) -> B=(1,0) and (0,0,1) -> C=(0,1).
    Inputs with more coordinates follow the same rule: every coordinate except
    the first is kept.

    Args:
        simplex_point: Sliceable sequence (list, tuple or 1-D NumPy array).

    Returns:
        ``simplex_point[1:]``, in the same container type as the input
        (a view when the input is a NumPy array).
    """
    tail = slice(1, None)
    return simplex_point[tail]


def map_cartesian_to_simplex(cartesian_point, res=RES):
    """
    Lift a cartesian grid point (x, y) back onto the standard 2-simplex.

    The input is expressed in grid units, i.e. the triangle has vertices
    (0,0), (res,0) and (0,res). The result is
    ((res - x - y) / res, x / res, y / res), whose coordinates sum to 1
    (up to rounding); applying ``map_simplex_to_cartesian`` to it yields
    (x / res, y / res).

    Args:
        cartesian_point: Point (x, y); anything ``np.array`` accepts.
        res: Grid resolution used as the normalizing divisor; defaults to the
            module-level ``RES``.

    Returns:
        1-D NumPy array with one more entry than ``cartesian_point``.
    """
    xy = np.array(cartesian_point)
    # The first barycentric coordinate is whatever is left after x and y.
    leftover = np.array([res - sum(xy)])
    return np.concatenate([leftover / res, xy / res])


# A region of the solution simplex, defined by its center, with coordinate mappings between simplex and cartesian coordinates
class SimplexRegion:
    """
    A region of the solution simplex, defined by a center point and a sampling radius.

    The center is stored in both simplex and cartesian coordinates. At
    construction time a fixed pool of points is drawn around the center with
    ``sample_L1_ball``. Each point is rescaled so its coordinates sum to 1, and
    ``get_client_subregion`` returns this pool.
    """

    def __init__(
        self,
        region_id,
        center,
        rho,
        num_samples=2000,
    ):
        """
        Args:
            region_id: Identifier of this region (``_compute_solution_simplex_``
                passes the region's list index).
            center: Center of the region in simplex coordinates.
            rho: Radius passed to ``sample_L1_ball`` when drawing points
                around ``center``.
            num_samples: Number of points to pre-sample. Defaults to 2000,
                the value that was previously hard-coded.
        """
        self.region_id = region_id
        self.center_simplex = center
        self.center_cartesian = map_simplex_to_cartesian(self.center_simplex)
        # Radius of the L1 ball that points are sampled from
        self.rho = rho

        # Pre-sample once so that repeated requests (e.g. during training)
        # do not redraw points.
        sampled_alphas = sample_L1_ball(self.center_simplex, self.rho, num_samples)
        # Normalization, IMPORTANT: the L1 perturbation generally changes the
        # coordinate sum, so each row is rescaled to sum to 1.
        # NOTE(review): rescaling does not clip negative coordinates; points
        # drawn near the simplex boundary (center entry < rho) may keep
        # negative entries.
        self.sampled_alphas = sampled_alphas / sampled_alphas.sum(axis=1).reshape(-1, 1)
        self.alphas_cartesian = [map_simplex_to_cartesian(alpha) for alpha in self.sampled_alphas]

    def get_client_subregion(self):
        """Return ``(sampled_alphas, alphas_cartesian)`` pre-sampled in ``__init__``."""
        return self.sampled_alphas, self.alphas_cartesian

    def get_region_center_simplex(self):
        """Return the region center in simplex coordinates."""
        return self.center_simplex

    def get_region_center_cartesian(self):
        """Return the region center in cartesian coordinates."""
        return self.center_cartesian

class SolutionSimplex:
    """
    Solution simplex that keeps track of all simplex regions and of the region
    each client is assigned to.
    """

    def __init__(self, cfg, xp_name):
        """
        Args:
            cfg: Experiment configuration; ``sample_uniform`` reads
                ``cfg.rule.num_points``.
            xp_name: Experiment name, stored as ``experiment_name``.
        """
        self.cfg = cfg
        self.experiment_name = xp_name
        # Filled by set_solution_simplex_regions. They are initialized here so
        # the getters work (returning empty results) before the first call,
        # instead of raising AttributeError.
        self.simplex_regions = []
        self.rho = None
        self.client_to_simplex_region_mapping = {}

    def set_solution_simplex_regions(
        self, projected_points, rho
    ):
        """
        Build one SimplexRegion per projected point and assign client ``i`` to region ``i``.

        Args:
            projected_points: Region centers in simplex coordinates.
            rho: Sampling radius shared by all regions.
        """
        self.simplex_regions = _compute_solution_simplex_(
            projected_points=projected_points,
            rho=rho
        )
        self.rho = rho
        self.client_to_simplex_region_mapping = {
            i: simplex_region.region_id
            for i, simplex_region in enumerate(self.simplex_regions)
        }

    def _region_for_client(self, client_id):
        """Return ``(region_id, region)`` for ``client_id``; raises KeyError if unassigned."""
        region_id = self.client_to_simplex_region_mapping[client_id]
        # Region ids double as list indices (_compute_solution_simplex_ sets region_id=i).
        return region_id, self.simplex_regions[region_id]

    def get_client_subregion(self, client_id):
        """
        Return the pre-sampled points of the region assigned to ``client_id``.

        Returns:
            ``(sampled_alpha_simplex, sampled_alpha_cartesian, region_id)``
        """
        region_id, region = self._region_for_client(client_id)
        sampled_alpha_simplex, sampled_alpha_cartesian = region.get_client_subregion()
        return sampled_alpha_simplex, sampled_alpha_cartesian, region_id

    def sample_uniform(self, client_id, num_samples=100000):
        """
        Draw ``num_samples`` points uniformly over the whole simplex.

        Normalizing i.i.d. Exponential(1) draws gives a Dirichlet(1, ..., 1)
        sample, i.e. a uniform point on the simplex. ``client_id`` is not used
        for sampling. It is returned unchanged in the region-id slot, so the
        result has the same shape as ``get_client_subregion``'s.

        Args:
            client_id: Returned as the region id.
            num_samples: Number of points to draw. Defaults to 100000, the
                value that was previously hard-coded.

        Returns:
            ``(alphas of shape (num_samples, cfg.rule.num_points),
            list of their cartesian projections, client_id)``
        """
        num_points = self.cfg.rule.num_points
        draws = np.random.exponential(scale=1.0, size=(num_samples, num_points))
        sampled_alpha_simplex = draws / draws.sum(1).reshape(-1, 1)
        sampled_alpha_cartesian = [map_simplex_to_cartesian(point) for point in sampled_alpha_simplex]
        return sampled_alpha_simplex, sampled_alpha_cartesian, client_id

    def get_client_center(self, client_id):
        """
        Return the center of the region assigned to ``client_id``.

        Returns:
            ``(center_simplex, center_cartesian, region_id)``
        """
        region_id, region = self._region_for_client(client_id)
        return region.get_region_center_simplex(), region.get_region_center_cartesian(), region_id

    def get_simplex_region_centers_cartesian(self):
        """Return every region center in cartesian coordinates, in region order."""
        return [simplex_region.get_region_center_cartesian() for simplex_region in self.simplex_regions]

    def get_simplex_region_centers_simplex(self):
        """Return every region center in simplex coordinates, in region order."""
        return [simplex_region.get_region_center_simplex() for simplex_region in self.simplex_regions]

def _compute_solution_simplex_(projected_points, rho):
    """
    Create one SimplexRegion per projected point.

    Region ids are the points' positions in ``projected_points``, so the
    returned list can be indexed by region id. All regions share ``rho``.
    """
    return [
        SimplexRegion(region_id=idx, center=point, rho=rho)
        for idx, point in enumerate(projected_points)
    ]

def sample_L1_ball(center, radius, num_samples):
    """
    Draw ``num_samples`` points around ``center`` within L1 distance ``radius``.

    Each sample is ``center + r * d``:

    - ``d`` lies on the L1 unit sphere (``sum(|d|) == 1``) and is obtained by
      normalizing i.i.d. Uniform(-1, 1) draws.
    - ``r ~ Uniform(0, radius)``.

    NOTE(review): for dim > 1 this is not a uniform distribution over the L1
    ball. Neither the direction nor the radius is drawn with the density a
    uniform sampler would need. Confirm this is intended before relying on
    uniformity.

    Args:
        center: Center point of length ``dim``.
        radius: Maximum L1 distance from ``center``.
        num_samples: Number of points to draw.

    Returns:
        Array of shape ``(num_samples, len(center))``.
    """
    dim = len(center)
    out = np.zeros((num_samples, dim))
    for row in out:
        # Direction on the surface of the L1 unit ball
        raw = np.random.uniform(-1, 1, dim)
        magnitudes = np.abs(raw)
        direction = np.sign(raw) * (magnitudes / np.sum(magnitudes))
        # Distance from the center, at most `radius`
        step = np.random.uniform(0, radius)
        row[:] = center + step * direction
    return out
