from library.rnn_architectures.base_rnn import *

import torch

class Rank(DynamicalSystem):
    '''
    Low rank RNN.

    The recurrent weight matrix is not stored as a free parameter. It is
    rebuilt from two factors as ``W = W_row.weight @ U.T`` where ``U`` holds the
    first ``rank`` columns of ``W_column.weight``, each scaled to unit L2 norm.
    ``W`` is therefore ``(dim, dim)`` with rank at most ``rank``.

    The remaining ``sum(in_dims)`` columns of ``W_column.weight`` hold the input
    maps, one contiguous group of columns per entry of ``in_dims``. Each input
    column is scaled to unit L2 norm and then multiplied by its own coefficient
    from ``W_in_coef``.

    Parameters
    ----------
    dim : number of units (size of the state vector).
    rank : number of rank-one components of the recurrent matrix.
    optimize_input_maps : if False, the input maps returned by
        ``construct_input_weight`` are detached from the autograd graph.
    in_dims : sizes of the separate input streams, one entry per stream.
    noise, noise_dim, device : forwarded to the base-class constructor.
    time_constant : the state derivative returned by ``f`` is divided by this.
    activation : elementwise nonlinearity applied to the state in ``f``.
    '''

    def __init__(self, dim, rank, optimize_input_maps=True, in_dims=(), noise=0.0, noise_dim=None, time_constant=1.0,
                 activation=torch.tanh, device='cpu'):
        super().__init__(noise_dim=noise_dim, noise=noise, device=device)

        # Plain (non-parameter) configuration.
        self.dim = dim
        self.rank = rank
        self.optimize_input_maps = optimize_input_maps
        self.in_dims = in_dims
        self.time_constant = time_constant
        self.activation = activation

        # cum_dims[i]:cum_dims[i+1] is the column range of input stream i,
        # relative to the first input column of W_column.weight.
        self.cum_dims = np.cumsum([0] + list(self.in_dims))

        # Coefficients of the initial state in the basis of normalised
        # recurrent columns, drawn uniformly from [-1, 1) / sqrt(rank).
        self.initial_state_linear_combination = nn.Parameter(
            (2 * torch.rand(rank, device=device) - 1) / np.sqrt(rank))

        # One scale per input column, initialised to sqrt(dim / sum(in_dims)).
        # The division stays inside the comprehension so that an empty
        # ``in_dims`` never divides by zero.
        total_in = sum(in_dims)
        self.W_in_coef = nn.ParameterList([
            nn.Parameter(torch.ones(in_dim, device=device) * np.sqrt(dim / total_in))
            for in_dim in in_dims
        ])

        # Fixed (non-trainable) scalars; they are not read anywhere in this class.
        self.init_weight = nn.Parameter(torch.tensor(1.0, device=device), requires_grad=False)
        self.init_bias = nn.Parameter(torch.tensor(0.0, device=device), requires_grad=False)

        self.define_parameters()

    def define_parameters(self):
        '''
        Create the two factor matrices and overwrite their default initialisation.

        ``W_column.weight`` is ``(dim, rank + sum(in_dims))`` with standard normal
        entries; ``W_row.weight`` is ``(dim, rank)`` with normal entries scaled by
        ``1 / sqrt(rank)``. The ``nn.Linear`` modules are used only as weight
        containers in this class; their ``forward`` is never called here.
        '''
        # NOTE(review): self.device is not assigned in this class; it is
        # presumably set by the base-class constructor -- confirm.
        self.W_column = nn.Linear(self.rank + sum(self.in_dims), self.dim, bias=False, device=self.device)
        self.W_row = nn.Linear(self.rank, self.dim, bias=False, device=self.device)

        with torch.no_grad():
            self.W_column.weight.copy_(torch.randn_like(self.W_column.weight))
            self.W_row.weight.copy_(
                torch.randn_like(self.W_column.weight[:, :self.rank]) / np.sqrt(self.rank))

    def _unit_norm_columns(self, start, stop):
        '''
        Return columns ``start:stop`` of ``W_column.weight``, each divided by its L2 norm.
        '''
        columns = self.W_column.weight[:, start:stop]
        return columns / torch.linalg.norm(columns, dim=0)

    def construct_weight(self):
        '''
        Build the ``(dim, dim)`` recurrent matrix from its low-rank factors.

        ``W[j, k] = sum_r W_row.weight[j, r] * U[k, r]`` with ``U`` the normalised
        first ``rank`` columns of ``W_column.weight``.
        '''
        unit_columns = self._unit_norm_columns(0, self.rank)

        return torch.einsum('jr,kr->jk', self.W_row.weight, unit_columns)

    def construct_input_weight(self):
        '''
        Build one input matrix per input stream.

        Returns a list with ``len(in_dims)`` tensors; entry ``i`` is
        ``(dim, in_dims[i])``: the normalised input columns of stream ``i`` times
        the per-column coefficients ``W_in_coef[i]``. When
        ``optimize_input_maps`` is False every entry is detached, so no gradient
        reaches ``W_column`` or ``W_in_coef`` through the input maps.
        '''
        W_in = []
        for i, (lo, hi) in enumerate(zip(self.cum_dims[:-1], self.cum_dims[1:])):
            # Input columns sit after the first ``rank`` (recurrent) columns.
            scaled = self._unit_norm_columns(self.rank + lo, self.rank + hi) * self.W_in_coef[i].unsqueeze(0)
            W_in.append(scaled if self.optimize_input_maps else scaled.detach())

        return W_in

    def get_components(self):
        '''
        Return the raw (un-normalised) factors, transposed: ``[W_row.weight.T, W_column.weight.T]``.
        '''
        return [self.W_row.weight.T, self.W_column.weight.T]

    def get_parameterization(self):
        '''
        Materialise the current weights as fresh leaf parameters.

        Sets ``self.W`` and ``self.W_in`` to detached copies of the constructed
        recurrent and input matrices. ``f`` reads ``self.W`` and ``self.W_in``,
        and within this class they are only assigned here.

        Returns ``[self.W, W, [self.W_in[0], W_in[0]], [self.W_in[1], W_in[1]], ...]``,
        where the un-prefixed tensors are the graph-attached constructions.
        '''
        W = self.construct_weight()
        W_in = self.construct_input_weight()

        self.W = nn.Parameter(W.detach())
        self.W_in = nn.ParameterList([w.detach() for w in W_in])

        # NOTE(review): the recurrent pair is flattened into the outer list while
        # the input pairs are nested two-element lists; kept as-is for callers --
        # confirm this asymmetry is what they expect.
        return [self.W, W] + [[leaf, built] for leaf, built in zip(self.W_in, W_in)]

    def get_initial_state(self, batch_size):
        '''
        Return the initial state repeated for a batch, shape ``(batch_size, dim)``.

        The state is ``sqrt(dim)`` times a linear combination of the normalised
        recurrent columns, with coefficients ``initial_state_linear_combination``.
        The result is detached, so no gradient reaches those coefficients through
        this method.
        '''
        unit_columns = self._unit_norm_columns(0, self.rank)

        # (1, rank) @ (rank, dim) -> (1, dim) -> (dim,)
        combination = (self.initial_state_linear_combination.unsqueeze(0) @ unit_columns.T).squeeze(0)
        # Tensor on the left so the product is dispatched by torch, not numpy.
        initial_state = combination * np.sqrt(self.dim)

        return initial_state.repeat(batch_size, 1).detach()

    def f(self, x, *args):
        '''
        State derivative: ``(activation(x) @ W.T + sum_i u_i @ W_in[i].T - x) / time_constant``.

        ``args`` are the input streams ``u_i``, matched positionally with
        ``self.W_in``. Requires ``self.W`` / ``self.W_in`` to exist (see
        ``get_parameterization``).
        '''
        # assumes x is (batch, dim) and each u_i is (batch, in_dims[i]) -- TODO confirm
        drive = self.activation(x) @ self.W.T
        # With no inputs the sum is the integer 0, which adds cleanly to ``drive``.
        drive = drive + sum(u @ self.W_in[i].T for i, u in enumerate(args))

        return (drive - x) * self.time_constant**-1
