import numba
import numpy as np
from scipy.spatial import distance_matrix
from sklearn.base import ClassifierMixin, BaseEstimator

from active_ranking import config


class UMessyRank:
    """Neighbourhood counts over items ranked by their empirical means.

    For each item ``i`` the constructor marks the *other* items ``j`` whose
    empirical means are close to item ``i``:
    ``|mu_i - mu_j| < parameter * delta_i``, where ``delta`` is the gap
    ``ucb - lcb`` (per item, or the single largest gap, see
    :meth:`compute_delta`).  ``card[i]`` is the number of such items and
    :meth:`q` returns the items whose count is small enough.

    :param empirical_mean: per-item empirical means (stored as float16).
    :param parameter: multiplier applied to ``delta`` in the closeness test.
    :param ucb: per-item upper bounds (array; ``ucb - lcb`` must be valid).
    :param lcb: per-item lower bounds.
    :param k: factor used by :meth:`q`.
    """

    def __init__(self, empirical_mean, parameter, ucb, lcb, k) -> None:
        # NOTE(review): float16 keeps only ~3 significant digits, so means
        # closer than ~5e-4 near 1.0 collapse onto the same value; confirm
        # this precision is acceptable when delta becomes small.
        self.mu = np.array(empirical_mean, dtype='float16')
        self.ucb = ucb
        self.lcb = lcb
        self.k = k
        self.compute_delta()

        # Row i of the threshold is parameter * delta_i.  A column vector
        # broadcasts over the (n, n) distance matrix for both a per-item and
        # a scalar delta, so no explicit n x n matrix is needed.
        threshold = parameter * np.asarray(
            self.delta, dtype=np.float64).reshape(-1, 1)
        column = self.mu.reshape(-1, 1)
        close = distance_matrix(column, column, p=1) < threshold
        # An item is never counted as its own neighbour.
        not_self = ~np.eye(len(self.mu), dtype=bool)
        self.criterion = close & not_self
        self.compute_card_u()

    def compute_delta(self, local=None):
        """Set ``self.delta`` from the bound gap ``ucb - lcb``.

        :param local: if true, keep one gap per item; if false, use the
            largest gap for every item.  ``None`` (the default) reads
            ``config.local_delta`` at call time, so config changes made
            after this module is imported are honoured.
        """
        if local is None:
            local = config.local_delta
        gaps = self.ucb - self.lcb
        self.delta = gaps if local else np.max(gaps)

    def compute_card_u(self):
        """Store in ``self.card`` how many items each row of the criterion marks."""
        self.card = np.count_nonzero(self.criterion, axis=1)

    def q(self, p, epsilon):
        """Return the indices ``i`` with
        ``delta_i * card_i <= k * epsilon * p * (1 - mu_i)``.
        """
        left = self.delta * self.card
        right = self.k * epsilon * p * (1 - self.mu)
        return np.where(left <= right)[0]


@numba.njit
def create_tuple_of_disjoint_cells(li):  # W in article
    """Build every ordered pair of cells that belong to different groups.

    Each entry of ``li`` is a group of cell identifiers joined by ``"+"``.
    For every cell ``left`` of group ``g`` and every cell ``right`` of a
    later group, both ``(left, right)`` and ``(right, left)`` are emitted.
    Pairs are produced group by group, following the order of ``li``.
    """
    groups = [elt.split("+") for elt in li]
    n_groups = len(groups)
    pairs = []
    for g_left in range(n_groups):
        for left in groups[g_left]:
            # Only later groups: each unordered group pair is visited once,
            # and both orientations are appended together.
            for g_right in range(g_left + 1, n_groups):
                for right in groups[g_right]:
                    pairs.append((left, right))
                    pairs.append((right, left))
    return pairs


def get_tuple_ij(arr1: np.array, i: str, j: str):  # U in article
    """Return every ordered pair linking the group of ``i`` to the group of ``j``.

    ``arr1`` holds groups of cell identifiers joined by ``"+"`` (e.g.
    ``"3+7"``).  Membership is tested on the ``"+"``-separated tokens, so
    identifier ``"1"`` is not mistaken for part of ``"11"``.  If several
    groups contain the identifier, the first one is used.

    :param arr1: list or array of ``"+"``-joined group strings.
    :param i: identifier whose group supplies the ``a`` side of the pairs.
    :param j: identifier whose group supplies the ``b`` side of the pairs.
    :return: list containing ``(b, a)`` then ``(a, b)`` for every ``a`` in
        the group of ``i`` (outer loop) and ``b`` in the group of ``j``.
    :raises IndexError: if no group contains ``i`` or ``j``.
    """
    groups = [elt.split("+") for elt in arr1]
    group_i = [group for group in groups if i in group][0]
    group_j = [group for group in groups if j in group][0]

    ret = []
    for a in group_i:
        for b in group_j:
            ret.append((b, a))
            ret.append((a, b))
    return ret


