"""
Toolbox for model definition.
"""
from numpy import False_
import torch
import torch.nn as nn


class LearningRule:
    """Base class that makes learning-rule selection more modular.

    Subclasses override :meth:`update` and/or :meth:`get_recurrent_grads`.
    Both base implementations are no-ops returning ``None`` and accept any
    positional or keyword arguments, so a rule that does not override one of
    them can still be called with the same arguments as a rule that does.
    """

    def __init__(self, lr):
        """Store the learning rate ``lr`` on the instance."""
        self.lr = lr

    def update(self, *args, **kwargs):
        """Do nothing and return ``None``; all arguments are ignored."""
        return None

    def get_recurrent_grads(self, *args, **kwargs):
        """Do nothing and return ``None``; all arguments are ignored."""
        return None


class SAM(LearningRule):
    """Sharpness-aware (SAM) gradient for ``model.rnn.weight_hh_l0``.

    Adapted from https://github.com/davda54/sam/blob/main/sam.py.
    """

    def __init__(self, lr):
        super().__init__(lr)

    def get_recurrent_grads(self, model, loss, inputs, criterion=nn.MSELoss(reduction="none"), rho=1):
        """Return the negated gradient of a second loss evaluated at ``w + e(w)``.

        Args:
            model: object exposing ``rnn.weight_hh_l0`` and ``f_step``.
            loss: scalar loss whose graph reaches ``model.rnn.weight_hh_l0``;
                its graph is retained.
            inputs: sequence of eight items
                ``(x, x1, r1, error, v1, xpert, xref, fb_in)`` passed
                positionally to ``model.f_step`` for the second forward pass.
            criterion: element-wise loss applied between the fourth value
                returned by ``model.f_step`` (indexed at ``[0]``) and zeros of
                the same shape; the result is averaged with ``.mean()``.
            rho: scale of the perturbation; ``e(w) = rho * dw / ||dw||_2``.

        Returns:
            ``-d(second loss)/d(weight_hh_l0)`` computed while the weights
            hold ``w + e(w)``. The weight values are restored before
            returning, including when the second pass raises.
        """
        weight = model.rnn.weight_hh_l0

        dw, = torch.autograd.grad(
            loss,
            weight,
            retain_graph=True,
            create_graph=True,
        )

        # Non-adaptive perturbation. It is only used as a value to shift the
        # weights, so it is detached from the graph of ``dw``.
        e_w = (dw * (rho / (dw.norm(2) + 1e-12))).detach()

        old_p = weight.data.clone()

        # climb to the local maximum "w + e(w)"
        # (``.data`` is reassigned, as before, instead of an in-place op on
        # the parameter itself)
        weight.data = weight.data + e_w
        try:
            # second forward pass
            x, x1, r1, error, v1, xpert, xref, fb_in = inputs
            _, _, _, v1_2, _ = model.f_step(
                x, x1, r1, error, v1, xpert, xref, fb_in
            )

            mse_loss_2 = criterion(v1_2[0], v1_2[0] * 0).mean()

            # Differentiate while the weights still hold "w + e(w)": the
            # backward pass may read the weight's current values, so
            # restoring first would mix perturbed activations with
            # unperturbed weights.
            sam_dw, = torch.autograd.grad(
                mse_loss_2,
                weight,
                retain_graph=True,
                create_graph=True,
            )
        finally:
            # get back to "w" from "w + e(w)", even if the second pass failed
            weight.data = old_p

        return -sam_dw


class BPTT_2(LearningRule):
    """Approximate second-order gradient for ``model.rnn.weight_hh_l0``."""

    def __init__(self, lr):
        super().__init__(lr)

    def get_recurrent_grads(self, model, loss, inputs=None, alpha = 1e-3):
        """Return ``-pinv(H + alpha * I) @ g`` reshaped like the weight.

        ``g`` is the gradient of ``loss`` with respect to
        ``model.rnn.weight_hh_l0`` and ``H`` is built row by row by
        differentiating each entry of ``g`` with respect to the same weight.
        ``inputs`` is accepted but not used.
        """
        weight = model.rnn.weight_hh_l0

        # First-order gradient; the graph is kept so each entry can be
        # differentiated again below.
        (first_order,) = torch.autograd.grad(
            loss, weight, retain_graph=True, create_graph=True
        )

        # One backward pass per gradient entry -> one Hessian row each.
        rows = [
            torch.autograd.grad(entry, weight, retain_graph=True)[0]
            .flatten()
            .detach()
            for entry in first_order.flatten()
        ]
        hessian = torch.stack(rows)  # |theta| x |theta|

        # compute approximate second order gradient
        identity = torch.eye(hessian.shape[0], device=hessian.device)
        damped = torch.add(hessian, identity, alpha=alpha)
        step = torch.pinverse(damped) @ first_order.detach().flatten()

        return -step.view_as(first_order)


