from typing import Any, Dict, Tuple
import torch
from lightning import LightningModule
from torch.nn.parallel import DistributedDataParallel
from torchmetrics import MaxMetric, MeanMetric, MeanSquaredError, MeanAbsoluteError, MeanAbsolutePercentageError, \
    MinMetric, Accuracy, Precision, Recall, F1Score, AUROC
from .powergpt_components.revin import RevIN
from .PowerGPT import patch_masking, MaskMSELoss, create_patch
import torch.nn as nn
from ..loss.loss import GraphNTXentLoss
def get_model(model):
    """Unwrap ``model`` if it is a data-parallel container.

    :param model: Any object; typically an ``nn.Module`` that may be wrapped in
        ``DistributedDataParallel`` or ``nn.DataParallel``.
    :return: ``model.module`` when ``model`` is one of those wrappers,
        otherwise ``model`` itself, unchanged.
    """
    parallel_wrappers = (DistributedDataParallel, nn.DataParallel)
    if isinstance(model, parallel_wrappers):
        return model.module
    return model

class PowerGPTModule(LightningModule):
    """Lightning wrapper that pretrains a PowerGPT network.

    Each training step combines two objectives:

    * masked-patch reconstruction (``MaskMSELoss``) on the first
      ``context_points`` steps of the RevIN-normalized series, and
    * a contrastive loss (``GraphNTXentLoss``) between the network outputs for
      two windows of the same series: the first and the last
      ``context_points`` steps.

    Validation and test hooks are no-ops in this class.
    """

    def __init__(
            self,
            name,
            net: torch.nn.Module,
            optimizer: torch.optim.Optimizer,
            scheduler: torch.optim.lr_scheduler,
    ) -> None:
        """Build the module.

        :param name: Run/model name. It is only recorded by
            ``save_hyperparameters`` and is not otherwise used in this class.
        :param net: The PowerGPT network. It must expose ``patch_len``,
            ``stride``, ``mask_ratio``, ``context_points`` and ``n_vars``.
        :param optimizer: Callable that returns an optimizer. It is invoked as
            ``optimizer(params=...)`` in :meth:`configure_optimizers`.
        :param scheduler: Callable that returns an LR scheduler, invoked as
            ``scheduler(optimizer=...)``, or ``None`` for no scheduler.
        """
        super().__init__()
        # Expose the init params as ``self.hparams`` and store them in
        # checkpoints. logger=False keeps them out of the experiment logger.
        self.save_hyperparameters(logger=False)

        self.net = net
        self.patch_len = self.net.patch_len
        self.stride = self.net.stride
        self.mask_ratio = self.net.mask_ratio
        self.context_points = self.net.context_points

        # Loss functions.
        self.criterion_mse = MaskMSELoss()
        self.criterion_cl = GraphNTXentLoss()

        # Training metrics, evaluated on the masked positions only.
        self.train_mse = MeanSquaredError()
        self.train_mae = MeanAbsoluteError()
        self.train_mape = MeanAbsolutePercentageError()

        # Averages the loss across batches.
        self.train_loss = MeanMetric()

        # The statistics from the 'norm' call in on_before_batch_transfer are
        # reused by the 'denorm' calls in model_step.
        self.revin = RevIN(num_features=net.n_vars, affine=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Delegate to the wrapped network.

        Callers in this class pass the whole prepared batch object, not a bare
        tensor.
        """
        return self.net(x)

    def on_train_start(self) -> None:
        """No-op hook."""
        pass

    def on_before_batch_transfer(self, batch: Any, dataloader_idx: int) -> Any:
        """Normalize, patch and mask the raw batch before it is moved to device.

        Original author's note: the input ``x`` is either the context window
        alone or the context plus ``delta`` extra steps, shaped
        ``[batch, context(+delta), 1]``.

        Attributes set on ``batch``:

        * ``batch_size_``: ``batch.input_id.shape[0]``. Only these leading
          rows contribute to the losses (presumably the seed nodes of a
          neighbour-sampled graph batch; TODO confirm).
        * ``x1`` / ``x2``: patched first / last ``context_points`` steps of
          the normalized series.
        * ``x_cov`` / ``x2_cov``: the same two windows of the covariates.
          ``x_cov`` is transposed on axes (2, 1) first; presumably it arrives
          as ``[batch, n_cov, time]``, but verify this.
        * ``x``, ``y``, ``mask``: outputs of ``patch_masking`` on the first
          window. ``mask`` is expanded along a new trailing axis of size
          ``patch_len``.

        :return: The same ``batch`` object, mutated in place.
        """
        batch.batch_size_ = batch.input_id.shape[0]
        # Per-sample normalization. The RevIN module keeps the statistics for
        # the later 'denorm' in model_step.
        x = self.revin(batch.x, 'norm')

        ctx = self.context_points
        patch_kwargs = dict(patch_len=self.patch_len, stride=self.stride)
        # Two views of the same series for the contrastive objective.
        batch.x1, _ = create_patch(x[:, :ctx, :], **patch_kwargs)
        batch.x2, _ = create_patch(x[:, -ctx:, :], **patch_kwargs)

        # Read the raw covariates once, before batch.x_cov is overwritten.
        x_cov = batch.x_cov.transpose(2, 1)
        batch.x2_cov, _ = create_patch(x_cov[:, -ctx:, :], **patch_kwargs)
        batch.x_cov, _ = create_patch(x_cov[:, :ctx, :], **patch_kwargs)

        x_masked, y, mask = patch_masking(
            x[:, :ctx, :], stride=self.stride, patch_len=self.patch_len, mask_ratio=self.mask_ratio
        )
        # Broadcast the per-patch mask over every element of each patch.
        mask = mask.unsqueeze(-1).repeat(1, 1, 1, self.patch_len)
        batch.x = x_masked
        batch.y = y
        batch.mask = mask
        return batch

    def model_step(
            self, batch: Tuple[torch.Tensor, torch.Tensor]
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run the three forward passes and compute the combined loss.

        ``batch`` must have been prepared by :meth:`on_before_batch_transfer`.
        Only the first ``batch.batch_size_`` rows contribute to the losses.

        The batch is mutated: ``x``/``x_cov`` end up holding the second view,
        and ``y``/``mask`` are truncated to the leading ``batch_size_`` rows.

        :return: ``(loss, pred, target)``. ``loss`` is the contrastive loss
            plus the masked MSE. ``pred`` and ``target`` are both in the
            de-normalized (original) scale, truncated to the leading rows.
        """
        n_seed = batch.batch_size_
        # Keep the full-batch target. The RevIN statistics cover every row, so
        # the target must be de-normalized before it is sliced.
        target_full = batch.y
        batch.y = target_full[:n_seed]

        # 1) Masked-patch reconstruction on the masked first window.
        pred = self.forward(batch)
        batch_size, n_vars, num_patch, patch_len = pred.shape

        # 2) Two views for the contrastive objective.
        batch.x = batch.x1
        h1 = self.forward(batch).reshape(batch_size, -1)[:n_seed]
        batch.x = batch.x2
        batch.x_cov = batch.x2_cov
        h2 = self.forward(batch).reshape(batch_size, -1)[:n_seed]

        # ``y`` was built from the RevIN-normalized series. De-normalize it
        # with the same statistics as ``pred`` so the loss and metrics compare
        # values in the same scale. The (batch, -1, 1) reshape appears to
        # assume a single variable (n_vars == 1), as the prediction path does.
        pred = self.revin(pred.reshape(batch_size, -1, 1), 'denorm')
        pred = pred.reshape(batch_size, n_vars, num_patch, patch_len)[:n_seed]
        target = self.revin(target_full.reshape(target_full.shape[0], -1, 1), 'denorm')
        batch.y = target.reshape(target_full.shape)[:n_seed]
        batch.mask = batch.mask[:n_seed]

        loss_cl = self.criterion_cl(h1, h2)
        loss_mse = self.criterion_mse(pred, batch.y, batch.mask)
        loss = loss_cl + loss_mse
        return loss, pred, batch.y

    def training_step(
            self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int
    ) -> torch.Tensor:
        """Compute the loss and update and log the epoch-level training metrics.

        The metrics use only the positions selected by ``batch.mask``.

        :return: The combined loss for backpropagation.
        """
        loss, preds, targets = self.model_step(batch)
        mask = batch.mask
        # Select the masked elements once and reuse them for every metric.
        masked_preds, masked_targets = preds[mask], targets[mask]

        # Update the metric states. They are logged per epoch only.
        self.train_loss(loss)
        self.train_mse(masked_preds, masked_targets)
        self.train_mae(masked_preds, masked_targets)
        self.train_mape(masked_preds, masked_targets)

        self.log("train/loss", self.train_loss, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/mse", self.train_mse, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/mae", self.train_mae, on_step=False, on_epoch=True, prog_bar=True)
        self.log("train/mape", self.train_mape, on_step=False, on_epoch=True, prog_bar=True)

        return loss

    def on_train_epoch_end(self) -> None:
        """No-op hook."""
        pass

    def validation_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """No-op: this module performs no validation."""
        pass

    def on_validation_epoch_end(self) -> None:
        """No-op hook."""
        pass

    def test_step(self, batch: Tuple[torch.Tensor, torch.Tensor], batch_idx: int) -> None:
        """No-op: this module performs no testing."""
        pass

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        pass

    def freeze(self):
        """Freeze every parameter of the wrapped network except its ``head``.

        Does nothing if the (unwrapped) network has no ``head`` attribute.
        This overrides ``LightningModule.freeze``. This version only toggles
        ``requires_grad`` on the network's parameters and does not change
        train/eval mode.
        """
        model = get_model(self.net)
        if hasattr(model, 'head'):
            for param in model.parameters():
                param.requires_grad = False
            for param in model.head.parameters():
                param.requires_grad = True

    def unfreeze(self):
        """Make every parameter of the wrapped network trainable again.

        This overrides ``LightningModule.unfreeze``. It only toggles
        ``requires_grad`` and does not change train/eval mode.
        """
        for param in get_model(self.net).parameters():
            param.requires_grad = True

    def configure_optimizers(self) -> Dict[str, Any]:
        """Configure the optimizer and, optionally, the LR scheduler.

        The scheduler (if any) steps once per epoch and is given
        ``train/loss`` as its monitored quantity.

        See https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A dict containing the optimizer and, if configured, the
            ``lr_scheduler`` config.
        """
        optimizer = self.hparams.optimizer(params=self.parameters())
        if self.hparams.scheduler is None:
            return {"optimizer": optimizer}

        scheduler = self.hparams.scheduler(optimizer=optimizer)
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "train/loss",
                "interval": "epoch",
                "frequency": 1,
            },
        }


if __name__ == "__main__":
    # No standalone entry point: running this file directly does nothing.
    pass
