import numpy as np
import pandas as pd
from scipy.spatial import distance_matrix

from active_ranking import config
from active_ranking.base.ucb_lcb import kl_bernoulli

# Candidate radii searched by __Complexity._compute: 3000 evenly spaced
# points on [0, 100] with both endpoints included (so x[0] == 0).
x = np.linspace(start=0, stop=100, num=3000)


def problem_complexity(eta_grid):
    """Return the per-level complexity terms ``H_i`` computed for ``eta_grid``.

    The returned array is ordered like the sorted, flattened eta values,
    not like the cells of ``eta_grid``.
    """
    complexity = __Complexity(eta_grid)
    return complexity.H_i


class __Complexity:
    """Per-level complexity terms for a grid of eta values.

    The flattened eta values are sorted into ``levels`` (K values). For each
    level ``l_i``, ``_compute`` picks a radius ``d_i`` from the module-level
    grid ``x``. That radius is turned into the term
    ``H_i = 1 / kl_bernoulli(max(0, l_i - d_i), [min(1, l_i + d_i)])[0]``.

    Attributes:
        dist: (K, K) matrix of absolute differences ``|l_i - l_j|``.
        right: length-K array ``K * config.epsilon * (1 - l_i) * mean(eta_grid)``.
        d_i: dict mapping the sorted-level index ``i`` to its chosen radius.
        H_i: array of complexity terms, ordered like the *sorted* levels
            (not like the cells of ``eta_grid``).
        H: sum of ``H_i``.
    """

    def __init__(self, eta_grid):
        levels = np.sort(np.ravel(eta_grid))
        K = len(levels)
        column = levels.reshape(-1, 1)
        # Minkowski p=1 distance between 1-D points is |l_i - l_j|.
        self.dist = distance_matrix(column, column, p=1)
        self.right = K * config.epsilon * (1 - levels) * np.mean(eta_grid)
        self.d_i = {}
        self._H_i = {}
        for i, level in enumerate(levels):
            d = self._compute(i)
            self.d_i[i] = d
            # Interval [level - d, level + d], clipped to [0, 1].
            min_ = np.max((0, level - d))
            max_ = np.min((1, level + d))
            # NOTE(review): assumes kl_bernoulli(p, q) returns an array with one
            # Bernoulli KL divergence per entry of q — confirm.
            self._H_i[i] = 1 / kl_bernoulli(min_, np.array([max_]))[0]
        self.H_i = np.array(list(self._H_i.values()))
        self.H = np.sum(self.H_i)

    def _compute(self, i):
        """Return the value of the grid ``x`` that best balances level ``i``.

        For each candidate ``x[k]``, ``left[k]`` is the number of levels whose
        distance to level ``i`` is strictly below ``x[k]``, minus one. The
        minus one discounts level ``i`` itself, whose self-distance is 0. The
        chosen radius minimises ``|right[i] / x - left|`` over the grid.
        """
        d = self.dist[i]
        # Broadcast (K, 1) against (1, len(x)): same boolean matrix as
        # comparing two explicit (K, len(x)) copies, without allocating them.
        left = np.sum(d[:, np.newaxis] < x[np.newaxis, :], axis=0) - 1
        # x[0] == 0, so right[i] / x divides by zero there. The resulting
        # +/-inf becomes inf after np.abs and loses the argmin to the finite
        # entries, so that expected warning is silenced. A 0/0 (right[i] == 0)
        # is an "invalid" operation and still warns.
        # NOTE(review): in that 0/0 case the NaN at x[0] is what np.argmin
        # returns, giving d = 0 and H_i = 1 / kl_bernoulli(l, [l]) — confirm
        # this degenerate outcome is intended.
        with np.errstate(divide="ignore"):
            ratio = self.right[i] / x
        idx = np.argmin(np.abs(ratio - left))
        return x[idx]


if __name__ == '__main__':
    # Script entry point: load a pickled eta grid and report its complexity.
    # The path is relative to the current working directory.
    # NOTE(review): pd.read_pickle unpickles arbitrary objects; only run this
    # on trusted result files.
    eta_grid_ = pd.read_pickle("results/eta_1").values
    H_i_ = problem_complexity(eta_grid_)
    # Previously the result was computed and discarded; report it instead.
    print("H_i =", H_i_)
    print("H =", np.sum(H_i_))
