import numpy as np
from methods.base_method import BaseMethod


class EXTRA(BaseMethod):
    """
    Gradient method with a consensus step over nodes (EXTRA-type recursion).

    With ``W`` denoting the ``consensus`` operator and ``gamma`` the step size,
    ``__call__`` runs::

        X_1     = W(X_0) - gamma * grad(X_0)
        Z_k     = 2 * X_k - X_{k-1}
        X_{k+1} = (Z_k + W(Z_k)) / 2 - gamma * (grad(X_k) - grad(X_{k-1}))

    ``self.return_history`` and ``self.history`` are read here but not created in
    this class; presumably they are provided by ``BaseMethod`` -- verify there.
    """

    def __init__(self, gamma0, N=1000, type_gamma=1, alpha=1 / 2, c=1,
                 beta=1, eta=1, min_gamma=False,
                 history_with_gamma=False, normalized_by_gamma=True, const_gamma=False,
                 *args, **kwargs):
        """
        Only ``gamma0``, ``N`` and ``history_with_gamma`` are read by the code of
        this class; the remaining options are stored as attributes and are not
        used by ``__call__`` below.

        :param gamma0: float or np.array[n], starting step size
        :param N: int, iterations number
        :param type_gamma: int, type of step size. Possible values: 1, 2, 3 (see article)
        :param alpha: float, parameter for backtracking
        :param c: float, parameter in d^nu updating
        :param beta: float, parameter for tilde{f} constructing
        :param eta: float, linear coefficient in backtracking condition
        :param min_gamma: bool, if True, we choose minimal step size on all nodes
        :param history_with_gamma: bool, if True, method saves gamma values into history
        :param normalized_by_gamma: bool, stored as is
        :param const_gamma: bool, if True, ``self.gamma`` is set to ``gamma0`` right away
        :param args: positional arguments forwarded to ``BaseMethod.__init__``
        :param kwargs: keyword arguments forwarded to ``BaseMethod.__init__``
        """
        super().__init__(*args, **kwargs)
        self.gamma0 = gamma0
        self.N = N
        self.type_gamma = type_gamma
        self.alpha = alpha
        self.beta = beta
        self.eta = eta
        self.c = c
        self.history_with_gamma = history_with_gamma
        self.min_gamma = min_gamma
        self.gamma_list = []
        self.normalized_by_gamma = normalized_by_gamma
        self.const_gamma = const_gamma
        if const_gamma:
            self.gamma = gamma0

    @staticmethod
    def _get_step(X, gamma, D):
        """
        Return ``X - gamma * D`` where ``gamma`` is applied along the first axis of ``D``.

        Transposing moves the first axis of ``D`` to the last position, so a
        1-D ``gamma`` of matching length scales ``D[i]`` by ``gamma[i]``.

        :param X: np.array, base point
        :param gamma: float or np.array, step size(s)
        :param D: np.array, direction with the same shape as X
        :return: np.array, the shifted point
        """
        Dgamma = gamma * D.T
        return X - Dgamma.T

    def _gamma_step(self, base, direction):
        """
        Subtract ``self.gamma * direction`` from ``base``.

        A 1-D ``self.gamma`` is treated as one step size per node (first axis),
        matching the documented ``np.array[n]`` form of ``gamma0``. Any other
        shape keeps plain NumPy broadcasting.
        """
        gamma = self.gamma
        if np.ndim(gamma) == 1:
            return self._get_step(base, gamma, direction)
        return base - gamma * direction

    def _record(self, gamma):
        """Append a copy of the current iterate (with ``gamma`` if requested) to history."""
        if not self.return_history:
            return
        snapshot = self.X.copy()
        if self.history_with_gamma:
            self.history.append((snapshot, gamma))
        else:
            self.history.append(snapshot)

    def __call__(self, X0, gradF, consensus, F=None, grad_sum=None):
        """
        Run ``self.N`` iterations starting from ``X0``.

        :param X0: np.array[n, ...], starting points in each node
        :param gradF: callable, function for gradient calculating of function F(X)=sum_i F_i(x_i);
            it is evaluated once per iteration and the previous value is reused
        :param consensus: callable, function that implements consensus procedure
        :param F: callable, object function (not used by this method)
        :param grad_sum: not used by this method
        :return: np.array[n, ...], obtained point with the same dimension as X0
        """
        self.X = X0.copy()
        X_prev = X0.copy()
        self.gamma = self.gamma0
        grad_prev = None
        for i in range(self.N):
            grad_cur = gradF(self.X)
            self._record(self.gamma)
            # Snapshot before calling consensus, in case it modifies its argument.
            X_cur = self.X.copy()
            if i == 0:
                # First step: plain consensus followed by a gradient step.
                self.X = self._gamma_step(consensus(self.X), grad_cur)
            else:
                x_part = 2 * self.X - X_prev
                # grad_prev is the gradient at X_prev computed one iteration ago,
                # so there is no need to evaluate gradF at X_prev again.
                self.X = self._gamma_step((x_part + consensus(x_part)) / 2,
                                          grad_cur - grad_prev)
            X_prev, grad_prev = X_cur, grad_cur
        # The final iterate is stored without a step size.
        self._record(None)
        return self.X
