""" Makes and trains models for comparison tests, based on what we could find
    in https://arxiv.org/pdf/1711.07356.pdf
"""
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
import plnn

##############################################################################
#                                                                            #
#                           TRAINING | TESTING CODE                          #
#                                                                            #
##############################################################################


def l1_loss(net):
    """Return the summed L1 norm of every parameter of ``net`` with dim > 1.

    One-dimensional parameters (e.g. bias vectors) are skipped, so only
    weight matrices / kernels contribute to the penalty. When no parameter
    qualifies, the plain int ``0`` is returned (same as ``sum([])``).
    """
    total = 0
    for param in net.parameters():
        if param.dim() > 1:
            total = total + param.norm(p=1)
    return total


def train_net(net, trainset, valset, num_epochs, adam_kwargs=None, l1_reg=None):
    """Train ``net`` with Adam on cross-entropy loss, reporting accuracy each epoch.

    Args:
        net: model to train; called on each batch of ``data``.
        trainset: iterable of ``(data, labels)`` batches used for optimization.
        valset: iterable of ``(data, labels)`` batches passed to ``test_acc``
            after every epoch.
        num_epochs: number of passes over ``trainset``.
        adam_kwargs: optional dict of keyword arguments for ``torch.optim.Adam``;
            entries override the defaults ``lr=1e-3, weight_decay=0``.
        l1_reg: optional coefficient; when given, ``l1_reg * l1_loss(net)`` is
            added to the loss.

    Returns:
        None. ``net`` is updated in place.
    """
    # Default Adam settings, overridden by any caller-supplied options.
    adam_options = {'lr': 1e-3, 'weight_decay': 0}
    if adam_kwargs is not None:
        adam_options.update(adam_kwargs)

    criterion = nn.CrossEntropyLoss()

    def loss_fxn(data, labels):
        loss = criterion(net(data), labels).view([1])
        if l1_reg is not None:
            # Out-of-place add (no in-place mutation of an autograd node), and
            # broadcasting instead of .view() so an int 0 penalty also works.
            loss = loss + l1_reg * l1_loss(net)
        return loss

    # Previously lr/weight_decay were hard-coded here, silently ignoring
    # adam_kwargs; use the merged options instead.
    opt = optim.Adam(net.parameters(), **adam_options)
    for epoch in range(num_epochs):
        for data, labels in trainset:
            opt.zero_grad()
            loss_val = loss_fxn(data, labels)
            loss_val.backward()
            opt.step()
        print("EPOCH %02d | TEST ACCURACY %.03f" %
              (epoch, test_acc(net, valset)))



def test_acc(net, valset):
    err_acc = 0
    err_count = 0
    for data, labels in valset:
        n = data.shape[0]
        output = net(Variable(data))
        err_acc += (output.max(1)[1].data != labels).float().mean() * n
        err_count += n
    return 1 - (err_acc / err_count).item()




##############################################################################
#                                                                            #
#                           MODEL BANK                                       #
#                                                                            #
##############################################################################
# Flattened MNIST image size (28 x 28 pixels) and number of digit classes.
MNIST_IN_DIM = 28 * 28
MNIST_OUT_DIM = 10

def build_MLP_AB(a_or_b='A', mnist=True):
    """Construct one of two fully-connected MNIST models as a ``plnn.PLNN``.

    Args:
        a_or_b: 'A' -> a single hidden layer of 500 units;
                'B' -> two hidden layers of 200 units each.
        mnist: must be truthy; only MNIST input/output sizes are supported.

    Returns:
        ``plnn.PLNN`` built with ``layer_sizes=[MNIST_IN_DIM, *hidden, MNIST_OUT_DIM]``.
    """
    assert a_or_b in ['A', 'B']
    assert mnist  # only works for mnist right now
    hidden_sizes = [500] if a_or_b == 'A' else [200, 200]
    return plnn.PLNN(layer_sizes=[MNIST_IN_DIM, *hidden_sizes, MNIST_OUT_DIM])




