import logging
import os
from argparse import ArgumentParser, Namespace
import copy
from typing import Type, Optional

from adl4cv.classification.data.caltech import Caltech101Definition, Caltech101HyperParameterSet, Caltech256Definition, \
    Caltech256HyperParameterSet
from adl4cv.classification.sampling.class_based_sampler import ClassBasedSamplerDefinition, \
    ClassBasedSamplerHyperParameterSet
from adl4cv.classification.sampling.combine_sampler import CombineSamplerDefinition, CombineSamplerHyperParameterSet
from adl4cv.classification.sampling.sampler import SamplerType, SubsetRandomSamplerDefinition
from adl4cv.classification.loss.center_loss import CenterLossDefinition, CenterLossHyperParameterSet
from adl4cv.classification.loss.loss_calculator import LossEvaluatorHyperParameterSet
from adl4cv.classification.loss.losses import CrossEntropyDefinition
from adl4cv.classification.model.convolutional.pretrained_resnet import PretrainedResNetDefinition, \
    PretrainedResNetHyperParameterSet
from adl4cv.classification.model.graph._graph.edge_attrib import EdgeAttributeType, \
    CorrelationEdgeAttributeDefinitionSet, CosineEdgeAttributeDefinitionSet, EdgeAttributeHyperParameterSet, \
    DistanceEdgeAttributeDefinitionSet, DistanceEdgeAttributeHyperParameterSet, NoEdgeAttributeDefinitionSet
from adl4cv.classification.model.graph._graph.graph_builder import GraphType, DenseGraphBuilderDefinitionSet, \
    DenseGraphGraphBuilderHyperParameterSet, DistinctGraphBuilderDefinitionSet, GraphBuilderHyperParameterSet
from adl4cv.classification.model.graph._transformer.softmax import AttentionScalingType
from adl4cv.classification.model.jenny.jenny_net import JennyNetDefinition, JennyNetHyperParameterSet
from adl4cv.classification.model.graph.general_net import GeneralNetDefinition, \
    GeneralNetHyperParameterSet
from adl4cv.classification.data.cifar import CIFAR10Definition, CIFAR10HyperParameterSet, \
    CIFAR100Definition, CIFAR100HyperParameterSet
from adl4cv.classification.data.datas import DataModuleType
from adl4cv.classification.data.transform import ComposeTransformDefinition, \
    ComposeTransformHyperParameterSet, RandomHorizontalFlipTransformDefinition, \
    RandomCropTransformDefinition, ToTensorTransformDefinition, NormalizeTransformDefinition, \
    NormalizeHyperParameterSet, RandomCropTransformHyperParameterSet, ResizeTransformDefinition, \
    ResizeHyperParameterSet, RepeatTransformDefinition, RepeatHyperParameterSet
from adl4cv.classification.log.logger import init_logger
from adl4cv.classification.model.convolutional.resnet import ResNet18Definition, \
    ResNet18HyperParameterSet, ResNet34Definition, ResNet34HyperParameterSet
from adl4cv.classification.model.dummy.fc_classifier import DummyModuleDefinition, \
    DummyModuleHyperParameterSet
from adl4cv.classification.model.graph.message_passing_net import \
    MessagePassingNetDefinition, MessagePassingNetHyperParameterSet
from adl4cv.classification.model.models import ModelType
from adl4cv.classification.optimizer.optimizers import SGDOptimizerDefinition, \
    SGDOptimizerHyperParameterSet, AdamOptimizerDefinition, OptimizerType, AdamOptimizerHyperParameterSet, \
    RAdamOptimizerDefinition, RAdamOptimizerHyperParameterSet
from adl4cv.classification.optimizer.schedulers import MultiStepLRSchedulerDefinition, \
    MultiStepLRSchedulerHyperParameterSet
from adl4cv.parameters.params import OptimizationDefinitionSet
from adl4cv.classification.trainer.trainers import PL_TrainerDefinition, \
    PL_TrainerHyperParameterSet, DeviceType, TrainerType
from adl4cv.parameters.optimizer import HyperParamOptimizer, PrunerType, OptimizationSamplerType, OptimizationType
from adl4cv.active_learning.active_datamodule import ActiveDataModuleHyperParameterSet
from adl4cv.utils.utils import remove_prefix, str2bool


