import torch
from diffusers import DiffusionPipeline

class Monitor:
  """Record every callback value a pipeline passes in, keyed by step.

  The ``__call__`` signature ``(pipeline, step, timestep, callback_kwargs)``
  returning ``callback_kwargs`` matches what diffusers expects for a
  ``callback_on_step_end`` callable. Each entry of ``callback_kwargs`` is
  stored under ``self.history[key][step]``. Tensors are stored as detached
  CPU copies, so later in-place updates by the caller cannot change the
  recorded values. Other values are stored by reference.
  """

  def __init__(self):
    # history[key][step] -> value recorded for that callback kwarg at that step
    self.history: dict = {}

  def coalesce(self):
    """Stack each key's recorded tensors along a new leading step axis.

    Returns:
      dict mapping every key whose recorded values are tensors to
      ``torch.stack`` of those tensors, ordered by ascending step. Keys whose
      values are not tensors are omitted. Whether a key counts as tensor-valued
      is decided by its value at the lowest recorded step.
    """
    res = {}
    for k, by_step in self.history.items():
      if not by_step:
        continue
      steps = sorted(by_step)
      if isinstance(by_step[steps[0]], torch.Tensor):
        res[k] = torch.stack([by_step[s] for s in steps])
    return res

  def __call__(
    self,
    calling_pipeline: "DiffusionPipeline",
    step: int,
    timestep: int,
    callback_kwargs: dict
  ):
    """Record every entry of ``callback_kwargs`` under ``step``.

    Args:
      calling_pipeline: the object invoking the callback (not used).
      step: step index supplied by the caller; used as the history key.
      timestep: timestep supplied by the caller (not used).
      callback_kwargs: values to record, keyed by name.

    Returns:
      ``callback_kwargs`` unchanged.

    Raises:
      RuntimeError: if any key already has a value recorded for ``step``
        (for example, when the same Monitor is reused for a second run).
        Nothing from this call is recorded in that case.
    """
    # Validate every key before mutating, so a failure leaves history intact.
    for k in callback_kwargs:
      if step in self.history.get(k, {}):
        raise RuntimeError(
          f"Monitor already has a value for {k!r} at step {step}; "
          "use a fresh Monitor for each pipeline run."
        )

    for k, v in callback_kwargs.items():
      if isinstance(v, torch.Tensor):
        # Detach first so the copy is not tracked by autograd. copy=True
        # guarantees a new tensor even when v already lives on the CPU
        # (plain .cpu() would return v itself and alias it).
        v = v.detach().to("cpu", copy=True)
      self.history.setdefault(k, {})[step] = v

    return callback_kwargs