
import pytorch_lightning as pl
import torch
from pytorch_lightning import Callback
from pytorch_lightning.utilities import rank_zero_only
import deepspeed

class LogParamsAndGrads(Callback):
    """Log summary statistics of trainable parameters, their gradients and LoRA deltas.

    LoRA adapter pairs are discovered once at construction time. Every trainable
    parameter whose name contains ``lora_A`` / ``lora_B`` is grouped by the name
    prefix preceding ``.lora_A`` / ``.lora_B``. Only groups that have both halves
    are kept.

    Groups are stored as *positional indices* into ``model.named_parameters()``
    rather than names. The lookup in :meth:`on_before_optimizer_step` therefore
    still works when that method receives a model whose parameter names differ
    but whose parameter order is the same.
    NOTE(review): presumably this is a wrapper such as a DeepSpeed engine; confirm.

    Args:
        model: module whose ``named_parameters()`` is scanned for LoRA pairs.
        log_gradient: log ``grad/<name>/...`` statistics for trainable
            parameters whose ``.grad`` is populated.
        log_params: log ``param/<name>/...`` and ``lora/<prefix>/...`` statistics.
        log_quantiles: stored on the instance but not used by this class.
        log_every_n_steps: statistics are emitted when ``global_step`` is a
            multiple of this value.
    """

    def __init__(self, model, log_gradient: bool, log_params: bool, log_quantiles: bool, log_every_n_steps: int):
        super().__init__()
        self.log_gradient = log_gradient
        self.log_params = log_params
        self.log_quantiles = log_quantiles  # NOTE(review): currently unused
        self.log_every_n_steps = log_every_n_steps

        lora_A_name = "lora_A"
        lora_B_name = "lora_B"

        groups = {}
        # Enumerate over *all* parameters (frozen ones included) so that the
        # indices line up with ``list(model.parameters())`` used later.
        for idx, (name, param) in enumerate(model.named_parameters()):
            if not param.requires_grad:
                continue
            for tag in (lora_A_name, lora_B_name):
                if tag in name:
                    prefix = name.split(f".{tag}")[0]
                    groups.setdefault(prefix, {})[tag] = idx

        # Keep only complete A/B pairs: a lone half would raise KeyError when
        # the product B @ A is formed at logging time.
        self.lora_params_indices = {
            prefix: group
            for prefix, group in groups.items()
            if lora_A_name in group and lora_B_name in group
        }

    @staticmethod
    def _summarize(tensor: torch.Tensor, prefix: str, include_abs: bool = True) -> dict:
        """Return ``{f"{prefix}/<stat>": float}`` summary statistics of ``tensor``.

        Statistics are computed on a float32 copy. This keeps half-precision
        inputs from overflowing (for example the fp16 sum of squares) and from
        losing precision. All values are moved to the host with a single
        ``tolist()`` call instead of one ``.item()`` sync per statistic.

        A tensor with at most one element gets only ``mean``, because its std
        would be NaN (and ``min``/``max`` fail on an empty tensor).

        Args:
            tensor: values to summarize; it is detached first.
            prefix: key prefix, e.g. ``"param/layer.weight"``.
            include_abs: also emit ``abs_mean`` and ``abs_std``.
        """
        t = tensor.detach().float()
        if t.numel() <= 1:
            return {f"{prefix}/mean": t.mean().item()}

        sq = t * t
        named = {
            "mean": t.mean(),
            "std": t.std(),
            "min": t.min(),
            "max": t.max(),
        }
        if include_abs:
            magnitude = t.abs()
            named["abs_mean"] = magnitude.mean()
            named["abs_std"] = magnitude.std()
        named["l2m"] = sq.mean()
        named["l2s"] = sq.sum()

        values = torch.stack(list(named.values())).tolist()
        out = {f"{prefix}/{key}": value for key, value in zip(named, values)}
        # Kept for dashboard compatibility; mathematically equal to l2m.
        out[f"{prefix}/l2sn"] = out[f"{prefix}/l2s"] / t.numel()
        return out

    #@rank_zero_only
    def on_before_optimizer_step(self, model, global_step, loggers):
        """Compute and log statistics every ``log_every_n_steps`` steps.

        When ``log_params`` is set, this logs ``lora/<prefix>/...`` statistics of
        ``lora_B @ lora_A`` for every discovered pair. It also logs
        ``param/<name>/...`` statistics of each trainable parameter and prints
        a line if a parameter contains NaN/Inf.

        When ``log_gradient`` is set, it logs ``grad/<name>/...`` statistics
        for each trainable parameter whose ``.grad`` is not None.

        Args:
            model: module to inspect. Its parameter order must match the model
                given to ``__init__``, because LoRA pairs are looked up by position.
            global_step: current step; also forwarded as ``step`` to loggers.
            loggers: iterable of objects exposing ``log_metrics(dict, step=...)``.
        """
        if global_step % self.log_every_n_steps != 0:
            return
        if not (self.log_params or self.log_gradient):
            return

        stats = {}

        if self.log_params and self.lora_params_indices:
            parameters = list(model.parameters())
            for prefix, group in self.lora_params_indices.items():
                lora_a = parameters[group["lora_A"]]
                lora_b = parameters[group["lora_B"]]
                # no_grad: this product is for inspection only, so skip
                # recording it in the autograd graph.
                with torch.no_grad():
                    delta = lora_b @ lora_a
                stats.update(self._summarize(delta, f"lora/{prefix}", include_abs=False))

        for name, param in model.named_parameters():
            if not param.requires_grad:
                continue

            if self.log_params:
                values = param.detach()
                if torch.isnan(values).any():
                    print(f"# NaN in param {name}")
                if torch.isinf(values).any():
                    print(f"# Inf in param {name}")
                stats.update(self._summarize(values, f"param/{name}"))

            # .grad may be None (e.g. unused parameter, or gradients managed
            # outside the module), so it is checked before summarizing.
            if self.log_gradient and param.grad is not None:
                grad = param.grad.detach()
                if torch.isnan(grad).any():
                    print(f"# NaN in grad {name}")
                if torch.isinf(grad).any():
                    print(f"# Inf in grad {name}")
                stats.update(self._summarize(grad, f"grad/{name}"))

        if stats:
            for logger in loggers:
                logger.log_metrics(stats, step=global_step)

