"""
Toolbox for model definition.
"""
from numpy import False_
import torch
import torch.nn as nn
from utils import relu_der
from learn_alg import *


# random
class RNN(nn.Module):
    """Rate RNN whose hidden units receive the readout error through feedback.

    The hidden state is integrated with Euler steps of size ``alpha = dt / tau``.
    At every step the hidden layer is driven by:
      * an input current (``weight_ih_l0``),
      * a recurrent current (``weight_hh_l0``; zeroed when ``recurrent=False``),
      * a feedback current: the (delayed) output error projected through
        ``mask.weight * feedback.weight``.
    The output error is either the direct velocity difference ``xref - v`` or,
    with ``pos_err=True``, that difference integrated over time (a position error).

    ``self.rnn`` (an ``nn.RNN``) is only used as a container for the input and
    recurrent weight matrices; its own forward is never called here.
    """

    def __init__(
        self,
        n_inputs,
        n_outputs,
        n_neurons,
        dtype,
        dt,
        tau,
        fb_delay=0,
        fb_density=1,
        pos_err=False,
        recurrent=True,
        error_type="error",
        error_detach=False,
    ):
        """Build the network.

        Args:
            n_inputs: Size of the external input.
            n_outputs: Size of the readout (and of the fed-back error).
            n_neurons: Number of hidden units.
            dtype: Tensor type the initial hidden state is cast to (``Tensor.type``).
            dt: Integration time step.
            tau: Hidden-unit time constant; ``alpha = dt / tau``.
            fb_delay: Number of steps before the recorded error is fed back.
            fb_density: Probability that each feedback connection is kept by the
                fixed random mask.
            pos_err: If True, feed back the integrated (position) error instead
                of the instantaneous (velocity) error.
            recurrent: If False, the recurrent current is multiplied by zero.
            error_type: Feedback signal type; only ``"error"`` is implemented.
            error_detach: If True, the fed-back error is detached from the graph
                and re-wrapped as a fresh leaf that requires grad.
        """
        super().__init__()
        self.n_neurons = n_neurons
        self.alpha = dt / tau
        self.dt = dt

        # Only used as a holder for weight_ih_l0 / weight_hh_l0 (no bias).
        self.rnn = nn.RNN(n_inputs, n_neurons, num_layers=1, bias=False)
        self.output = nn.Linear(n_neurons, n_outputs)
        # feedback.bias doubles as the hidden-unit bias in f_step.
        self.feedback = nn.Linear(n_outputs, n_neurons)
        self.dtype = dtype

        self.nonlin = torch.nn.ReLU()
        self.nonlin_der = relu_der

        # Fixed binary mask that sparsifies the feedback weights (not trained).
        self.mask = nn.Linear(n_outputs, n_neurons, bias=False)
        self.mask.weight = nn.Parameter(
            (torch.rand(n_neurons, n_outputs) < fb_density).float()
        )
        self.mask.weight.requires_grad = False

        self.delay = fb_delay
        self.pos_err = pos_err

        self.recurrent = recurrent
        self.error_type = error_type
        self.error_detach = error_detach

    def init_hidden(self):
        """Return a random initial hidden state, uniform in [-0.1, 0.1).

        Shape is ``(self.batch_size, self.n_neurons)``, cast with ``self.dtype``.
        Relies on ``self.batch_size``, which ``forward`` sets before calling this.
        """
        return ((torch.rand(self.batch_size, self.n_neurons) - 0.5) * 0.2).type(
            self.dtype
        )

    # ONE SIMULATION STEP
    def f_step(self, xin, x1, r1, v1fb, v1, pin, xref, fb_in=True):
        """Advance the network by one Euler step.

        Args:
            xin: External input for this step.
            x1: Current hidden pre-activation.
            r1: Current hidden rates.
            v1fb: Error signal fed back into the hidden layer.
            v1: Current error (used as the integrator state when ``pos_err``).
            pin: Perturbation added to the readout.
            xref: Reference the perturbed readout is compared against.
            fb_in: Whether the feedback current enters the dynamics.

        Returns:
            Tuple ``(x1, r1, vt_, v1, in1, in2, in3, error)``. ``x1``/``r1`` are
            the new pre-activation and rates, and ``vt_`` is the readout before the
            perturbation. ``v1`` is the new error. ``in1``/``in2``/``in3`` are the
            input, recurrent and feedback currents; ``in3`` is computed even when
            ``fb_in`` is False. ``error`` is the fed-back signal actually used.
        """
        w_in = self.rnn.weight_ih_l0
        in1 = xin.to(w_in) @ w_in.T  # input current (input cast to weight dtype/device)
        in2 = r1 @ self.rnn.weight_hh_l0.T  # recurrent current
        if not self.recurrent:
            # Multiply instead of replacing, so weight_hh_l0 stays in the graph
            # and receives a zero gradient.
            in2 = in2 * 0

        # Sparsified feedback matrix, computed once per step.
        w_fb = self.mask.weight * self.feedback.weight
        if self.error_detach:
            # Fresh leaf: gradients stop here instead of flowing into the
            # computation that produced the error, but grads w.r.t. it exist.
            error = v1fb.detach().requires_grad_(True)
        else:
            error = v1fb
        in3 = error @ w_fb.T  # feedback current

        # Summation order matches the original to keep results bit-identical.
        if fb_in:
            drive = in1 + in2 + in3 + self.feedback.bias
        else:
            drive = in1 + in2 + self.feedback.bias
        x1 = (1 - self.alpha) * x1 + self.alpha * drive
        r1 = self.nonlin(x1)  # activation
        vt_ = self.output(r1)  # output
        # Move the per-step data to the model's device instead of forcing CUDA.
        vt = vt_ + pin.to(vt_.device)  # velocity
        xref = xref.to(vt.device)
        if self.pos_err:
            v1 = v1 + self.dt * (xref - vt)  # integrated output error (position)
        else:
            v1 = xref - vt  # direct output error (velocity)
        return x1, r1, vt_, v1, in1, in2, in3, error

    # GET VELOCITY OUTPUT (NOT ERROR)
    def get_output(self, testl1):
        """Apply the readout layer to hidden rates (before any perturbation)."""
        return self.output(testl1)

    # RUN MODEL
    def forward(
        self,
        X,
        Xpert,
        Xref,
        local_learning_rule=None,
        fb_in=True,
        use_transpose=False,
        analysis=False,
        compute_hessian=False,
    ):
        """Simulate the network over time, optionally learning online.

        Args:
            X: Inputs indexed per step as ``X[j]``. ``X.size(0)`` is the number of
                steps and ``X.size(1)`` is the batch size.
            Xpert: Readout perturbation, one entry per step.
            Xref: Reference signal, one entry per step.
            local_learning_rule: Optional rule object. If given, its ``update`` and
                ``get_recurrent_grads`` are called every step, and
                ``weight_hh_l0`` is changed by ``lr * dw``.
            fb_in: Whether the feedback current enters the dynamics.
            use_transpose: Forwarded to non-BPTT rules' ``get_recurrent_grads``.
            analysis: If True, ``hidden1`` and ``extras["hidden_raw"]`` are left as
                Python lists (not stacked; ``hidden_raw`` not detached).
            compute_hessian: Not implemented. Raises for non-BPTT rules.

        Returns:
            ``(poserr, hidden1, extras)``. ``poserr`` stacks the error at every
            step, with index 0 as zeros. ``hidden1`` holds the rates. ``extras``
            is a dict of per-step diagnostics.

        Raises:
            NotImplementedError: If ``error_type`` is not ``"error"``, or if
                ``compute_hessian`` is requested with a non-BPTT rule.
        """
        # Only the plain error signal is implemented; fail loudly instead of
        # silently treating other values as "error".
        if self.error_type != "error":
            raise NotImplementedError("error type not implemented")

        rule_name = type(local_learning_rule).__name__
        # Rules that derive their gradient from an MSE loss via autograd.
        loss_based_rule = rule_name in ("BPTT", "BPTT_2", "SAM")
        # Reference BPTT gradient, only needed to score the other rules.
        bptt = None
        if local_learning_rule is not None and not loss_based_rule:
            bptt = BPTT(lr=0.0)  # just for comparison of grads
        gradients = []  # never filled here; kept for the extras interface
        criterion = nn.MSELoss(reduction="none")  # hard coded for now

        self.batch_size = X.size(1)
        x1 = self.init_hidden()
        self.x0 = x1
        r1 = self.nonlin(x1)
        # NOTE(review): with pos_err=True the integrated error starts from this
        # readout value rather than from zero, while poserr[0] is recorded as
        # zeros -- confirm this is intended.
        v1 = self.output(r1)

        # Per-step records; index 0 holds the initial state.
        hidden1 = [r1]
        poserr = [v1 * 0]
        raw_output = [v1]
        hidden_raw = [x1]
        r_inputs1 = [torch.zeros_like(x1)]
        r_inputs2 = [r1 @ self.rnn.weight_hh_l0.T]
        r_inputs3 = [torch.zeros_like(x1)]
        local_alignments = []
        local_grads = []
        errors = []
        control = []

        # Initial variables read by the learning rules.
        self.presum_alt = self.alpha * r1
        self.prepostsum = torch.zeros((x1.shape[1], x1.shape[1]), device=x1.device)

        for j in range(X.size(0)):
            if rule_name == "BPTT_2":
                print(j)  # progress output for the BPTT_2 rule
            # Previous rates, needed by RFLO-style rules.
            r1_prev = torch.clone(r1)
            # Feed back the error recorded `delay` entries earlier in poserr.
            # Until that much history exists, feed back zeros.
            if j < self.delay:
                fb_signal = poserr[0] * 0
            else:
                fb_signal = poserr[j - self.delay]
            x1, r1, vt, v1, in1, in2, in3, error = self.f_step(
                X[j], x1, r1, fb_signal, v1, Xpert[j], Xref[j], fb_in
            )

            # adaptation
            if local_learning_rule is not None:
                # NOTE(review): only v1[0] (presumably the first batch sample)
                # drives learning -- confirm intended.
                fb = v1[0]
                local_learning_rule.update(model=self, r1=r1, r1_prev=r1_prev, x1=x1)

                if loss_based_rule:
                    mse_loss = criterion(fb, fb * 0).mean()
                    inputs = (X[j], x1, r1, error, v1, Xpert[j], Xref[j], fb_in)
                    dw = local_learning_rule.get_recurrent_grads(
                        self, mse_loss, inputs
                    )
                else:
                    dw = local_learning_rule.get_recurrent_grads(
                        self, fb, use_transpose=use_transpose
                    )
                    # Cosine alignment between the local update and BPTT's gradient.
                    mse_loss = criterion(fb, fb * 0).mean()
                    true_grad = bptt.get_recurrent_grads(self, mse_loss)
                    local_alignments.append(
                        torch.nn.functional.cosine_similarity(
                            dw.detach().flatten(),
                            true_grad.detach().flatten(),
                            dim=0,
                        )
                    )
                    if compute_hessian:
                        raise NotImplementedError("Not implemented!")

                local_grads.append(dw.detach().clone())
                # Rebind .data to a new tensor rather than updating in place.
                # Presumably this keeps tensors saved by autograd for earlier
                # steps unmodified, avoiding in-place version-counter errors.
                self.rnn.weight_hh_l0.data = (
                    self.rnn.weight_hh_l0.data.detach()
                    + local_learning_rule.lr * dw.detach()
                )

            hidden1.append(r1)
            hidden_raw.append(x1)
            raw_output.append(vt)
            poserr.append(v1)
            control.append(in3)
            errors.append(error)
            r_inputs1.append(in1)
            r_inputs2.append(in2)
            r_inputs3.append(in3)

        if not analysis:
            hidden1 = torch.stack(hidden1)
            hidden_raw = torch.stack(hidden_raw).detach()
        poserr = torch.stack(poserr)
        raw_output = torch.stack(raw_output)
        control = torch.stack(control)

        extras = {
            "local_grads": local_grads,
            "local_alignments": local_alignments,
            "raw_output": raw_output,
            "r_inputs1": r_inputs1,
            "r_inputs2": r_inputs2,
            "r_inputs3": r_inputs3,
            "gradients": gradients,
            "hidden_raw": hidden_raw,
            "control": control,
            "errors": errors,
        }
        return poserr, hidden1, extras
