'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  trainer.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
import os
import shutil
import time
from contextlib import nullcontext
from pathlib import Path
from typing import Dict, Union, List, Optional, Any


import numpy as np

import torch

from torch_ema import ExponentialMovingAverage

from src.data.data import AtomicData

from src.model.forward import ForwardAtomisticNetwork
from src.model.calculators import StructurePropertyCalculator

from src.training.callbacks import TrainingCallback
from src.training.loss_fns import LossFunction, TotalLossTracker

from src.utils.torch_geometric import Data
from src.utils.torch_geometric.dataloader import DataLoader
from src.utils.misc import load_object, save_object, get_default_device


def eval_metrics(calc: StructurePropertyCalculator,
                 dl: DataLoader,
                 eval_loss_fns: Dict[str, LossFunction],
                 eval_output_variables: List[str],
                 device: str = 'cuda:0',
                 early_stopping_loss_fn: Optional[LossFunction] = None) -> Dict[str, Any]:
    """Evaluates error metrics using the provided data set.

    Every batch of `dl` is moved to `device` and passed through `calc`. The predictions are accumulated by one
    `TotalLossTracker` per evaluation loss (and, optionally, one for the early-stopping loss). The final values are
    computed from the total number of structures and atoms seen in `dl`.

    Args:
        calc (StructurePropertyCalculator): Torch calculator for the atomistic model, see `calculators.py`.
        dl (DataLoader): Atomic data loader.
        eval_loss_fns (Dict[str, LossFunction]): Loss functions defined for evaluating model's performance.
        eval_output_variables (List[str]): Output variables: energy, forces, etc. Only 'forces', 'stress', and
                                           'virials' are checked to decide which derivatives `calc` computes.
        device (str, optional): Available device (e.g., 'cuda:0' or 'cpu'). Defaults to 'cuda:0'.
        early_stopping_loss_fn (Optional[LossFunction], optional): Optional early stopping loss (used, e.g., 
                                                                   during training). Defaults to None.

    Returns:
        Dict[str, Any]: Dictionary with the key 'eval_losses', mapping each name in `eval_loss_fns` to its final
                        value (Python scalar), and, if `early_stopping_loss_fn` is given, the key 'early_stopping'
                        holding the final value of the early-stopping loss.
    """
    loss_trackers = {name: TotalLossTracker(loss_fn, requires_grad=False)
                     for name, loss_fn in eval_loss_fns.items()}
    early_stopping_loss_tracker = (TotalLossTracker(early_stopping_loss_fn, requires_grad=False)
                                   if early_stopping_loss_fn is not None else None)

    # which derivative properties the calculator has to provide does not change between batches
    compute_forces = 'forces' in eval_output_variables
    compute_stress = 'stress' in eval_output_variables
    compute_virials = 'virials' in eval_output_variables

    n_structures_total = 0
    n_atoms_total = 0

    for batch in dl:
        # count before moving the batch, so `.item()` reads the tensor before the device transfer
        n_structures_total += len(batch.n_atoms)
        n_atoms_total += batch.n_atoms.sum().item()

        # rebind, so that predictions and reference values passed to the trackers are guaranteed to come from the
        # same (moved) batch, regardless of whether `to()` operates in place
        batch = batch.to(device)

        # NOTE(review): `create_graph=True` is kept from the original; evaluation does not back-propagate, so
        # `False` may suffice depending on how the calculator uses this flag -- confirm before changing.
        results = calc(batch,
                       forces=compute_forces,
                       stress=compute_stress,
                       virials=compute_virials,
                       create_graph=True)

        if early_stopping_loss_tracker is not None:
            early_stopping_loss_tracker.append_batch(results, batch)

        for loss_tracker in loss_trackers.values():
            loss_tracker.append_batch(results, batch)

    metrics: Dict[str, Any] = {
        'eval_losses': {name: loss_tracker.compute_final_result(n_structures_total, n_atoms_total).item()
                        for name, loss_tracker in loss_trackers.items()},
    }

    if early_stopping_loss_tracker is not None:
        metrics['early_stopping'] = early_stopping_loss_tracker.compute_final_result(n_structures_total,
                                                                                     n_atoms_total).item()

    return metrics


