import numpy as np
from numpy.linalg import norm

from numba import njit
from sklearn.linear_model import cd_fast
from sklearn.utils import check_random_state
from scipy import linalg

from clar.utils import (
    clp_sqrt, BST, clp_sigma_gls, get_alpha_max, get_sigma_min, clp_sigma_inv)
from clar.duality_gap import get_duality_gap


def get_path(
        X, measurement, list_p_alpha, alpha_max,
        sigma_min, B0=None,
        n_iter=10**4, tol=10**-4, gap_freq=10, active_set_freq=5,
        S_freq=10, pb_name="CLaR", use_accel=False,
        n_nncvx_iter=10, verbose=True, heur_stop=False):
    """Solve the problem for a sequence of regularization strengths.

    The solution found for one value of ``list_p_alpha`` is used as the
    starting point (warm start) for the next one.

    Parameters
    ----------
    X: np.array
        gain matrix, forwarded to ``solver``; ``X.shape[1]`` is the number
        of sources.
    measurement: np.array
        observations, forwarded to ``solver``. For ``pb_name == "MTLME"``
        its shape is unpacked as (n_epochs, n_sensors, n_times).
    list_p_alpha: iterable of float
        fractions of ``alpha_max``; one problem is solved for each entry with
        ``alpha = p_alpha * alpha_max``. Entries are used as dictionary keys,
        so they must be distinct.
    alpha_max: float
        reference regularization strength.
    sigma_min: float
        forwarded to ``solver``.
    B0: np.array or None
        starting point for the first value of the path. ``None`` lets the
        solver start from its own default initialization.
    n_iter, tol, gap_freq, active_set_freq, S_freq, pb_name, use_accel,
    n_nncvx_iter, verbose, heur_stop:
        forwarded unchanged to ``solver``. ``verbose`` also controls the
        progress header printed here.

    Returns
    -------
    dict_masks: dict
        maps each ``p_alpha`` to the boolean mask of the rows of the
        estimate that are not identically zero.
    dict_dense_Bs: dict
        maps each ``p_alpha`` to the rows of the estimate selected by the
        mask. For "MTLME" the estimate is first reshaped to
        (n_sources, n_epochs, n_times) and averaged over epochs.
    """
    dict_masks = {}
    dict_dense_Bs = {}
    # B0 seeds the first problem; afterwards each solution warm-starts
    # the next one.
    B_hat = B0
    n_alphas = len(list_p_alpha)

    for n_alpha, p_alpha in enumerate(list_p_alpha):
        if verbose:
            print("--------------------------------------------------------")
            print("%i-th alpha over %i" % (n_alpha, n_alphas))
        alpha = p_alpha * alpha_max
        B_hat, _, _, _ = solver(
            X, measurement, alpha, alpha_max, sigma_min, B0=B_hat,
            n_iter=n_iter, gap_freq=gap_freq, active_set_freq=active_set_freq,
            S_freq=S_freq, pb_name=pb_name, tol=tol,
            use_accel=use_accel, n_nncvx_iter=n_nncvx_iter, verbose=verbose,
            heur_stop=heur_stop)

        # rows of the estimate that are not identically zero
        mask = np.abs(B_hat).sum(axis=1) != 0
        if pb_name == "MTLME":
            # the MTLME estimate stacks the epochs along the time axis:
            # unstack it and average over the epochs before storing
            n_sources = X.shape[1]
            n_epochs, _, n_times = measurement.shape
            B_dense = B_hat.reshape(
                (n_sources, n_epochs, n_times)).mean(axis=1)[mask, :]
        else:
            B_dense = B_hat[mask, :]
        dict_masks[p_alpha] = mask
        dict_dense_Bs[p_alpha] = B_dense

    # duplicated entries in list_p_alpha would overwrite each other
    assert len(dict_dense_Bs.keys()) == len(list_p_alpha)
    return dict_masks, dict_dense_Bs


