################################################################################
# deepthinking/models/deep_thinking_recall.py
#
# 
# 
# 2023
#
# Implementation of a wrapper for deep thinking networks with recall.
# Non-recall (vanilla) version originally described in [1], with recall
#  modifications from [2].
# Extended to be a generalized version for implementation of individual
#  components (modules).
#
# Changes from 'DeepThinking' will have comments starting with 'RECALL:'.

import torch

from typing import Callable, Optional, Tuple, Union

from .deep_thinking_vanilla import DeepThinking

Module = torch.nn.Module
Tensor = torch.Tensor

class DeepThinkingRecall(DeepThinking):
  """
  Wrapper for deep thinking systems with recall. Non-recall (vanilla) version
  originally described by Schwarzschild et al. (2021), with recall modifications
  by Bansal et al. (2022).

  In addition to the modules passed on to ``DeepThinking``, this wrapper holds a
  ``preprocessing_module`` that turns the input module result into the first
  thought. It also passes the input module result to the thought module on
  every iteration (recall).
  """

  def __init__(self,
      # Arguments:
      input_module:         Module,
      preprocessing_module: Module,
      thought_module:       Module,
      output_module:        Module,
      *args,
      # Keyword Arguments:
      max_iterations:           int  = 1,
      use_incremental_progress: bool = False,
      **kwargs
    ):
    """
    Initializes ``DeepThinkingRecall``.

    Args:
      input_module (Module):
        Transforms initial input for all further computation.
      preprocessing_module (Module):
        Preprocesses transformed input to first thought.
      thought_module (Module):
        Will be iterated to perform the thought, with the addition of recall
        to input. Called as ``thought_module(phi, x_tilde)``.
      output_module (Module):
        Transforms thought into output.
      max_iterations (int, optional):
        Maximum iterations for training and inference. Must be at least ``1``
        when incremental progress training is used.
        Defaults to ``1``.
      use_incremental_progress (bool, optional):
        Whether to use incremental progress training.
        Defaults to ``False``.
      *args:
        Additional positional arguments forwarded to ``DeepThinking``.
      **kwargs:
        Additional keyword arguments forwarded to ``DeepThinking``.
    """
    super(DeepThinkingRecall, self).__init__(
      input_module,
      thought_module,
      output_module,
      *args,
      max_iterations = max_iterations,
      use_incremental_progress = use_incremental_progress,
      **kwargs
    )
    # RECALL: extra module that maps the input module result to the first
    #  thought (not passed to the parent constructor).
    self.preprocessing_module = preprocessing_module

  def incremental_progress(self,
      # Arguments:
      x_tilde:        Tensor,
      max_iterations: int
    ) -> Tuple[Tensor, Tensor]:
    """
    Implements the iterative portion of the "Incremental Progress Training
    Algorithm" described by Bansal et al. (2022).

    Samples ``n ~ U{0, m-1}`` and ``k ~ U{1, m-n}``. The progressive thought
    runs ``n`` iterations without gradient tracking, then ``k`` iterations
    with gradient tracking. The maximum thought runs all ``m`` iterations
    with gradient tracking.

    Args:
      x_tilde (Tensor):
        Result from input module (used for recall and as input to
        preprocessing module).
      max_iterations (int):
        Maximum iterations for incremental progress training (``m``).

    Returns:
      Tuple[Tensor, Tensor]:
        (phi_prog, phi_m)...
        Thought module output from progressive iterations and maximum
        iterations.

    Raises:
      ValueError:
        If ``max_iterations`` is less than ``1``.
    """
    if max_iterations < 1:
      raise ValueError(
        "max_iterations must be at least 1 for incremental progress "
        f"training, got {max_iterations}."
      )
    # n ~ U{0, m-1}, k ~ U{1, m-n}. ``torch.randint``'s ``high`` is exclusive,
    #  so the upper bounds are m and m-n+1 respectively.
    n: int = torch.randint(
      low = 0, high = max_iterations,
      size = (1,)
    ).item()
    k: int = torch.randint(
      low = 1, high = max_iterations - n + 1,
      size = (1,)
    ).item()
    # RECALL: initial phi comes from the preprocessing module.
    phi_0: Tensor = self.preprocessing_module(x_tilde)
    # Set up progressive thought.
    phi: Tensor = phi_0
    if n > 0:
      # First n iterations do not track gradients.
      with torch.no_grad():
        for _ in range(n):
          phi = self.perform_iteration(phi, x_tilde)
      # Make sure the interim thought is cut from the graph. This is done only
      #  when n > 0, so that with n == 0 the progressive path still trains the
      #  preprocessing module.
      phi = phi.detach()
    # Final k iterations do track gradients.
    for _ in range(k):
      phi = self.perform_iteration(phi, x_tilde)
    phi_prog: Tensor = phi
    # Set up maximum iteration thought; all iterations track gradients.
    phi = phi_0
    for _ in range(max_iterations):
      phi = self.perform_iteration(phi, x_tilde)
    phi_m: Tensor = phi
    # Return both outputs.
    return (phi_prog, phi_m)

  # incremental_progress_loss is unchanged from 'deep_thinking_vanilla.py'.

  def perform_iteration(self,
      # Arguments:
      phi: Tensor,
      # Keyword Arguments:
      x_tilde: Tensor = None
    ) -> Tensor:
    """
    Implements a single iteration of the thought module.

    Args:
      phi (Tensor):
        Current state of thought module.
      x_tilde (Tensor, optional):
        Original input module result to be used as recall.
        Defaults to ``None``. ``None`` is forwarded as-is to the thought
        module, which is expected to need the recall input.

    Returns:
      Tensor:
        Next state of thought module.
    """
    # RECALL: the thought module receives the input module result on every
    #  iteration in addition to the current thought.
    return self.thought_module(phi, x_tilde)

  def forward(self,
      # Arguments:
      x: Tensor,
      # Keyword Arguments:
      max_iterations: Optional[int]    = None,
      return_thought: bool             = False,
      phi:            Optional[Tensor] = None
    ) -> Union[
      Tuple[Tensor, Tensor, Tensor, Tensor], Tuple[Tensor, Tensor], Tensor
    ]:
    """
    Forward pass of the model.

    Args:
      x (Tensor):
        Model input.
      max_iterations (int, optional):
        Override attribute value.
        Defaults to ``None`` (don't override value).
      return_thought (bool, optional):
        Whether to return the thought after iterations.
        Defaults to ``False``.
      phi (Tensor, optional):
        An intermediary thought to start iterating from instead of the
        preprocessing module result (avoids recomputing earlier iterations).
        Ignored while incremental progress training is active.
        Defaults to ``None``.

    Returns:
      Tuple[Tensor, Tensor, Tensor, Tensor]:
        (y_hat_prog, y_hat_m, phi_prog, phi_m)...
        If ``use_incremental_progress`` *and* ``return_thought``.
      Tuple[Tensor, Tensor]:
        (y_hat_prog, y_hat_m) OR (y_hat, phi)...
        First case if *only* ``use_incremental_progress``.
        Second case if *only* ``return_thought``.
      Tensor:
        y_hat...
        If neither ``use_incremental_progress`` nor ``return_thought``.

    Raises:
      ValueError:
        If incremental progress training is active and the iteration count is
        less than ``1``.
    """
    m = self.max_iterations if max_iterations is None else max_iterations
    # Get the input module result.
    x_tilde = self.input_module(x)
    # Only use incremental training if currently training and has the flag set.
    if self.training and self.use_incremental_progress:
      phi_prog, phi_m = self.incremental_progress(x_tilde, m)
      y_hat_prog = self.output_module(phi_prog)
      y_hat_m    = self.output_module(phi_m)
      return (y_hat_prog, y_hat_m, phi_prog, phi_m) if return_thought else \
             (y_hat_prog, y_hat_m)
    # Otherwise just compute iterations normally.
    # RECALL: the starting thought comes from the preprocessing module unless
    #  an intermediary thought is supplied.
    phi = self.preprocessing_module(x_tilde) if phi is None else phi
    for _ in range(m):
      phi = self.perform_iteration(phi, x_tilde)
    y_hat = self.output_module(phi)
    return (y_hat, phi) if return_thought else y_hat

# REFERENCES:
#
# [1] Avi Schwarzschild, Eitan Borgnia, Arjun Gupta, Furong Huang, Uzi Vishkin,
#     Micah Goldblum, and Tom Goldstein.
#     "Can You Learn an Algorithm? Generalizing from Easy to Hard Problems
#      with Recurrent Networks".
#     CoRR. June 2021.
#     https://arxiv.org/abs/2106.04537
#
# [2] Arpit Bansal, Avi Schwarzschild, Eitan Borgnia, Zeyad Emam, Furong Huang,
#     Micah Goldblum, and Tom Goldstein.
#     "End-to-End Algorithm Synthesis with Recurrent Networks: Extrapolation
#      without Overthinking".
#     Advances in Neural Information Processing Systems. Oct 2022.
#     https://arxiv.org/abs/2202.05826