from typing import Callable, Dict, List, Optional

import numpy as np
import torch
import torch.nn as nn
from scipy.special import logit

import utils
from dct import LinearDCT
from disentangle.AbstractTrainer import AbstractLitModule
from models.encoders import MLP, TransformerEncoder
from solver.ode_layer import ODESYSLayer

class MechanisticModule(torch.nn.Module):
    """Encode trajectories into parameters and decode them through an ODE solver.

    An encoder (``MLP`` or ``TransformerEncoder``) maps each trajectory to a
    ``param_dim`` vector. Three decoder heads turn those parameters into the
    right-hand side, the time-varying coefficients and the step sizes consumed
    by ``ODESYSLayer``, whose solution ``u0`` is returned by ``solve``.
    """

    def __init__(
        self,
        batch_size: int = 10,
        order: int = 2,  # order of polynomial of MNN
        state_dim: int = 1,
        n_step: int = 60,
        mlp_enc: bool = True,
        dct_layer: bool = False,
        code_sharing: Optional[Dict[int, List[int]]] = None,
        feature_sharing_fn: Optional[Callable] = None,
        **kwargs,
    ):
        """Build the encoder, the decoder heads and the ODE layer.

        Args:
            batch_size: per-view batch size. The ODE layer is built for
                ``batch_size * n_views`` systems and ``forward`` reshapes by it.
            order: order of derivatives used to approximate the dynamics,
                not necessarily the ground-truth order of the ODE.
            state_dim: number of state variables (also the number of equations).
            n_step: number of time steps per trajectory.
            mlp_enc: use an ``MLP`` encoder if True, else a ``TransformerEncoder``.
            dct_layer: if True, encode only the first
                ``int(freq_frac_to_keep * n_step)`` DCT coefficients.
            code_sharing: stored as ``self.code_sharing`` (not used in this class).
            feature_sharing_fn: callable used by ``feature_sharing`` to combine
                parameters across views.
            **kwargs: optional ``hidden_dim`` (default 1024), ``n_views`` (2),
                ``n_iv_steps`` (1), ``param_dim`` (20), ``embedding_dim`` (64)
                and ``freq_frac_to_keep`` (0.5).
        """
        super().__init__()

        self.n_step = n_step
        # the order of derivatives we use to approximate,
        # not necessarily the ground truth order of the ODE
        self.order = order

        # state dimension
        self.state_dim = state_dim

        # hidden dim for submodules
        self.hidden_dim = kwargs.get("hidden_dim", 1024)
        self.n_views = kwargs.get("n_views", 2)

        self.batch_size = batch_size
        self.n_coeff = self.n_step * (self.order + 1)
        self.n_iv_steps = kwargs.get("n_iv_steps", 1)  # how many leading steps are given as initial values

        # for the initial value problem n_iv_steps = 1
        self.step_dim = (self.n_step - 1) * self.state_dim
        self.ode_layer = ODESYSLayer(
            bs=batch_size * self.n_views,
            n_ind_dim=1,
            order=self.order,
            n_equations=self.state_dim,  # equals number of states
            n_dim=self.state_dim,
            n_iv=1,
            n_step=self.n_step,
            n_iv_steps=self.n_iv_steps,
            solver_dbl=True,
        )
        # define the dimensions
        self.rhs_dim = self.state_dim * self.n_step  # time_steps * state_dim

        self.coeff_dim = (
            self.ode_layer.n_ind_dim
            * self.ode_layer.n_equations
            * self.ode_layer.n_step
            * self.ode_layer.n_dim
            * (self.order + 1)
        )
        self.param_dim = kwargs.get("param_dim", 20)
        self.embedding_dim = kwargs.get("embedding_dim", 64)

        # add the DCT transform if requested; the encoder then sees only the
        # lowest frequencies
        self.dct_layer: bool = dct_layer
        self.freq_frac_to_keep: float = kwargs.get("freq_frac_to_keep", 0.5)
        if dct_layer:
            self.dct: nn.Module = LinearDCT(self.n_step, "dct", norm="ortho").double()
            self.idct: nn.Module = LinearDCT(self.n_step, "idct", norm="ortho").double()
            input_dim = int(self.freq_frac_to_keep * self.n_step) * self.state_dim
        else:
            input_dim = self.rhs_dim

        # encoder: by default an MLP over the flattened trajectory
        if mlp_enc:
            self.params_enc = MLP(input_dim, self.param_dim, self.hidden_dim)
        else:
            self.params_enc = TransformerEncoder(
                self.n_step, self.state_dim, self.param_dim, self.embedding_dim, self.hidden_dim
            )

        # decoders from params to the ODE inputs
        self.rhs_t = MLP(input_dim=self.param_dim, output_dim=self.rhs_dim, hidden_dim=self.hidden_dim, num_layers=3)

        self.coeffs_mlp = MLP(
            input_dim=self.param_dim,
            output_dim=self.coeff_dim,
            hidden_dim=self.hidden_dim,
            num_layers=3,
        )

        self.pre_steps_mlp = nn.Sequential(
            nn.Linear(self.param_dim, self.hidden_dim),
            nn.LeakyReLU(),
            nn.Linear(self.hidden_dim, self.hidden_dim),
            nn.LeakyReLU(),
        )

        self.steps_layer = nn.Linear(self.hidden_dim, self.step_dim)

        # zero weights + bias logit(0.1): sigmoid(steps_layer(.)) starts at 0.1 for every step
        step_bias = logit(0.1)
        self.steps_layer.weight.data.fill_(0.0)
        self.steps_layer.bias.data.fill_(step_bias)

        self.code_sharing = code_sharing
        self.feature_sharing_fn = feature_sharing_fn

    def state_transform(self, states: torch.Tensor) -> torch.Tensor:
        """Return the lowest ``int(freq_frac_to_keep * n_step)`` DCT coefficients.

        The transform is applied along the time axis (second-to-last dim); any
        leading dims (e.g. ``[n_views, bs]``) are preserved. Requires
        ``dct_layer=True`` (otherwise ``self.dct`` does not exist).
        """
        freqs: torch.Tensor = self.dct(states.swapaxes(-1, -2)).swapaxes(-1, -2)
        return freqs[..., : int(self.freq_frac_to_keep * self.n_step), :]

    def state_inverse_transform(self, freqs: torch.Tensor) -> torch.Tensor:
        """Map truncated DCT coefficients back to the time domain.

        The high-frequency coefficients dropped by ``state_transform`` are
        filled with zeros before the inverse transform.

        Args:
            freqs: ``[..., n_freqs_kept, state_dim]`` with any number of leading dims.

        Returns:
            ``[..., n_step, state_dim]`` time-domain states.
        """
        n_missing = self.n_step - freqs.shape[-2]
        # pad over *all* leading dims so inputs of any rank (3-D or 4-D) work
        padding = freqs.new_zeros(*freqs.shape[:-2], n_missing, freqs.shape[-1])
        full = torch.cat([freqs, padding], dim=-2)
        return self.idct(full.swapaxes(-1, -2)).swapaxes(-1, -2)

    def decode_from_params(self, params: torch.Tensor):
        """Decode parameters into the inputs of the ODE layer.

        Args:
            params: ``(N, param_dim)``.

        Returns:
            ``(rhs, coeffs, steps)``: ``rhs`` has ``n_step * state_dim`` features,
            ``coeffs`` has ``coeff_dim`` features and ``steps`` is
            ``(N, (n_step - 1) * state_dim)`` with values in ``[0.001, 0.999]``.
        """
        rhs: torch.Tensor = self.rhs_t(params)  # right-hand side of the ODE
        coeffs: torch.Tensor = self.coeffs_mlp(params)  # time-varying ODE coefficients
        hidden = self.pre_steps_mlp(params)  # (N, hidden_dim)
        # sigmoid keeps learned step sizes positive; the clip bounds them away from 0 and 1
        steps: torch.Tensor = torch.sigmoid(self.steps_layer(hidden)).clip(min=0.001, max=0.999)
        return rhs, coeffs, steps

    def solve(self, params: torch.Tensor, iv_rhs: torch.Tensor) -> torch.Tensor:
        """Solve the ODE described by ``params`` from the initial values ``iv_rhs``.

        Args:
            params: ``(n_views * batch_size, param_dim)``.
            iv_rhs: ``(n_views * batch_size, n_iv_steps, state_dim)``.

        Returns:
            Time-domain solution of shape ``[n_views, bs, n_step, state_dim]``.
        """
        rhs, coeffs, steps = self.decode_from_params(params)
        # only u0 is used; the other outputs of the layer are discarded
        u0, _u1, _u2, _eps, _steps = self.ode_layer(coeffs=coeffs, rhs=rhs, iv_rhs=iv_rhs, steps=steps)
        u0 = u0.squeeze(1)  # (n_views * bs, n_step, state_dim); axis 1 presumably n_ind_dim (= 1)
        return u0.reshape(self.n_views, -1, self.n_step, self.state_dim)

    def feature_sharing(self, params: torch.Tensor, **kwargs):
        """Combine parameters across views with ``feature_sharing_fn``.

        The sharing pattern is inherent to the data-generating process, so the
        function is supplied from outside rather than learned here.

        Args:
            params: ``(n_views, bs, param_dim)``.
            **kwargs: forwarded to ``feature_sharing_fn``.

        Raises:
            TypeError: if there is more than one view but no ``feature_sharing_fn``.
        """
        if len(params) == 1:
            return params  # a single view has nothing to share with
        if self.feature_sharing_fn is None:
            raise TypeError("feature_sharing_fn must be provided when there is more than one view")
        return self.feature_sharing_fn(params, **kwargs)

    def forward(self, states: torch.Tensor, **kwargs):
        """Encode ``states``, share parameters across views and decode via the ODE.

        Args:
            states: time-domain trajectories with trailing dims ``(n_step, state_dim)``
                whose leading dims flatten to ``n_views * batch_size``
                (e.g. ``[n_views, bs, n_step, state_dim]``).
            **kwargs: forwarded to ``feature_sharing``.

        Returns:
            ``(states, u0s, params, shared)``. ``states`` and ``u0s`` are in the
            frequency domain when ``dct_layer`` is set, else in the time domain.
            ``u0s`` is ``[n_views, bs, ·, state_dim]``, ``params`` is
            ``(n_views, bs, param_dim)`` and ``shared`` is the output of
            ``feature_sharing``.
        """
        # extract the initial values before the DCT so they stay in the time domain
        iv_rhs = states[..., : self.n_iv_steps, :]
        if self.dct_layer:
            states = self.state_transform(states)
        params: torch.Tensor = self.params_enc(states.reshape(-1, states.shape[-2] * states.shape[-1]))
        params = params.reshape(-1, self.batch_size, self.param_dim)  # (n_views, bs, param_dim)
        iv_rhs = iv_rhs.reshape(-1, self.batch_size, self.n_iv_steps, self.state_dim)  # (n_views, bs, n_iv_steps, state_dim)
        shared = self.feature_sharing(params, **kwargs)
        # reshape rather than view: feature_sharing_fn may return a non-contiguous tensor.
        # The solver output is always in the time domain: [n_views, bs, n_step, state_dim]
        u0s = self.solve(
            shared.reshape(-1, self.param_dim),
            iv_rhs.reshape(-1, self.n_iv_steps, self.state_dim),
        )
        if self.dct_layer:
            # move u0s to the frequency domain to match `states`; make sure it is double
            u0s = self.state_transform(u0s.double())
        return states, u0s, params, shared