def wrap_solver(
        X, obs, p_alpha, alpha_max=None, sigma_min=None, B0=None,
        n_iter=10**4, tol=10**-4, gap_freq=10, active_set_freq=20,
        S_freq=10, pb_name="CLaR", use_accel=False,
        n_nncvx_iter=10, verbose=True, heur_stop=False,
        alpha_Sigma_inv=0.0001):
    """Run ``solver`` for one regularization level and return the support.

    Parameters
    ----------
    X: np.array
        gain matrix, forwarded to ``solver``.
    obs: np.array
        observations, forwarded to ``solver``.
    p_alpha: float
        fraction of ``alpha_max``; the problem is solved with
        ``alpha = p_alpha * alpha_max``.
    alpha_max: float or None
        reference regularization strength. When ``None`` it is computed with
        ``get_alpha_max(X, obs, sigma_min, pb_name, alpha_Sigma_inv=...)``.
    sigma_min: float or None
        when ``None`` it is computed with ``get_sigma_min(obs)``.
    B0: np.array or None
        initial value of the estimate, forwarded to ``solver``.
    n_iter, tol, gap_freq, active_set_freq, S_freq, pb_name, use_accel,
    n_nncvx_iter, verbose, heur_stop, alpha_Sigma_inv:
        forwarded unchanged to ``solver``.

    Returns
    -------
    B_dns: np.array
        rows of the estimate whose Euclidean norm is non-zero.
    supp: np.array of bool
        mask of those rows among all the rows of the estimate.
    """
    if sigma_min is None:
        sigma_min = get_sigma_min(obs)
    if alpha_max is None:
        alpha_max = get_alpha_max(
            X, obs, sigma_min, pb_name, alpha_Sigma_inv=alpha_Sigma_inv)
    alpha = p_alpha * alpha_max

    # Only the estimate is needed here; the caller's B0 is forwarded so a
    # user-supplied starting point is actually used.
    B = solver(
        X, obs, alpha, alpha_max, sigma_min, B0=B0,
        n_iter=n_iter, tol=tol, gap_freq=gap_freq,
        active_set_freq=active_set_freq,
        S_freq=S_freq, pb_name=pb_name, use_accel=use_accel,
        n_nncvx_iter=n_nncvx_iter, verbose=verbose,
        heur_stop=heur_stop,
        alpha_Sigma_inv=alpha_Sigma_inv)[0]

    supp = norm(B, axis=1) != 0
    return B[supp, :], supp


