import torch

@torch.no_grad()
def nxr(parameters_to_optim, config, scale=0.):
    """
    nBAR for non-overparametrized problems.

    For every layer ``0 .. config.n_layer - 1`` the merged ``c_attn`` LoRA
    tensors are cut into chunks along dim 0 (``lora_A`` in pieces of
    ``config.lora_attn_dim`` rows, ``lora_B`` in pieces of ``config.n_embd``
    rows) and each (A, B) chunk pair is rescaled in place. The direction is
    chosen from the Frobenius norms of the chunks' gradients: when A's
    gradient norm is at least B's, A is multiplied by ``1 + scale`` and B by
    ``1 - scale``; otherwise the two factors are swapped.

    Args:
        parameters_to_optim: mapping from parameter name to tensor. Must hold
            ``module.transformer.h.<l>.attn.c_attn.lora_A`` / ``lora_B`` for
            every layer, each with a populated ``.grad``.
        config: object exposing ``n_layer``, ``lora_attn_dim`` and ``n_embd``.
        scale: rescaling strength; the default ``0.`` multiplies every chunk
            by one.

    Returns:
        None. The parameter tensors are modified in place.
    """
    # The number of merged chunks is read from layer 0 and reused for all layers.
    first_a = parameters_to_optim['module.transformer.h.0.attn.c_attn.lora_A']
    n_chunks = first_a.shape[0] // config.lora_attn_dim

    grow, shrink = 1.0 + scale, 1.0 - scale

    for layer_idx in range(config.n_layer):
        prefix = f'module.transformer.h.{layer_idx}.attn.c_attn.'

        a_param = parameters_to_optim[prefix + 'lora_A']
        a_chunks = a_param.split(config.lora_attn_dim, dim=0)
        b_param = parameters_to_optim[prefix + 'lora_B']
        b_chunks = b_param.split(config.n_embd, dim=0)

        a_grad_chunks = a_param.grad.split(config.lora_attn_dim, dim=0)
        b_grad_chunks = b_param.grad.split(config.n_embd, dim=0)

        # lora weights are merged: split() returns views, so mul_ below
        # writes straight into the underlying parameters.
        for k in range(n_chunks):
            gap = a_grad_chunks[k].norm(p='fro') - b_grad_chunks[k].norm(p='fro')
            if gap >= 0.:
                a_factor, b_factor = grow, shrink
            else:
                a_factor, b_factor = shrink, grow
            a_chunks[k].mul_(a_factor)
            b_chunks[k].mul_(b_factor)


@torch.no_grad()
def oxr(parameters_to_optim, config, scale=0.):
    """
    oBAR for overparametrized problems.

    For every layer ``0 .. config.n_layer - 1`` the merged ``c_attn`` LoRA
    tensors are cut into chunks along dim 0 (``lora_A`` in pieces of
    ``config.lora_attn_dim`` rows, ``lora_B`` in pieces of ``config.n_embd``
    rows) and each (A, B) chunk pair is rescaled in place. The direction is
    chosen from the Frobenius norms of the chunks themselves: when A's norm
    is strictly larger than B's, A is multiplied by ``1 - scale`` and B by
    ``1 + scale``; otherwise (including ties) the two factors are swapped.

    Args:
        parameters_to_optim: mapping from parameter name to tensor. Must hold
            ``module.transformer.h.<l>.attn.c_attn.lora_A`` / ``lora_B`` for
            every layer.
        config: object exposing ``n_layer``, ``lora_attn_dim`` and ``n_embd``.
        scale: rescaling strength; the default ``0.`` multiplies every chunk
            by one.

    Returns:
        None. The parameter tensors are modified in place.
    """
    # The number of merged chunks is read from layer 0 and reused for all layers.
    first_a = parameters_to_optim['module.transformer.h.0.attn.c_attn.lora_A']
    n_chunks = first_a.shape[0] // config.lora_attn_dim

    grow, shrink = 1.0 + scale, 1.0 - scale

    for layer_idx in range(config.n_layer):
        prefix = f'module.transformer.h.{layer_idx}.attn.c_attn.'

        a_chunks = parameters_to_optim[prefix + 'lora_A'].split(
            config.lora_attn_dim, dim=0)
        b_chunks = parameters_to_optim[prefix + 'lora_B'].split(
            config.n_embd, dim=0)

        # lora weights are merged: split() returns views, so mul_ below
        # writes straight into the underlying parameters.
        for k in range(n_chunks):
            a_chunk = a_chunks[k]
            b_chunk = b_chunks[k]
            gap = a_chunk.norm(p='fro') - b_chunk.norm(p='fro')
            if gap > 0.:
                a_factor, b_factor = shrink, grow
            else:
                a_factor, b_factor = grow, shrink
            a_chunk.mul_(a_factor)
            b_chunk.mul_(b_factor)


def xr_linear_schedule(mu_min, mu_max, cur_iter, all_iters=11271):
    """Linearly decay a coefficient from ``mu_max`` down to ``mu_min``.

    The value is ``mu_max`` at ``cur_iter == 0`` and reaches ``mu_min`` at
    ``cur_iter == all_iters``. Progress is clamped to ``[0, 1]`` so the
    result always stays between the two bounds: iterations past the end of
    the schedule keep returning ``mu_min`` instead of extrapolating below it,
    and negative iterations return ``mu_max``.

    Args:
        mu_min: value at the end of the schedule.
        mu_max: value at the start of the schedule.
        cur_iter: current iteration index.
        all_iters: total number of iterations the schedule spans.

    Returns:
        The scheduled coefficient. A non-positive ``all_iters`` is treated as
        an already finished schedule and yields ``mu_min``.
    """
    if all_iters <= 0:
        # Zero-length schedule: avoid dividing by zero, report the end value.
        return mu_min
    progress = min(max(cur_iter / all_iters, 0.0), 1.0)
    return mu_min + (mu_max - mu_min) * (1.0 - progress)