@numba.njit
def _compute_active_set(ucb, lcb, w, version: str = "max"):
    """Pick the overlapping index pair with the largest ``ucb[i] - lcb[j]``.

    A pair ``(i, j)`` from ``w`` qualifies when ``ucb[i] - lcb[j] > 0`` and
    ``lcb[i] - ucb[j] < 0``, i.e. the intervals ``[lcb, ucb]`` of ``i`` and
    ``j`` strictly overlap.  Among qualifying pairs, the first one reaching
    the maximum gap ``ucb[i] - lcb[j]`` is returned.

    :param ucb: upper bounds, indexed by the entries of ``w``.
    :param lcb: lower bounds, indexed by the entries of ``w``.
    :param w: iterable of integer index pairs ``(i, j)``.
    :param version: currently unused by the body; kept for call
        compatibility.
    :return: the selected ``(i, j)``, or ``(0, 0)`` if no pair qualifies
        (note: indistinguishable from a genuine pair ``(0, 0)``).
    """
    # Largest gap found so far.  No type annotation: Python never evaluates
    # local-variable annotations, so a dtype hint here would be cosmetic.
    max_ = 0
    idx = (0, 0)
    for i, j in w:
        d = ucb[i] - lcb[j]
        r = lcb[i] - ucb[j]
        # max_ starts at 0 and only grows, so d > max_ already implies d > 0.
        if d > max_ and r < 0:
            max_ = d
            idx = (i, j)
    return idx


def test_active_set():
    # The only candidate (0, 1) has ucb[0] - lcb[1] == 0, which is not a
    # strictly positive gap, so the (0, 0) "nothing selected" default returns.
    assert _compute_active_set(
        ucb=np.array([0, 1, 0]),
        lcb=np.array([0, 0, 0]),
        w=np.array([(0, 1)])) == (0, 0)

    # (0, 1) has a negative gap, (2, 0) overlaps with gap 0.5, and (1, 2)
    # overlaps with the largest gap 1.0 - 0.1, so (1, 2) is selected.
    assert _compute_active_set(
        ucb=np.array([0.2, 1.0, 0.5]),
        lcb=np.array([0.0, 0.3, 0.1]),
        w=np.array([(0, 1), (1, 2), (2, 0)])) == (1, 2)


@numba.njit
def estimate(x: np.array, y: np.array):
    """Return the mean of the labels ``y``, or 0 when there are none.

    ``x`` is not used; presumably it is accepted so the function matches the
    callback signature expected by ``cell.set_value`` (TODO confirm).
    """
    n_labels = len(y)
    if n_labels > 0:
        return np.sum(y) / n_labels
    return 0


class BasicEstimator(ClassifierMixin, BaseEstimator):
    """Piecewise-constant classifier defined on a ``Partition(j, d)``.

    ``fit`` stores the labelled points in the partition and assigns each cell
    a value through :func:`estimate` (mean label, 0 for an empty cell).
    ``predict_proba`` returns ``[1 - value, value]`` of the cell containing
    each query point, and ``predict`` thresholds the second column at 0.5.

    :param j: first argument of ``Partition`` (presumably its resolution;
        confirm in ``active_ranking.base.space``).
    :param d: second argument of ``Partition`` (presumably the input
        dimension; confirm).
    """

    def __init__(self, j, d):
        # Imported lazily (kept from the original design, presumably to avoid
        # an import cycle with active_ranking.base.space — confirm).
        from active_ranking.base.space import Partition
        self.j = j
        self.d = d
        # NOTE(review): sklearn's convention is to only store parameters in
        # __init__; building the partition here is kept for compatibility.
        self.__partition = Partition(self.j, self.d)
        self.__keys = list(self.__partition.p_cells.keys())

    def fit(self, X, y):
        """Add the labelled points to the partition and set every cell value.

        :return: ``self``, per the scikit-learn estimator contract, so calls
            such as ``est.fit(X, y).predict(X)`` work.
        """
        X = np.array(X)
        y = np.array(y).astype(float)
        self.__partition.add_labels(X, y)
        for cell in self.__partition.p_cells.values():
            cell.set_value(estimate)
        return self

    def __helper_find_x(self, x):
        """Return the index (into the cell key list) of the cell of each row."""
        return self.__partition.find_cells(x)

    def predict(self, X):
        """Return a boolean array: positive-class probability above 0.5."""
        return self.predict_proba(X)[:, 1] > 0.5

    def predict_proba(self, X):
        """Return an ``(len(X), 2)`` array of ``[1 - value, value]`` rows."""
        indexes = self.__helper_find_x(X)
        cells = self.__partition.p_cells
        keys = self.__keys
        # Look each cell up once; the negative-class column is derived from it.
        p_pos = np.array([cells[keys[i]].value for i in indexes])
        ret = np.zeros((len(X), 2))
        ret[:, 0] = 1 - p_pos
        ret[:, 1] = p_pos
        return ret