class BPTT(LearningRule):
    """Plain autograd gradient for ``model.rnn.weight_hh_l0``."""

    def __init__(self, lr):
        super().__init__(lr)

    def get_recurrent_grads(self, model, loss, inputs=None):
        """Return the negated gradient of ``loss`` w.r.t. ``weight_hh_l0``.

        The graph of ``loss`` is retained and the gradient itself is
        differentiable (``create_graph=True``). ``inputs`` is accepted but
        not used.
        """
        grads = torch.autograd.grad(
            loss,
            model.rnn.weight_hh_l0,
            retain_graph=True,
            create_graph=True,
        )
        (grad_w,) = grads
        return -grad_w


class FED(LearningRule):
    """Learning rule driven by the running trace ``model.presum_alt``."""

    def __init__(self, lr):
        super().__init__(lr)

    def update(self, model, r1, r1_prev, x1):
        """Update the eligibility trace stored in ``model.presum_alt``.

        The new trace is ``alpha * r1 + (1 - alpha) * presum_alt`` with
        ``alpha = model.alpha``. ``r1_prev`` and ``x1`` are accepted but not
        used.
        """
        # TODO: is this correct?
        keep = 1 - model.alpha
        model.presum_alt = model.alpha * r1 + keep * model.presum_alt

    def get_recurrent_grads(self, model, fb, use_transpose=False):
        """Return ``outer(fb @ w, model.presum_alt[0])``.

        ``w`` is ``model.output.weight`` when ``use_transpose`` is true,
        otherwise the transpose of ``model.feedback.weight``.
        """
        w = model.output.weight if use_transpose else model.feedback.weight.T
        projected = fb @ w
        return torch.outer(projected, model.presum_alt[0])


class FED_t(FED):
    """Variant of ``FED`` that always projects through ``model.output.weight``."""

    def __init__(self, lr):
        super().__init__(lr)

    def get_recurrent_grads(self, model, fb, use_transpose=True):
        """Return ``outer(fb @ model.output.weight, model.presum_alt[0])``.

        ``use_transpose`` is accepted for signature compatibility but is not
        read.
        """
        projected = fb @ model.output.weight
        trace = model.presum_alt[0]
        return torch.outer(projected, trace)


class RFLO(LearningRule):
    """Learning rule driven by the running trace ``model.prepostsum``."""

    def __init__(self, lr):
        super().__init__(lr)

    def update(self, model, r1, r1_prev, x1):
        """Update the eligibility trace stored in ``model.prepostsum``.

        The new trace is
        ``alpha * outer(nonlin_der(x1[0]), r1_prev[0]) + (1 - alpha) * prepostsum``
        with ``alpha = model.alpha``. ``r1`` is accepted but not used.
        """
        instantaneous = torch.outer(model.nonlin_der(x1[0]), r1_prev[0])
        model.prepostsum = (
            model.alpha * instantaneous + (1 - model.alpha) * model.prepostsum
        )

    def get_recurrent_grads(self, model, fb, use_transpose=False):
        """Return ``((fb @ w) * model.prepostsum.T).T``.

        ``w`` is ``model.output.weight`` when ``use_transpose`` is true,
        otherwise the transpose of ``model.feedback.weight``.
        """
        # assumes model.prepostsum is 2-D so that ``.T`` is a plain matrix
        # transpose -- TODO confirm
        w = model.output.weight if use_transpose else model.feedback.weight.T
        projected = fb @ w
        scaled_t = projected * model.prepostsum.T
        return scaled_t.T


class RFLO_t(RFLO):
    """Variant of ``RFLO`` that always projects through ``model.output.weight``."""

    def __init__(self, lr):
        super().__init__(lr)

    def get_recurrent_grads(self, model, fb, use_transpose=True):
        """Return ``outer(fb @ model.output.weight, ones) * model.prepostsum[0]``.

        ``use_transpose`` is accepted for signature compatibility but is not
        read.
        """
        # NOTE(review): this indexes ``model.prepostsum[0]`` whereas
        # ``RFLO.get_recurrent_grads`` uses the whole ``model.prepostsum``.
        # If ``prepostsum`` is 2-D only its first row enters the result here
        # -- confirm the intended shape.
        trace = model.prepostsum[0]
        projected = fb @ model.output.weight
        # Repeat the projected feedback along a second axis of length
        # ``trace.shape[0]`` before the element-wise product.
        ones = torch.ones(trace.shape[0], device=fb.device)
        tiled = torch.outer(projected, ones)
        return tiled * model.prepostsum[0]
