import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F

from torchmetrics import Precision, Recall

from models.edge_encoder import STEdgeEncoder

from utils import TemporalData
from utils import init_weights

from torch_geometric.utils import subgraph


class InterpretableAssociation(pl.LightningModule):
    """Binary edge classifier built on top of ``STEdgeEncoder``.

    Edge embeddings produced by the encoder are layer-normalised and fed to a
    two-layer MLP that emits two logits per edge. Training uses cross-entropy
    on the edges selected by ``data.candidate_mask``; validation additionally
    tracks binary precision and recall.
    """

    def __init__(self,
                 historical_steps: int,
                 embed_dim: int,
                 num_heads: int,
                 dropout: float,
                 num_att_layers: int,
                 lr: float = 3e-3,
                 weight_decay: float = 1e-4,
                 T_max: int = 64,
                 **kwargs) -> None:
        """Build the encoder, classifier head, loss and metrics.

        Args:
            historical_steps: Stored on the module; not used elsewhere in
                this class.
            embed_dim: Width of the edge embeddings and of the classifier's
                hidden layer.
            num_heads: Forwarded to ``STEdgeEncoder``.
            dropout: Forwarded to ``STEdgeEncoder``.
            num_att_layers: Forwarded to ``STEdgeEncoder``.
            lr: AdamW learning rate used by ``configure_optimizers``.
            weight_decay: AdamW weight decay used by ``configure_optimizers``.
            T_max: ``T_max`` of the cosine-annealing scheduler.
            **kwargs: Extra hyperparameters; recorded by
                ``save_hyperparameters`` but otherwise ignored.
        """
        super().__init__()
        self.save_hyperparameters()

        self.historical_steps = historical_steps
        self.embed_dim = embed_dim

        # configure_optimizers() reads these attributes, so they must exist
        # on the module; the defaults mirror add_model_specific_args().
        self.lr = lr
        self.weight_decay = weight_decay
        self.T_max = T_max

        self.edge_encoder = STEdgeEncoder(embed_dim=embed_dim,
                                          num_heads=num_heads,
                                          num_att_layers=num_att_layers,
                                          dropout=dropout)

        self.layer_norm = nn.LayerNorm(embed_dim)

        # Two output logits: index 1 is treated as the "associated" class
        # in forward().
        self.classifier = nn.Sequential(nn.Linear(embed_dim, embed_dim),
                                        nn.ReLU(),
                                        nn.Linear(embed_dim, 2))

        self.cls_loss = torch.nn.CrossEntropyLoss()

        self.binary_precision = Precision(task='binary', average='macro', num_classes=2)
        self.recall = Recall(task='binary', average='macro', num_classes=2)

        self.apply(init_weights)

    def forward(self,
                data: TemporalData,
                st_embed):
        """Classify edges and attach the predictions to ``data``.

        Args:
            data: Batch providing ``edge_index`` and ``v2x_mask``.
            st_embed: Embeddings handed to the edge encoder together with
                ``data`` and ``data.edge_index``.

        Returns:
            A tuple ``(cls_out, st_edge_embed)``: the two-class logits and
            the edge embeddings they were computed from.

        Side effects:
            Sets ``data.edge_asso`` (edges kept by the 0.5 threshold),
            ``data.edge_conf`` (class-1 probability per edge) and
            ``data.edge_mask`` (``edge_conf > 0.5``).
        """
        edge_index = data.edge_index
        st_edge_embed = self.edge_encoder(data, st_embed, edge_index)
        cls_out = self.classifier(self.layer_norm(st_edge_embed))

        # Probability of class 1 for every classified edge.
        edge_conf = F.softmax(input=cls_out, dim=1)[:, 1]
        edge_mask = edge_conf > 0.5

        # NOTE(review): assumes the encoder returns one row per edge selected
        # by data.v2x_mask, so edge_mask lines up with that subset - confirm
        # against STEdgeEncoder.
        edge_pred = data.edge_index[:, data.v2x_mask][:, edge_mask]

        data.edge_asso = edge_pred
        data.edge_conf = edge_conf
        data.edge_mask = edge_mask

        return cls_out, st_edge_embed

    def _candidate_logits_and_labels(self, data):
        """Run the model and return logits and int64 labels of candidate edges."""
        # NOTE(review): forward() also requires `st_embed`; SOURCE does not
        # show where it should come from, so this call is left as written and
        # will raise TypeError until the embedding is supplied - confirm the
        # intended source with the caller.
        out, _ = self(data)

        candidate_mask = data.candidate_mask
        link_label = data.gt_mask[candidate_mask].to(torch.int64)
        return out[candidate_mask], link_label

    def training_step(self, data, batch_idx):
        """Compute and log the cross-entropy loss on candidate edges."""
        out, link_label = self._candidate_logits_and_labels(data)

        loss = self.cls_loss(out, link_label)
        self.log('train_loss', loss, prog_bar=True, on_step=True, on_epoch=True, batch_size=1)
        return loss

    def validation_step(self, data, batch_idx):
        """Log validation loss, precision and recall on candidate edges."""
        out, link_label = self._candidate_logits_and_labels(data)

        loss = self.cls_loss(out, link_label)
        self.log('val_loss', loss, prog_bar=True, on_step=False, on_epoch=True, batch_size=1)

        pred_link = out.argmax(dim=1)

        # Only metrics registered in __init__ are updated here; updating an
        # unregistered metric attribute would raise AttributeError.
        self.binary_precision.update(pred_link, link_label)
        self.recall.update(pred_link, link_label)
        self.log('Precision', self.binary_precision, prog_bar=True, on_step=False, on_epoch=True, batch_size=data.num_graphs)
        self.log('Recall', self.recall, prog_bar=True, on_step=False, on_epoch=True, batch_size=data.num_graphs)

    def configure_optimizers(self):
        """Return AdamW with a cosine-annealing learning-rate schedule."""
        optimizer = torch.optim.AdamW(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=self.T_max, eta_min=0.0)
        return [optimizer], [scheduler]

    @staticmethod
    def add_model_specific_args(parent_parser):
        """Register this model's command-line options on ``parent_parser``.

        Returns:
            The same ``parent_parser``, with a new argument group added.
        """
        parser = parent_parser.add_argument_group('V2XLinkPredictor')
        parser.add_argument('--historical_steps', type=int, default=50)
        parser.add_argument('--embed_dim', type=int, default=64)
        parser.add_argument('--num_heads', type=int, default=8)
        parser.add_argument('--dropout', type=float, default=0.1)
        # NOTE(review): __init__ requires `num_att_layers`, which is not
        # registered here; `--num_temporal_layers` is kept unchanged so
        # existing command lines keep working - confirm the intended flag.
        parser.add_argument('--num_temporal_layers', type=int, default=2)
        parser.add_argument('--lr', type=float, default=3e-3)
        parser.add_argument('--weight_decay', type=float, default=1e-4)
        parser.add_argument('--T_max', type=int, default=64)
        return parent_parser
