from abc import ABC

import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Rectangle

from active_ranking import config, utils
from active_ranking.base import ucb_lcb
from active_ranking.base.sampler import sampler_d
from active_ranking.base.space import Partition
from active_ranking.base.utils import GridAsFunction
from active_ranking.base.utils import estimate
from active_ranking.metrics.roc import ROC, infinity_norm, norm_one


class Learner(ABC):
    """Base class of an active ranking learner built on a dyadic partition.

    The learner repeatedly queries points (``sample``), labels them with the
    oracle ``eta``, updates per-cell estimates and confidence bounds, and
    periodically rebuilds a piecewise-constant scoring function on a grid of
    ``2 ** j_max`` bins per axis. ``sample``, ``compute_active_set``,
    ``stopping``, ``merge_cells`` and ``regret`` are hooks for subclasses.
    """

    # Ground-truth regression function used for regret tracking. None means
    # "no tracker": regret is skipped. Defaulting it at class level lets
    # init_alg / one_step / pursue run on learners that never register one
    # (they previously raised AttributeError).
    true_eta = None
    # Stopping flag written by ``stopping()``; defaulted so it always exists.
    stop = False

    def __init__(self, j_max, d, eta):
        """
        Parameters
        ----------
        j_max : int
            Maximal depth of the partition; the score grid has
            ``2 ** j_max`` bins along each of the ``d`` axes.
        d : int
            Dimension of the input space.
        eta : callable
            Labelling oracle: ``eta(X)`` returns the labels of the rows of X.
        """
        self.partition = Partition(j_max, d)
        self.estimates = {}
        self.ucb = {}
        self.lcb = {}
        self.eta = eta
        # Fixed test set on which predictions (and the ROC) are evaluated.
        self.x_test = sampler_d(10000, self.partition.d)
        self.y_test = self.eta(self.x_test)
        self.predictions = {}
        self.grid = np.zeros((2 ** self.partition.j_max,) * self.partition.d)
        self.n_sample = []
        self.norm_infinity = []
        self.norm_one = []
        # Number of cells, in multi-armed-bandit notation (one arm per cell).
        self.K = int(2 ** (j_max * d))
        self._break = False
        self.epsilon = config.epsilon
        self.active_execution_time = False
        self.__p_estimation = False
        self.t = 1

    def activate_p_estimation(self):
        """Set the private p-estimation flag (not read within this class)."""
        self.__p_estimation = True

    def sample(self, n):
        """Hook: return ``n`` new query points.

        The base implementation returns None; subclasses must override it.
        """
        ...

    def compute_active_set(self):
        """Hook: update the set of cells still being explored (no-op here)."""
        pass

    def stopping(self):
        """Hook: decide whether to stop; the base learner never stops."""
        self.stop = False

    def merge_cells(self):
        """Hook: merge cells of the partition (no-op here)."""
        pass

    def init_alg(self, n_0, n_max):
        """Initialise the algorithm with ``n_0`` points from ``sampler_d``.

        Stores ``n_max``, labels and adds the initial points, computes the
        active set and the score, and records the regret when a tracker is
        set.
        """
        # Reset so that run()/run_and_plot() can be called more than once;
        # otherwise a stale True flag stops a new run after a single step.
        self._break = False
        self.n_max = n_max
        X = sampler_d(n_0, d=self.partition.d)
        y = self.eta(X)
        self.add_labels(X, y)
        self.compute_active_set()
        self.scoring_function()
        if self.true_eta is not None:
            self.regret()

    def _query_and_update(self):
        """Query one point, label it with ``eta`` and refresh the model."""
        X = self.sample(1)
        y = self.eta(X)
        self.add_labels(X, y)
        self.compute_active_set()
        self.merge_cells()
        self.stopping()

    def one_step(self, n_max):
        """Perform one partition step; set ``_break`` when the run must end.

        The run ends when ``stopping`` sets ``self.stop`` or when more than
        ``n_max`` samples have been collected.
        """
        self.partition.ini_step()
        self._query_and_update()
        # The score (and the regret) is only rebuilt every 5 steps.
        if self.partition.step % 5 == 0:
            self.scoring_function()
            if self.true_eta is not None:
                self.regret()
        if self.stop or self.partition.x.shape[0] > n_max:
            self._break = True

    def run(self, n_0, n_max):
        """Initialise with ``n_0`` points, then step until ``_break`` is set."""
        self.init_alg(n_0, n_max)
        # do-while: at least one step is always performed.
        while True:
            self.one_step(n_max)
            if self._break:
                break

    def pursue(self, n):
        """Continue an existing run for about ``n`` additional samples.

        Unlike ``one_step``, every iteration rebuilds the score (and regret)
        and does not advance the partition step. Stops when ``self.stop`` is
        set or more than ``n`` samples were added.
        """
        n_0 = self.partition.x.shape[0]
        while True:
            self._query_and_update()
            self.scoring_function()
            if self.true_eta is not None:
                self.regret()
            if self.stop or self.partition.x.shape[0] > n + n_0:
                break

    def update_t(self):
        """Set ``t`` to the average number of samples per cell."""
        self.t = self.partition.x.shape[0] / self.K

    @utils.execution_time
    def add_labels(self, X, y):
        """Add labelled samples and refresh every elementary cell's statistics.

        Each cell of ``partition.p_cells`` recomputes its value with
        ``estimate`` and the ``beta`` from ``ucb_lcb.beta``. Its value, UCB
        and LCB are mirrored into ``estimates``, ``ucb`` and ``lcb``.
        ``bounds`` becomes an array with one ``(ucb, lcb)`` row per entry of
        those dicts.
        """
        self.partition.add_labels(X, y)
        _beta = ucb_lcb.beta(self.t, config.c, config.delta, self.K)
        for i, c in self.partition.p_cells.items():
            c.set_value(estimate, _beta)
            self.estimates[i] = c.value
            self.ucb[i] = c.ucb
            self.lcb[i] = c.lcb
        self.bounds = np.array((
            list(self.ucb.values()),
            list(self.lcb.values())
        )).T

    @utils.execution_time
    def scoring_function(self):
        """Rebuild the piecewise-constant score and predict on the test set.

        Refreshes the estimates of the current (merged) cells and writes
        each elementary cell's estimate into ``self.grid`` at its position.
        The grid is wrapped with ``GridAsFunction`` as ``self.score``, and the
        scores of ``x_test`` are stored in ``predictions[partition.step]``.
        """
        _beta = ucb_lcb.beta(self.t, config.c, config.delta, self.K)
        # Refresh each merged cell once. Locating the merged cell of every
        # elementary cell by substring test is wrong ("1" in "10+11") and
        # quadratic. Assumes current_cells() covers every elementary cell,
        # as its "+"-joined keys suggest — TODO confirm against Partition.
        for merged_cell in self.partition.current_cells().values():
            merged_cell.set_value(estimate, _beta)
        # Grid entries hold each elementary cell's own estimate.
        for c in self.partition.p_cells.values():
            idx = self.partition.positions[c.id]
            self.grid[tuple(idx)] = c.value
        self.score = GridAsFunction(self.grid)
        self.predictions[self.partition.step] = np.array(
            self.score(self.x_test))

    def print_state(self):
        """Print the current step, merged cells and elementary cells."""
        print("=" * 50)
        print(f"step : {self.partition.step}")
        print(self.partition.current_cells())
        print(self.partition.p_cells)

    def regret(self):
        """Hook: record regret metrics; does nothing in the base class."""
        ...