def solver(
        X, all_epochs, alpha, alpha_max, sigma_min, B0=None,
        n_iter=10**4, tol=10**-4, gap_freq=10, active_set_freq=5,
        S_freq=10, pb_name="CLaR", use_accel=False,
        n_nncvx_iter=10, verbose=True, heur_stop=False,
        alpha_Sigma_inv=0.1):
    """Prepare the observations for ``pb_name`` and run the solver.

    Parameters
    --------------
    X: np.array, shape (n_sensors, n_sources)
        gain matrix
    all_epochs: np.array
        observations. Shape (n_sensors, n_times) is required for "SGCL"
        and "MTL" and is what the "MLE" / "MRCE" branches expand with a
        leading axis; shape (n_epochs, n_sensors, n_times) is required for
        "MTLME" and is passed as is for "CLaR", "MRCER" and "MLER".
    alpha: float
        positive number, coefficient multiplying the penalization
    alpha_max: float
        positive number, if alpha is bigger than alpha max, B=0
    sigma_min: float
        positive number, value to which eigenvalues smaller than sigma_min
        are set when computing the inverse of ZZT
    B0: np.array, shape (n_sources, n_times)
        initial value of B. When None, B starts at zero; for "MTLME" the
        zero initialization has shape (n_sources, n_times * n_epochs).
    n_iter: int
        number of iterations of the algorithm
    tol : float
        The tolerance for the optimization: if the updates are
        smaller than ``tol``, the optimization code checks the
        dual gap for optimality and continues until it is smaller
        than ``tol``
    gap_freq: int
        Compute the duality gap every gap_freq iterations.
    active_set_freq: int
        When updating B, while B_{j, :} != 0,  B_{j, :} keeps to
        be updated, at most active_set_freq times.
    S_freq: int
        S is updated every S_freq iterations.
    pb_name: str
        problem to solve, one of "SGCL", "MTL", "MTLME", "CLaR", "MLE",
        "MLER", "MRCE", "MRCER", "NNCVX". "MLE" and "MRCE" are solved as
        "MLER" and "MRCER" on a single epoch; "MTLME" is solved as "MTL"
        on the epochs concatenated along the time axis.
    use_accel: bool
        States if you want to use acceleration while computing the dual.
        Only accepted for "SGCL".
    n_nncvx_iter: int
        An approach to solve such non-convex problems is to solve a
        succession of convex problems. n_nncvx_iter is the number of
        iterations in the outer loop (only used for "NNCVX").
    verbose: bool
        print progress information.
    heur_stop: bool
        States if you want to use a heuristic stopping criterion to stop
        the algorithm. Here the heuristic stopping criterion is
        primal[i] - primal[i+1] < primal[0] * tol / 10.
    alpha_Sigma_inv: float
        forwarded to ``solver_``.

    Returns
    -------
    tuple
        For "NNCVX": (B, support, all_n_active, all_weights), where B is
        the full (n_sources, n_times) estimate. Otherwise the tuple
        returned by ``solver_``.

    Raises
    ------
    NotImplementedError
        if ``use_accel`` is set for a problem other than "SGCL".
    ValueError
        if ``pb_name`` is unknown or ``all_epochs`` does not have the number
        of dimensions required by "SGCL", "MTL" or "MTLME".
    """
    if use_accel and pb_name != "SGCL":
        raise NotImplementedError()

    X = np.asfortranarray(X, dtype='float64')

    n_sources = X.shape[1]
    n_times = all_epochs.shape[-1]
    if verbose:
        print("--------- %s -----------------" % pb_name)

    # Validate before allocating B: the MTLME initialization below unpacks
    # a 3-d shape and would otherwise fail with an unrelated message.
    if pb_name == "MTLME" and all_epochs.ndim != 3:
        raise ValueError("Wrong number of dimensions, expected 3, "
                         "got %d " % all_epochs.ndim)

    if B0 is None:
        if pb_name != "MTLME":
            B = np.zeros((n_sources, n_times), dtype=float)
        else:
            n_epochs, _, n_times = all_epochs.shape
            B = np.zeros((n_sources, n_times * n_epochs), dtype=float)
    else:
        B = B0.copy().astype(np.float64)

    if pb_name == "MTLME":
        # concatenate the epochs along the time axis and solve one MTL
        # problem on the resulting (n_sensors, n_epochs * n_times) matrix
        stacked = all_epochs.transpose((1, 0, 2))
        stacked = stacked.reshape(stacked.shape[0], -1)
        return solver(
            X, stacked, alpha, alpha_max, sigma_min, B0=B,
            n_iter=n_iter, gap_freq=gap_freq,
            use_accel=use_accel, active_set_freq=active_set_freq,
            S_freq=S_freq, tol=tol, pb_name="MTL", verbose=verbose,
            heur_stop=heur_stop, alpha_Sigma_inv=alpha_Sigma_inv)

    if pb_name == "NNCVX":
        # NOTE(review): `nncvx` is not imported in the part of the file
        # visible here; confirm where this module comes from.
        B_support, support_absolute, all_n_active, all_weights = \
            nncvx.nncvx_solver(
                X, all_epochs, alpha, alpha_max, sigma_min, B0=B0,
                n_nncvx_iter=n_nncvx_iter, tol=tol, max_iter_inner=n_iter,
                gap_freq=gap_freq, active_set_freq=active_set_freq,
                S_freq=S_freq, heur_stop=heur_stop)
        B_algo = np.zeros((n_sources, n_times))
        B_algo[support_absolute, :] = B_support
        return B_algo, support_absolute, all_n_active, all_weights

    if pb_name in ("SGCL", "MTL"):
        if all_epochs.ndim != 2:
            raise ValueError("Wrong number of dimensions, expected 2, "
                             "got %d " % all_epochs.ndim)
        observations = all_epochs[None, :, :]
    elif pb_name in ("MLE", "MRCE"):
        # single-epoch problems are solved by their repeated-measurement
        # counterparts on one epoch
        observations = all_epochs[None, :, :]
        pb_name = "MLER" if pb_name == "MLE" else "MRCER"
    elif pb_name in ("CLaR", "MRCER", "MLER"):
        observations = all_epochs
    else:
        raise ValueError("Unknown solver %s" % pb_name)

    return solver_(
        observations, X, alpha, alpha_max, sigma_min, B, n_iter, gap_freq,
        use_accel, active_set_freq, S_freq, tol=tol,
        pb_name=pb_name, verbose=verbose,
        heur_stop=heur_stop, alpha_Sigma_inv=alpha_Sigma_inv)


