from copy import deepcopy
import torch
import torch.nn as nn


def validate(net,
             dataloader,
             criterion=nn.MSELoss(),
             device=torch.device('cpu')):
    """Return the mean per-batch loss of ``net`` over one pass of ``dataloader``.

    Args:
        net: ``torch.nn.Module`` to evaluate. It is moved to ``device`` for the
            pass and moved back to the CPU before returning.
        dataloader: Object supporting ``len()`` (number of batches) and
            ``next()`` (yielding ``(inputs, targets)``). Its ``idx`` attribute
            is set to 0 before iterating -- presumably this rewinds the
            project's custom loader; confirm against its implementation.
        criterion: Callable ``criterion(preds, targets)`` returning the loss
            for one batch.
        device: Device on which the forward passes are run.

    Returns:
        A 0-dim tensor holding the sum of batch losses divided by the number
        of batches. It carries no autograd history.

    Raises:
        ZeroDivisionError: If ``len(dataloader)`` is 0.
    """
    net.to(device)
    dataloader.idx = 0
    nbatch = len(dataloader)
    loss_t = 0

    # Evaluate in eval mode (dropout off, batch-norm statistics frozen) and
    # restore whatever mode the caller had afterwards.
    was_training = net.training
    net.eval()
    try:
        # No gradients are needed here; without no_grad() every batch's graph
        # would stay alive through the accumulated loss.
        with torch.no_grad():
            for _ in range(nbatch):
                inputs, targets = next(dataloader)
                inputs, targets = inputs.to(device), targets.to(device)
                # NOTE(review): squeeze() drops *all* size-1 dims, including a
                # batch dimension of size 1 -- confirm targets match.
                preds = net(inputs).squeeze()
                loss_t += criterion(preds, targets)
    finally:
        net.train(was_training)

    net.cpu()
    return loss_t / nbatch


def train(net,
          dataloader,
          nsteps,
          lr,
          criterion=nn.MSELoss(),
          display=10,
          num_checkpoint=0,
          device=torch.device('cpu')):
    """Train ``net`` with plain SGD for ``nsteps`` mini-batch steps.

    Args:
        net: ``torch.nn.Module`` to optimise in place. It is moved to
            ``device`` for training and back to the CPU before returning.
        dataloader: Object from which ``next()`` yields ``(inputs, targets)``
            batches; one batch is drawn per step.
        nsteps: Number of optimisation steps.
        lr: Learning rate passed to ``torch.optim.SGD``.
        criterion: Callable ``criterion(preds, targets)`` returning the loss.
        display: Print the training loss every ``display`` steps. A falsy
            value (e.g. 0) disables printing.
        num_checkpoint: Requested number of parameter snapshots. Snapshots are
            taken every ``nsteps // num_checkpoint`` steps (at least every
            step), starting at step 0. With ``num_checkpoint=0`` only the
            step-0 snapshot is stored.
        device: Device on which training is run.

    Returns:
        dict with keys
        ``'train'``: per-step loss values (floats), measured before that
            step's parameter update;
        ``'net'``: CPU ``state_dict`` snapshots, taken before the update of
            the step they belong to;
        ``'iter'``: 1-based step number of each snapshot;
        ``'test'``: always empty here (nothing in this function fills it).
    """
    net = net.to(device)
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)

    history = {'train': [], 'test': [], 'net': [], 'iter': []}

    # Snapshot interval. Clamped to >= 1 so that num_checkpoint > nsteps does
    # not produce a modulo-by-zero; num_checkpoint == 0 yields an interval of
    # nsteps, i.e. only step 0 matches inside range(nsteps).
    save_every = max(1, nsteps // max(1, num_checkpoint))

    for step in range(nsteps):
        inputs, targets = next(dataloader)
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        # NOTE(review): squeeze() drops *all* size-1 dims, including a batch
        # dimension of size 1 -- confirm targets match.
        preds = net(inputs).squeeze()
        loss = criterion(preds, targets)
        loss_value = loss.item()  # read once; .item() copies to the host

        if step % save_every == 0:
            # Copy first so the stored tensors are detached from the live
            # network and are not changed by later updates.
            net_copy = deepcopy(net).cpu()
            history['net'].append(net_copy.state_dict())
            history['iter'].append(step + 1)
        history['train'].append(loss_value)
        if display and (step + 1) % display == 0:
            print('{:}, loss_tr: {:.1e}'.format(step + 1, loss_value))

        loss.backward()
        optimizer.step()

    net.cpu()
    return history



# class SGD:
#     def __init__(self, model, lr):
#         self.para = model.parameters()
#         self.lr = lr 

#     def generate_noise(self, stddev):
#         pass 

#     def step(self):
#         noise_r = self.generate_noise()

#         idx = 0
#         for p in self.para:
#             p_shape = p.data.shape 
#             p_num  = p.data.numel()

#             dp = self.para.grad.data  + noise[idx:idx+p_num].reshape(p_shape)

#             self.para.data.add_(-self.lr * dp)


# class IsotropicSGD(SGD):
#     def __init__(self, model, lr):
#         super(IsotropicSGD, self).__init__()
#         self.np = num_para(model)

#     def generate_noise(self, stddev):
#         noise = torch.randn(self.np)
#         noise *= stddev
#         return noise 


# class GramSGD(SGD):
#     def __init__(self, model, lr):
#         super(GramSGD, self).__init__()
#         self.np = num_para(model)
#         self.F = None # pxn
#         self.ns = 0

#     def compute_feature_matrix(self, X, y):
#         ana = AnalyzeLargeNet(net, X_tr, y_tr)
#         ana.compute_grads()
#         self.F = ana.grads.t()
#         self.ns = F.shape[1]

#     def generate_noise(self, stddev):
#         z = torch.randn(self.ns)
#         noise = (self.F @ z) * stddev
#         return noise 
























