import numpy as np
from methods.base_method import BaseMethod


class AdaNIDS(BaseMethod):
    """
    Decentralized first-order method whose step size is adapted by backtracking.

    Each iteration mixes the iterate through ``consensus``, evaluates the
    gradient at the mixed point, builds a corrected direction from the dual
    variable ``Y`` and then picks a step size with a per-node backtracking test.

    The attributes ``EPS``, ``EPS_GAMMA``, ``return_history`` and ``history``
    are used here but not defined in this class; presumably they are provided
    by ``BaseMethod`` -- TODO confirm.
    """

    def __init__(self, gamma0, N=1000, type_gamma=1, alpha=1 / 2, c=1,
                 beta=1, eta=1, min_gamma=True,
                 history_with_gamma=False, normalized_by_gamma=True, const_gamma=False,
                 use_consensus_error=False,
                 *args, **kwargs):
        """
        :param gamma0: float or np.array[n], starting step size
        :param N: int, iterations number
        :param type_gamma: int, type of step size. Possible values: 1, 2, 3 (see article)
        :param alpha: float, multiplier applied to a step size on each failed
            backtracking test (should lie in (0, 1) for the step to shrink)
        :param c: float, parameter in d^nu updating
        :param beta: float, parameter for tilde{f} constructing; here it sets the
            growth factor (k + beta + 1) / (k + 1) applied to gamma before backtracking
        :param eta: float, linear coefficient in backtracking condition
        :param min_gamma: bool, if True, we choose minimal step size on all nodes
        :param history_with_gamma: bool, if True, method saves gamma values into history
        :param normalized_by_gamma: bool, stored on the instance
        :param const_gamma: bool, if True the step size stays equal to gamma0 and
            no backtracking is performed
        :param use_consensus_error: bool, stored on the instance
        :param args: forwarded to BaseMethod
        :param kwargs: forwarded to BaseMethod

        NOTE(review): ``type_gamma``, ``eta``, ``min_gamma``,
        ``normalized_by_gamma`` and ``use_consensus_error`` are stored but not
        read by any method of this class; in particular the minimum over nodes
        is always taken, whatever ``min_gamma`` says.
        """
        super().__init__(*args, **kwargs)
        self.gamma0 = gamma0
        self.N = N
        self.type_gamma = type_gamma
        self.alpha = alpha
        self.beta = beta
        self.eta = eta
        self.c = c
        self.history_with_gamma = history_with_gamma
        self.min_gamma = min_gamma
        # Raw per-node step sizes found by every backtracking run.
        self.gamma_list = []
        self.normalized_by_gamma = normalized_by_gamma
        self.const_gamma = const_gamma
        self.use_consensus_error = use_consensus_error
        if const_gamma:
            self.gamma = gamma0

    @staticmethod
    def _change_gamma(_cond_value, gamma, alpha):
        """
        Method implements one step of backtracking
        :param _cond_value: np.array[n], values for backtracking in each node
        :param gamma: np.array[n], current step sizes (modified in place)
        :param alpha: float, parameter for updating
        :return: updated gamma (the same array object that was passed in)
        """
        # Entries with a positive condition value are scaled by alpha,
        # every other entry is multiplied by one and stays as it is.
        gamma *= np.where(_cond_value > 0, alpha, 1.0)
        return gamma

    @staticmethod
    def _get_step(X, gamma, D):
        """
        Return X - gamma * D with gamma applied along the first axis of D.

        The double transpose lets a scalar gamma and an np.array[n] gamma go
        through the same expression.
        """
        Dgamma = gamma * D.T
        return X - Dgamma.T

    def _gamma_snapshot(self):
        """Return a copy of the current step size that is safe to keep in history."""
        gamma = self.gamma
        # A plain Python float has no ``copy`` method and is immutable anyway.
        return gamma.copy() if hasattr(gamma, "copy") else gamma

    def update_gamma(self, F, X_tilde, X_delta, current_iteration, consensus_error):
        """
        Function for updating Gamma through backtracking in each node
        :param F: callable, object function; its values are compared entry-wise
            with the per-node step sizes
        :param X_tilde: np.array[n, ...], point the step is taken from
        :param X_delta: np.array[n, ...], direction; the trial point is
            X_tilde - gamma * X_delta
        :param current_iteration: int, iteration counter used for the growth factor
        :param consensus_error: not used by this implementation
        :raises TypeError: if F is None while backtracking is enabled

        The gradient at X_tilde is read from ``self._new_grad``. The resulting
        raw step sizes are appended to ``self.gamma_list``; ``self.gamma``
        becomes their minimum, bounded below by ``self.EPS_GAMMA``.
        """
        if self.const_gamma:
            return
        if F is None:
            raise TypeError(
                "AdaNIDS needs the objective F for backtracking unless const_gamma=True"
            )
        F_x = F(X_tilde)
        gradF_X_tilde = self._new_grad
        # Both sums do not depend on gamma, so compute them once.
        grad_dot_delta = (gradF_X_tilde * X_delta).sum(axis=-1)
        delta_sq_norm = (X_delta * X_delta).sum(axis=-1)

        def get_cond_value(gamma):
            # Value of F(x_new) - F(x) + gamma * <grad, delta> - gamma / 2 * ||delta||^2;
            # a step is accepted where this does not exceed self.EPS.
            X_new = self._get_step(X_tilde, gamma, X_delta)
            diff_F = F(X_new) - F_x
            first_ord = gamma * grad_dot_delta
            square_part = -gamma / 2 * delta_sq_norm
            return diff_F + first_ord + square_part

        # Let the step grow a little before testing it.
        coef = (current_iteration + self.beta + 1) / (current_iteration + 1)
        gamma = coef * self.gamma * np.ones(self.X.shape[0])
        _cond_value = get_cond_value(gamma)
        while not (_cond_value <= self.EPS).all():
            # Stop when another pass cannot change the outcome: either no entry
            # would be shrunk (e.g. NaN condition values), or every entry that
            # would be is already at or below the floor applied after the loop.
            can_shrink = (_cond_value > 0) & (gamma > self.EPS_GAMMA)
            if not can_shrink.any():
                break
            gamma = self._change_gamma(_cond_value, gamma, self.alpha)
            _cond_value = get_cond_value(gamma)
        self.gamma_list.append(np.copy(gamma))
        # One common step for all nodes, never below EPS_GAMMA.
        self.gamma = np.maximum(gamma.min(), self.EPS_GAMMA)

    def consensus_error_step(self, consensus_error):
        """Return X - c * consensus_error."""
        return self.X - self.c * consensus_error

    def add_direction_construction(self, consensus, gradF_X_tilde):
        """
        Return Y - c * (Y_tilde - consensus(Y_tilde)) with Y_tilde = gradF_X_tilde + Y.
        :param consensus: callable, function that implements consensus procedure
        :param gradF_X_tilde: np.array[n, ...], gradient at the mixed point
        """
        Y_tilde = gradF_X_tilde + self.Y
        return self.Y - self.c * (Y_tilde - consensus(Y_tilde))

    def x_update(self, X_tilde, x_delta):
        """
        Set X to X_tilde - gamma * x_delta and remember the previous iterate.
        :param X_tilde: np.array[n, ...], point the step is taken from
        :param x_delta: np.array[n, ...], direction of the step
        """
        previous_X = self.X
        # _get_step applies gamma along the first axis, so an np.array[n]
        # step size scales each node's row instead of broadcasting over columns.
        self.X = self._get_step(X_tilde, self.gamma, x_delta)
        self.old_X = previous_X
        # The stopping test in __call__ reads ``_old_X``; keep both names in sync.
        self._old_X = previous_X

    def dual_update(self, add_direction, consensus_error):
        """
        Set Y to add_direction + c * consensus_error / gamma.
        :param add_direction: np.array[n, ...], output of add_direction_construction
        :param consensus_error: np.array[n, ...], X - consensus(X)
        """
        scaled_error = self.c * consensus_error
        # Divide along the first axis, consistently with _get_step.
        self.Y = add_direction + (scaled_error.T / self.gamma).T

    def __call__(self, X0, gradF, consensus, F=None, grad_sum=None):
        """
        :param X0: np.array[n, ...], starting points in each node
        :param gradF: callable, function for gradient calculating of function F(X)=sum_i F_i(x_i)
        :param consensus: callable, function that implements consensus procedure
        :param F: callable, object function; required unless const_gamma=True
        :param grad_sum: not used by this implementation
        :return: np.array[n, ...], obtained point with the same dimension as X0
        """
        # The dual variable starts at zero. An earlier gradF(X0) initialisation
        # was overwritten immediately, so that evaluation is not performed.
        self.Y = np.zeros(X0.shape)
        self.X = X0.copy()
        self._old_grad = self.Y.copy()
        self.gamma = self.gamma0
        # Defined up front so the final history record also works when N == 0.
        X_tilde = self.X
        for i in range(self.N):
            consensus_error = (self.X - consensus(self.X))
            X_tilde = self.consensus_error_step(consensus_error)
            self._new_grad = gradF(X_tilde)
            add_direction = self.add_direction_construction(consensus, self._new_grad)
            x_delta = add_direction + self._new_grad
            self.update_gamma(F, X_tilde, x_delta, i, consensus_error)
            if self.return_history:
                # Recorded before the primal/dual update: the iterate the step
                # starts from together with the step size just selected.
                if self.history_with_gamma:
                    elem = (self.X.copy(), self._gamma_snapshot(), X_tilde.copy(), self.Y.copy())
                else:
                    elem = self.X.copy()
                self.history.append(elem)
            self.x_update(X_tilde, x_delta)
            self.dual_update(add_direction, consensus_error)
            # Stop once the iterate has effectively stopped moving.
            previous_X = getattr(self, "_old_X", None)
            if previous_X is not None and np.linalg.norm(self.X - previous_X) <= 1e-16:
                break
        if self.return_history:
            # Final point; no step size is attached to it.
            if self.history_with_gamma:
                elem = (self.X.copy(), None, X_tilde.copy(), self.Y.copy())
            else:
                elem = self.X.copy()
            self.history.append(elem)
        return self.X