def solver_(
        all_epochs, X, alpha, alpha_max, sigma_min, B, n_iter, gap_freq,
        use_accel, active_set_freq=5, S_freq=10, tol=10**-4,
        pb_name="CLaR", verbose=True, heur_stop=False, alpha_Sigma_inv=0.01):
    """Alternate updates of S and B until a stopping criterion is met.

    Parameters
    ----------
    all_epochs: np.array, shape (n_epochs, n_sensors, n_times)
        observations. For "CLaR", "MRCER" and "MLER" all the epochs are
        used; for the other problems only ``all_epochs[0]`` is used.
    X: np.array
        gain matrix.
    alpha: float
        coefficient multiplying the penalization.
    alpha_max: float
        not used in this function.
    sigma_min: float
        forwarded to ``update_S``.
    B: np.array
        initial estimate; it is updated in place by ``update_B``.
    n_iter: int
        maximal number of iterations.
    gap_freq: int
        the duality gap is computed every ``gap_freq`` iterations (every
        iteration for "SGCL").
    use_accel: bool
        only changes the shape of the returned gaps (see Returns); no
        accelerated dual point is computed in this function.
    active_set_freq: int
        forwarded to ``update_B`` as the number of active-set passes.
    S_freq: int
        ``update_S`` is called every ``S_freq`` iterations.
    tol: float
        iterations stop once gap < tol * E[0], E[0] being the primal value
        computed at B = 0.
    pb_name: str
        one of "CLaR", "MRCER", "MLER", "MTL", "SGCL", "MTLME".
    verbose: bool
        print progress information.
    heur_stop: bool
        also stop when E[-2] - E[-1] < tol * |E[0]| / 10.
    alpha_Sigma_inv: float
        forwarded to ``update_S``.

    Returns
    -------
    B: np.array
        the estimate.
    S_inv: np.array or None
        last matrix returned by ``update_S`` (for "MLER" and "MRCER" this
        is Sigma_inv); None if no iteration was run.
    E: np.array
        primal values: first the one at B = 0, then one per gap
        computation.
    gaps: np.array, or tuple of two np.array when ``use_accel`` is True
        duality gaps; with ``use_accel`` the second array (gaps of the
        accelerated dual point) is always empty here.

    Raises
    ------
    ValueError
        if ``pb_name`` is not one of the names listed above.
    """
    gaps = []
    gaps_acc = []
    E = []  # E for energy, ie value of primal objective
    p_obj = np.inf
    # never updated below: no accelerated dual point is computed here
    d_obj_acc = -np.inf
    n_epochs, n_sensors, n_times = all_epochs.shape

    if pb_name in ("CLaR", "MRCER", "MLER"):
        # Y2 and Y (averages over the epochs of Y_l Y_l^T and Y_l) are
        # costly, compute them once
        Y2 = np.zeros((n_sensors, n_sensors))
        Y = np.zeros((n_sensors, n_times))
        for epoch in all_epochs:
            Y2 += epoch @ epoch.T
            Y += epoch
        Y2 /= n_epochs
        Y /= n_epochs
    elif pb_name in ("MTL", "SGCL", "MTLME"):
        # "MTLME" is kept here because the previous always-true condition
        # (`pb_name == "MTL" or "SGCL"`) routed it through this branch
        Y = all_epochs[0]
        Y2 = None
    else:
        raise ValueError("Unknown solver %s" % pb_name)

    R = np.asfortranarray(Y - X @ B, dtype='float64')

    # value of the primal at B = 0, used to scale the stopping criteria
    B_zero = np.zeros_like(B)
    _, _, _, kwargs_gap = update_S(
        X, Y, Y2, all_epochs, Y - X @ B_zero, B_zero, alpha, alpha_Sigma_inv,
        sigma_min, pb_name)
    p_first, _ = get_duality_gap(pb_name=pb_name, **kwargs_gap)
    E.append(p_first)
    if verbose:
        print("------------------------")
        print("First primal: %0.2e" % p_first)

    S_inv = None
    gap = np.inf
    for t in range(n_iter):
        # update S (update_S also handles "SGCL")
        if t % S_freq == 0:
            S_inv, S_inv_R, S_inv_X, kwargs_gap = update_S(
                X, Y, Y2, all_epochs, R, B, alpha, alpha_Sigma_inv,
                sigma_min, pb_name)

        # update B (B, R and S_inv_R are modified in place)
        update_B(X, Y, B, R, S_inv_R, S_inv_X, alpha, pb_name,
                 active_set_freq)

        # compute duality gap
        if t % gap_freq == 0 or pb_name == "SGCL":
            p_obj, d_obj = get_duality_gap(pb_name=pb_name, **kwargs_gap)
            E.append(p_obj)
            gap = p_obj - d_obj
            gaps.append(gap)
            if verbose:
                print("p_obj: %.6e" % (p_obj))
                print("d_obj: %.6e" % (d_obj))
                print("length support: %i" % (norm(B, axis=1) != 0).sum())
                print("iteration: %d, gap: %.4e" % (t, gap))

        if t // gap_freq >= 1 and heur_stop:
            heuristic_stopping_criterion = (E[-2] - E[-1]) < \
                tol * np.abs(E[0]) / 10
        else:
            heuristic_stopping_criterion = False

        if gap < tol * E[0] or \
                (use_accel and (p_obj - d_obj_acc < tol * E[0])) \
                or heuristic_stopping_criterion:
            break

    # be careful: for MLE and MRCE S_inv is Sigma_inv
    if use_accel:
        return (B, S_inv, np.asarray(E),
                (np.asarray(gaps), np.asarray(gaps_acc)))
    return (B, S_inv, np.asarray(E), np.asarray(gaps))