class HyperParamOptimizationExperiment:
    def __init__(self, arguments: Namespace):
        """
        Stores the parsed arguments and creates the logger of the experiment
        :param arguments: The parsed command line arguments of the experiment
        """
        # The logger is named after the concrete class and logs from INFO level
        self.module_logger = init_logger(type(self).__name__, logging_level=logging.INFO)
        self.arguments = arguments

    @classmethod
    def add_argparse_args(cls, parent_parser: Optional[ArgumentParser] = None):
        """
        Defines the command line arguments of the experiment
        :param parent_parser: Accepted for interface compatibility only; it is not consulted,
                              a fresh parser is always created
        :return: A new parser (created with add_help=False) holding every argument of the experiment
        """
        parser = ArgumentParser(description="Script for running hyperparameter optimization experiments.",
                                add_help=False
                                )
        # --------------------------- OPTIMIZATION ---------------------------
        parser.add_argument(
            "-t", "--timeout", type=int, default=None,
            help="The timeout of the optimization defined in seconds"
        )
        parser.add_argument(
            "-nt", "--n_trials", type=int, default=None,
            help="The number of trials to be executed"
        )
        parser.add_argument(
            "-desc", "--description", type=str, default=None,
            help="Description of the training"
        )
        parser.add_argument(
            "-optt", "--optimization_type", type=OptimizationType, choices=list(OptimizationType), default=OptimizationType.CrossValidation,
            help="Type of the hyperparameter optimization"
        )
        parser.add_argument(
            "-st", "--sampler_type", type=OptimizationSamplerType, choices=list(OptimizationSamplerType), default=OptimizationSamplerType.TPE,
            help="Sampler to use"
        )

        # -------------------------------- MODEL --------------------------------
        parser.add_argument(
            "-m", "--model", type=ModelType, choices=list(ModelType), required=True,
            help="Model to use in the active learning"
        )
        parser.add_argument(
            "-bbm", "--backbone_model", type=ModelType, choices=list(ModelType), default=ModelType.ResNet18,
            help="Model to use as GeneralNet backbone"
        )
        parser.add_argument(
            "-gm", "--gradient_multiplier", type=float, default=0.0,
            help="The multiplier of the GradientGate before the IID head"
        )
        parser.add_argument(
            "-hb", "--head_bias", type=str2bool, default=True,
            help="Using the head bias or not"
        )
        parser.add_argument(
            "-il", "--intermediate_layers", type=str, nargs='+', default=[],
            help="The intermediate layers to use"
        )
        parser.add_argument(
            "-th", "--train_head", action="store_true", default=False,
            help="Indicates whether to train the iid head or not"
        )
        # NOTE: type=bool would turn every non-empty string (even "False") into True,
        # therefore the boolean valued options are parsed with str2bool
        parser.add_argument(
            "-conc", "--concat", type=str2bool, default=False,
            help="Concatenate the heads in multi-headed attention"
        )
        parser.add_argument(
            "-skp", "--skip_downsampling", type=str2bool, default=False,
            help="Skip the downsampling in transformer layer, causing feature size expansion"
        )
        parser.add_argument(
            "-bet", "--beta", type=str2bool, default=False,
            help="Will scale the message passing and root weighting"
        )
        parser.add_argument(
            "-rw", "--root_weight", type=str2bool, default=True,
            help="Using the root weight in the message passing or not"
        )
        parser.add_argument(
            "-pt", "--pretrained", action="store_true", default=False,
            help="Indicates whether to use pretrained ResNet18 or not"
        )
        parser.add_argument(
            "-ot", "--optimizer_type", type=OptimizerType, choices=list(OptimizerType), default=OptimizerType.SGD,
            help="Type of the model optimizer"
        )
        parser.add_argument(
            "-dg", "--distinct_graph", action="store_true", default=False,
            help="Indicates whether to use distinct graph or not"
        )
        parser.add_argument(
            "-eat", "--edge_attrib_type", type=EdgeAttributeType, choices=list(EdgeAttributeType),
            default=EdgeAttributeType.NO_EDGE_ATTRIB,
            help="Type of the edge attributes"
        )
        parser.add_argument(
            "-ean", "--edge_attrib_normalize", action="store_true", default=False,
            help="Indicates whether to normalize the edge attributes or not"
        )
        parser.add_argument(
            "-gbt", "--graph_builder_type", type=GraphType, choices=list(GraphType), default=GraphType.DENSE,
            help="Type of the graph builder"
        )
        parser.add_argument(
            "-gssl", "--skip_self_loops", action="store_true", default=False,
            help="Indicates whether to skip self loops in dense graphs or not"
        )
        parser.add_argument(
            "-ed", "--edge_dimension", type=int, default=None,
            help="Defines the edge dimension in MPN"
        )
        parser.add_argument(
            "-ascl", "--attention_scaling", type=AttentionScalingType, choices=list(AttentionScalingType),
            default=AttentionScalingType.NO_SCALING,
            help="Type of the attention scaling"
        )
        parser.add_argument(
            "-qkbs", "--query_key_bias", type=str2bool, default=True,
            help="Using the query/key bias or not"
        )
        parser.add_argument(
            "-to", "--temperature_optimization", action="store_true", default=False,
            help="Optimize temperature or not"
        )
        parser.add_argument("-dmp", "--disable_mp", type=str2bool, default=False,
                            help="Whether to shut down the MP in the Transformer model"
                            )

        # -------------------------------- DATA --------------------------------
        parser.add_argument(
            "-d", "--dataset", type=DataModuleType, choices=list(DataModuleType), default=DataModuleType.CIFAR10,
            help="Dataset to use in the active learning"
        )
        parser.add_argument(
            "-ndw", "--num_data_workers", type=int, default=4,
            help="Number of workers used by the data module"
        )
        parser.add_argument(
            "-tspc", "--train_samples_per_class", type=int, default=None,
            help="The number of training samples in each class"
        )

        # -------------------------------- TRAIN --------------------------------
        parser.add_argument(
            "-nis", "--num_initial_samples", type=int, default=1000,
            help="Number of samples in the initial labeled pool"
        )
        parser.add_argument(
            "-dev", "--device", type=DeviceType, choices=list(DeviceType), default=DeviceType.GPU,
            help="Defines the device type used by the trainer"
        )
        parser.add_argument(
            "-fdr", "--fast_dev_run", action="store_true", default=False,
            help="Indicates whether to run fast dev run or not"
        )
        parser.add_argument(
            "-tst", "--train_sampler_type", type=SamplerType, choices=list(SamplerType),
            default=SamplerType.SubsetRandomSampler,
            help="Defines the type of the sampler used during training"
        )
        parser.add_argument(
            "-tsnc", "--train_sampler_num_classes", type=int, default=None, help="Number of classes in batch"
        )
        parser.add_argument(
            "-cvf", "--cross_validation_folds", type=int, default=5,
            help="Number of cross validation folds"
        )
        parser.add_argument(
            "-cvs", "--cross_validation_steps", type=int, default=5,
            help="Number of cross validation steps"
        )
        parser.add_argument(
            "-ld", "--loss_definition", nargs='+', default=None,
            help="The loss definitions to use"
        )
        parser.add_argument(
            "-lw", "--loss_weights", type=float, nargs='+', default=None,
            help="The loss weights to use"
        )
        parser.add_argument(
            "-me", "--max_epochs", type=int, default=None, help="Defines the maximum number of epochs"
        )

        # -------------------------------- PARAMS --------------------------------
        parser.add_argument(
            "-bs", "--batch_size", type=int, default=None, help="Batch size"
        )

        parser.add_argument(
            "-fs", "--feature_size", type=int, default=None, help="Feature size of MPN"
        )
        parser.add_argument(
            "-nmp", "--num_message_passings", type=int, default=None,
            help="Number of message passings by MPN"
        )
        parser.add_argument(
            "-nh", "--num_heads", type=int, default=None,
            help="Number of heads of attention during message passings"
        )
        parser.add_argument(
            "-ast", "--attention_scaling_threshold", type=float, default=None,
            help="Defines the scale of attention scaling"
        )

        parser.add_argument(
            "-lrate", "--learning_rate", type=float, default=None,
            help="The learning rate"
        )
        parser.add_argument(
            "-lratemsr", "--lr_milestone_ratio", type=float, nargs='+', default=None,
            help="The milestone ratio for learning rate scheduling"
        )
        parser.add_argument(
            "-mom", "--momentum", type=float, default=None,
            help="The momentum"
        )
        parser.add_argument(
            "-wd", "--weight_decay", type=float, default=None,
            help="The weight decay"
        )

        return parser

    @classmethod
    def from_argparse_args(cls: Type['_T'], arguments: Namespace) -> '_T':
        """
        Creates new object from arguments
        """
        return cls(arguments)

    def _get_optimization_id(self, optimization_id: Optional[str]):
        if optimization_id is None:
            if "SLURM_ARRAY_JOB_ID" in os.environ:
                return int(os.environ['SLURM_ARRAY_JOB_ID'])
            else:
                self.module_logger.debug(f"SLURM_ARRAY_JOB_ID not found in {os.environ}")
                return self._get_job_id()

    def _get_job_id(self):
        if "SLURM_JOB_ID" in os.environ:
            return int(os.environ['SLURM_JOB_ID'])
        else:
            self.module_logger.debug(f"SLURM_JOB_ID not found in {os.environ}")
            self.module_logger.warning("Job ID not found, using default value: -1")
            return -1

    def execute(self):
        """
        Builds the hyperparameter optimizer from the arguments and runs the optimization
        """
        args = self.arguments
        description = self._get_description()
        definition_space = self._get_optimization_definition_space()

        optimizer_kwargs = dict(
            optimization_definition_space=definition_space,
            description=description,
            datamodule_hyperparams=self._get_datamodule_hyperparams(),
            job_id=self._get_job_id(),
            callbacks=None,
            sampler_type=args.sampler_type,
            pruner_type=PrunerType.PercentilePruner,
            n_trials=args.n_trials,
            timeout=args.timeout,
            n_jobs=1,
            num_initial_samples=args.num_initial_samples,
            cross_validation_folds=args.cross_validation_folds,
            cross_validation_steps=args.cross_validation_steps,
            temperature_optimization=args.temperature_optimization,
            optimization_type=args.optimization_type,
        )
        optimizer = HyperParamOptimizer(**optimizer_kwargs)

        print(f"Running hyperparameter optimization: {definition_space.dumps()}")
        # The return value of the optimization is not used here
        optimizer.optimize()
        print("Optimization done!")

    def _get_optimization_definition_space(self):
        default_optimization_set = self._get_optimization_set()
        return default_optimization_set.definition_space()

    def _get_optimization_set(self):
        """
        Assembles the data, model and trainer definitions into one optimization set
        :return: The OptimizationDefinitionSet created with the fixed seed 1234
        """
        data_def = self._get_data_definition()
        # The model is built for the number of classes of the selected dataset
        num_classes = data_def.hyperparams.num_classes
        model_def = self._get_model_definition(num_classes)
        trainer_def = self._get_trainer_definition()

        return OptimizationDefinitionSet(data_definition=data_def,
                                         model_definition=model_def,
                                         trainer_definition=trainer_def,
                                         seed=1234)

    def _get_datamodule_hyperparams(self):
        """
        Creates the hyperparameters of the active datamodule
        :return: ActiveDataModuleHyperParameterSet holding the train sampler definition
                 selected by self.arguments.train_sampler_type
        """
        num_classes_in_batch = self.arguments.train_sampler_num_classes

        random_sampler_def = SubsetRandomSamplerDefinition()
        combine_sampler_def = CombineSamplerDefinition(
            CombineSamplerHyperParameterSet(num_classes_in_batch=num_classes_in_batch,
                                            num_samples_per_class=None))
        class_based_sampler_def = ClassBasedSamplerDefinition(
            ClassBasedSamplerHyperParameterSet(num_classes_in_batch=num_classes_in_batch,
                                               num_samples_per_class=None))

        available_sampler_defs = {
            # SamplerType.SubsetSequentialSampler: SubsetSequentialSamplerDefinition(),
            SamplerType.SubsetRandomSampler: random_sampler_def,
            SamplerType.CombineSampler: combine_sampler_def,
            SamplerType.ClassBasedSampler: class_based_sampler_def,
        }
        selected_sampler_def = available_sampler_defs[self.arguments.train_sampler_type]
        return ActiveDataModuleHyperParameterSet(train_sampler_definition=selected_sampler_def)

    def _get_data_definition(self):
        """
        Builds the data definition of every supported dataset and picks the requested one
        :return: The data definition belonging to self.arguments.dataset
        """
        num_workers = self.arguments.num_data_workers
        batch_size = self.arguments.batch_size
        samples_per_class = self.arguments.train_samples_per_class

        def compose(transform_defs):
            # Wraps the list of transform definitions into one composed transform definition
            return ComposeTransformDefinition(ComposeTransformHyperParameterSet(transform_defs))

        def normalize(mean, std):
            # Fresh lists on every call, so the definitions never share a list object
            return NormalizeTransformDefinition(
                NormalizeHyperParameterSet(mean=list(mean), std=list(std)))

        def resize():
            return ResizeTransformDefinition(ResizeHyperParameterSet(size=(224, 224)))

        def repeat():
            return RepeatTransformDefinition(RepeatHyperParameterSet(desired_num_of_channels=3))

        def cifar(definition_cls, hyperparams_cls, mean, std):
            # CIFAR style pipeline: flip + padded 32x32 crop for training, plain tensor for testing
            return definition_cls(
                hyperparams_cls(
                    num_workers=num_workers,
                    batch_size=batch_size,
                    val_ratio=None,
                    duplication_factor=1,
                    train_transforms_def=compose([
                        RandomHorizontalFlipTransformDefinition(),
                        RandomCropTransformDefinition(
                            RandomCropTransformHyperParameterSet(size=32, padding=4)),
                        ToTensorTransformDefinition(),
                        normalize(mean, std),
                    ]),
                    val_transforms_def=None,
                    test_transforms_def=compose([
                        ToTensorTransformDefinition(),
                        normalize(mean, std),
                    ])))

        def caltech(definition_cls, hyperparams_cls, mean, std):
            # Caltech style pipeline: resize to 224x224 and repeat to 3 channels in both modes
            return definition_cls(
                hyperparams_cls(
                    num_workers=num_workers,
                    batch_size=batch_size,
                    val_ratio=None,
                    train_samples_per_class=samples_per_class,
                    train_transforms_def=compose([
                        resize(),
                        RandomHorizontalFlipTransformDefinition(),
                        RandomCropTransformDefinition(
                            RandomCropTransformHyperParameterSet(size=224, padding=16)),
                        ToTensorTransformDefinition(),
                        repeat(),
                        normalize(mean, std),
                    ]),
                    val_transforms_def=None,
                    test_transforms_def=compose([
                        resize(),
                        ToTensorTransformDefinition(),
                        repeat(),
                        normalize(mean, std),
                    ])))

        # Every definition is created eagerly, the requested one is selected afterwards
        available_data_defs = {
            DataModuleType.CIFAR10: cifar(
                CIFAR10Definition, CIFAR10HyperParameterSet,
                mean=(0.4914, 0.4822, 0.4465), std=(0.2023, 0.1994, 0.2010)),
            DataModuleType.CIFAR100: cifar(
                CIFAR100Definition, CIFAR100HyperParameterSet,
                mean=(0.5071, 0.4867, 0.4408), std=(0.2675, 0.2565, 0.2761)),
            DataModuleType.CALTECH101: caltech(
                Caltech101Definition, Caltech101HyperParameterSet,
                mean=(0.5013, 0.4772, 0.4475), std=(0.3331, 0.3277, 0.3343)),
            DataModuleType.CALTECH256: caltech(
                Caltech256Definition, Caltech256HyperParameterSet,
                mean=(0.5118, 0.4911, 0.4646), std=(0.3352, 0.3299, 0.3384)),
        }
        return available_data_defs[self.arguments.dataset]

    def _get_model_definition(self, num_classes):
        """
        Builds the definition of every supported model and picks the requested one
        :param num_classes: The number of classes, used as the output size of the models
        :return: The model definition belonging to self.arguments.model
        """
        optimizer_definition = self._get_optimizer_definition()
        scheduler_definition = MultiStepLRSchedulerDefinition(
            MultiStepLRSchedulerHyperParameterSet(
                milestone_ratios=self.arguments.lr_milestone_ratio
            )
        )
        loss_calc_params = self._get_loss_calc_params(num_classes)

        def backbone(model_type):
            # Backbone definition sharing the optimizer, scheduler and loss parameters
            return self._get_backbone_def(
                model_type,
                self.arguments.pretrained,
                num_classes,
                optimizer_definition,
                scheduler_definition,
                loss_calc_params)

        def standalone_backbone(model_type):
            # Backbone used as the model itself: gradient multiplier and iid loss params are updated
            return self._update_iid_loss_calc_params(
                self._update_gradient_multiplier(backbone(model_type)))

        # Every definition is created eagerly, the requested one is selected afterwards
        model_definitions = \
            {
                ModelType.ResNet18: standalone_backbone(ModelType.ResNet18),
                # The ResNet34 entry has to request a ResNet34 backbone (not ResNet18)
                ModelType.ResNet34: standalone_backbone(ModelType.ResNet34),
                ModelType.EfficientNetB3: standalone_backbone(ModelType.EfficientNetB3),
                ModelType.GeneralNet: GeneralNetDefinition(
                    GeneralNetHyperParameterSet(
                        iid_net_def=self._update_iid_loss_calc_params(
                            backbone(self.arguments.backbone_model)),
                        mpn_net_def=MessagePassingNetDefinition(
                            MessagePassingNetHyperParameterSet(
                                feature_size=self.arguments.feature_size,
                                num_message_pass=self.arguments.num_message_passings,
                                num_heads=self.arguments.num_heads,
                                output_size=num_classes,
                                concat=self.arguments.concat,
                                skip_downsampling=self.arguments.skip_downsampling,
                                graph_builder_def=self._get_graph_builder_definition(),
                                beta=self.arguments.beta,
                                root_weight=self.arguments.root_weight,
                                edge_dim=self.arguments.edge_dimension,
                                attention_scaling=self.arguments.attention_scaling,
                                attention_scaling_threshold=self.arguments.attention_scaling_threshold,
                                bias=self.arguments.query_key_bias,
                                disable_mp=self.arguments.disable_mp,
                                # The MPN gets no own optimizer/scheduler, those are set on the GeneralNet
                                optimizer_definition=None,
                                scheduler_definition=None
                            )),
                        optimizer_definition=optimizer_definition,
                        scheduler_definition=scheduler_definition,
                        loss_calc_params=loss_calc_params
                    )),
                ModelType.JennyNet: JennyNetDefinition(
                    JennyNetHyperParameterSet(
                        classif_net_def=backbone(self.arguments.backbone_model),
                        distinct_graph=self.arguments.distinct_graph,
                        optimizer_definition=optimizer_definition,
                        scheduler_definition=scheduler_definition,
                        loss_calc_params=loss_calc_params
                    )
                ),
                ModelType.Dummy: DummyModuleDefinition(
                    DummyModuleHyperParameterSet(
                        hidden_dim=None,
                        output_size=num_classes,
                        optimizer_definition=optimizer_definition,
                        scheduler_definition=scheduler_definition,
                        loss_calc_params=loss_calc_params))
            }
        return model_definitions[self.arguments.model]

    def _get_optimizer_definition(self):
        """
        Builds the definition of every supported optimizer and picks the requested one
        :return: The optimizer definition belonging to self.arguments.optimizer_type
        """
        learning_rate = self.arguments.learning_rate
        weight_decay = self.arguments.weight_decay

        sgd_def = SGDOptimizerDefinition(
            SGDOptimizerHyperParameterSet(lr=learning_rate,
                                          momentum=self.arguments.momentum,
                                          weight_decay=weight_decay))
        adam_def = AdamOptimizerDefinition(
            AdamOptimizerHyperParameterSet(lr=learning_rate,
                                           betas=(0.9, 0.999),
                                           eps=1e-8,
                                           weight_decay=weight_decay,
                                           amsgrad=False))
        radam_def = RAdamOptimizerDefinition(
            RAdamOptimizerHyperParameterSet(lr=learning_rate,
                                            betas=(0.9, 0.999),
                                            eps=1e-8,
                                            weight_decay=weight_decay))

        available_optimizer_defs = {
            OptimizerType.SGD: sgd_def,
            OptimizerType.Adam: adam_def,
            OptimizerType.RAdam: radam_def,
        }
        return available_optimizer_defs[self.arguments.optimizer_type]

    def _get_graph_builder_definition(self):
        """
        Builds the definition of every supported graph builder and picks the requested one
        :return: The graph builder definition belonging to self.arguments.graph_builder_type
        """
        # Both builders receive the same edge attribute definition object
        edge_attributes = self._get_edge_attrib_definition()

        dense_builder_def = DenseGraphBuilderDefinitionSet(
            DenseGraphGraphBuilderHyperParameterSet(edge_attrib_def=edge_attributes,
                                                    skip_self_loops=self.arguments.skip_self_loops))
        distinct_builder_def = DistinctGraphBuilderDefinitionSet(
            GraphBuilderHyperParameterSet(edge_attrib_def=edge_attributes))

        available_builder_defs = {
            GraphType.DENSE: dense_builder_def,
            GraphType.DISTINCT: distinct_builder_def,
        }
        return available_builder_defs[self.arguments.graph_builder_type]

    def _get_edge_attrib_definition(self):
        """
        Return the edge attribute definition selected by ``arguments.edge_attrib_type``.

        Every supported edge attribute definition is constructed with
        ``shift_negative_attrib=False`` (the "no edge attribute" one takes no hyperparameters);
        the Euclidean variant additionally receives ``arguments.edge_attrib_normalize`` as ``norm``.

        Returns:
            The edge attribute definition registered for ``self.arguments.edge_attrib_type``.

        Raises:
            KeyError: If ``self.arguments.edge_attrib_type`` is not in the table.
        """
        no_attrib = NoEdgeAttributeDefinitionSet()
        # Correlation and cosine each get their own hyperparameter instance.
        correlation = CorrelationEdgeAttributeDefinitionSet(
            EdgeAttributeHyperParameterSet(shift_negative_attrib=False))
        cosine = CosineEdgeAttributeDefinitionSet(
            EdgeAttributeHyperParameterSet(shift_negative_attrib=False))
        euclidean = DistanceEdgeAttributeDefinitionSet(
            DistanceEdgeAttributeHyperParameterSet(
                shift_negative_attrib=False,
                norm=self.arguments.edge_attrib_normalize))

        definition_by_type = {
            EdgeAttributeType.NO_EDGE_ATTRIB: no_attrib,
            EdgeAttributeType.CORRELATION_EDGE_ATTRIB: correlation,
            EdgeAttributeType.COSINE_EDGE_ATTRIB: cosine,
            EdgeAttributeType.EUCLIDEAN_EDGE_ATTRIB: euclidean,
        }
        return definition_by_type[self.arguments.edge_attrib_type]

    def _get_trainer_definition(self):
        """
        Build the trainer definition from the command line arguments.

        Validation is scheduled every epoch when ``arguments.fast_dev_run`` is set and
        every fifth epoch otherwise. The weights summary is disabled (``None``).

        Returns:
            A ``PL_TrainerDefinition`` wrapping the assembled ``PL_TrainerHyperParameterSet``.
        """
        args = self.arguments

        if args.fast_dev_run:
            validation_interval = 1
        else:
            validation_interval = 5

        trainer_hyperparams = PL_TrainerHyperParameterSet(
            runtime_mode=args.device,
            max_epochs=args.max_epochs,
            fast_dev_run=args.fast_dev_run,
            check_val_every_n_epoch=validation_interval,
            weights_summary=None)
        return PL_TrainerDefinition(trainer_hyperparams)

    def _get_loss_calc_params(self, num_classes):
        """
        Build the loss evaluator parameters for the losses selected on the command line.

        The losses are taken from ``arguments.loss_definition``. When that is ``None`` the
        default is ``["mpn_cross_entropy", "iid_cross_entropy"]`` for ``ModelType.GeneralNet``
        and ``["cross_entropy"]`` for every other model. Each selected loss is weighted with
        the matching entry of ``arguments.loss_weights``, or with ``1`` when no weights are given.

        Args:
            num_classes: Number of classes, forwarded to the center loss definitions.

        Returns:
            A dict mapping every selected loss identifier to its
            ``LossEvaluatorHyperParameterSet`` with the ``weight`` attribute set.

        Raises:
            ValueError: If loss weights are given but their count differs from the number
                of selected losses.
            KeyError: If a selected loss identifier is not one of the known losses.
        """
        # Resolve the selection first so that invalid input is rejected before any
        # definition object is constructed.
        loss_definition = self.arguments.loss_definition
        if loss_definition is None:
            if self.arguments.model == ModelType.GeneralNet:
                loss_definition = ["mpn_cross_entropy", "iid_cross_entropy"]
            else:
                loss_definition = ["cross_entropy"]

        loss_weights = self.arguments.loss_weights
        if loss_weights is None:
            loss_weights = [1] * len(loss_definition)
        elif len(loss_weights) != len(loss_definition):
            # An explicit exception instead of ``assert``: asserts are stripped under
            # ``python -O`` and ``zip`` below would then silently drop losses or weights.
            raise ValueError("The size of the loss weights does not match the size of the loss definitions!")

        # NOTE(review): ``layers_needed`` presumably maps loss argument names to the names of
        # the layers whose outputs are fed to the loss - confirm against the loss evaluator.
        all_loss_params = \
            {
                "cross_entropy":
                    LossEvaluatorHyperParameterSet(
                        layers_needed=None,
                        loss_definition=CrossEntropyDefinition()),
                "mpn_cross_entropy":
                    LossEvaluatorHyperParameterSet(
                        layers_needed=None,
                        loss_definition=CrossEntropyDefinition()),
                "iid_cross_entropy":
                    LossEvaluatorHyperParameterSet(
                        layers_needed={"input": "model.iid_net.model.classifier"},
                        loss_definition=CrossEntropyDefinition()),
                "center_loss":
                    LossEvaluatorHyperParameterSet(
                        layers_needed={"x": "flatten4"},
                        loss_definition=CenterLossDefinition(
                            CenterLossHyperParameterSet(
                                num_classes=num_classes,
                                feat_dim=512
                            )
                        )),
                "mpn_center_loss":
                    LossEvaluatorHyperParameterSet(
                        layers_needed={"x": "model.mpn_net.model.mpn_layers.layer_0"},
                        loss_definition=CenterLossDefinition(
                            CenterLossHyperParameterSet(
                                num_classes=num_classes,
                                feat_dim=None
                            )
                        )),
                "iid_center_loss":
                    LossEvaluatorHyperParameterSet(
                        layers_needed={"x": "model.iid_net.model.backbone"},
                        loss_definition=CenterLossDefinition(
                            CenterLossHyperParameterSet(
                                num_classes=num_classes,
                                feat_dim=512
                            )
                        ))
            }

        # Fail with a readable message (still a KeyError, as a plain lookup would raise).
        unknown_losses = [loss_id for loss_id in loss_definition if loss_id not in all_loss_params]
        if unknown_losses:
            raise KeyError(
                f"Unknown loss definition(s) {unknown_losses}; available: {sorted(all_loss_params)}")

        loss_calc_params = {}
        for loss_weight, loss_param_id in zip(loss_weights, loss_definition):
            selected_params = all_loss_params[loss_param_id]
            selected_params.weight = loss_weight
            loss_calc_params[loss_param_id] = selected_params

        return loss_calc_params

    @staticmethod
    def _update_iid_loss_calc_params(iid_model_def):
        """
        Derive a model definition whose losses are restricted to the IID network.

        The input definition is deep-copied and left untouched. In the copy only the losses
        whose key contains ``"iid"`` are kept, and the ``"model.iid_net."`` prefix is removed
        from their layer names via ``remove_prefix``. Losses whose key contains neither
        ``"iid"`` nor ``"mpn"`` are then added back.

        Note that the re-added entries are taken from ``iid_model_def`` itself, so those
        parameter objects are shared with the input definition rather than copied.

        Args:
            iid_model_def: Model definition exposing ``hyperparams.loss_calc_params``.

        Returns:
            The adjusted deep copy of ``iid_model_def``.
        """
        iid_prefix = "model.iid_net."
        adjusted_def = copy.deepcopy(iid_model_def)

        # Keep only the IID specific losses of the copy.
        iid_losses = {}
        for loss_id, loss_params in adjusted_def.hyperparams.loss_calc_params.items():
            if "iid" in loss_id:
                iid_losses[loss_id] = loss_params
        adjusted_def.hyperparams.loss_calc_params = iid_losses

        # Strip the IID prefix from the layer names of the kept losses.
        for loss_params in adjusted_def.hyperparams.loss_calc_params.values():
            if loss_params.layers_needed is None:
                continue
            stripped_layers = {}
            for argument_name, layer_name in loss_params.layers_needed.items():
                stripped_layers[argument_name] = remove_prefix(layer_name, iid_prefix)
            loss_params.layers_needed = stripped_layers

        # Re-add the losses that belong to neither sub-network (original objects, not copies).
        generic_losses = {}
        for loss_id, loss_params in iid_model_def.hyperparams.loss_calc_params.items():
            if "iid" not in loss_id and "mpn" not in loss_id:
                generic_losses[loss_id] = loss_params
        adjusted_def.hyperparams.loss_calc_params.update(generic_losses)

        return adjusted_def

    def _update_gradient_multiplier(self, resnet_def):
        """
        Return a deep copy of ``resnet_def`` whose gradient multiplier is reset to ``1.0``.

        The passed definition itself is not modified.

        Args:
            resnet_def: Definition exposing ``hyperparams.gradient_multiplier``.

        Returns:
            The adjusted deep copy.
        """
        unscaled_def = copy.deepcopy(resnet_def)
        setattr(unscaled_def.hyperparams, "gradient_multiplier", 1.0)
        return unscaled_def

    def _get_fixed_params(self):
        """
        Collect the hyperparameters that were given a value on the command line.

        Returns:
            A dict from argument name to its value, in the order listed below, containing
            only the arguments whose value is not ``None``.
        """
        argument_names = (
            "batch_size",
            "feature_size",
            "num_message_passings",
            "num_heads",
            "attention_scaling_threshold",
            "learning_rate",
            "lr_milestone_ratio",
            "momentum",
            "weight_decay",
            "max_epochs",
        )

        fixed = {}
        for argument_name in argument_names:
            argument_value = getattr(self.arguments, argument_name)
            # Only ``None`` marks an unset argument; falsy values such as 0 are kept.
            if argument_value is not None:
                fixed[argument_name] = argument_value
        return fixed

    def _get_description(self):
        """
        Return the description of the experiment.

        An explicitly given ``arguments.description`` wins. Otherwise the description is
        assembled from the fixed parameters as ``"name-value"`` pairs joined by ``"|"``.

        Returns:
            The description string, or ``None`` when no description was given and there
            are no fixed parameters.
        """
        explicit_description = self.arguments.description
        if explicit_description is not None:
            return explicit_description

        fixed = self._get_fixed_params()
        if len(fixed) == 0:
            return None

        pairs = [f"{name}-{setting}" for name, setting in fixed.items()]
        return "|".join(pairs)

    def _get_backbone_def(self, model, pretrained, num_classes, optimizer_definition, scheduler_definition,
                          loss_calc_params):
        """
        Build the ResNet backbone definition for the requested model.

        With ``pretrained`` set, a ``PretrainedResNetDefinition`` is created, which is only
        allowed for ``ModelType.ResNet18``. Otherwise a ``ResNet34Definition`` is created for
        ``ModelType.ResNet34`` and a ``ResNet18Definition`` for every other model type.
        The configured intermediate layers are applied to the result before it is returned.

        Args:
            model: The requested model type.
            pretrained: Whether the pretrained ResNet definition should be used.
            num_classes: Passed to the definition as ``output_size``.
            optimizer_definition: Optimizer definition forwarded to the hyperparameters.
            scheduler_definition: Scheduler definition forwarded to the hyperparameters.
            loss_calc_params: Loss evaluator parameters forwarded to the hyperparameters.

        Returns:
            The backbone definition with ``intermediate_layers_to_return`` set.

        Raises:
            ValueError: If ``pretrained`` is set for a model other than ``ModelType.ResNet18``.
        """
        if pretrained:
            # An explicit exception instead of ``assert``: asserts are stripped under
            # ``python -O`` and an unsupported combination would then be built silently.
            if model != ModelType.ResNet18:
                raise ValueError("Only ResNet18 is supported with pretrained!")
            resnet_def = PretrainedResNetDefinition(
                PretrainedResNetHyperParameterSet(
                    output_size=num_classes,
                    gradient_multiplier=self.arguments.gradient_multiplier,
                    optimizer_definition=optimizer_definition,
                    scheduler_definition=scheduler_definition,
                    loss_calc_params=loss_calc_params
                ))
        else:
            # Both non-pretrained variants take the same hyperparameters; only the classes differ.
            # Every model type other than ResNet34 falls back to the ResNet18 definition.
            if model == ModelType.ResNet34:
                definition_cls, hyperparams_cls = ResNet34Definition, ResNet34HyperParameterSet
            else:
                definition_cls, hyperparams_cls = ResNet18Definition, ResNet18HyperParameterSet
            resnet_def = definition_cls(
                hyperparams_cls(
                    output_size=num_classes,
                    head_bias=self.arguments.head_bias,
                    gradient_multiplier=self.arguments.gradient_multiplier,
                    optimizer_definition=optimizer_definition,
                    scheduler_definition=scheduler_definition,
                    loss_calc_params=loss_calc_params
                ))

        return self._update_intermediate_layers_def(resnet_def)

    def _update_intermediate_layers_def(self, resnet_def):
        """
        Return a deep copy of ``resnet_def`` configured with the requested intermediate layers.

        ``arguments.intermediate_layers`` is stored as ``intermediate_layers_to_return`` on the
        hyperparameters of the copy's backbone definition. The passed definition is not modified.

        Args:
            resnet_def: Definition exposing ``hyperparams.backbone_definition.hyperparams``.

        Returns:
            The adjusted deep copy.
        """
        configured_def = copy.deepcopy(resnet_def)
        backbone_hyperparams = configured_def.hyperparams.backbone_definition.hyperparams
        backbone_hyperparams.intermediate_layers_to_return = self.arguments.intermediate_layers
        return configured_def


if __name__ == "__main__":
    # Script entry point: parse the command line, build the experiment from it and run it.
    parser = HyperParamOptimizationExperiment.add_argparse_args()
    arguments = parser.parse_args()
    experiment = HyperParamOptimizationExperiment.from_argparse_args(arguments)
    experiment.execute()
