import sys

import torch
from tqdm import tqdm

sys.path.append(".")
from src.tools.sharpness_tools.math_utils import hvp
from src.tools.sharpness_tools.utils import get_param_dim, get_device
from torch.cuda.amp import GradScaler


def frob_norm(model, data_loader, mcmc_itr):
    """
    Estimate the Frobenius norm of the operator applied by ``hvp`` by sampling.

    For any matrix ``H`` and a random probe ``v ~ N(0, I)``,
    ``E[||H v||_2^2] = trace(H^T H) = ||H||_F^2``. The square root of the
    sample mean of ``||hvp(v)||^2`` over ``mcmc_itr`` probes is therefore an
    estimate of ``||H||_F``.

    NOTE(review): assumes ``hvp(model, data_loader, v, scaler)`` returns the
    Hessian-vector product of the loss over ``data_loader``. Confirm against
    ``src.tools.sharpness_tools.math_utils.hvp``.

    :param model: model whose parameter count (``get_param_dim``) sets the probe
        dimension and whose device (``get_device``) the probe is created on.
    :param data_loader: data passed through unchanged to ``hvp``.
    :param mcmc_itr: number of random probe vectors to average over; must be >= 1.
    :return: estimated Frobenius norm as a Python float.
    :raises ValueError: if ``mcmc_itr`` is less than 1.
    """
    if mcmc_itr < 1:
        raise ValueError(f"mcmc_itr must be >= 1, got {mcmc_itr}")

    scaler = GradScaler()
    model_dim = get_param_dim(model)
    # Loop-invariant: look the device up once rather than on every iteration.
    device = get_device(model)

    total = 0.0
    for _ in tqdm(range(mcmc_itr), ncols=120):
        # Sample the Gaussian probe directly on the target device (avoids a
        # host->device copy per iteration), then cast to half as before.
        v = torch.randn(model_dim, device=device).half()
        hv = hvp(model, data_loader, v, scaler)
        # Compute ||Hv||^2 in float32: squaring a float16 norm overflows to
        # inf once ||Hv|| exceeds ~256 (float16 max is 65504). ``.item()``
        # keeps the running sum a plain Python float instead of a tensor that
        # could retain autograd history across iterations.
        total += hv.float().pow(2).sum().item()

    return (total / mcmc_itr) ** 0.5
