import torch
from IPython import display
import matplotlib.pyplot as plt
import torch.nn.functional as F

def display_loss(loss, xlabel='#iteration', ylabel='loss', title='Training loss'):
    """Redraw the loss history as a single green curve on a log-scaled y axis.

    The previous cell output is cleared first (``wait=True``), presumably so
    that repeated calls from a notebook refresh one plot in place.

    Args:
        loss: sequence of loss values, plotted against their index.
        xlabel: label for the x axis.
        ylabel: label for the y axis.
        title: plot title.
    """
    display.clear_output(wait=True)
    _, ax = plt.subplots(figsize=(8, 6))
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.plot(loss, 'g')
    ax.set_yscale("log")
    plt.show()

def set_T(model, T):
    """Set attribute ``T`` on every item of ``model`` that looks like an ODE block.

    ``model`` is iterated directly (presumably an ``nn.Sequential`` or a
    similar container of layers — TODO confirm against callers). An item
    counts as an ODE block when the text of its class repr (which includes
    the defining module path) contains the substring ``'ODEBlock'``. Every
    ODE block class is therefore expected to carry ``ODEBlock`` in its name.
    All other items are left untouched.
    """
    ode_layers = (item for item in model if 'ODEBlock' in str(item.__class__))
    for ode_layer in ode_layers:
        ode_layer.T = T

def train_ODENet(model, device, train_loader, optimizer, epoch, train_loss, T=None):
    """Train ``model`` for one pass over ``train_loader`` and plot the loss history.

    For each ``(data, target)`` batch, the function:

    1. moves the batch to ``device``;
    2. computes ``F.nll_loss`` on the model output;
    3. back-propagates and takes an ``optimizer`` step;
    4. appends the scalar batch loss to ``train_loss``, which is mutated in
       place.

    After the pass, the full ``train_loss`` history is plotted via
    ``display_loss``.

    Since ``F.nll_loss`` is used, the model presumably outputs log-probabilities
    (e.g. via ``log_softmax``) — TODO confirm.

    NOTE: ``epoch`` and ``T`` are accepted but not used. Applying ``T`` through
    ``set_T`` is currently disabled.
    """
    model.train()
    for data, target in train_loader:
        inputs = data.to(device)
        labels = target.to(device)
        optimizer.zero_grad()
        batch_loss = F.nll_loss(model(inputs), labels)
        batch_loss.backward()
        optimizer.step()
        train_loss.append(batch_loss.item())
    display_loss(train_loss)

def test_ODENet(model, device, test_loader, test_loss, test_accuracy, is_training=False, T=None):
    if is_training:
        model.train()
    else:
        model.eval()
    cur_test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            cur_test_loss += F.nll_loss(output, target, reduction='sum').item() # sum up batch loss
            pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
            correct += pred.eq(target.view_as(pred)).sum().item()

    cur_test_loss /= len(test_loader.dataset)
    test_loss.append(cur_test_loss)
    cur_accuracy = 100. * correct / len(test_loader.dataset)
    test_accuracy.append(cur_accuracy)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        cur_test_loss, correct, len(test_loader.dataset),
        cur_accuracy))