class LearnerAnalyser(Learner):
    """Learner instrumented with regret tracking and diagnostic plots."""

    # First sample counts at which norm_one / infinity_norm between the true
    # and the learned ROC dropped below config.epsilon; 0 means "not yet".
    __stopping_time_d1__ = 0
    __stopping_time_d_inf__ = 0
    # Ground-truth eta, set by add_tracker; None disables regret tracking.
    # Class-level default so inherited init_alg/one_step work before
    # add_tracker is called.
    true_eta = None

    def __init__(self, j_max, d, eta):
        super().__init__(j_max, d, eta)

    def add_tracker(self, eta_real):
        """Register the true regression function and its reference ROC.

        ``roc_true`` is the ROC of ``eta_real`` scores on the test set,
        against which ``regret`` compares the learned scores.
        """
        self.true_eta = eta_real
        self.roc_true = ROC(
            y_true=self.y_test,
            y_prediction=self.true_eta(self.x_test)
        )

    def update(self):
        """Recompute ``ret_p_cells`` / ``ret_c_cells`` with ``estimation``."""
        self.ret_p_cells = self.estimation(
            self.partition.p_cells,
            eta_real=self.true_eta)
        self.ret_c_cells = self.estimation(
            self.partition.current_cells(),
            eta_real=self.true_eta)

    @utils.execution_time
    def regret(self):
        """Compare the ROC of the latest predictions with the true ROC.

        Appends the current sample count to ``n_sample`` and the
        ``infinity_norm`` / ``norm_one`` values to ``norm_infinity`` /
        ``norm_one``. It also records the first sample count at which each
        value drops below ``config.epsilon``. Prints a message and returns
        None when no tracker is registered.
        """
        if self.true_eta is None:
            print("Regret cannot be compute without knowing eta")
            return None
        roc = ROC(
            y_true=self.y_test,
            y_prediction=list(self.predictions.values())[-1])

        self.n_sample.append(len(self.partition.x))
        self.norm_infinity.append(infinity_norm(self.roc_true, roc))
        self.norm_one.append(norm_one(self.roc_true, roc))

        # Record only the FIRST crossing: the times start at 0, so the old
        # "!= 0" guard meant they could never be set.
        if (self.norm_infinity[-1] < config.epsilon
                and self.__stopping_time_d_inf__ == 0):
            self.__stopping_time_d_inf__ = self.n_sample[-1]

        if (self.norm_one[-1] < config.epsilon
                and self.__stopping_time_d1__ == 0):
            self.__stopping_time_d1__ = self.n_sample[-1]

    def estimation(self, list_cells: dict, eta_real=None):
        """Flatten cell statistics per elementary cell id and rank them.

        ``list_cells`` maps keys of the form ``"id1+id2+..."`` (a single id
        for an elementary cell) to cells exposing ``value``, ``ucb``, ``lcb``
        and ``n``. Every elementary id inherits the statistics of its cell.

        Returns a dict with:
          * ``estimates`` / ``lcb``: per-id values in ``list_cells`` order;
          * ``ranked_estimates`` / ``ranked_ucb`` / ``ranked_lcb`` /
            ``ranked_count``: the same statistics in increasing-estimate
            order (ties keep their original order);
          * ``rank``: the ids in increasing-estimate order (ndarray);
          * ``indexes``: the ids in ``list_cells`` order (ndarray);
          * ``eta``: ``eta_real`` evaluated at the first center of each
            ranked elementary cell, or ``[]`` when ``eta_real`` is None.
        """
        estimates, ucb, lcb, counts = {}, {}, {}, {}
        for key, cell in list_cells.items():
            for cell_id in key.split("+"):
                estimates[cell_id] = cell.value
                ucb[cell_id] = cell.ucb
                lcb[cell_id] = cell.lcb
                counts[cell_id] = cell.n

        # sorted() is stable, so equal estimates keep their insertion order.
        order = sorted(estimates, key=estimates.__getitem__)

        eta = []
        if eta_real is not None:
            eta = [
                eta_real(np.array([self.partition.p_cells[k].centers[0]]))
                for k in order
            ]
        return {
            "estimates": estimates,
            "lcb": lcb,
            "ranked_ucb": {k: ucb[k] for k in order},
            "ranked_lcb": {k: lcb[k] for k in order},
            "ranked_estimates": {k: estimates[k] for k in order},
            "rank": np.array(order),
            "ranked_count": {k: counts[k] for k in order},
            "indexes": np.array(list(lcb)),
            "eta": eta,
        }

    def plot_n_sample(self):
        """Scatter the first two coordinates of all samples by arrival order."""
        self.update()
        plt.figure(figsize=(6, 6), dpi=200)
        plt.scatter(self.partition.x[:, 0],
                    self.partition.x[:, 1],
                    c=range(self.partition.x.shape[0]),
                    cmap="rainbow_r")

    def plot_ucb_lcb(self, partition_type="c_cells", plot_nb_points=True,
                     shades=True, plot_active_cell=False,
                     plot_sampled_cell=False):
        """Plot cell estimates and confidence bounds, ranked by estimate.

        Parameters
        ----------
        partition_type : str
            ``"c_cells"`` (default) plots the current (merged) cells; any
            other value plots the elementary ``p_cells``.
        plot_nb_points : bool
            Overlay the number of samples per cell on a secondary y-axis.
        shades : bool
            Draw the (LCB, UCB) band as a shaded area instead of two lines.
        plot_active_cell : bool
            Highlight the cells listed in ``self.ij`` or, failing that,
            ``self.active_set`` (when either attribute exists).
        plot_sampled_cell : bool
            Unused; kept for backward compatibility.
        """
        self.update()
        if partition_type == "c_cells":
            res = self.ret_c_cells
        else:
            res = self.ret_p_cells
        plt.figure(dpi=200)
        ax = plt.gca()

        def steps(values):
            # Duplicate each value so that, paired with r_plot, every cell is
            # drawn as a flat segment across its slot.
            return np.array([v for v in values for _ in range(2)])

        r = range(len(res["ranked_lcb"]))
        lcb = np.array(list(res["ranked_lcb"].values()))
        mu = np.array(list(res["ranked_estimates"].values()))
        ucb = np.array(list(res["ranked_ucb"].values()))

        # Slot edges: cell k spans [k - 0.5, k + 0.5], i.e.
        # r_plot[2k] = k - 0.5 and r_plot[2k + 1] = k + 0.5.
        r_plot = np.concatenate(([0], steps(r)[:-1] + 1)) - 0.5
        mu_plot = steps(mu)
        lcb_plot = steps(lcb)
        ucb_plot = steps(ucb)
        if shades:
            plt.fill_between(r_plot, lcb_plot, ucb_plot, zorder=20,
                             label="$(LCB, UCB)$", alpha=0.2)
            plt.plot(r_plot, mu_plot, zorder=0, label="$\\mu$")
        else:
            plt.plot(r_plot, lcb_plot, zorder=20, label="$LCB$")
            plt.plot(r_plot, mu_plot, zorder=10, label="$\\mu$")
            plt.plot(r_plot, ucb_plot, zorder=0, label="$UCB$")

        ax.set_xticks(r, list(res["ranked_lcb"].keys()))

        if res["eta"]:
            plt.step(r_plot, steps(res["eta"]), color="gray", zorder=20,
                     label="$\\eta$", ls="--")

        if plot_nb_points:
            ax2 = ax.twinx()
            color = config.colors[1]
            ax2.step(r_plot, steps(res["ranked_count"].values()), lw=0.9,
                     zorder=0, color=color, where="mid")
            ax2.grid(False)
            ax2.set_ylabel("Number of sample", color=color)

        if plot_active_cell:
            if hasattr(self, "ij"):
                active_list = self.ij
            elif hasattr(self, "active_set"):
                active_list = self.active_set
            else:
                active_list = []
            for k, i in enumerate(active_list):
                # Locate the cell in the ranking that is actually plotted
                # (res), not in the p_cells ranking: the two orders differ
                # for the default "c_cells" view.
                loc_i = np.where(res["rank"] == i)[0][0]
                ax.add_patch(
                    Rectangle((r_plot[2 * loc_i], -0.05),
                              r_plot[2 * loc_i + 1] - r_plot[2 * loc_i], 0.05,
                              facecolor=config.colors[2], alpha=.4,
                              zorder=-10, lw=0,
                              label="active set" if k == 0 else None
                              ))
        # Make sure the legend/grid below target the main axes, not ax2.
        plt.sca(ax)
        ax.legend(loc='upper center', bbox_to_anchor=(0.5, 1.25),
                  ncol=2)
        # Too many cells for readable tick labels.
        if len(self.partition.p_cells) > 16:
            ax.set_xticklabels([])
            ax.set_xlabel("Cell index")
        plt.grid(linewidth=0.2)
        plt.tight_layout()

    def check_p_cell(self, i):
        """Show the samples of elementary cell ``i`` in two figures.

        The first figure shows the samples and the cell's first center; the
        second colours the samples by their label. Both draw the cell with
        ``plot_rectangle`` and only plot the first two coordinates on the
        unit square.
        """
        self.update()
        cell = self.partition.p_cells[i]

        plt.figure(figsize=(6, 6), dpi=200)
        plt.scatter(cell.x[:, 0], cell.x[:, 1])
        plt.scatter(cell.centers[0][0], cell.centers[0][1])
        cell.plot_rectangle()
        plt.xlim((0, 1))
        plt.ylim((0, 1))

        plt.figure(figsize=(6, 6), dpi=200)
        plt.scatter(cell.x[:, 0], cell.x[:, 1], c=cell.y, cmap="coolwarm")
        cell.plot_rectangle(linewidth=1, facecolor="none")
        plt.xlim((0, 1))
        plt.ylim((0, 1))

    def run_and_plot(self, n_0, n_max, step_mod=10, starting_step=0):
        """Like ``run``, plotting ``plot_ucb_lcb`` periodically.

        A figure (with active cells highlighted) is drawn every ``step_mod``
        partition steps once the step exceeds ``starting_step``.
        """
        self.init_alg(n_0, n_max)
        # do-while: at least one step is always performed.
        while True:
            self.one_step(n_max)
            step = self.partition.step
            if step > starting_step and step % step_mod == 0:
                self.plot_ucb_lcb(plot_active_cell=True)
                plt.title(f"step : {self.partition.step}")
            if self._break:
                break