class Trainer:
    """Trains an atomistic model using the provided training data set. It uses early stopping to prevent 
    overfitting.

    The model is optimized with `opt_class` and a `ReduceLROnPlateau` scheduler. The scheduler is driven by the
    early-stopping loss on the validation data set. Checkpoints for restarting are written to
    `<model_path>/logs` every `save_epoch` epochs. The model with the lowest early-stopping loss is written to
    `<model_path>/best`.

    Args:
        model (ForwardAtomisticNetwork): Atomistic model.
        lr (float): Learning rate.
        lr_factor (float): Factor by which the learning rate is reduced when the early-stopping loss plateaus.
        scheduler_patience (int): Number of validation steps without improvement of the early-stopping loss after
                                  which the learning rate is reduced by `lr_factor`.
        model_path (str): Path to the model; `logs` and `best` sub-directories are created in it.
        train_loss (LossFunction): Train loss function.
        eval_losses (Dict[str, LossFunction]): Evaluation loss functions, keyed by name.
        early_stopping_loss (LossFunction): Early stopping loss function.
        device (Optional[str], optional): Available device (e.g., 'cuda:0' or 'cpu'). Defaults to None, in which
                                          case `get_default_device()` is used.
        max_epoch (int, optional): Maximal training epoch. Defaults to 1000.
        save_epoch (int, optional): Frequency for storing models for restarting. Defaults to 100.
        validate_epoch (int, optional): Frequency for evaluating models on validation data set and storing 
                                        best models, if requested.  Defaults to 1.
        train_batch_size (int, optional): Training mini-batch size. Defaults to 32.
        valid_batch_size (int, optional): Validation mini-batch size. Defaults to 100.
        callbacks (Optional[List[TrainingCallback]], optional): Callbacks to track training process. 
                                                                Defaults to None (no callbacks).
        opt_class (optional): Optimizer class; it is called with per-group `weight_decay`, `lr`, and `amsgrad`.
                              Defaults to torch.optim.Adam.
        amsgrad (bool, optional): If True, use amsgrad variant of adam. Defaults to False.
        max_grad_norm (float, optional): Maximal gradient norm used for gradient clipping; no clipping if None.
                                         Defaults to None.
        weight_decay (float, optional): Weight decay for the parameters of the product basis and for the
                                        `linear_second.weight` parameters of the interactions. Defaults to 5e-7.
        ema (bool, optional): If True, use exponential moving average. Defaults to False.
        ema_decay (float, optional): Decay parameter for the exponential moving average. Defaults to 0.99.
    """
    def __init__(self,
                 model: ForwardAtomisticNetwork,
                 lr: float,
                 lr_factor: float,
                 scheduler_patience: int,
                 model_path: str,
                 train_loss: LossFunction,
                 eval_losses: Dict[str, LossFunction],
                 early_stopping_loss: LossFunction,
                 device: Optional[str] = None,
                 max_epoch: int = 1000,
                 save_epoch: int = 100,
                 validate_epoch: int = 1,
                 train_batch_size: int = 32,
                 valid_batch_size: int = 100,
                 callbacks: Optional[List[TrainingCallback]] = None,
                 opt_class=torch.optim.Adam,
                 amsgrad: bool = False,
                 max_grad_norm: Optional[float] = None,
                 weight_decay: float = 5e-7,
                 ema: bool = False,
                 ema_decay: float = 0.99):
        self.model = model
        self.device = device or get_default_device()
        self.calc = StructurePropertyCalculator(self.model, training=True).to(self.device)
        self.train_loss = train_loss
        self.eval_loss_fns = eval_losses
        self.early_stopping_loss_fn = early_stopping_loss
        self.train_output_variables = self.train_loss.get_output_variables()
        # union of the output variables required by any of the evaluation losses
        self.eval_output_variables = list({variable
                                           for loss_fn in self.eval_loss_fns.values()
                                           for variable in loss_fn.get_output_variables()})
        self.early_stopping_output_variables = self.early_stopping_loss_fn.get_output_variables()

        # within the interactions, weight decay is applied only to `linear_second.weight` parameters
        decay_interactions = []
        no_decay_interactions = []
        for name, param in self.model.representation.interactions.named_parameters():
            if "linear_second.weight" in name:
                decay_interactions.append(param)
            else:
                no_decay_interactions.append(param)

        param_groups = [
            {
                'name': 'embedding',
                'params': list(self.model.representation.node_embedding.parameters()),
                'weight_decay': 0.0,
            },
            {
                'name': 'interactions_decay',
                'params': decay_interactions,
                'weight_decay': weight_decay,
            },
            {
                'name': 'interactions_no_decay',
                'params': no_decay_interactions,
                'weight_decay': 0.0,
            },
            {
                'name': 'products',
                'params': list(self.model.representation.products.parameters()),
                'weight_decay': weight_decay,
            },
            {
                'name': 'readouts',
                'params': list(self.model.readouts.parameters()),
                'weight_decay': 0.0,
            },
        ]

        self.optimizer = opt_class(params=param_groups, lr=lr, amsgrad=amsgrad)
        self.lr_sched = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer=self.optimizer, factor=lr_factor,
                                                                   patience=scheduler_patience)

        self.ema = ExponentialMovingAverage(self.model.parameters(), decay=ema_decay) if ema else None

        # normalize to a list so that `fit()` can iterate over it even if no callbacks were given
        self.callbacks = list(callbacks) if callbacks is not None else []
        self.model_path = model_path
        self.max_epoch = max_epoch
        self.save_epoch = save_epoch
        self.validate_epoch = validate_epoch
        self.train_batch_size = train_batch_size
        self.valid_batch_size = valid_batch_size
        self.max_grad_norm = max_grad_norm

        self.epoch = 0
        # `np.inf` (not the alias `np.Inf`, which was removed in NumPy 2.0)
        self.best_es_metric = np.inf
        self.best_epoch = 0
        self.best_eval_metrics = None

        # create best and log directories to save/restore training progress
        self.log_dir = os.path.join(self.model_path, 'logs')
        self.best_dir = os.path.join(self.model_path, 'best')

        for directory in (self.model_path, self.log_dir, self.best_dir):
            os.makedirs(directory, exist_ok=True)

    def _ema_context(self):
        """Returns a context manager that swaps in the EMA-averaged parameters if EMA is enabled (no-op otherwise)."""
        return self.ema.average_parameters() if self.ema is not None else nullcontext()

    def save(self,
             path: Union[Path, str]):
        """Saves the model and the training state as a new checkpoint in the folder.

        The checkpoint is written to the sub-folder `ckpt_<epoch>`; if EMA is enabled, the EMA-averaged parameters
        are stored. Older checkpoint folders in `path` are deleted only after the new one has been written.

        Args:
            path (Union[Path, str]): Path to the model.

        Raises:
            RuntimeError: If an older checkpoint folder contains a sub-folder; nothing is deleted in that case.
        """
        path = Path(path)
        new_folder = path / f'ckpt_{self.epoch}'

        to_save = {'opt': self.optimizer.state_dict(),
                   'lr_sched': self.lr_sched.state_dict(),
                   'ema': self.ema.state_dict() if self.ema is not None else None,
                   'best_es_metric': self.best_es_metric,
                   'best_epoch': self.best_epoch,
                   'epoch': self.epoch,
                   'best_eval_metrics': self.best_eval_metrics}

        os.makedirs(path, exist_ok=True)
        # The folder for the current epoch may already exist, e.g., when training is resumed from an earlier log
        # checkpoint and reaches an epoch that was already saved by the interrupted run. It is overwritten instead
        # of failing with `FileExistsError`, and it must not be treated as an old checkpoint. Plain files are
        # ignored rather than treated as checkpoint folders.
        old_folders = [folder for folder in path.iterdir() if folder.is_dir() and folder != new_folder]
        os.makedirs(new_folder, exist_ok=True)

        with self._ema_context():
            self.model.save(new_folder)
        save_object(new_folder / 'training_state.pkl', to_save)

        # delete older checkpoints after the new one has been saved; check all of them before deleting any
        for folder in old_folders:
            if any(p.is_dir() for p in folder.iterdir()):
                # folder contains another folder, this shouldn't occur, we don't want to delete anything important
                raise RuntimeError(f'Model saving folder {folder} contains another folder, will not be deleted')
        for folder in old_folders:
            shutil.rmtree(folder)

    def try_load(self,
                 path: Union[Path, str]):
        """Loads the model and the training state from the newest checkpoint in the folder, if any.

        Checkpoints are sub-folders named `ckpt_<epoch>`, as written by `save()`; the one with the largest epoch is
        loaded. If the folder does not exist or contains no such checkpoint, nothing is loaded.

        Args:
            path (Union[Path, str]): Path to the model.
        """
        path = Path(path)
        if not path.is_dir():
            return  # no checkpoint exists

        prefix = 'ckpt_'
        checkpoints = []
        for entry in path.iterdir():
            suffix = entry.name[len(prefix):]
            if entry.is_dir() and entry.name.startswith(prefix) and suffix.isdecimal():
                checkpoints.append((int(suffix), entry))
        if not checkpoints:
            return  # no checkpoint exists

        _, folder = max(checkpoints, key=lambda checkpoint: checkpoint[0])

        self.model.load_params(folder / 'params.pkl')

        state_dict = load_object(folder / 'training_state.pkl')
        self.optimizer.load_state_dict(state_dict['opt'])
        self.lr_sched.load_state_dict(state_dict['lr_sched'])
        if self.ema is not None and state_dict['ema'] is not None:
            self.ema.load_state_dict(state_dict['ema'])
        else:
            # NOTE(review): a checkpoint stored without EMA state disables EMA for the resumed run, even if EMA
            # was requested at construction time.
            self.ema = None
        self.best_es_metric = state_dict['best_es_metric']
        self.best_eval_metrics = state_dict['best_eval_metrics']
        self.best_epoch = state_dict['best_epoch']
        self.epoch = state_dict['epoch']

    def _train_step(self,
                    batch: Data,
                    train_loss_trackers: Dict[str, TotalLossTracker]):
        """Performs a training step using the provided batch.

        Computes the training loss, back-propagates it, optionally clips the gradient norm, updates the parameters
        (and the EMA, if enabled), and finally appends the batch predictions to `train_loss_trackers`.

        Args:
            batch (Data): Atomic data graph (already on `self.device`).
            train_loss_trackers (Dict[str, TotalLossTracker]): Dictionary of loss trackers using during training; 
                                                               see `loss_fns.py`.
        """
        self.optimizer.zero_grad(set_to_none=True)
        results = self.calc(batch,
                            forces='forces' in self.train_output_variables,
                            stress='stress' in self.train_output_variables,
                            virials='virials' in self.train_output_variables,
                            create_graph=True)

        # compute sum of train losses for model
        tracker = TotalLossTracker(self.train_loss, requires_grad=True)
        tracker.append_batch(results, batch)
        loss = tracker.compute_final_result(n_atoms_total=batch.n_atoms.sum(), n_structures_total=batch.n_atoms.shape[0])
        loss.backward()

        if self.max_grad_norm is not None:
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=self.max_grad_norm)

        # optimizer update
        self.optimizer.step()

        if self.ema is not None:
            self.ema.update()

        with torch.no_grad():
            for loss_tracker in train_loss_trackers.values():
                loss_tracker.append_batch(results, batch)

    def _evaluate(self, valid_dl: DataLoader) -> Dict[str, Any]:
        """Evaluates the model on `valid_dl`, using the EMA-averaged parameters if EMA is enabled."""
        with self._ema_context():
            return eval_metrics(calc=self.calc, dl=valid_dl, eval_loss_fns=self.eval_loss_fns,
                                eval_output_variables=self.eval_output_variables,
                                early_stopping_loss_fn=self.early_stopping_loss_fn,
                                device=self.device)

    def fit(self,
            train_ds: List[AtomicData],
            valid_ds: List[AtomicData]):
        """Trains atomistic models using provided training structures. Validation data is used for early stopping.

        Training resumes from the newest checkpoint in the log directory, if one exists.

        Args:
            train_ds (List[AtomicData]): Training data.
            valid_ds (List[AtomicData]): Validation data.
        """
        # todo: put model in train() mode in the beginning and in eval() mode (or the mode they had before) at the end?
        # reset in case this fit() is called multiple times and try_load() doesn't find a checkpoint
        self.epoch = 0
        self.best_es_metric = np.inf
        self.best_epoch = 0
        self.best_eval_metrics = None

        self.try_load(self.log_dir)

        # start timing
        start_session = time.time()

        # generate data queues for efficient training
        use_gpu = self.device.startswith('cuda')
        pin_memory_device = self.device if use_gpu else ''
        train_dl = DataLoader(train_ds, batch_size=self.train_batch_size, shuffle=True, drop_last=True,
                              pin_memory=use_gpu, pin_memory_device=pin_memory_device)
        valid_dl = DataLoader(valid_ds, batch_size=self.valid_batch_size, shuffle=False, drop_last=False,
                              pin_memory=use_gpu, pin_memory_device=pin_memory_device)

        for callback in self.callbacks:
            callback.before_fit(self)

        while self.epoch < self.max_epoch:
            start_epoch = time.time()

            self.epoch += 1

            train_loss_trackers = {name: TotalLossTracker(loss_fn, requires_grad=False)
                                   for name, loss_fn in self.eval_loss_fns.items()}

            n_structures_total = 0
            n_atoms_total = 0

            for batch in train_dl:
                n_structures_total += len(batch.n_atoms)
                n_atoms_total += batch.n_atoms.sum().item()

                self._train_step(batch.to(self.device), train_loss_trackers)

            train_metrics = {name: loss_tracker.compute_final_result(n_structures_total, n_atoms_total).item()
                             for name, loss_tracker in train_loss_trackers.items()}

            valid_metrics = None
            if self.epoch % self.validate_epoch == 0 or self.epoch == self.max_epoch:
                # check performance on validation step
                valid_metrics = self._evaluate(valid_dl)

                # update best metric based on early stopping score
                es_metric = valid_metrics['early_stopping']
                if es_metric < self.best_es_metric:
                    self.best_es_metric = es_metric
                    self.best_eval_metrics = valid_metrics['eval_losses']
                    self.best_epoch = self.epoch
                    self.save(self.best_dir)

                self.lr_sched.step(metrics=es_metric)

            if self.epoch % self.save_epoch == 0:
                # save progress for restoring; done after validation so that the checkpoint also contains this
                # epoch's best-model bookkeeping and learning-rate scheduler step
                self.save(self.log_dir)

            if valid_metrics is not None:
                end_epoch = time.time()

                for callback in self.callbacks:
                    callback.after_epoch(self, train_metrics, valid_metrics['eval_losses'], end_epoch - start_epoch)

        end_session = time.time()

        for callback in self.callbacks:
            callback.after_fit(self, end_session - start_session)
