import math
import time
from copy import deepcopy

import torch
import torch.nn as nn


# helper functions for analyzing models
def para_norm(net, p=2):
    """Return the p-norm of all parameters of ``net`` viewed as one flat vector.

    Each parameter contributes ``||w||_p ** p``; the p-th root of the total is
    returned, so the result equals the p-norm of the concatenated parameters.
    """
    total = sum(weight.view(-1).norm(p=p) ** p for weight in net.parameters())
    return total ** (1.0 / p)

def grad_norm(net, p=2):
    """Return the p-norm of all parameter gradients of ``net`` as one flat vector.

    Each gradient contributes ``||g||_p ** p`` and the p-th root of the total is
    returned.  Parameters whose ``.grad`` is None (no gradient has been
    accumulated for them) are skipped instead of raising AttributeError.
    """
    gm = 0
    for para in net.parameters():
        if para.grad is None:
            continue
        # Use the requested order p (previously hard-coded to 2, which made the
        # result wrong for every p != 2).
        gm += para.grad.reshape(-1).norm(p=p) ** p
    return gm ** (1.0 / p)

def num_para(net):
    """Return the total number of scalar entries across all parameters of ``net``."""
    return sum(weight.data.numel() for weight in net.parameters())

def vecterize_para(net):
    """Flatten all parameters of ``net`` into one 1-D tensor.

    Values are laid out in ``net.parameters()`` order.  The output is created
    with ``torch.zeros`` (default dtype and device), and each parameter's
    values are copied into its slice.
    """
    flat = torch.zeros(num_para(net))

    offset = 0
    for weight in net.parameters():
        chunk = weight.data.view(-1)
        end = offset + chunk.numel()
        flat[offset:end] = chunk.clone()
        offset = end
    return flat

def vecterize_grad(net):
    """Flatten the gradients of all parameters of ``net`` into one 1-D tensor.

    Entries are laid out in ``net.parameters()`` order, one slice per
    parameter of length ``p.numel()``.  A parameter whose ``.grad`` is None
    keeps a block of zeros in its slice (instead of raising AttributeError),
    so the layout stays aligned with the flattened parameter vector.
    The output is created with ``torch.zeros`` (default dtype and device).
    """
    params = list(net.parameters())
    grad = torch.zeros(sum(p.numel() for p in params))

    i = 0
    for p in params:
        n = p.numel()
        if p.grad is not None:
            grad[i:i + n] = p.grad.detach().reshape(-1)
        i += n
    return grad

def set_model_para(model, new_para):
    """Overwrite the parameters of ``model`` in place from a flat vector.

    new_para: a vector containing the new parameters, laid out in
    ``model.parameters()`` order; consecutive slices are reshaped to each
    parameter's shape and copied into ``p.data``.
    """
    start = 0
    for weight in model.parameters():
        target = weight.data
        stop = start + target.numel()
        target.copy_(new_para[start:stop].reshape(target.shape))
        start = stop


# some linear algebra helpers
def eigh(A):
    """Eigen-decompose a symmetric/Hermitian matrix with eigenvalues in descending order.

    Wraps ``torch.linalg.eigh`` (which returns ascending eigenvalues) and
    reverses the order.

    Returns:
        (sA, vA): eigenvalues of shape ``(..., n)`` and eigenvectors of shape
        ``(..., n, n)``; column k of ``vA`` pairs with ``sA[..., k]``.
    """
    sA, vA = torch.linalg.eigh(A)
    # Flip only the eigen axis (last dim) so batched inputs keep their batch
    # order; for a single matrix this is identical to flipping dims 0 / 1.
    sA = torch.flip(sA, dims=(-1,))
    vA = torch.flip(vA, dims=(-1,))
    return sA, vA

def eig_gram_matrix(F):
    """
    F: n x m, the feature matrix
    G=FF^T/n is a nxn matrix

    Returns the n eigenvalues of G in descending order.
    """
    n = F.shape[0]
    G = (F @ F.t()) / n
    # torch.linalg.eigvalsh returns ascending eigenvalues; flip to descending.
    # (The former unused `torch.randn(m)` was removed: it silently advanced
    # the global RNG state on every call.)
    s = torch.linalg.eigvalsh(G).flip(dims=(0,))
    return s


def power_method(v0, Av_func, n_iters=10, tol=1e-3, verbose=False):
    """Estimate the dominant eigenvalue of a linear operator by power iteration.

    Args:
        v0: starting vector (1-D tensor); normalised before the first step.
        Av_func: callable mapping a 1-D tensor ``v`` to ``A @ v``.
        n_iters: maximum number of iterations.
        tol: stop once the relative change of the estimate drops below ``tol``.
        verbose: print per-iteration timing and the current estimate.

    Returns:
        The Rayleigh quotient ``<Av, v>`` (a Python float) from the last
        iteration, i.e. the power-iteration estimate of the eigenvalue of
        largest magnitude.  Returns ``0.0`` if ``Av_func`` maps the current
        iterate to the zero vector.
    """
    mu = 0
    v = v0 / v0.norm()
    for i in range(n_iters):
        time_start = time.time()

        Av = Av_func(v)
        mu_pre = mu
        mu = torch.dot(Av, v).item()
        Av_norm = Av.norm()
        if Av_norm == 0:
            # v lies in the null space of A: the Rayleigh quotient is exactly 0
            # and normalising Av would produce NaNs.
            break
        v = Av / Av_norm

        # Relative-change test written multiplicatively so mu == 0 cannot raise
        # ZeroDivisionError; otherwise equivalent to |mu-mu_pre|/|mu| < tol.
        if abs(mu - mu_pre) < tol * abs(mu):
            break
        if verbose:
            print('%d-th step takes %.0f seconds, \t %.2e'%(i+1,time.time()-time_start,mu))
    return mu