def _residual_second_moment(X, Y, Y2, B, n_times):
    """Return (Y2 - Y (XB)^T - (XB) Y^T + (XB)(XB)^T) / n_times."""
    XB = X @ B
    YXB = Y @ XB.T
    return (Y2 - YXB - YXB.T + XB @ XB.T) / n_times


def update_S(X, Y, Y2, all_epochs, R, B, alpha, alpha_Sigma_inv,
             sigma_min, pb_name):
    """Update the noise-related matrix and the quantities derived from it.

    Parameters
    ----------
    X: np.array
        gain matrix.
    Y: np.array, 2-d
        observation matrix; its second dimension is used as n_times.
    Y2: np.array or None
        second-moment matrix of the observations; only read for "CLaR",
        "MRCER" and "MLER".
    all_epochs: np.array
        only stored in the returned keyword arguments for "CLaR".
    R: np.array
        current residual.
    B: np.array
        current estimate.
    alpha: float
        coefficient multiplying the penalization.
    alpha_Sigma_inv: float
        regularization passed to ``update_Sigma_gls`` ("MRCER" only).
    sigma_min: float
        clipping level passed to ``clp_sqrt`` ("CLaR", "SGCL") or, squared,
        to ``clp_sigma_inv`` ("MLER").
    pb_name: str
        one of "CLaR", "MRCER", "MLER", "SGCL", "MTL", "MTLME".

    Returns
    -------
    S_inv: np.array
        for "MRCER" and "MLER" this is Sigma_inv rather than S_inv; for
        "MTL" / "MTLME" it is ``np.eye(1)``.
    S_inv_R: np.array
        ``S_inv @ R``; for "MTL" / "MTLME" it is ``R`` itself (same object).
    S_inv_X: np.array
        ``S_inv @ X``; for "MTL" / "MTLME" it is ``X`` itself.
    kwargs_gap: dict
        keyword arguments later passed to ``get_duality_gap``.

    Raises
    ------
    ValueError
        if ``pb_name`` is not one of the names listed above.
    """
    _, n_times = Y.shape

    if pb_name == "CLaR":
        ZZT = _residual_second_moment(X, Y, Y2, B, n_times)
        S_trace, S_inv = clp_sqrt(ZZT, sigma_min)
        S_inv_R = np.asfortranarray(S_inv @ R)
        S_inv_X = S_inv @ X
        kwargs_gap = {
            'X': X, 'all_epochs': all_epochs, 'B': B,
            'S_trace': S_trace, 'S_inv': S_inv, 'sigma_min': sigma_min,
            'alpha': alpha}
    elif pb_name == "MRCER":
        emp_cov = _residual_second_moment(X, Y, Y2, B, n_times)
        Sigma, S_inv = update_Sigma_gls(emp_cov, alpha_Sigma_inv)
        S_inv_R = S_inv @ R  # be careful this is not real S_inv_R
        S_inv_X = S_inv @ X
        kwargs_gap = {
            'X': X, 'Y': Y, 'Y2': Y2, 'Sigma': Sigma,
            'Sigma_inv': S_inv, 'alpha': alpha,
            'alpha_Sigma_inv': alpha_Sigma_inv, 'B': B,
            'sigma_min': sigma_min}
    elif pb_name == "MLER":
        emp_cov = _residual_second_moment(X, Y, Y2, B, n_times)
        _, S_inv = clp_sigma_inv(emp_cov, sigma_min ** 2)
        S_inv_R = S_inv @ R  # be careful this is not real S_inv_R
        S_inv_X = S_inv @ X  # it is Sigma_inv_R and Sigma_inv_X
        kwargs_gap = {
            'X': X, 'Y': Y, 'Y2': Y2, 'Sigma_inv': S_inv,
            'alpha': alpha, 'B': B}
    elif pb_name == "SGCL":
        Z = Y - X @ B
        ZZT = Z @ Z.T / n_times
        S_trace, S_inv = clp_sqrt(ZZT, sigma_min)
        S_inv_R = S_inv @ R
        S_inv_X = S_inv @ X
        kwargs_gap = {
            'R': R, 'X': X, 'Y': Y, 'B': B, 'S_trace': S_trace,
            'S_inv_R': S_inv_R, 'sigma_min': sigma_min, 'alpha': alpha}
    elif pb_name in ("MTL", "MTLME"):
        # no noise matrix is estimated: R and X are returned as they are
        S_inv = np.eye(1)
        S_inv_R = R
        S_inv_X = X
        kwargs_gap = {'X': X, 'Y': Y, 'B': B, 'alpha': alpha}
    else:
        # previously fell through to an UnboundLocalError at the return
        raise ValueError("Unknown solver %s" % pb_name)
    return S_inv, S_inv_R, S_inv_X, kwargs_gap


