from collections import OrderedDict, defaultdict
from typing import Callable, Optional

import torch
import torch.nn as nn

from adl4cv.classification.model.classification_module import ClassificationModuleHyperParameterSet, \
    ClassificationModule, TrainStage
from adl4cv.classification.model.convolutional.resnet import ResNet18Definition
from adl4cv.classification.model.models import ClassificationModuleDefinition, ModelType
from adl4cv.classification.model.jenny import losses
from adl4cv.classification.model.jenny.gnn_base import GNNReID
from adl4cv.classification.model.jenny.graph_generator import GraphGenerator
from adl4cv.classification.optimizer.optimizers import OptimizerDefinition
from adl4cv.classification.optimizer.schedulers import SchedulerDefinition


class JennyNetHyperParameterSet(ClassificationModuleHyperParameterSet):
    """HyperParameterSet of the JennyNet"""

    def __init__(self,
                 classif_net_def: Optional[ClassificationModuleDefinition] = None,
                 distinct_graph: bool = True,
                 optimizer_definition: OptimizerDefinition = None,
                 scheduler_definition: SchedulerDefinition = None, **kwargs):
        """
        Creates new HyperParameterSet
        :param classif_net_def: The definition of the feature extraction part (backbone). When omitted, a fresh
            ResNet18Definition is created for this instance (a default built once at import time would be shared
            by every hyperparameter set).
        :param distinct_graph: Stored as ``graph_params["distinct_graph"]`` for the graph generator
        :param optimizer_definition: Forwarded to the base class
        :param scheduler_definition: Forwarded to the base class
        :param kwargs: Forwarded to the base class, see
            :func:`~ClassificationModuleHyperParameterSet.__init__`
        """
        super().__init__(optimizer_definition, scheduler_definition, **kwargs)
        self.classif_net_def = classif_net_def if classif_net_def is not None else ResNet18Definition()

        # Configuration of the message passing network; the dicts are rebuilt per instance, so mutating
        # one hyperparameter set never affects another.
        self.gnn_params = {
            "pretrained_path": "no",
            "red": 1,
            "cat": 0,
            "every": 0,
            "gnn": {
                "num_layers": 1,
                "aggregator": "add",
                "num_heads": 4,
                "attention": "dot",
                "mlp": 1,
                "dropout_mlp": 0.1,
                "norm1": 1,
                "norm2": 1,
                "res1": 1,
                "res2": 1,
                "dropout_1": 0.1,
                "dropout_2": 0.1,
                "mult_attr": 0,
            },
            "classifier": {
                "neck": 1,
                "num_classes": 10,
                "dropout_p": 0.4,
                "use_batchnorm": 0,
            },
        }

        # Keyword arguments for the graph generator built over the batch features
        self.graph_params = {
            "sim_type": "correlation",
            "thresh": "no",
            "set_negative": "hard",
            "distinct_graph": distinct_graph,
        }


class JennyNetDefinition(ClassificationModuleDefinition):
    """Definition of the JennyNet"""

    def __init__(self, hyperparams: Optional[JennyNetHyperParameterSet] = None):
        """
        Creates a new JennyNet definition
        :param hyperparams: The hyperparameters of the model. When omitted, a fresh JennyNetHyperParameterSet is
            created for this definition instead of sharing one instance built at import time.
        """
        if hyperparams is None:
            hyperparams = JennyNetHyperParameterSet()
        super().__init__(ModelType.JennyNet, hyperparams)

    @property
    def _instantiate_func(self) -> Callable:
        """The model class this definition instantiates"""
        return JennyNet


