"""
Hyperparameter optimization
"""
import copy
import json
import logging
import os
import time
from collections import defaultdict
from enum import Enum
from typing import Optional, List

import numpy as np
import optuna
import pytorch_lightning as pl
from optuna.exceptions import StorageInternalError
from optuna.pruners import MedianPruner, NopPruner, PercentilePruner
from optuna.samplers import RandomSampler, TPESampler, GridSampler
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import EarlyStopping
from sklearn.model_selection import StratifiedKFold

from adl4cv.active_learning.active_datamodule import ActiveDataModuleHyperParameterSet, ActiveDataModuleDefinition
from adl4cv.classification.calibration.temperature_scaling import CrossValidationModelWithTemperature
from adl4cv.classification.callback.callback import MetricsCallback
from adl4cv.classification.data.caltech import Caltech101DefinitionSpace
from adl4cv.classification.data.datas import DataModuleType
from adl4cv.classification.sampling.sampler import SamplerType
from adl4cv.classification.log.logger import TENSORBOARD_LOG_DIR, init_logger, ClassificationLogger
from adl4cv.classification.model.dummy.fc_classifier import DummyModuleDefinitionSpace, DummyModuleHyperParameterSpace, \
    DummyModuleHyperParameterSet
from adl4cv.classification.model.graph._graph.edge_attrib import EdgeAttributeType
from adl4cv.classification.model.graph._graph.graph_builder import GraphType
from adl4cv.classification.model.graph._transformer.softmax import AttentionScalingType
from adl4cv.classification.model.models import ModelType
from adl4cv.classification.optimizer.optimizers import OptimizerType
from adl4cv.classification.trainer.trainers import PL_TrainerDefinition, PL_TrainerHyperParameterSet, DeviceType
from adl4cv.classification.training.train import train_model
from adl4cv.parameters.hyperparam_storage import HyperParameterStorage
from adl4cv.parameters.params import OptimizationDefinitionSpace, SingleDefinitionSpace
from adl4cv.utils.summary import summary
from adl4cv.utils.utils import relative_file_lock, get_lock_folder, try_plot, SerializableEnum


class OptimizationSamplerType(SerializableEnum):
    """
    Definition of the available Optuna Samplers

    Each member is mapped to a concrete Optuna sampler by `HyperParamOptimizer._get_sampler`.
    """
    # Mapped to optuna.samplers.RandomSampler
    Random = "random"
    # Mapped to optuna.samplers.TPESampler
    TPE = "tpe"
    # Mapped to optuna.samplers.GridSampler (built from the search grid of the definition space)
    Grid = "grid"


class PrunerType(SerializableEnum):
    """
    Definition of the available Optuna Pruners

    Each member is mapped to a concrete Optuna pruner by `HyperParamOptimizer._get_pruner`.
    """
    # Mapped to optuna.pruners.NopPruner
    NopPruner = "NopPruner"
    # Mapped to optuna.pruners.MedianPruner
    MedianPruner = "MedianPruner"
    # Mapped to optuna.pruners.PercentilePruner, instantiated with percentile 75
    PercentilePruner = "PercentilePruner"


class OptimizationType(SerializableEnum):
    """
    Definition of the available objective functions

    `HyperParamOptimizer.optimize` uses the member to pick the objective that is handed to the study.
    """
    # Selects HyperParamOptimizer.cv_objective
    CrossValidation = "CrossValidation"
    # Selects HyperParamOptimizer.testset_objective
    TestSet = "TestSet"