@numba.njit
def _grid_as_function_helper(grid, x):
    """Map each row of ``x`` to integer cell coordinates on ``grid``.

    Axis ``d`` of ``grid`` with ``s`` entries uses thresholds ``k / s``
    (``k = 0..s-1``); coordinate ``x[:, d]`` receives the index of the last
    threshold strictly below it, clamped to 0.  So an interior boundary
    ``k / s`` falls in cell ``k - 1``, values above ``(s-1)/s`` fall in the
    last cell, and values ``<= 0`` fall in cell 0.  Presumably ``x`` lies in
    ``[0, 1]^D`` (TODO confirm with callers).

    :param grid: array whose shape gives the number of cells per axis.
    :param x: 2-D array of points, one per row.
    :return: int64 array shaped like ``x`` with the cell index per axis.
    """
    ret = np.zeros(x.shape, dtype=np.int64)
    for d in range(len(grid.shape)):
        s = grid.shape[d]
        thresholds = np.arange(s) / s
        # searchsorted (side='left') - 1 -> last threshold strictly below x.
        search = np.searchsorted(thresholds, x[:, d]) - 1
        # x <= 0 yields -1, which would silently wrap to the *last* cell when
        # used as an index; clamp it to the first cell instead.
        ret[:, d] = np.maximum(search, 0)
    return ret


class GridAsFunction:
    """Callable that returns the grid value of the cell containing each point.

    Cell coordinates come from ``_grid_as_function_helper``, which bins each
    coordinate with thresholds ``k / s`` (so points are presumably in
    ``[0, 1]^D``).
    """

    def __init__(self, grid: np.array):
        self.grid = grid

    def __call__(self, x):
        """Return a list with one grid value per row of ``x``."""
        cell_idx = _grid_as_function_helper(self.grid, x).astype(int)
        # Column k of cell_idx indexes axis k of the grid; a single
        # fancy-indexing lookup fetches every point's value at once.
        per_axis = tuple(cell_idx.T)
        return list(self.grid[per_axis])


class FunctionAsLabeler:
    """Turn a probability function into a random binary labeler.

    Calling the labeler on ``x`` returns ``u < func(x)`` with ``u`` drawn
    uniformly on ``[0, 1)``, so a point is labelled positive with
    probability ``func(x)``.  Uniforms are pre-drawn in bulk from the global
    NumPy RNG and consumed sequentially.

    :param func: callable mapping a batch ``x`` to values comparable with an
        array of ``len(x)`` uniforms.
    """

    # Number of uniforms drawn per refill (at least; see __call__).
    buffer_size = int(1e6)

    def __init__(self, func: callable):
        self.fun = func
        self.add_rand()

    def add_rand(self, size=None):
        """Replace the uniform buffer with ``size`` fresh draws and rewind it.

        :param size: number of draws; defaults to ``buffer_size``.
        """
        if size is None:
            size = self.buffer_size
        self.rand = np.random.uniform(size=int(size))
        self.i = 0

    def __call__(self, x):
        """Return a boolean label for each element of ``x``."""
        n = len(x)
        if self.i + n > len(self.rand):
            # The refill must hold the whole batch; otherwise a batch larger
            # than buffer_size would get a truncated slice of uniforms.
            self.add_rand(max(self.buffer_size, n))
        rand = self.rand[self.i:self.i + n]
        self.i += n
        eta_ = self.fun(x)
        return rand < eta_


def test_estimator():
    # A single positive point close to the origin must be classified as
    # positive by an estimator fitted on that point alone.
    model = BasicEstimator(3, 2)
    points = np.array([[0.001, 0.001]])
    labels = [True]
    model.fit(points, labels)
    assert np.array_equal(model.predict(points), [1])


def test_grid_as_function():
    grid = np.array([[0.1, 0.5], [0.1, 0.8]])
    x = np.array([[0.2, 0.6]] * 50)
    cells = _grid_as_function_helper(grid, x)
    # Two cells per axis split at 0.5: 0.2 -> cell 0 and 0.6 -> cell 1.
    assert cells.shape == (50, 2)
    assert (cells == [0, 1]).all()
    # The wrapper reads grid[0, 1] for those points.
    assert GridAsFunction(grid)(x[:3]) == [0.5, 0.5, 0.5]


if __name__ == '__main__':
    # Running the module directly only executes the grid-helper smoke test.
    test_grid_as_function()
