from __future__ import print_function
import os
import numpy as np
import argparse
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from collections import OrderedDict
from data_preprocess.process_fmnist import FMNIST_Dataset
from trainer.PLAD_trainer_fmnist import PLADTrainer
from VAE_fmnist import VAE
import itertools
import scipy.io

class FashionMNIST_LeNet(nn.Module):
    """LeNet-style convolutional network producing one scalar per image.

    Two bias-free 5x5 convolution stages (each followed by non-affine batch
    normalisation, leaky ReLU and 2x2 max pooling) feed three bias-free
    linear layers of widths 128, ``rep_dim`` (64) and 1.
    """

    def __init__(self):
        super().__init__()

        self.rep_dim = 64
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=16, kernel_size=5,
                               padding=2, bias=False)
        self.bn1 = nn.BatchNorm2d(num_features=16, eps=1e-04, affine=False)
        self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5,
                               padding=2, bias=False)
        self.bn2 = nn.BatchNorm2d(num_features=32, eps=1e-04, affine=False)
        # Two 2x2 poolings shrink the 28x28 input to 7x7 with 32 channels.
        self.fc1 = nn.Linear(in_features=32 * 7 * 7, out_features=128,
                             bias=False)
        self.fc2 = nn.Linear(in_features=128, out_features=self.rep_dim,
                             bias=False)
        self.fc3 = nn.Linear(in_features=self.rep_dim, out_features=1,
                             bias=False)

    def forward(self, x):
        """Map a batch of images to a tensor of shape ``(batch, 1)``.

        ``x`` is viewed as ``(batch, 1, 28, 28)``, so any layout with 784
        elements per sample that supports ``view`` is accepted.
        """
        batch = x.shape[0]
        images = x.view(batch, 1, 28, 28)

        # Convolutional stage 1: conv -> batch norm -> leaky ReLU -> pool.
        stage1 = self.pool(F.leaky_relu(self.bn1(self.conv1(images))))
        # Convolutional stage 2: same pattern with 32 output channels.
        stage2 = self.pool(F.leaky_relu(self.bn2(self.conv2(stage1))))

        # Flatten per sample, then apply the three linear layers in sequence
        # (no activation between them).
        flat = stage2.view(stage2.size(0), -1)
        return self.fc3(self.fc2(self.fc1(flat)))


def main():
    """Train or evaluate a PLAD model for one Fashion-MNIST normal class.

    Configuration is read from the module-level ``args`` and ``device``
    globals, which the ``__main__`` block assigns before calling this
    function.

    When ``args.eval == 0`` the model is trained and then stored through
    ``trainer.save(args.model_dir)``. For any other value a checkpoint is
    loaded through ``trainer.load`` and the test AUC is printed.

    Raises:
        SystemExit: With exit status 1 when evaluation is requested but the
            checkpoint file does not exist.
    """
    # NOTE(review): the data root is hard-coded to "data"; args.data_path is
    # not consulted here -- confirm whether that is intended.
    dataset = FMNIST_Dataset("data", args.normal_class)
    train_loader, test_loader = dataset.loaders(batch_size=args.batch_size)
    print("Fashion-MNIST class: ", args.normal_class)

    model = FashionMNIST_LeNet().to(device)
    model = nn.DataParallel(model)

    e_ae = VAE().to(device)

    # A single optimizer updates both the classifier and the VAE.
    params = itertools.chain(model.parameters(), e_ae.parameters())
    if args.optim == 1:
        optimizer = optim.SGD(params, lr=args.lr, momentum=args.mom)
        print("Optimizer: SGD")
    else:
        optimizer = optim.Adam(params, lr=args.lr, amsgrad=True)
        print("Optimizer: Adam")

    trainer = PLADTrainer(model, e_ae, optimizer, args.lamda, device)

    if args.eval == 0:
        # Training the model
        trainer.train(train_loader, test_loader, args.lr, args.epochs, metric=args.metric)
        trainer.save(args.model_dir)
        return

    # Evaluation: the checkpoint path is relative to args.model_dir.
    filename = './fmnist_trained_model/fmnist-{}.pt'.format(args.normal_class)
    if not os.path.exists(os.path.join(args.model_dir, filename)):
        print('Saved model not found. Cannot run evaluation.')
        # Exit with a non-zero status so callers can detect the failure; the
        # bare exit() helper comes from the site module and is not guaranteed
        # to exist in every interpreter configuration.
        raise SystemExit(1)

    trainer.load(args.model_dir, filename)
    print("Testing the trained model on Fashion-MNIST class {}".format(args.normal_class))
    print("Saved Model Loaded")

    score = trainer.test(test_loader, 'AUC')
    print('Test AUC: {}'.format(score))

if __name__ == '__main__':
    # Script entry point: parse the command line, prepare the checkpoint
    # directory and compute device, then run main().
    torch.set_printoptions(precision=5)

    parser = argparse.ArgumentParser(description='PLAD Training')
    parser.add_argument('--normal_class', type=int, default=5, metavar='N',
                        help='Fashion-MNIST normal class index')
    parser.add_argument('--batch_size', type=int, default=512, metavar='N',
                        help='batch size for training')
    parser.add_argument('--epochs', type=int, default=100, metavar='N',
                        help='number of epochs to train')
    parser.add_argument('--lr', type=float, default=0.005, metavar='LR',
                        help='learning rate')
    parser.add_argument('--lamda', type=float, default=5, metavar='N',
                        help='Weight of the perturbator loss')
    parser.add_argument('--optim', type=int, default=1, metavar='N',
                        help='0 : Adam 1: SGD')
    parser.add_argument('--mom', type=float, default=0.0, metavar='M',
                        help='momentum')
    parser.add_argument('--model_dir', default='log',
                        help='path where to save checkpoint')
    parser.add_argument('--eval', type=int, default=1, metavar='N',
                        help='whether to load a saved model and evaluate (0/1)')
    parser.add_argument('-d', '--data_path', type=str, default='.')
    parser.add_argument('--metric', type=str, default='AUC')
    # `args` and `device` are deliberately module-level: main() reads them
    # as globals.
    args = parser.parse_args()

    # Model save path; exist_ok avoids the check-then-create race.
    model_dir = args.model_dir
    os.makedirs(model_dir, exist_ok=True)

    use_cuda = torch.cuda.is_available()
    device = torch.device("cuda" if use_cuda else "cpu")
    main()