@njit
def update_B(
        X, Y, B, R,  S_inv_R, S_inv_X,
        alpha, pb_name,
        active_set_passes=5):
    """Run passes of block coordinate descent on the rows of B, in place.

    Nothing is returned: ``B``, ``R`` and (unless ``pb_name`` is "MTL" or
    "MTLME") ``S_inv_R`` are modified in place.

    Parameters
    ----------
    X: np.array
        gain matrix; ``X.shape[1]`` is the number of sources (rows of B).
    Y: np.array, 2-d
        only its shape (n_sensors, n_times) is used, to scale the threshold.
    B: np.array
        current estimate, one row per source.
    R: np.array
        residual, kept in sync with B by adding / subtracting
        ``X[:, j] B[j, :]`` around each row update.
    S_inv_R: np.array
        quantity the row updates are computed from. It is only adjusted
        here when ``pb_name`` is neither "MTL" nor "MTLME".
        NOTE(review): for "MTL"/"MTLME" this presumably relies on the
        caller passing the same array as ``R`` (``update_S`` returns
        ``S_inv_R = R`` for these problems) -- confirm.
    S_inv_X: np.array
        same shape as X; used for the per-source constants ``L`` and to
        adjust ``S_inv_R``.
    alpha: float
        coefficient multiplying the penalization.
    pb_name: str
        name of the problem, only compared with "MTL" and "MTLME".
    active_set_passes: int
        number of passes. The first pass visits every source, the following
        ones only the sources whose row was non-zero after their last
        update.
    """
    n_sensors, n_times = Y.shape
    n_sources = X.shape[1]

    is_not_MTL = (pb_name != "MTL") and (pb_name != "MTLME")
    # active_set[j] != 0 means row j of B was non-zero after its last
    # update; every source starts as active
    active_set = np.ones(n_sources)

    # per-source constants, computed once so they are not recomputed at
    # every pass. L[j] is used as a divisor below; a source with
    # L[j] == 0 is not guarded against here.
    L = np.zeros(n_sources)
    for j in range(n_sources):
        L[j] = X[:, j] @ S_inv_X[:, j]

    for t in range(active_set_passes):
        if t == 0:
            sources_to_update = np.arange(n_sources)
        else:
            sources_to_update = np.where(active_set != 0)[0]

        for j in sources_to_update:
            # remove the contribution of row j from the residuals before
            # recomputing it (active_set is all ones during the first pass,
            # so this is done for every source there)
            if active_set[j]:
                R += X[:, j:j+1] @ B[j:j+1, :]
                if is_not_MTL:
                    S_inv_R += S_inv_X[:, j:j+1] @ B[j:j+1, :]

            # BST comes from clar.utils: it returns the new row and a flag
            # telling whether that row is zero (presumably block
            # soft-thresholding -- confirm against clar.utils)
            B[j:j+1, :], line_is_zero = BST(
                X[:, j:j+1].T @ S_inv_R / L[j],
                alpha * n_sensors * n_times / L[j])

            active_set[j] = not line_is_zero
            # put the contribution of the new row back; nothing to do when
            # the row is zero
            if not line_is_zero:
                R -= X[:, j:j+1] @ B[j:j+1, :]
                if is_not_MTL:
                    S_inv_R -= S_inv_X[:, j:j+1] @ B[j:j+1, :]