class MechanisticLitModule(AbstractLitModule):
    """Lightning wrapper that trains a float64 ``MechanisticModule``.

    The training objective is the reconstruction MSE plus ``alignment_reg``
    times an MSE between the per-view parameters and their shared version.
    """

    def __init__(
        self,
        learning_rate: float = 1e-5,
        alignment_reg: float = 10,
        eval_metrics: Optional[List[str]] = None,
        **model_kwargs,
    ):
        """Build the model and the loss.

        Args:
            learning_rate: forwarded to ``AbstractLitModule``.
            alignment_reg: weight of the parameter-alignment loss; ``<= 0`` disables it.
            eval_metrics: forwarded to ``AbstractLitModule`` (defaults to an empty list).
            **model_kwargs: forwarded to the base class and to ``MechanisticModule``,
                and also set as attributes on this module.
        """
        # fresh list per instance instead of a shared mutable default
        eval_metrics = [] if eval_metrics is None else eval_metrics
        super().__init__(
            learning_rate=learning_rate,
            eval_metrics=eval_metrics,
            **model_kwargs,
        )
        # The sharing function is always this module's own; keep it out of
        # model_kwargs so save_hyperparameters() (which reads the frame's
        # locals) does not store a bound method in hparams / checkpoints.
        model_kwargs.pop("feature_sharing_fn", None)

        for name, value in model_kwargs.items():
            setattr(self, name, value)
        # save hyperparameters
        self.save_hyperparameters()

        self.model = MechanisticModule(**model_kwargs, feature_sharing_fn=self.feature_sharing_fn).double()
        # NOTE: xavier initialisation (utils.xavier_init) is currently disabled.

        self.loss = nn.MSELoss().double()
        self.alignment_reg = alignment_reg
        # NOTE(review): this instance attribute shadows nn.Module.type(); kept for compatibility.
        self.type = torch.float64

    def forward(self, states: torch.Tensor, **kwargs):
        """Delegate to ``MechanisticModule.forward``."""
        return self.model(states, **kwargs)

    def training_step(self, batch, batch_idx):
        """Return the reconstruction loss plus the weighted alignment loss.

        Only the reconstruction term is logged as ``train_loss``. The alignment
        term is logged separately as ``param_loss``.
        """
        # copy the batch instead of mutating the caller's dict
        batch = {**batch, "states": batch["states"].to(self.type)}
        # states/u0s may be in the frequency domain when the model uses a DCT layer;
        # u0s: [n_views, bs, ·, state_dim], params: [n_views, bs, param_dim]
        states, u0s, params, shared = self.forward(**batch)
        recon_loss = self.loss(u0s.double().reshape(*states.shape), states.double())
        if self.alignment_reg > 0.0:
            # NOTE(review): this slices the *last* (param_dim) axis, so only the first
            # n_views - 1 parameters are aligned; confirm this is intended.
            n_aligned = self.model.n_views - 1
            param_loss = self.loss(params[..., :n_aligned], shared[..., :n_aligned])
            loss = recon_loss + self.alignment_reg * param_loss  # TODO: maybe use another optimizer later
            self.log("param_loss", param_loss, prog_bar=True, on_step=True, on_epoch=True)
        else:
            loss = recon_loss
        self.log("train_loss", recon_loss, prog_bar=True, on_step=True, on_epoch=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Collect predicted (and, if present, ground-truth) parameters in ``self.misc``."""
        states = batch["states"].to(self.type)
        if self.model.dct_layer:
            states = self.model.state_transform(states)
        flat = states.reshape(-1, states.shape[-2] * states.shape[-1])
        # detach so .numpy() also works if validation runs with autograd enabled
        params = self.model.params_enc(flat).detach().cpu().numpy()
        self.misc["pred_params"].append(params)
        if "gt_params" in batch:
            gt_params = batch["gt_params"]
            if isinstance(gt_params, dict):
                # stack the named ground-truth parameters along a new last axis
                gt_params = torch.stack(list(gt_params.values()), -1)
            self.misc["gt_params"].append(gt_params.detach().cpu().numpy())

    # TODO: possible to forecast over the whole trajectory
    def predict_step(self, batch, batch_id):
        """Forecast the window after the first ``n_step`` steps and return its MSE.

        The first ``n_step`` steps are encoded into parameters. The ODE is then
        solved from the first ``n_iv_steps`` of the following window and compared
        with the next ``n_step`` ground-truth steps.
        """
        n_step = self.model.n_step
        n_iv_steps = self.model.n_iv_steps
        state_dim = self.model.state_dim

        states = batch["states"].to(self.type)
        input_states = states[..., :n_step, :]
        # the solver produces exactly n_step steps, so compare against that many
        future_states = states[..., n_step : 2 * n_step, :]
        if self.model.dct_layer:
            input_states = self.model.state_transform(input_states)

        params = self.model.params_enc(input_states.reshape(-1, input_states.shape[-2] * input_states.shape[-1]))
        iv_rhs = future_states[..., :n_iv_steps, :].reshape(-1, n_iv_steps, state_dim)
        # the solver output is always in the time domain: [n_views, bs, n_step, state_dim]
        u0s = self.model.solve(params.reshape(-1, self.model.param_dim), iv_rhs)
        # align layouts with the targets, as training_step does for states
        return self.loss(u0s.double().reshape(*future_states.shape), future_states.double())