class HyperParamOptimizer:
    """
    Class for hyper parameter optimization

    Owns an Optuna study (created or loaded in the constructor) and runs one of two objectives
    on it: a cross-validation objective or a test-set objective.
    """
    # Name of the folder below TENSORBOARD_LOG_DIR that holds every optimization
    OPTIMIZATION_FOLDER = "optimization"
    # Same file name as the one `_optimization_space_file_path` points to
    OPTIMIZATION_SPACE_FILE = "optimization_space.json"
    def __init__(
            self,
            optimization_definition_space: OptimizationDefinitionSpace,
            description: Optional[str] = None,
            datamodule_hyperparams: Optional[ActiveDataModuleHyperParameterSet] = None,
            job_id: Optional[int] = None,
            callbacks: Optional[List[Callback]] = None,
            sampler_type: OptimizationSamplerType = OptimizationSamplerType.TPE,
            pruner_type: PrunerType = PrunerType.PercentilePruner,
            n_trials: Optional[int] = None,
            timeout: Optional[float] = None,
            n_jobs: int = 1,
            num_initial_samples: int = 1000,
            cross_validation_folds: int = 5,
            cross_validation_steps: int = 5,
            temperature_optimization: bool = False,
            optimization_type: OptimizationType = OptimizationType.CrossValidation
    ):
        """
        Creates new instance
        :param optimization_definition_space: The definition space of the optimization
        :param description: Free text that becomes the last element of the optimization id,
            "default" is used when omitted
        :param datamodule_hyperparams: The hyper parameters of the active datamodule, a fresh
            ActiveDataModuleHyperParameterSet is created for every optimizer when omitted
        :param job_id: The id of the job, used in the name of the study log file (-1 when omitted)
        :param callbacks: The PyTorchLightning callbacks of the training
        :param sampler_type: The type of the sampler
        :param pruner_type: The type of the pruner
        :param n_trials: The number of trials
        :param timeout: The timeout of the hyper parameter optimization
        :param n_jobs: The number of parallel jobs
        :param num_initial_samples: The number of samples labeled before an objective is evaluated
        :param cross_validation_folds: The number of splits of the stratified k-fold
        :param cross_validation_steps: The maximal number of folds that are actually trained
        :param temperature_optimization: Whether the cross-validation objective optimizes a temperature
        :param optimization_type: Selects the objective function used by `optimize`
        """
        self.description = description if description is not None else "default"
        self.logger = init_logger(self.__class__.__name__)
        # The default is created here and not in the signature: a default built in the signature
        # is evaluated only once and would be shared by every optimizer instance
        self.datamodule_hyperparams = datamodule_hyperparams if datamodule_hyperparams is not None \
            else ActiveDataModuleHyperParameterSet()
        self.optimization_definition_space = optimization_definition_space
        self.callbacks = callbacks if callbacks is not None else []
        self.n_trials = n_trials
        self.timeout = timeout
        self.n_jobs = n_jobs
        self.num_initial_samples = num_initial_samples
        self.cross_validation_folds = cross_validation_folds
        self.cross_validation_steps = cross_validation_steps
        self.temperature_optimization = temperature_optimization
        self.optimization_type = optimization_type

        # The optimization id depends on the description and on the definition space set above
        self.optimization_id = self._get_optimization_id()
        self.job_id = self._get_job_id(job_id)

        self.optimization_root_folder_path = os.path.join(TENSORBOARD_LOG_DIR, self.OPTIMIZATION_FOLDER)
        self.log_folder_path = os.path.join(self.optimization_root_folder_path, self.optimization_id)
        os.makedirs(self.log_folder_path, exist_ok=True)

        self.study_log_file_path = self._get_log_file_path()

        # Every instance attaches one more file handler to the root logger
        logging.getLogger().addHandler(logging.FileHandler(self.study_log_file_path))
        optuna.logging.enable_propagation()  # Propagate logs to the root logger.

        # -----------------------------------------------
        # Define the sampler
        # -----------------------------------------------
        self.sampler = self._get_sampler(sampler_type)

        # -----------------------------------------------
        # Define the pruner
        # -----------------------------------------------
        self.pruner = self._get_pruner(pruner_type)

        # -----------------------------------------------
        # Define the study
        # -----------------------------------------------
        self.hyperparam_storage = HyperParameterStorage(self.optimization_root_folder_path, self.optimization_id)
        self.study = self._get_study()

    def _get_study(self):
        """
        Creates the study, or loads it when the creation fails
        The first attempt creates the study (or loads it if it already exists), every later
        attempt only loads it. After a failed attempt the method waits 10 seconds; the exception
        of the sixth failed attempt is re-raised.
        :return: The Optuna study
        """
        # WORKAROUND: The create_study function throws exception in case of concurrent creation, unable to handle that
        max_retries = 5
        wait_seconds = 10
        attempt = 0
        while True:
            try:
                if attempt > 0:
                    self.logger.info("Loading study!")
                    return optuna.load_study(study_name=self.optimization_id,
                                             storage=self.hyperparam_storage.storage,
                                             sampler=self.sampler,
                                             pruner=self.pruner)

                self.logger.info("Creating study!")
                return optuna.create_study(study_name=self.optimization_id,
                                           direction="maximize",
                                           sampler=self.sampler,
                                           pruner=self.pruner,
                                           storage=self.hyperparam_storage.storage,
                                           load_if_exists=True)
            except Exception:
                if attempt == max_retries:
                    self.logger.error("Unable to load the study!")
                    raise

                self.logger.info("Study not found, waiting 10 seconds!")
                attempt += 1
                time.sleep(wait_seconds)

    def _get_optimization_id(self):
        if hasattr(self.optimization_definition_space.model_definition_space.hyperparam_space.default_hyperparam_set, "mpn_net_def"):
            attention_scaling = self.optimization_definition_space.model_definition_space.hyperparam_space.default_hyperparam_set.mpn_net_def.hyperparams.attention_scaling
            graph_type = self.optimization_definition_space.model_definition_space.hyperparam_space.default_hyperparam_set.mpn_net_def.hyperparams.graph_builder_def.type
            if self.optimization_definition_space.model_definition_space.hyperparam_space.default_hyperparam_set.mpn_net_def.hyperparams.graph_builder_def.hyperparams.edge_attrib_def is not None:
                edge_type = self.optimization_definition_space.model_definition_space.hyperparam_space.default_hyperparam_set.mpn_net_def.hyperparams.graph_builder_def.hyperparams.edge_attrib_def.type
            else:
                edge_type = None
        else:
            attention_scaling = None
            graph_type = None
            edge_type = None
        return os.path.join(
            *self.get_scope(
                data_type=self.optimization_definition_space.data_definition_space.type,
                model_type=self.optimization_definition_space.model_definition_space.type,
                optimizer_type=self.optimization_definition_space.model_definition_space.hyperparam_space.default_hyperparam_set.optimizer_definition.type,
                loss_params=self.optimization_definition_space.model_definition_space.hyperparam_space.default_hyperparam_set.loss_calc_params.keys(),
                attention_scaling=attention_scaling,
                graph_type=graph_type,
                edge_type=edge_type,
                description=self.description))

    @staticmethod
    def _get_job_id(job_id):
        if job_id is None:
            return -1
        return job_id

    def _get_log_file_path(self):
        return os.path.join(self.log_folder_path, f"study_{self.job_id}.log")

    @staticmethod
    def get_scope(data_type: DataModuleType = None,
                  model_type: ModelType = None,
                  optimizer_type: OptimizerType = None,
                  loss_params: List[str] = None,
                  attention_scaling: AttentionScalingType = None,
                  graph_type: GraphType = None,
                  edge_type: EdgeAttributeType = None,
                  description: str = "default"):
        """
        Get the scope of the logger based on the optimization
        :param optimization_definition_set: The definition of the optimization task
        :return: List of scopes of the logger
        """
        loss_params = '-'.join(loss_params) if loss_params is not None else None
        full_scope = [data_type.value if data_type is not None else None,
                      model_type.value if model_type is not None else None,
                      optimizer_type.value if optimizer_type is not None else None,
                      loss_params,
                      attention_scaling.value if attention_scaling is not None else None,
                      graph_type.value if graph_type is not None else None,
                      edge_type.value if edge_type is not None else None]
        # if None in full_scope:
        #     scope_level = next(idx for idx, x in enumerate(full_scope) if x is None)
        #     assert all(x is None for x in full_scope[scope_level:]), "Incorrect scope!"
        #     full_scope = full_scope[:scope_level]
        full_scope = [str(x) for x in full_scope]
        return full_scope + [description]

    def optimize(self) -> optuna.trial.FrozenTrial:
        """
        Optimizes the hyper parameters of an optimization definition inside the defined definition space
        The optimization space is logged, the study is run with the objective selected by
        `optimization_type`, then the best definition, the summary and the plots are written.
        :return: The best trial
        :raises ValueError: If the optimization type has no objective function
        """
        # Resolve the objective first, an unknown type fails before anything is written or trained
        if self.optimization_type == OptimizationType.CrossValidation:
            objective = self.cv_objective
        elif self.optimization_type == OptimizationType.TestSet:
            objective = self.testset_objective
        else:
            raise ValueError(f"Unknown optimization type: {self.optimization_type}")

        self.log_optimization_space()

        # Trials raising one of the caught exceptions are marked as failed, the study goes on
        self.study.optimize(objective,
                            n_trials=self.n_trials,
                            timeout=self.timeout,
                            n_jobs=self.n_jobs,
                            catch=(ValueError, StorageInternalError, RuntimeError))
        self.save_best_definition()

        self.update_summary()
        self.plot_results()
        return self.study.best_trial

    @property
    def _result_file_path(self):
        return os.path.join(self.log_folder_path, "study_results.json")

    @property
    def _optimization_space_file_path(self):
        return os.path.join(self.log_folder_path, "optimization_space.json")

    @property
    def _viz_folder_path(self):
        viz_path = os.path.join(self.log_folder_path, "viz")
        os.makedirs(viz_path, exist_ok=True)
        return viz_path

    def _read_study_results(self):
        if os.path.exists(self._result_file_path):
            with open(self._result_file_path, 'r', encoding='utf-8') as f:
                study_logs = json.load(f)
            return study_logs
        return {}

    def cv_objective(self, trial: optuna.trial.Trial) -> float:
        """
        Defines the objective function for the hyper parameter optimization as the validation accuracy
        A definition set is suggested for the trial, `num_initial_samples` samples are labeled and
        the labeled pool is split with a stratified k-fold. At most `cross_validation_steps` folds
        are trained; after every fold the last metric recorded by the MetricsCallback is reported
        to the trial so that the pruner can stop it.
        :param trial: The Optuna trial
        :return: The validation accuracy of the trial (mean of the values of the trained folds)
        :raises optuna.TrialPruned: If the pruner stops the trial after a fold
        """
        # -----------------------------------------------
        # Get trial optimization_parameter_set
        # -----------------------------------------------
        optimization_definition_set = self.optimization_definition_space.suggest(trial)
        self._update_loss_feature_size(optimization_definition_set)
        # The validation ratio is cleared, the train/validation indices come from the folds below
        optimization_definition_set.data_definition.hyperparams.val_ratio = None

        self.logger.info(f"Starting trial: {trial.params}")

        datamodule_definition = ActiveDataModuleDefinition(self.datamodule_hyperparams)
        datamodule = datamodule_definition.instantiate(optimization_definition_set.data_definition)

        # Class based samplers: derive the samples per class from the region size of the datamodule
        if datamodule.params.train_sampler_definition.type == SamplerType.CombineSampler \
                or datamodule.params.train_sampler_definition.type == SamplerType.ClassBasedSampler:
            datamodule.params.train_sampler_definition.hyperparams.num_samples_per_class = \
                datamodule.region_size // datamodule.params.train_sampler_definition.hyperparams.num_classes_in_batch

        # Both sampler definitions start as independent copies of the *train* sampler definition,
        # their indices are overwritten for every fold by update_sampler
        optimization_definition_set.data_definition.hyperparams.train_sampler_def = copy.deepcopy(
            datamodule.params.train_sampler_definition)
        optimization_definition_set.data_definition.hyperparams.val_sampler_def = copy.deepcopy(
            datamodule.params.train_sampler_definition)

        datamodule.label_initial_samples(self.num_initial_samples)

        # -----------------------------------------------
        # Cross validation
        # -----------------------------------------------

        labeled_indices = np.array(datamodule.labeled_pool_indices)
        labeled_targets = np.array(datamodule.dataset.targets_train)[labeled_indices]

        val_accuracies = []
        num_epochs_list = []
        temperatures = []
        calibrated_model = CrossValidationModelWithTemperature()
        skf = StratifiedKFold(n_splits=self.cross_validation_folds)
        for fold_idx, (train_pool, val_pool) in enumerate(skf.split(labeled_indices, labeled_targets)):
            # Only the first `cross_validation_steps` folds are trained
            if fold_idx >= self.cross_validation_steps:
                break

            self.logger.info(f"Starting fold {fold_idx}!")
            # The fold pools index into the labeled pool, map them back to dataset indices
            train_indices = labeled_indices[train_pool]
            val_indices = labeled_indices[val_pool]

            train_targets = labeled_targets[train_pool]
            val_targets = labeled_targets[val_pool]

            # self.logger.debug(f"Train indices: {train_indices}")
            # self.logger.debug(f"Val indices: {val_indices}")

            self.update_sampler(optimization_definition_set.data_definition.hyperparams.train_sampler_def,
                                train_indices.tolist(),
                                train_targets,
                                datamodule.batch_size)

            self.update_sampler(optimization_definition_set.data_definition.hyperparams.val_sampler_def,
                                val_indices.tolist(),
                                val_targets,
                                datamodule.batch_size)

            # -----------------------------------------------
            # Callbacks
            # -----------------------------------------------
            metrics_callback = MetricsCallback()
            early_stop_callback = EarlyStopping(
                monitor='valid_acc',
                min_delta=0.00,
                patience=10,  # 10*5 (validation check of trainer) = 50 epochs
                verbose=True,
                mode='max'
            )
            # The user callbacks are copied for every fold, the two fold callbacks are used as they are
            callbacks = copy.deepcopy(self.callbacks) + \
                        [metrics_callback, early_stop_callback]

            # -----------------------------------------------
            # Train model and get accuracy
            # -----------------------------------------------
            # NOTE(review): `datamodule` is rebound to the object returned by train_model, the next
            # fold reads `batch_size` from that object - confirm this is intended
            trainer, model, datamodule = train_model(optimization_definition_set=optimization_definition_set,
                                                     save_dir=self.optimization_root_folder_path,
                                                     model_name=self.optimization_id,
                                                     version=f"{self.hyperparam_storage.get_trial_name(trial.number)}/"
                                                             f"{self.hyperparam_storage.get_fold_name(fold_idx)}",
                                                     callbacks=callbacks)
            if self.temperature_optimization:
                calibrated_model.add_fold_model(model, datamodule.val_dataloader())

            num_epochs = trainer.current_epoch + 1  # Indexing from zero
            # The last recorded metric of the fold is used as its validation accuracy
            valid_acc = metrics_callback.metrics[-1].item()

            # The fold index is the step of the intermediate value
            trial.report(valid_acc, fold_idx)
            if trial.should_prune():
                raise optuna.TrialPruned()

            val_accuracies.append(valid_acc)
            num_epochs_list.append(num_epochs)

            self.logger.info(
                f"Fold {fold_idx} finished in {num_epochs} epochs with validation accuracy: {val_accuracies[-1]}!")

        # NOTE(review): max() raises ValueError when no fold was trained (e.g. cross_validation_steps == 0)
        trial.set_user_attr("num_epochs", max(num_epochs_list))

        if self.temperature_optimization:
            # A single temperature is optimized over the models of all folds, so `temperatures`
            # holds one value and the stored average equals it
            temperature = calibrated_model.optimize_temperature()
            temperatures.append(temperature)
            # NOTE(review): `fold_idx` is the last value of the loop variable here, not a fold of its own
            self.logger.info(f"Optimized temperature {fold_idx}: {temperature}")

            trial.set_user_attr("temperature", sum(temperatures) / len(temperatures))

        return float(np.mean(val_accuracies))

    def testset_objective(self, trial: optuna.trial.Trial):
        """
        Defines the objective function for the hyper parameter optimization as the best accuracy
        recorded while the test set is used as the validation set
        The suggested definition set is trained three times on the initially labeled samples.
        Every run contributes the maximum of the metrics recorded by the MetricsCallback.
        :param trial: The Optuna trial
        :return: The mean of the best recorded accuracies of the runs
        """
        # -----------------------------------------------
        # Get trial optimization_parameter_set
        # -----------------------------------------------
        test_accs = []
        for run_id in range(3):
            optimization_definition_set = self.optimization_definition_space.suggest(trial)
            self._update_loss_feature_size(optimization_definition_set)
            optimization_definition_set.data_definition.hyperparams.val_ratio = None

            self.logger.info(f"Starting run {run_id} with params: {trial.params}")

            datamodule_definition = ActiveDataModuleDefinition(self.datamodule_hyperparams)
            datamodule = datamodule_definition.instantiate(optimization_definition_set.data_definition)

            # Class based samplers: derive the samples per class from the region size of the datamodule
            if datamodule.params.train_sampler_definition.type == SamplerType.CombineSampler \
                    or datamodule.params.train_sampler_definition.type == SamplerType.ClassBasedSampler:
                datamodule.params.train_sampler_definition.hyperparams.num_samples_per_class = \
                    datamodule.region_size // datamodule.params.train_sampler_definition.hyperparams.num_classes_in_batch

            # Both sampler definitions are independent copies of the *train* sampler definition
            optimization_definition_set.data_definition.hyperparams.train_sampler_def = copy.deepcopy(
                datamodule.params.train_sampler_definition)
            optimization_definition_set.data_definition.hyperparams.val_sampler_def = copy.deepcopy(
                datamodule.params.train_sampler_definition)

            datamodule.label_initial_samples(self.num_initial_samples)

            # -----------------------------------------------
            # Labeled pool (no cross validation in this objective)
            # -----------------------------------------------

            # NOTE(review): the four variables below are never read afterwards in this method
            labeled_indices = np.array(datamodule.labeled_pool_indices)
            labeled_targets = np.array(datamodule.dataset.targets_train)[labeled_indices]

            train_indices = labeled_indices

            train_targets = labeled_targets

            # -----------------------------------------------
            # Callbacks
            # -----------------------------------------------
            metrics_callback = MetricsCallback()
            # Early stopping is disabled for this objective:
            # early_stop_callback = EarlyStopping(
            #     monitor='valid_acc',
            #     min_delta=0.00,
            #     patience=10,  # 10*5 (validation check of trainer) = 50 epochs
            #     verbose=True,
            #     mode='max'
            # )
            callbacks = copy.deepcopy(self.callbacks) + \
                        [metrics_callback]

            # -----------------------------------------------
            # Train model and get accuracy
            # -----------------------------------------------

            # Use test set as validation
            datamodule.dataset.dataset_valid = datamodule.dataset.dataset_test
            datamodule.dataset.val_sampler = datamodule.dataset.test_sampler

            # -----------------------------------------------
            # Model
            # -----------------------------------------------
            model = optimization_definition_set.model_definition.instantiate()

            # -----------------------------------------------
            # Seeding
            # -----------------------------------------------
            pl.seed_everything(optimization_definition_set.seed)

            # -----------------------------------------------
            # Logging
            # -----------------------------------------------
            # NOTE(review): the version does not contain `run_id`, all runs of a trial share it
            logger = ClassificationLogger(save_dir=self.optimization_root_folder_path,
                                          name=self.optimization_id,
                                          version=f"{self.hyperparam_storage.get_trial_name(trial.number)}/testset")
            logger.log_parameters(optimization_definition_set)
            self.logger.info(f"Optimization set: {optimization_definition_set.dumps()}")

            # -----------------------------------------------
            # Training
            # -----------------------------------------------
            trainer: Trainer = optimization_definition_set.trainer_definition.instantiate(logger=logger,
                                                                                          callbacks=callbacks)

            summary(model.to(optimization_definition_set.trainer_definition.hyperparams.device),
                    datamodule.dims,
                    batch_size=datamodule.batch_size,
                    device=optimization_definition_set.trainer_definition.hyperparams.device)
            trainer.fit(model=model, datamodule=datamodule.dataset)

            # The result of the final test run is only logged, it is not part of the objective
            final_accuracy = trainer.test(model=model, datamodule=datamodule.dataset)

            # The best recorded metric of the run and its position among the recorded metrics
            accs = np.array([acc.item() for acc in metrics_callback.metrics])
            num_epochs = int(np.argmax(accs))
            test_acc = accs[num_epochs]
            # presumably one metric is recorded per 5 epochs (validation check interval) - TODO confirm
            num_epochs *= 5

            self.logger.info(
                f"Run {run_id} finished in {num_epochs} epochs with test acc: {final_accuracy}, best acc: {test_acc}!")

            # Overwritten by every run, the value of the last run is kept
            trial.set_user_attr("num_epochs", num_epochs)
            test_accs.append(test_acc)

        return sum(test_accs) / len(test_accs)

    def _update_loss_feature_size(self, optimization_definition_set):
        """
        Synchronizes the feature dimension of the losses with the feature size of the network
        Only definition sets with a GeneralNet model are touched. Every loss whose hyper
        parameters have a `feat_dim` attribute gets the `feature_size` of the `mpn_net_def`.
        :param optimization_definition_set: The definition set of the trial, modified in place
        """
        model_definition = optimization_definition_set.model_definition
        if not (model_definition.type == ModelType.GeneralNet):
            return

        for loss_calc_param in model_definition.hyperparams.loss_calc_params.values():
            target_hyperparams = loss_calc_param.loss_definition.hyperparams
            if not hasattr(target_hyperparams, "feat_dim"):
                continue
            target_hyperparams.feat_dim = model_definition.hyperparams.mpn_net_def.hyperparams.feature_size

    def update_sampler(self, sampler_def, indices, targets, batch_size):
        """
        Writes the indices of a split into the definition of a sampler
        A SubsetRandomSampler receives the indices as they are. A CombineSampler or a
        ClassBasedSampler receives the indices grouped by class, and the product of its classes
        per batch and samples per class is asserted to be equal to the batch size. Other sampler
        types are left untouched.
        :param sampler_def: The definition of the sampler, modified in place
        :param indices: The indices of the split
        :param targets: The targets belonging to the indices
        :param batch_size: The expected batch size
        """
        sampler_type = sampler_def.type
        if sampler_type == SamplerType.SubsetRandomSampler:
            sampler_def.hyperparams.indices = indices
            return

        if sampler_type == SamplerType.CombineSampler or sampler_type == SamplerType.ClassBasedSampler:
            sampler_hyperparams = sampler_def.hyperparams
            sampler_hyperparams.indices_of_classes = self.indices_map_by_class(indices, targets)
            samples_in_batch = sampler_hyperparams.num_classes_in_batch * sampler_hyperparams.num_samples_per_class
            assert samples_in_batch == batch_size

    def indices_map_by_class(self, indices_list, targets):
        list_of_indices_for_each_class = defaultdict(list)
        for idx, target in zip(indices_list, targets):
            list_of_indices_for_each_class[target].append(idx)
        return list_of_indices_for_each_class

    def plot_results(self):
        """Plots the results of the study"""
        print("Plotting the results")
        self._plot_contour()
        self._plot_intermediate_values()
        self._plot_optimization_history()
        self._plot_parallel_coordinate()
        self._plot_param_importances()

    @try_plot
    def _plot_contour(self):
        """Writes the contour plot of the study into the visualization folder"""
        # NOTE(review): the file name is spelled "countour", it is kept so existing outputs keep their name
        target_path = os.path.join(self._viz_folder_path, "countour.html")
        figure = optuna.visualization.plot_contour(self.study)
        self._write_html(figure, target_path)

    @try_plot
    def _plot_intermediate_values(self):
        """Writes the plot of the intermediate values of the trials into the visualization folder"""
        target_path = os.path.join(self._viz_folder_path, "intermediate_values.html")
        figure = optuna.visualization.plot_intermediate_values(self.study)
        self._write_html(figure, target_path)

    @try_plot
    def _plot_optimization_history(self):
        """Writes the plot of the optimization history into the visualization folder"""
        target_path = os.path.join(self._viz_folder_path, "optimization_history.html")
        figure = optuna.visualization.plot_optimization_history(self.study)
        self._write_html(figure, target_path)

    @try_plot
    def _plot_parallel_coordinate(self):
        """Writes the parallel coordinate plot of the study into the visualization folder"""
        target_path = os.path.join(self._viz_folder_path, "parallel_coordinate.html")
        figure = optuna.visualization.plot_parallel_coordinate(self.study)
        self._write_html(figure, target_path)

    @try_plot
    def _plot_param_importances(self):
        """
        Writes the plot of the parameter importances into the visualization folder
        A ValueError, ZeroDivisionError or RuntimeError raised meanwhile is only logged as a warning.
        """
        try:
            figure = optuna.visualization.plot_param_importances(self.study)
            target_path = os.path.join(self._viz_folder_path, "param_importances.html")
            self._write_html(figure, target_path)
        except (ValueError, ZeroDivisionError, RuntimeError) as plot_error:
            self.logger.warning(plot_error)

    def _write_html(self, plot, path):
        """
        Writes a plot into an HTML file while holding the file lock of the path
        :param plot: The plot, it has to offer `write_html`
        :param path: The path of the HTML file
        """
        lock_root = get_lock_folder(self.optimization_root_folder_path)
        with relative_file_lock(path, timeout=10, root_path=lock_root):
            plot.write_html(path)

    def _get_sampler(self, sampler_type: OptimizationSamplerType):
        """
        Gets sampler from type
        The random and the TPE sampler are seeded with the fixed seed 1234, the grid sampler is
        built from the search grid of the optimization definition space.
        :param sampler_type: The type of the sampler
        :return: Instance of the sampler
        :raises ValueError: If the sampler type is unknown
        """
        fixed_seed = 1234
        if sampler_type == OptimizationSamplerType.Random:
            return RandomSampler(seed=fixed_seed)
        if sampler_type == OptimizationSamplerType.TPE:
            return TPESampler(seed=fixed_seed)
        if sampler_type == OptimizationSamplerType.Grid:
            search_grid = self.optimization_definition_space.search_grid
            return GridSampler(search_grid)
        raise ValueError(f"Unknown sampler type: {sampler_type}")

    @staticmethod
    def _get_pruner(pruner_type: PrunerType):
        """
        Gets pruner from type
        The percentile pruner is created with percentile 75, the other pruners with their defaults.
        :param pruner_type: The type of the pruner
        :return: Instance of the pruner
        :raises ValueError: If the pruner type is unknown
        """
        if pruner_type == PrunerType.NopPruner:
            return NopPruner()
        if pruner_type == PrunerType.MedianPruner:
            return MedianPruner()
        if pruner_type == PrunerType.PercentilePruner:
            percentile = 75
            return PercentilePruner(percentile)
        raise ValueError(f"Unknown pruner type: {pruner_type}")

    def update_summary(self):
        self.hyperparam_storage.update_summary()

    def save_best_definition(self):
        self.hyperparam_storage.save_best_definition()

    def log_optimization_space(self):
        """Dumps the optimization definition space into its file while holding the file lock of the path"""
        space_file_path = self._optimization_space_file_path
        lock_root = get_lock_folder(self.optimization_root_folder_path)
        with relative_file_lock(space_file_path, timeout=10, root_path=lock_root):
            self.optimization_definition_space.dumps_to_file(self._optimization_space_file_path)


def main():
    """
    Example usage
    Runs two trials on Caltech101 with the dummy classifier, the trainer is a fast dev run on CPU.
    """
    data_space = Caltech101DefinitionSpace()

    model_hyperparam_space = DummyModuleHyperParameterSpace(DummyModuleHyperParameterSet(output_size=101))
    model_space = DummyModuleDefinitionSpace(model_hyperparam_space)

    trainer_hyperparams = PL_TrainerHyperParameterSet(fast_dev_run=True, runtime_mode=DeviceType.CPU)
    trainer_space = SingleDefinitionSpace(PL_TrainerDefinition(trainer_hyperparams))

    search_space = OptimizationDefinitionSpace(data_definition_space=data_space,
                                               model_definition_space=model_space,
                                               trainer_definition_space=trainer_space)

    HyperParamOptimizer(optimization_definition_space=search_space, n_trials=2).optimize()


if __name__ == "__main__":
    main()
