################################################################################
# training/utils/reduction_tools.py
#
# 
# 
# 2023
#
# Utility functions for reductions (averages, sums, etc.).

import math

from typing_extensions import Self

from .type_checking import *

# Efficient running average class.
class RunningAverage():
  """
  Efficient running average class.
  """

  def __init__(self,
      scale: float = 1.0
    ):
    """
    Initializes ``RunningAverage``.

    Args:
      scale (float, optional):
        A scale applied to each added number to reduce the chances of
        overflow or floating point errors, if necessary *1*.
        Defaults to ``1.0``.

    *1*: The likelihood is this will never (and should never) be used.
    """
    check_if_type(scale, float, "scale")
    self.scale = scale
    # Start the x_sum as 0.0.
    self._x_sum: float = 0.0
    # Start the number of x values as 0.
    self._n_x: int = 0
  
  def __add__(self,
      # Arguments:
      x: float
    ) -> Self:
    """
    Adds a value to the running average.

    Args:
      x (float):
        Value to add to the running average.

    Returns:
      Self:
        The current ``RunningAverage`` object.
    """
    # Add the value divided by the total number.
    self._x_sum += self.scale * x
    # Increment the number of values.
    self._n_x += 1
    return self
  
  # Annoying boilerplate because procedural programming is tHe BeST tHiNg
  #  EvER!!11!
  def __radd__(self,
      # Arguments:
      x: float
    ) -> Self:
    """
    Adds a value to the running average.

    Args:
      x (float):
        Value to add to the running average.

    Returns:
      Self:
        The current ``RunningAverage`` object.
    """
    return self.__add__(x)
  
  def __call__(self) -> float:
    """
    Calculates the current average.

    Returns:
      float:
        The current calculated average.
    """
    return (self._x_sum / self._n_x) / self.scale