import numpy as np

import torch
import torch.nn as nn

from .solver import get_solver
from .grad import make_pair, backward_factory
from .jacobian import power_method
from .layer_utils import deq_decorator

from .norm import reset_norm
from .dropout import reset_dropout


class DEQBase(nn.Module):
    """
    Base class for Deep Equilibrium (DEQ) model.

    Stores the forward/backward solver configuration read from ``args``. Subclasses implement
    ``_solve_fixed_point`` and ``forward`` to build the actual computational graph.

    Args:
        args (argparse.Namespace): Parsed command line arguments. Reads ``f_solver``, ``b_solver``,
            ``f_thres``, ``b_thres``, ``f_eps``, ``b_eps``, ``f_stop_mode``, ``b_stop_mode``,
            ``eval_f_thres`` and ``eval_factor``.
    """
    def __init__(self, args):
        super(DEQBase, self).__init__()

        self.args = args
        self.f_solver = get_solver(args.f_solver)
        self.b_solver = get_solver(args.b_solver)

        # Forwarded to `deq_decorator` by the subclasses; only set for the 'speedy_naive' solver.
        self.speedy = args.f_solver == 'speedy_naive'

        self.f_thres = args.f_thres
        self.b_thres = args.b_thres

        self.f_eps = args.f_eps
        self.b_eps = args.b_eps

        self.f_stop_mode = args.f_stop_mode
        self.b_stop_mode = args.b_stop_mode

        # Inference threshold: an explicit positive `eval_f_thres`, otherwise the training
        # threshold scaled by `eval_factor`.
        self.eval_f_thres = args.eval_f_thres if args.eval_f_thres > 0 else int(self.f_thres * args.eval_factor)

        # Not used in this file; presumably a slot for subclasses / external hooks — TODO confirm.
        self.hook = None

    def _sradius(self, deq_func, z_star):
        """
        Estimates the spectral radius of the Jacobian of ``deq_func`` at ``z_star`` using the power method.

        The estimate is computed on a detached copy of ``z_star``, so the caller's tensor is not
        switched to ``requires_grad=True`` as a side effect (which would otherwise leak into the
        fixed point returned to the user).

        Args:
            deq_func (callable): The DEQ function.
            z_star (torch.Tensor): The fixed point solution.

        Returns:
            The spectral radius estimate produced by ``power_method`` after 100 iterations.
        """
        # Fresh leaf tensor: enables autograd w.r.t. the fixed point without mutating the caller's tensor.
        z_star = z_star.detach().requires_grad_()
        with torch.enable_grad():
            new_z_star = deq_func(z_star)
        _, sradius = power_method(new_z_star, z_star, n_iters=100)

        return sradius

    def _solve_fixed_point(
            self, deq_func, z_init,
            f_thres=None,
            solver_kwargs=None,
            **kwargs
            ):
        """
        Solves for the fixed point. Must be overridden in subclasses.

        Args:
            deq_func (callable): The DEQ function.
            z_init (torch.Tensor): Initial tensor for fixed point solver.
            f_thres (float, optional):
                Forward step threshold for overwriting the solver threshold in this call. Default None.
            solver_kwargs (dict, optional):
                Additional arguments for the solver used in this forward pass. These arguments will overwrite the default solver arguments.
                Refer to the documentation of the specific solver for the list of accepted arguments. Default None.

        Raises:
            NotImplementedError: If the method is not overridden.
        """
        raise NotImplementedError

    def forward(
            self, func, z_init,
            solver_kwargs=None,
            sradius_mode=False,
            backward_writer=None,
            **kwargs
            ):
        """
        Defines the computation graph and gradients of DEQ. Must be overridden in subclasses.

        Args:
            func (callable): The DEQ function.
            z_init (torch.Tensor): Initial tensor for fixed point solver.
            solver_kwargs (dict, optional):
                Additional arguments for the solver used in this forward pass. These arguments will overwrite the default solver arguments.
                Refer to the documentation of the specific solver for the list of accepted arguments. Default None.
            sradius_mode (bool, optional):
                If True, computes the spectral radius in validation and adds 'sradius' to the `info` dictionary. Default False.
            backward_writer (callable, optional):
                Callable function to monitor the backward pass. It should accept the solver statistics dictionary as input. Default None.

        Raises:
            NotImplementedError: If the method is not overridden.
        """
        raise NotImplementedError