def update_Sigma_gls(
        emp_cov, alpha_Sigma_inv, cov_init=None, mode='cd', tol=1e-4,
        enet_tol=1e-4, sigmamin=1e-4, max_iter=1e4, verbose=False,
        return_costs=False, eps=np.finfo(np.float64).eps,
        return_n_iter=False):
    """Do one pass over the features updating covariance and precision.

    For each feature, a penalized regression on the other features is
    solved with scikit-learn's ``enet_coordinate_descent_gram`` (called with
    ``alpha_Sigma_inv`` and ``0`` as its two penalty arguments), and the
    corresponding row / column of the covariance and precision estimates is
    rewritten from the resulting coefficients.

    Parameters
    ----------
    emp_cov: np.array, shape (n_features, n_features)
        empirical covariance; it is not modified.
    alpha_Sigma_inv: float
        penalty passed to the coordinate-descent routine.
    cov_init: np.array or None
        initial covariance; ``emp_cov`` is used when None. In both cases the
        starting point is a copy scaled by 0.95 whose diagonal is reset to
        the diagonal of ``emp_cov``.
    enet_tol: float
        tolerance passed to the coordinate-descent routine.
    max_iter: int or float
        maximal number of iterations of the coordinate-descent routine;
        converted with ``int``.
    eps: float
        1000 * eps is added to the diagonal precision entry used as a
        divisor when initializing the coefficients.
    mode, tol, sigmamin, verbose, return_costs, return_n_iter:
        accepted but not used by this function.

    Returns
    -------
    covariance_: np.array
        updated covariance estimate.
    precision_: np.array
        updated precision estimate.

    Raises
    ------
    FloatingPointError
        if the sum of the entries of the precision estimate is not finite.
    """
    _, n_features = emp_cov.shape
    covariance_ = (emp_cov if cov_init is None else cov_init).copy()
    covariance_ *= 0.95
    covariance_.flat[::n_features + 1] = emp_cov.flat[::n_features + 1]
    precision_ = linalg.pinvh(covariance_)

    # the default is the float 1e4 while the routine takes a number of
    # iterations, so hand it a real int
    max_iter = int(max_iter)
    indices = np.arange(n_features)
    errors = dict(over='raise', invalid='ignore')
    sub_covariance = np.copy(covariance_[1:, 1:], order='C')

    for idx in range(n_features):
        others = indices != idx
        # keep the contiguous `sub_covariance` equal to covariance_ without
        # row and column idx: when idx moves by one, only one row and one
        # column need to be refreshed
        if idx > 0:
            di = idx - 1
            sub_covariance[di] = covariance_[di][others]
            sub_covariance[:, di] = covariance_[:, di][others]
        else:
            sub_covariance[:] = covariance_[1:, 1:]
        row = emp_cov[idx, others]
        with np.errstate(**errors):
            # Use coordinate descent, started from the coefficients implied
            # by the current precision
            coefs = -(precision_[others, idx]
                      / (precision_[idx, idx] + 1000 * eps))
            coefs, _, _, _ = cd_fast.enet_coordinate_descent_gram(
                coefs, alpha_Sigma_inv, 0, sub_covariance,
                row, row, max_iter, enet_tol,
                check_random_state(None), False)
        # Update the precision matrix
        precision_[idx, idx] = (
            1. / (covariance_[idx, idx]
                  - np.dot(covariance_[others, idx], coefs)))
        off_diagonal = - precision_[idx, idx] * coefs
        precision_[others, idx] = off_diagonal
        precision_[idx, others] = off_diagonal
        # Update the covariance matrix
        coefs = np.dot(sub_covariance, coefs)
        covariance_[idx, others] = coefs
        covariance_[others, idx] = coefs
    if not np.isfinite(precision_.sum()):
        raise FloatingPointError('The system is too ill-conditioned '
                                 'for this solver')
    return covariance_, precision_
