'''
 *
 *     ICTP: Irreducible Cartesian Tensor Potentials
 *
 *        File:  callbacks.py
 *
 *     Authors: Deleted for purposes of anonymity 
 *
 *     Proprietor: Deleted for purposes of anonymity --- PROPRIETARY INFORMATION
 * 
 * The software and its source code contain valuable trade secrets and shall be maintained in
 * confidence and treated as confidential information. The software may only be used for 
 * evaluation and/or testing purposes, unless otherwise explicitly stated in the terms of a
 * license agreement or nondisclosure agreement with the proprietor of the software. 
 * Any unauthorized publication, transfer to third parties, or duplication of the object or
 * source code---either totally or in part---is strictly prohibited.
 *
 *     Copyright (c) 2024 Proprietor: Deleted for purposes of anonymity
 *     All Rights Reserved.
 *
 * THE PROPRIETOR DISCLAIMS ALL WARRANTIES, EITHER EXPRESS OR 
 * IMPLIED, INCLUDING BUT NOT LIMITED TO IMPLIED WARRANTIES OF MERCHANTABILITY 
 * AND FITNESS FOR A PARTICULAR PURPOSE AND THE WARRANTY AGAINST LATENT 
 * DEFECTS, WITH RESPECT TO THE PROGRAM AND ANY ACCOMPANYING DOCUMENTATION. 
 * 
 * NO LIABILITY FOR CONSEQUENTIAL DAMAGES:
 * IN NO EVENT SHALL THE PROPRIETOR OR ANY OF ITS SUBSIDIARIES BE 
 * LIABLE FOR ANY DAMAGES WHATSOEVER (INCLUDING, WITHOUT LIMITATION, DAMAGES
 * FOR LOSS OF BUSINESS PROFITS, BUSINESS INTERRUPTION, LOSS OF INFORMATION, OR
 * OTHER PECUNIARY LOSS AND INDIRECT, CONSEQUENTIAL, INCIDENTAL,
 * ECONOMIC OR PUNITIVE DAMAGES) ARISING OUT OF THE USE OF OR INABILITY
 * TO USE THIS PROGRAM, EVEN IF the proprietor HAS BEEN ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGES.
 * 
 * For purposes of anonymity, the identity of the proprietor is not given herewith. 
 * The identity of the proprietor will be given once the review of the 
 * conference submission is completed. 
 *
 * THIS HEADER MAY NOT BE EXTRACTED OR MODIFIED IN ANY WAY.
 *
'''
import os
from typing import Dict

import torch

from src.utils.misc import padded_str, count_parameters


class TrainingCallback:
    """Interface for hooks that a `Trainer` can invoke around a training run.

    The base class implements no behavior: each hook raises `NotImplementedError`,
    so concrete callbacks are expected to override all three of them.
    """

    def before_fit(self, trainer: 'Trainer') -> None:
        """Hook executed once, before the first training epoch.

        Args:
            trainer (Trainer): The object driving the training of the model.

        Raises:
            NotImplementedError: Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def after_epoch(self, trainer: 'Trainer', train_metrics: Dict[str, float],
                    eval_metrics: Dict[str, float], epoch_time: float) -> None:
        """Hook executed at the end of every epoch.

        Args:
            trainer (Trainer): The object driving the training of the model.
            train_metrics (Dict[str, float]): Running-average error metrics on the
                training data set, keyed by metric name.
            eval_metrics (Dict[str, float]): Running-average error metrics on the
                validation data set, keyed by metric name.
            epoch_time (float): Wall time spent on the epoch.

        Raises:
            NotImplementedError: Always; subclasses must provide an implementation.
        """
        raise NotImplementedError

    def after_fit(self, trainer: 'Trainer', session_time: float) -> None:
        """Hook executed once, after training has finished.

        Args:
            trainer (Trainer): The object driving the training of the model.
            session_time (float): Total wall time of the training session.

        Raises:
            NotImplementedError: Always; subclasses must provide an implementation.
        """
        raise NotImplementedError


class FileLoggingCallback(TrainingCallback):
    """Generates file logging callbacks."""
    def __init__(self):
        self.train_out = None
        self.column_widths = None

        self.test_out = None
        self.test_column_widths = None

    def before_fit(self, trainer: 'Trainer'):
        # define files to save the progress of training
        self.train_out = os.path.join(trainer.model_path, 'train.out')

        headings = [metric + ' (train/valid/best_valid)' for metric in trainer.eval_loss_fns]
        # last column does not need to be whitespace-padded because it does not matter visually
        self.column_widths = [17] + [len(heading) + 2 for heading in headings] + [0]
        headings = ['Epoch'] + headings + ['Time']
        
        if trainer.epoch > 0:
            # restored checkpoint
            f = open(self.train_out, "a+")
            f.write("".ljust(sum(self.column_widths) + 9, "=") + "\n")
            f.write('Training is restarted from epoch {} \n'.format(trainer.epoch))

        else:
            # start new session
            f = open(self.train_out, "a+")
            f.write('ICTP: Irreducible Cartesian tensor potentials\n')
            f.write('CUDA version: {}, CUDA device: {} \n'.format(torch.version.cuda, trainer.device))
            f.write('Model: {} \n'.format(trainer.model))
            f.write('Number of parameters: {} \n'.format(count_parameters(trainer.model)))
            f.write('Optimizer: {} \n'.format(trainer.optimizer))
            f.write("".ljust(sum(self.column_widths) + 9, "=") + "\n")

        f.write('Best checkpoints for the model can be found in          ............. {} \n'.format(trainer.best_dir))
        f.write('Checkpoints for restart for the model can be found in   ............. {} \n'.format(trainer.log_dir))
        f.write(' \n')
        f.write(padded_str(headings, self.column_widths) + '\n')
        f.write("".ljust(sum(self.column_widths) + 9, "=") + "\n")
        f.close()

    def after_epoch(self,
                    trainer: 'Trainer',
                    train_metrics: Dict[str, float],
                    eval_metrics: Dict[str, float],
                    epoch_time: float):
        f = open(self.train_out, "a+")
        strs = [f'Epoch {trainer.epoch}/{trainer.max_epoch}: ']
        for metric_name in train_metrics:
            vals = [train_metrics[metric_name], eval_metrics[metric_name],
                    trainer.best_eval_metrics[metric_name]]
            strs.append('/'.join([f'{val:6.4f}' for val in vals]) + ' ')
        strs.append(f'[{epoch_time:5.2f} s]')
        f.write(padded_str(strs, self.column_widths) + '\n')
        f.close()

    def after_fit(self,
                  trainer: 'Trainer',
                  session_time: float):
        f = open(self.train_out, "a+")
        f.write("".ljust(sum(self.column_widths) + 9, "=") + "\n")
        f.write('Timing report \n')
        f.write("".ljust(13, "-") + "\n")
        f.write(f"Total time                    ............. {session_time:g} s \n")
        f.write('Best model report \n')
        f.write("".ljust(17, "-") + "\n")
        f.write(f'Best epochs from the training  ............. {trainer.best_epoch} \n')
        f.close()