class DEQIndexing(DEQBase):
    """
    DEQ computational graph that samples fixed point states at specific indices.
    Gradients are applied to the sampled solver states.

    The forward solver runs once without gradient tracking. During training it also returns the
    states at the steps listed in ``self.indexing``, and each sampled state is passed through the
    matching gradient function in ``self.produce_grad``.

    Args:
        args (argparse.Namespace): Parsed command line arguments.
    """
    def __init__(self, args):
        super(DEQIndexing, self).__init__(args)

        # Define gradient functions through the backward factory.

        # First compute the f_thres indexing where we add corrections.
        self.indexing = self._compute_f_thres(args.f_thres)

        # By default, we use the same phantom grad for all correction losses.
        # You can also set different grad steps a, b, and c for different terms by `args.grad a b c ...`.
        indexing_pg = make_pair(self.indexing, args.grad)
        produce_grad = [
                backward_factory(grad_type=pg, tau=args.tau, sup_gap=args.sup_gap, sup_loc=args.sup_loc) for pg in indexing_pg
                ]

        # Enabling args.ift (or args.hook_ift) replaces the last gradient function by IFT.
        if args.ift or args.hook_ift:
            produce_grad[-1] = backward_factory(
                grad_type='ift', hook_ift=args.hook_ift, b_solver=self.b_solver,
                b_solver_kwargs=dict(threshold=args.b_thres, eps=args.b_eps, stop_mode=args.b_stop_mode)
                )

        self.produce_grad = produce_grad

    def _compute_f_thres(self, f_thres):
        """
        Computes the steps for sampling internal solver states.
        Priority: args.n_losses > args.indexing.
        Uses args.n_losses to uniformly divide the solver forward threshold if args.n_losses is designated.
        Otherwise, uses args.indexing to generate the sample sequence.
        By default, it returns the f_thres if no args.n_losses or args.indexing apply.

        Example: f_thres=10, n_losses=3 gives [4, 7, 10]; f_thres=12, n_losses=3 gives [4, 8, 12].

        Args:
            f_thres (float): Forward step threshold for the DEQ solver.

        Returns:
            list[int]: List of solver steps to be sampled. The last entry is always f_thres.
        """
        if self.args.n_losses > 1:
            # Never sample more states than solver steps. Cast to int so range() also works
            # when a float f_thres is smaller than n_losses.
            n_losses = int(max(min(f_thres, self.args.n_losses), 1))
            delta = int(f_thres // n_losses)
            if f_thres % n_losses == 0:
                return [(k + 1) * delta for k in range(n_losses)]
            # Otherwise anchor the samples at f_thres and count back in steps of delta.
            return [f_thres - (n_losses - k - 1) * delta for k in range(n_losses)]
        return [*self.args.indexing, f_thres]

    def _solve_fixed_point(
               self, deq_func, z_init,
               f_thres=None,
               indexing=None,
               solver_kwargs=None,
               **kwargs
               ):
        """
        Solves for the fixed point using the DEQ solver (without gradient tracking).

        Args:
            deq_func (callable): The DEQ function.
            z_init (torch.Tensor): Initial tensor for fixed point solver.
            f_thres (float, optional): Forward step threshold for overwriting the solver threshold in this call.
                Default None, which uses ``self.f_thres`` in training and ``self.eval_f_thres`` in inference.
            indexing (list, optional): Indexing list for sampling the DEQ solver. Ignored in inference. Default None.
            solver_kwargs (dict, optional):
                Additional arguments for the solver used in this forward pass. These arguments will overwrite the default solver arguments
                (``threshold``, ``eps``, ``stop_mode``, ``indexing``). The key 'f_thres' is ignored here; pass it as ``f_thres``.
                Refer to the documentation of the specific solver for the list of accepted arguments. Default None.

        Returns:
            torch.Tensor: The fixed point solution.
            list[torch.Tensor]: Sampled fixed point trajectory according to args.n_losses or args.indexing.
            dict: A dict containing solver statistics.
        """
        if f_thres is None:
            # No per-call override: use the configured threshold for the current mode.
            f_thres = self.f_thres if self.training else self.eval_f_thres

        # Intermediate states are only sampled during training.
        indexing = indexing if self.training else None

        call_kwargs = dict(threshold=f_thres, eps=self.f_eps, stop_mode=self.f_stop_mode, indexing=indexing)
        # Per-call solver arguments override the defaults. Spreading them next to explicit keywords
        # would raise TypeError on duplicates such as 'eps'. 'f_thres' is not a solver argument.
        call_kwargs.update({k: v for k, v in (solver_kwargs or {}).items() if k != 'f_thres'})

        with torch.no_grad():
            # x0=z_init allows reusing previous fixed points as the initial guess.
            z_star, trajectory, info = self.f_solver(deq_func, x0=z_init, **call_kwargs)

        return z_star, trajectory, info

    def forward(
            self, func, z_init,
            solver_kwargs=None,
            sradius_mode=False,
            backward_writer=None,
            **kwargs
            ):
        """
        Defines the computation graph and gradients of DEQ.

        This method carries out the forward pass computation for the DEQ model, by solving for the fixed point.
        During training, it also keeps track of the trajectory of the solution.
        In inference mode, it returns the final fixed point.

        Args:
            func (callable): The DEQ function.
            z_init (torch.Tensor): Initial tensor for fixed point solver.
            solver_kwargs (dict, optional):
                Additional arguments for the solver used in this forward pass. These arguments will overwrite the default solver arguments.
                A numeric 'f_thres' entry overrides the forward threshold (and, in training, the sampled steps); None keeps the default.
                Refer to the documentation of the specific solver for the list of accepted arguments. Default None.
            sradius_mode (bool, optional):
                If True, computes the spectral radius in validation and adds 'sradius' to the `info` dictionary. Default False.
            backward_writer (callable, optional):
                Callable function to monitor the backward pass. It should accept the solver statistics dictionary as input. Default None.

        Returns:
            list[torch.Tensor]:
                During training, returns the tracked gradients of the sampled fixed point trajectory according to args.n_losses or args.indexing.
                During inference, returns a list containing the fixed point solution only.
            dict:
                A dict containing solver statistics.
        """
        deq_func, z_init = deq_decorator(func, z_init, speedy=self.speedy)

        if solver_kwargs is None:
            solver_kwargs = dict()

        # Optional per-call threshold override; a missing key or None keeps the configured default.
        f_thres = solver_kwargs.get('f_thres', None)

        if self.training:
            # Recompute the sampled steps for a numeric override (bool subclasses int but is not a step count).
            if isinstance(f_thres, (int, float, np.integer, np.floating)) and not isinstance(f_thres, bool):
                indexing = self._compute_f_thres(f_thres)
            else:
                indexing = self.indexing

            _, trajectory, info = self._solve_fixed_point(
                    deq_func, z_init,
                    f_thres=self.f_thres if f_thres is None else f_thres,
                    indexing=indexing, solver_kwargs=solver_kwargs)

            z_out = []
            for z_pred, produce_grad in zip(trajectory, self.produce_grad):
                z_out += produce_grad(self, deq_func, z_pred, writer=backward_writer)   # See torchdeq.grad for the backward pass

            z_out = [deq_func.vec2list(each) for each in z_out]
        else:
            # During inference, we directly solve for the fixed point
            z_star, _, info = self._solve_fixed_point(
                    deq_func, z_init,
                    f_thres=self.eval_f_thres if f_thres is None else f_thres,
                    solver_kwargs=solver_kwargs)

            sradius = self._sradius(deq_func, z_star) if sradius_mode else torch.zeros(1, device=z_star.device)
            info['sradius'] = sradius

            z_out = [deq_func.vec2list(z_star)]

        return z_out, info


class DEQSliced(DEQBase):
    """
    DEQ computational graph that slices the full solver trajectory to apply gradients.

    During training the forward threshold is split into consecutive slices. Each slice is solved
    without gradient tracking, its end state is passed through the matching gradient function in
    ``self.produce_grad``, and the next slice resumes from that function's output.

    Args:
        args (argparse.Namespace): Parsed command line arguments.
    """
    def __init__(self, args):
        super(DEQSliced, self).__init__(args)

        # Define gradient functions through the backward factory.

        # First compute the f_thres indexing (slice lengths) where we add corrections.
        self.indexing = self._compute_f_thres(args.f_thres)

        # By default, we use the same phantom grad for all correction losses.
        # You can also set different grad steps a, b, and c for different terms by `args.grad a b c ...`.
        indexing_pg = make_pair(self.indexing, args.grad)
        produce_grad = [
                backward_factory(grad_type=pg, tau=args.tau, sup_gap=args.sup_gap, sup_loc=args.sup_loc) for pg in indexing_pg
                ]

        # Enabling args.ift (or args.hook_ift) replaces the last gradient function by IFT.
        if args.ift or args.hook_ift:
            produce_grad[-1] = backward_factory(
                grad_type='ift', hook_ift=args.hook_ift, b_solver=self.b_solver,
                b_solver_kwargs=dict(threshold=args.b_thres, eps=args.b_eps, stop_mode=args.b_stop_mode)
                )

        self.produce_grad = produce_grad

    def _compute_f_thres(self, f_thres):
        """
        Computes the number of solver steps in each slice.
        Priority: args.n_losses > args.indexing.
        Uses args.n_losses to divide the solver forward threshold into that many slices if args.n_losses is designated.
        Otherwise, uses the gaps between consecutive entries of [0, *args.indexing, f_thres].
        By default, it returns [f_thres] if no args.n_losses or args.indexing apply.

        Example: f_thres=10, n_losses=3 gives [4, 3, 3]; f_thres=12, n_losses=3 gives [4, 4, 4].

        Args:
            f_thres (float): Forward step threshold for the DEQ solver.

        Returns:
            list[int]: Solver steps per slice.
        """
        n_losses = self.args.n_losses
        if n_losses > 1:
            # Slice lengths add up to f_thres: the division remainder goes to the first slice
            # instead of being dropped. For integer f_thres >= n_losses these are the same split
            # points that DEQIndexing samples.
            delta = int(f_thres // n_losses)
            first = int(f_thres) - (n_losses - 1) * delta
            return [first] + [delta] * (n_losses - 1)
        return np.diff([0, *self.args.indexing, f_thres]).tolist()

    def _solve_fixed_point(
            self, deq_func, z_init,
            f_thres=None,
            solver_kwargs=None,
            **kwargs
            ):
        """
        Solves for the fixed point using the DEQ solver (without gradient tracking).

        Args:
            deq_func (callable): The DEQ function.
            z_init (torch.Tensor): Initial tensor for fixed point solver.
            f_thres (float, optional): Forward step threshold for overwriting the solver threshold in this call.
                Default None, which uses ``self.f_thres`` in training and ``self.eval_f_thres`` in inference.
            solver_kwargs (dict, optional):
                Additional arguments for the solver used in this forward pass. These arguments will overwrite the default solver arguments
                (``threshold``, ``eps``, ``stop_mode``). The key 'f_thres' is ignored here; pass it as ``f_thres``.
                Refer to the documentation of the specific solver for the list of accepted arguments. Default None.

        Returns:
            torch.Tensor: The fixed point solution.
            dict: A dict containing solver statistics.
        """
        if f_thres is None:
            # No per-call override: use the configured threshold for the current mode.
            f_thres = self.f_thres if self.training else self.eval_f_thres

        call_kwargs = dict(threshold=f_thres, eps=self.f_eps, stop_mode=self.f_stop_mode)
        # Per-call solver arguments override the defaults. Spreading them next to explicit keywords
        # would raise TypeError on duplicates such as 'eps'. 'f_thres' is not a solver argument.
        call_kwargs.update({k: v for k, v in (solver_kwargs or {}).items() if k != 'f_thres'})

        with torch.no_grad():
            # x0=z_init allows reusing the previous fixed point as the initial guess.
            z_star, _, info = self.f_solver(deq_func, x0=z_init, **call_kwargs)

        return z_star, info

    def forward(
            self, func, z_star,
            solver_kwargs=None,
            sradius_mode=False,
            backward_writer=None,
            **kwargs
            ):
        """
        Defines the computation graph and gradients of DEQ.

        Args:
            func (callable): The DEQ function.
            z_star (torch.Tensor): Initial tensor for the fixed point solver (e.g. a previous fixed point).
            solver_kwargs (dict, optional):
                Additional arguments for the solver used in this forward pass. These arguments will overwrite the default solver arguments.
                A numeric 'f_thres' entry overrides the forward threshold (and, in training, the slicing); None keeps the default.
                Refer to the documentation of the specific solver for the list of accepted arguments. Default None.
            sradius_mode (bool, optional):
                If True, computes the spectral radius in validation and adds 'sradius' to the `info` dictionary. Default False.
            backward_writer (callable, optional):
                Callable function to monitor the backward pass. It should accept the solver statistics dictionary as input. Default None.

        Returns:
            list[torch.Tensor]:
                During training, returns the tracked gradients of the sampled fixed point trajectory according to args.n_losses or args.indexing.
                During inference, returns a list containing the fixed point solution only.
            dict:
                A dict containing solver statistics (from the last slice during training).
        """
        deq_func, z_star = deq_decorator(func, z_star, speedy=self.speedy)

        if solver_kwargs is None:
            solver_kwargs = dict()

        # Optional per-call threshold override; a missing key or None keeps the configured default.
        f_thres = solver_kwargs.get('f_thres', None)

        if self.training:
            # Recompute the slicing for a numeric override (bool subclasses int but is not a step count).
            if isinstance(f_thres, (int, float, np.integer, np.floating)) and not isinstance(f_thres, bool):
                indexing = self._compute_f_thres(f_thres)
            else:
                indexing = self.indexing

            z_out = []
            for slice_thres, produce_grad in zip(indexing, self.produce_grad):
                z_star, info = self._solve_fixed_point(deq_func, z_star, f_thres=slice_thres, solver_kwargs=solver_kwargs)
                z_out += produce_grad(self, deq_func, z_star, writer=backward_writer)   # See torchdeq.grad for implementations
                z_star = z_out[-1]                                                      # Add the gradient chain to the solver.

            z_out = [deq_func.vec2list(each) for each in z_out]
        else:
            # During inference, we directly solve for the fixed point
            z_star, info = self._solve_fixed_point(
                    deq_func, z_star,
                    f_thres=self.eval_f_thres if f_thres is None else f_thres,
                    solver_kwargs=solver_kwargs)

            # Keep the placeholder on the same device as the solution.
            sradius = self._sradius(deq_func, z_star) if sradius_mode else torch.zeros(1, device=z_star.device)
            info['sradius'] = sradius

            z_out = [deq_func.vec2list(z_star)]

        return z_out, info


# Registry mapping an `args.core` name to its DEQ class; extended at runtime by `register_deq`.
_deq_class = dict(
    indexing=DEQIndexing,
    sliced=DEQSliced,
)


def register_deq(deq_type, deq_class):
    """
    Register a user-defined DEQ class so that `get_deq` can instantiate it.

    The class is stored in the module-level DEQ registry under the key `deq_type`.
    Registering a key that already exists replaces the previous entry.

    Args:
        deq_type (str): Registry key; `get_deq` looks it up via `args.core`.
        deq_class (type): The DEQ class to store. `get_deq` calls it with `args` as its only argument.

    Example:
        >>> register_deq('custom', CustomDEQ)
    """
    _deq_class.update({deq_type: deq_class})


def get_deq(args):
    """
    Factory that builds a DEQ model from the command line arguments.

    The computational core is selected by `args.core` — e.g. `indexing` for DEQIndexing,
    `sliced` for DEQSliced, or any name added through `register_deq`.

    DEQIndexing and DEQSliced build different computational graphs during training but behave
    the same at test time.

    DEQIndexing runs the solver once and samples internal solver states at chosen steps,
    applying the gradient functions to those samples. This attaches the gradient functions
    alongside the full solver graph. The maximum number of DEQ function calls is `args.f_thres`.

    DEQSliced splits the solver steps into several smaller no-grad segments. A gradient function
    (PG or IFT) is applied to the end state of each segment, and the next segment resumes from
    its output. This inserts the gradient functions into the solver graph. The maximum number of
    DEQ function calls is, for example, `args.f_thres + args.n_losses * args.grad`.

    Args:
        args (argparse.Namespace): Parsed command line arguments specifying the configuration of the DEQ model.

    Returns:
        torch.nn.Module: DEQ model based on the specified configuration.

    Raises:
        AssertionError: If `args.core` is not a registered DEQ type.

    Example:
        >>> args = argparse.Namespace(core='sliced', ...)   # plus the fields the chosen class reads
        >>> deq = get_deq(args)
    """
    core = args.core
    assert core in _deq_class, 'Not registered DEQ class!'

    deq_cls = _deq_class[core]
    return deq_cls(args)
    

def reset_deq(model):
    """
    Reset the normalization and dropout layers of a DEQ model (usually before each training iteration).

    Calls `reset_norm` and then `reset_dropout` on `model`.

    Args:
        model (torch.nn.Module): The DEQ model to reset.

    Example:
        >>> deq_layer = DEQLayer(args)          # A Pytorch Module that defines the f in z* = f(z*, x).
        >>> reset_deq(deq_layer)
    """
    for reset_fn in (reset_norm, reset_dropout):
        reset_fn(model)
