from typing import Dict
import torch


class Callback:
    """Base class for all callbacks.

    A callback lets outside code perform actions while a model is training.
    Every hook defined here is a no-op that returns ``None``. Subclasses
    override only the hooks they care about; the others can be left
    untouched.
    """

    def on_train_begin(self) -> None:
        """Run when training begins, before the first epoch starts."""

    def on_train_end(self) -> None:
        """Run when training ends, after the last epoch ends."""

    def on_epoch_test_begin(self, epoch: int) -> None:
        """Run during an epoch, before metrics are evaluated on the test set.

        :param int epoch: Current epoch.
        """

    def on_epoch_test_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        """Run during an epoch, after metrics are evaluated on the test set.

        :param int epoch: Current epoch.
        :param dict[str, float] metrics: Metric values obtained on the test
            set, keyed by metric name.
        """

    def on_epoch_begin(self, epoch: int) -> None:
        """Run when an epoch begins.

        :param int epoch: Current epoch.
        """

    def on_epoch_end(self, epoch: int, metrics: Dict[str, float]) -> None:
        """Run when an epoch ends.

        :param int epoch: Current epoch.
        :param dict[str, float] metrics: Average of the model metrics
            evaluated during the epoch, keyed by metric name.
        """

    def on_forward_begin(self, epoch: int, i: int, xs: torch.Tensor, ys: torch.Tensor) -> None:
        """Run before the forward pass.

        :param int epoch: Current epoch.
        :param int i: Current iteration.
        :param torch.Tensor xs: Inputs that will be fed to the model.
        :param torch.Tensor ys: Ground-truth outputs.
        """

    def on_forward_end(self, epoch: int, i: int, xs: torch.Tensor,
                       ys: torch.Tensor, outputs: torch.Tensor) -> None:
        """Run after the forward pass.

        :param int epoch: Current epoch.
        :param int i: Current iteration.
        :param torch.Tensor xs: Inputs fed to the model.
        :param torch.Tensor ys: Ground-truth outputs.
        :param torch.Tensor outputs: Model's outputs.
        """

    def on_loss_begin(self, epoch: int, i: int, xs: torch.Tensor,
                      ys: torch.Tensor, outputs: torch.Tensor) -> None:
        """Run before the loss is computed.

        :param int epoch: Current epoch.
        :param int i: Current iteration.
        :param torch.Tensor xs: Inputs fed to the model.
        :param torch.Tensor ys: Ground-truth outputs.
        :param torch.Tensor outputs: Model's outputs.
        """

    def on_loss_end(self, epoch: int, i: int, xs: torch.Tensor, ys: torch.Tensor,
                    outputs: torch.Tensor, loss: torch.Tensor) -> None:
        """Run after the loss is computed.

        :param int epoch: Current epoch.
        :param int i: Current iteration.
        :param torch.Tensor xs: Inputs fed to the model.
        :param torch.Tensor ys: Ground-truth outputs.
        :param torch.Tensor outputs: Model's outputs.
        :param torch.Tensor loss: Loss at the current iteration (scalar).
        """

    def on_backward_begin(self, epoch: int, i: int, xs: torch.Tensor, ys: torch.Tensor,
                          outputs: torch.Tensor, loss: torch.Tensor) -> None:
        """Run before the backward pass is performed.

        :param int epoch: Current epoch.
        :param int i: Current iteration.
        :param torch.Tensor xs: Inputs fed to the model.
        :param torch.Tensor ys: Ground-truth outputs.
        :param torch.Tensor outputs: Model's outputs.
        :param torch.Tensor loss: Loss at the current iteration (scalar).
        """

    def on_backward_end(self, epoch: int, i: int, xs: torch.Tensor, ys: torch.Tensor,
                        outputs: torch.Tensor, loss: torch.Tensor) -> None:
        """Run after the backward pass is performed.

        At this point the gradients are stored in the model's parameters.

        :param int epoch: Current epoch.
        :param int i: Current iteration.
        :param torch.Tensor xs: Inputs fed to the model.
        :param torch.Tensor ys: Ground-truth outputs.
        :param torch.Tensor outputs: Model's outputs.
        :param torch.Tensor loss: Loss at the current iteration (scalar).
        """

    def on_optimizer_step_begin(self, epoch: int, i: int, xs: torch.Tensor, ys: torch.Tensor,
                                outputs: torch.Tensor, loss: torch.Tensor) -> None:
        """Run before the parameters are updated.

        :param int epoch: Current epoch.
        :param int i: Current iteration.
        :param torch.Tensor xs: Inputs fed to the model.
        :param torch.Tensor ys: Ground-truth outputs.
        :param torch.Tensor outputs: Model's outputs.
        :param torch.Tensor loss: Loss at the current iteration (scalar).
        """

    def on_optimizer_step_end(self, epoch: int, i: int, xs: torch.Tensor, ys: torch.Tensor,
                              outputs: torch.Tensor, loss: torch.Tensor) -> None:
        """Run after the parameters are updated.

        :param int epoch: Current epoch.
        :param int i: Current iteration.
        :param torch.Tensor xs: Inputs fed to the model.
        :param torch.Tensor ys: Ground-truth outputs.
        :param torch.Tensor outputs: Model's outputs.
        :param torch.Tensor loss: Loss at the current iteration (scalar).
        """
