# Adapted from: https://github.com/t-sakai-kure/pywsl/blob/master/pywsl/pul/pu_mr.py

import numpy as np

from utils import comcalc as com
# from utils.PULearning.utils import check


class PU_SL(object):
    """Positive-unlabeled (PU) classifier fitted in closed form.

    The decision function is linear in a design matrix built from the data:
    a Gaussian basis centred on all training points ('gauss') or the output
    of ``com.homo_coord`` ('lm'). The kernel width ``sigma`` and the ridge
    strength ``lambda`` are chosen by ``n_fold`` cross-validation of a
    zero-one PU risk estimate (see ``calc_risk``).

    Keyword Args:
        prior: class prior of the positive class (default 0.5).
        n_fold: number of cross-validation folds (default 5).
        sigma_list: candidate Gaussian widths. When None and ``model`` is
            'gauss', a grid is derived from the data on the first ``fit`` and
            stored on the instance (later fits reuse it).
        lambda_list: candidate ridge strengths
            (default [0.001, 0.01, 0.1, 1, 10]).
        n_basis: stored on the instance but not used by this class.
        model: 'gauss' (default) or 'lm'.
    """

    def __init__(self, **kwargs):
        self.prior = kwargs.get('prior', 0.5)
        self.n_fold = kwargs.get('n_fold', 5)
        self.sigma_list = kwargs.get('sigma_list', None)
        self.lambda_list = kwargs.get('lambda_list', None)
        self.n_basis = kwargs.get('n_basis', 1000)  # not used in this class
        self.model = kwargs.get('model', 'gauss')

        if self.lambda_list is None:
            self.lambda_list = [0.001, 0.01, 0.1, 1, 10]

    def fit(self, x_p, x_u):
        """Select hyper-parameters by CV, fit the weights, return a scorer.

        Args:
            x_p: positive samples, one row per sample.
            x_u: unlabeled samples, one row per sample; same number of
                columns as ``x_p``.

        Returns:
            A callable ``f_dec(x_t)`` returning the real-valued score of each
            row of ``x_t``. Scores > 0 are the positive side (``calc_risk``
            counts positives scored <= 0 as errors).

        Raises:
            ValueError: if the feature dimensions differ, or ``self.model``
                is not supported.
        """
        if x_p.shape[1] != x_u.shape[1]:
            raise ValueError("""Dimension must be the same.
        Expected: x1.shape[1] == x2.shape[1]
        Actual: x1.shape[1] != x2.shape[1]""")

        n_p, n_u = x_p.shape[0], x_u.shape[0]

        x_c, d_p, d_u = self.calc_sigma(x_p, x_u)
        sigma, lam = self.cross_validation(n_p, n_u, d_p, d_u)

        K_p = self.ker(d_p, sigma, self.model)
        K_u = self.ker(d_u, sigma, self.model)

        H, h = self.solve_prepare(K_p, K_u)
        w = self.solve(H, h, lam)

        # self.model is read when the scorer is called, as in the original.
        f_dec = lambda x_t: self.make_func(w, x_c, sigma, self.model, x_t)

        return f_dec

    def solve(self, H, h, lam):
        """Return w solving the ridge-regularized system (H + lam * I) w = h."""
        n_dim = H.shape[0]
        return np.linalg.solve(H + lam * np.eye(n_dim), h)

    def solve_prepare(self, K_p, K_u):
        """Build the quadratic term H and the linear term h of the objective.

        H = K_u^T K_u / n_u and h = 2 * prior * mean_rows(K_p) - mean_rows(K_u).
        ``solve`` then returns the minimiser of
        w^T H w - 2 h^T w + lam * ||w||^2.

        NaN entries of h are replaced by 0. They arise when K_p or K_u has no
        rows, e.g. an empty training split during cross-validation.
        """
        H = K_u.T.dot(K_u) / K_u.shape[0]
        h = 2 * self.prior * np.mean(K_p, axis=0) - np.mean(K_u, axis=0)
        h[np.isnan(h)] = 0

        return H, h

    def cross_validation(self, n_p, n_u, d_p, d_u):
        """Choose (sigma, lambda) by k-fold cross-validation of the PU risk.

        Args:
            n_p, n_u: number of positive / unlabeled samples.
            d_p, d_u: model inputs produced by ``calc_sigma``.

        Returns:
            The pair from ``sigma_list`` x ``lambda_list`` whose held-out
            ``calc_risk``, averaged over folds, is lowest. Ties go to the
            first pair in row-major order.
        """
        if len(self.sigma_list) == 1 and len(self.lambda_list) == 1:
            # A single candidate pair needs no cross-validation.
            return self.sigma_list[0], self.lambda_list[0]

        # Per-sample fold assignment, compared against each fold number below.
        cv_index_p = com.cv_index(n_p, self.n_fold)
        cv_index_u = com.cv_index(n_u, self.n_fold)

        score_cv_fold = np.empty(
            (len(self.sigma_list), len(self.lambda_list), self.n_fold))

        for ite_sig, sigma in enumerate(self.sigma_list):
            # The design matrices depend only on sigma: build once per sigma.
            K_p = self.ker(d_p, sigma, self.model)
            K_u = self.ker(d_u, sigma, self.model)

            for ite_fold in range(self.n_fold):
                K_ptr = K_p[cv_index_p != ite_fold, :]
                K_pte = K_p[cv_index_p == ite_fold, :]
                K_utr = K_u[cv_index_u != ite_fold, :]
                K_ute = K_u[cv_index_u == ite_fold, :]

                # H and h do not depend on lambda: prepare once per fold.
                H_tr, h_tr = self.solve_prepare(K_ptr, K_utr)

                for ite_lam, lam in enumerate(self.lambda_list):
                    w = self.solve(H_tr, h_tr, lam)
                    score_cv_fold[ite_sig, ite_lam, ite_fold] = self.calc_risk(
                        K_pte.dot(w), K_ute.dot(w))

        # An empty held-out fold yields a NaN risk (mean of an empty array).
        # np.argmin would pick the first NaN, so average only over finite
        # folds and give +inf to a candidate that had no usable fold at all.
        # Without NaNs this equals np.mean over the fold axis.
        valid = ~np.isnan(score_cv_fold)
        n_valid = valid.sum(axis=2)
        total = np.where(valid, score_cv_fold, 0.0).sum(axis=2)
        score_cv = np.where(n_valid > 0, total / np.maximum(n_valid, 1), np.inf)

        best_sigma_index, best_lambda_index = np.unravel_index(
            np.argmin(score_cv), score_cv.shape)

        return (self.sigma_list[best_sigma_index],
                self.lambda_list[best_lambda_index])

    def calc_sigma(self, x_p, x_u):
        """Compute the model inputs and fill ``sigma_list`` for ``self.model``.

        Returns:
            (x_c, d_p, d_u).
            'gauss': x_c stacks x_p over x_u (the basis centres), and d_p /
            d_u are ``com.squared_dist`` from x_p / x_u to x_c.
            'lm': x_c is None, and d_p / d_u are ``com.homo_coord(x_p)`` /
            ``com.homo_coord(x_u)``.

        Raises:
            ValueError: if ``self.model`` is neither 'gauss' nor 'lm'.
        """
        if self.model == 'gauss':
            x_c = np.concatenate((x_p, x_u), axis=0)
            d_p, d_u = com.squared_dist(x_p, x_c), com.squared_dist(x_u, x_c)
            if self.sigma_list is None:
                # 10 log-spaced widths from 0.01x to 10x the square root of
                # the median squared distance. Stored on the instance, so a
                # later fit() on other data reuses this grid.
                med = np.median(np.concatenate((d_p.ravel(), d_u.ravel()), axis=0))
                self.sigma_list = np.sqrt(med) * np.logspace(-2, 1, 10)
        elif self.model == 'lm':
            x_c = None
            d_p, d_u = com.homo_coord(x_p), com.homo_coord(x_u)
            # The linear model has no width; one dummy value keeps the CV grid 2-D.
            self.sigma_list = [1]
        else:
            raise ValueError('Invalid model: {} is not supported.'.format(self.model))

        return x_c, d_p, d_u

    def calc_risk(self, gp, gu):
        """Estimate the zero-one PU risk from scores on positive/unlabeled data.

        f_n is the fraction of positive scores <= 0 (missed positives), and
        f_pu is the fraction of unlabeled scores >= 0. The estimate is
        prior * f_n + max(f_pu + prior * f_n - prior, 0), i.e. the
        negative-class term is clipped at zero.
        """
        f_n = np.mean(gp <= 0)
        f_pu = np.mean(gu >= 0)
        pu_risk = self.prior * f_n + np.maximum(f_pu + self.prior * f_n - self.prior, 0)
        return pu_risk

    def ker(self, d, sigma, model):
        """Return the design matrix for ``model`` from precomputed inputs ``d``.

        'gauss' -> ``com.gauss_basis(d, sigma)``; 'lm' -> ``d`` unchanged.

        Raises:
            ValueError: for any other model name.
        """
        if model == 'gauss':
            return com.gauss_basis(d, sigma)
        if model == 'lm':
            return d
        raise ValueError('Invalid model: {} is not supported.'.format(model))

    def make_func(self, w, x_c, sigma, model, x_t):
        """Score test points ``x_t`` with fitted weights ``w``.

        Rebuilds the design matrix of ``x_t`` exactly as in training (Gaussian
        basis around centres ``x_c`` for 'gauss', ``com.homo_coord`` for 'lm')
        and returns its product with ``w``.

        Raises:
            ValueError: if ``model`` is neither 'gauss' nor 'lm'.
        """
        if model == 'gauss':
            K = com.gauss_basis(com.squared_dist(x_t, x_c), sigma)
        elif model == 'lm':
            K = com.homo_coord(x_t)
        else:
            raise ValueError('Invalid model: {} is not supported.'.format(model))
        return K.dot(w)