class JennyNet(ClassificationModule):
    """
    JennyNet, which consists of two main parts:
      1. Feature extraction: a backbone (``classif_net``) which returns latent representations of the input samples
      2. Feature refinement: a Graph Neural Network (``mpn_net``) which refines those features on a graph built
         over the batch by a GraphGenerator
    """

    def __init__(self, params: Optional[JennyNetHyperParameterSet] = None):
        """
        Creates a new JennyNet
        :param params: The hyperparameters of the model. When omitted, a fresh JennyNetHyperParameterSet is created
            instead of sharing one instance built at import time.
        """
        if params is None:
            params = JennyNetHyperParameterSet()
        # These are assigned BEFORE super().__init__ on purpose: define_model() sets self.graph_generator and is
        # presumably invoked by the base constructor (TODO confirm), so assigning afterwards would overwrite it.
        self.edge_index = None
        self.graph_generator = None
        self.losses = defaultdict(list)
        super().__init__(params)
        # Keep the loss class count in sync with the GNN classifier configuration instead of hard-coding it
        self.get_loss_fn(params.gnn_params["classifier"]["num_classes"])

    def define_model(self) -> torch.nn.Module:
        """
        Builds the backbone and the message passing network, and creates the graph generator
        :return: A Sequential container with the submodules ``classif_net`` (backbone) and ``mpn_net`` (GNN)
        """
        # The hyperparameter set stores the backbone definition under ``classif_net_def``
        classif_net = self.params.classif_net_def.instantiate()
        # NOTE(review): 512 is assumed to be the backbone feature dimension (ResNet-18) - TODO derive from backbone
        mpn_net = GNNReID(self.device,
                          self.params.gnn_params,
                          512).to(self.device)

        self.graph_generator = GraphGenerator(self.device, **self.params.graph_params)

        return nn.Sequential(OrderedDict([
            ('classif_net', classif_net),
            ('mpn_net', mpn_net),
        ]))

    def initialize_model(self):
        """No custom initialization; submodules keep their own initialization."""
        pass

    def forward_iid(self, x: torch.Tensor):
        """
        Runs only the backbone on the data
        :param x: The input to be forwarded
        :return: The backbone output
        """
        return self.model.classif_net(x)

    def _refine_features(self, fc7: torch.Tensor):
        """Builds the graph over the batch features and runs the GNN on it; returns the GNN's predictions."""
        edge_attr, edge_index, fc7 = self.graph_generator.get_graph(fc7)
        pred, _ = self.model.mpn_net(fc7, edge_index, edge_attr, "plain")
        return pred

    def forward(self, x: torch.Tensor):
        """
        Runs the forward pass on the data
        :param x: The input to be forwarded
        :return: The last prediction of the message passing network
        """
        fc7 = self.model.classif_net.features(x)
        return self._refine_features(fc7)[-1]

    def general_step(self, batch, batch_idx, mode: TrainStage):
        """
        General step used in all phases
        :param batch: The current batch of data as ``(x, y)``
        :param batch_idx: The current batch index
        :param mode: The current phase
        :return: The total loss (backbone cross entropy + GNN loss)
        """
        x, y = batch

        fc7 = self.model.classif_net.features(x)
        logits = self.model.classif_net.head(fc7)

        # Label-smoothed cross entropy on the backbone's own classification head
        ce_term = self.ce(logits, y)
        self.losses['Cross Entropy'].append(ce_term.item())

        # Label-smoothed cross entropy on the GNN-refined prediction
        pred = self._refine_features(fc7)
        gnn_term = self.gnn_loss(pred[-1], y)
        self.losses['GNN'].append(gnn_term.item())

        loss = ce_term + gnn_term
        self.losses['Total Loss'].append(loss.item())

        accuracy_metric = self.get_accuracy_metric(mode)
        accuracy_metric(pred[-1], y)

        return loss

    def get_loss_fn(self, num_classes):
        """
        (Re)creates the loss functions and resets the recorded loss history
        :param num_classes: Number of classes expected by both label-smoothed cross entropy losses
        """
        self.losses = defaultdict(list)
        self.losses_mean = defaultdict(list)

        # NOTE(review): ``use_gpu`` presumably moves targets to CUDA inside the loss; only request it when CUDA
        # is actually available so the model can also train on CPU-only machines - TODO confirm in losses.py
        use_gpu = torch.cuda.is_available()

        # GNN loss
        self.gnn_loss = losses.CrossEntropyLabelSmooth(num_classes=num_classes, use_gpu=use_gpu)

        # CrossEntropy Loss
        self.ce = losses.CrossEntropyLabelSmooth(num_classes=num_classes, use_gpu=use_gpu)

