import torch.optim as optim
import torch
from models import *
from color_mnist import *
from attacks import attack
import torchvision.models

# Target device for the model and every batch: CUDA when PyTorch reports it
# as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

import torchattacks

class Trainer(object):
    def __init__(self, model_type, flip_frac, spur_corr, dset, adv_params=None, num_epochs=20):
        """Build the model, data loaders, loss and optimizer for one run.

        Args:
            model_type: Architecture key handled by ``init_model``.
            flip_frac: Stored as ``flip_fraction``; forwarded to
                ``insert_spur`` and encoded in the checkpoint filename.
            spur_corr: Spurious-correlation key handled by
                ``init_spur_corr_fn``.
            dset: Dataset key, either ``'mnist'`` or ``'cifar'``.
            adv_params: Optional dict with the keys ``'eps'``, ``'alpha'``,
                ``'steps'`` and ``'norm'``. ``None`` disables the attack.
            num_epochs: Number of epochs ``run`` iterates over.

        Raises:
            ValueError: If ``dset`` is not a supported dataset key. The
                ``init_*`` helpers may raise ``ValueError`` for their own
                unrecognised keys.
        """
        # Fail fast on an unknown dataset key. Without this check the loaders
        # would simply never be assigned and the mistake would only show up
        # later as a missing ``train_loader`` attribute.
        if dset not in ('mnist', 'cifar'):
            raise ValueError('Did not recognize dataset {}'.format(dset))

        self.model_type = model_type
        self.adv_params = adv_params
        self.adv_train = (adv_params is not None)
        self.flip_fraction = flip_frac
        self.spur_corr = spur_corr
        self.dset = dset

        self.init_model()
        self.init_spur_corr_fn()
        if self.adv_train:
            # The attack wraps self.model, so it must be built after it.
            self.init_attack()

        self.construct_save_path()
        if self.dset == 'mnist':
            self.train_loader, self.test_loader = get_mnist_loaders()
        elif self.dset == 'cifar':
            self.train_loader, self.test_loader = get_cifar_loaders()

        self.criterion = nn.CrossEntropyLoss()
        self.optimizer = optim.Adam(self.model.parameters(), lr=0.001, betas=(0.9,0.999), weight_decay=0.0001)
        self.num_epochs = num_epochs

        # Best test accuracy seen so far; ``run`` checkpoints on improvement.
        self.best_acc = 0

    def init_attack(self):
        """Create ``self.attack`` from ``self.adv_params``.

        Reads ``'eps'``, ``'alpha'``, ``'steps'`` and ``'norm'`` from the
        parameter dict and instantiates ``torchattacks.PGDL2`` for ``'l2'``
        or ``torchattacks.PGD`` for ``'linf'`` around ``self.model``.

        Raises:
            ValueError: If ``'norm'`` is neither ``'l2'`` nor ``'linf'``.
        """
        params = self.adv_params
        eps = params['eps']
        alpha = params['alpha']
        steps = params['steps']

        norm = params['norm']
        if norm == 'l2':
            attack_cls = torchattacks.PGDL2
        elif norm == 'linf':
            attack_cls = torchattacks.PGD
        else:
            raise ValueError('Adv params norm must either be l2 or linf')

        self.attack = attack_cls(self.model, eps=eps, alpha=alpha, steps=steps)

    def init_model(self):
        """Instantiate the network named by ``self.model_type`` on ``device``.

        ``'MLP'`` builds ``Net``, ``'resnet'`` builds
        ``torchvision.models.resnet18`` and ``'ConvNet'`` builds ``ConvNet``;
        the result is moved to ``device`` and stored as ``self.model``.

        Raises:
            ValueError: If ``self.model_type`` is none of the keys above.
        """
        kind = self.model_type
        if kind not in ('MLP', 'resnet', 'ConvNet'):
            raise ValueError('Did not recognize model type {}'.format(self.model_type))

        if kind == 'MLP':
            network = Net()
        elif kind == 'resnet':
            network = torchvision.models.resnet18()
        else:
            network = ConvNet()
        self.model = network.to(device)

    def init_spur_corr_fn(self):
        """Bind ``self.add_spurious_correlation`` for ``self.spur_corr``.

        Selects a pair of batch transforms according to ``self.spur_corr``
        and stores a callable ``add_spurious_correlation(batch, labels)``
        that forwards to ``insert_spur`` together with the flip fraction
        read when this method runs.

        Raises:
            ValueError: If ``self.spur_corr`` is not one of ``'recolor'``,
                ``'lighting'``, ``'color_shift'`` or ``'no_spur_corr'``.
        """
        ff = self.flip_fraction
        if self.spur_corr == 'recolor':
            fn_1 = make_blue
            fn_2 = make_red
        elif self.spur_corr == 'lighting':
            fn_1 = lambda batch : alter_lighting(batch, scale=1.25)
            fn_2 = lambda batch : alter_lighting(batch, scale=0.75)
        elif self.spur_corr == 'color_shift':
            fn_1 = lambda batch : shift_color(batch, channel_dim=0)
            fn_2 = lambda batch : shift_color(batch, channel_dim=1)
        elif self.spur_corr == 'no_spur_corr':
            fn_1 = lambda batch: batch
            fn_2 = lambda batch: batch
        else:
            # Reject unknown keys here. Otherwise fn_1/fn_2 stay unbound and
            # the failure is deferred to the first forward pass as a
            # confusing NameError raised from inside the closure.
            raise ValueError('Did not recognize spurious correlation {}'.format(self.spur_corr))
        self.add_spurious_correlation = lambda batch, labels : insert_spur(batch, labels, ff, fn_1, fn_2)

    def construct_save_path(self):
        """Compute ``self.save_path``, the checkpoint file for this setup.

        The path has the form
        ``./models/<dset>/<spur_corr>_flip_frac_<percent><ext>.pth`` where
        ``<percent>`` is the flip fraction as a whole-number percentage and
        ``<ext>`` is ``_<norm>_eps<eps>`` when adversarial training is
        enabled, empty otherwise.
        """
        # flip_fraction * 100 can land a hair below an integer
        # (0.29 * 100 == 28.999999999999996). Round that noise away before
        # truncating so 0.29 is filed as 29 rather than colliding with 0.28.
        flip_percent = int(round(self.flip_fraction * 100, 6))
        ext = '_{}_eps{}'.format(self.adv_params['norm'], self.adv_params['eps']) if self.adv_train else ''
        self.save_path = './models/{}/{}_flip_frac_{}{}.pth'.format(self.dset, self.spur_corr, flip_percent, ext)

    def forward_pass(self, batch, labels):
        """Run one batch through the model and score it.

        The batch first goes through ``add_spurious_correlation``, is moved
        to ``device`` together with its labels and, when ``adv_params`` is
        set, is replaced by the output of ``self.attack``.

        Args:
            batch: Input batch as yielded by a data loader.
            labels: Labels matching ``batch``.

        Returns:
            A ``(loss, correct)`` pair: the criterion output for the logits
            and the number of predictions equal to their label (Python int).
        """
        inputs = self.add_spurious_correlation(batch, labels)
        inputs = inputs.to(device)
        labels = labels.to(device)

        if self.adv_params is not None:
            inputs = self.attack(inputs, labels)

        logits = self.model(inputs)
        loss = self.criterion(logits, labels)
        # Predicted class is the index of the largest logit along dim 1.
        predictions = logits.max(dim=1).indices
        num_correct = (predictions == labels).sum().item()
        return loss, num_correct

    def test(self):
        self.model.eval()
        test_loss = 0
        total, total_correct = 0, 0
        for batch, labels in self.test_loader:
            loss, correct = self.forward_pass(batch, labels)
            total += labels.shape[0]
            test_loss += loss.item()
            total_correct += correct

        test_loss, test_acc = [x/total for x in [test_loss, total_correct]]
        return test_loss, test_acc

    def train(self):
        self.model.train()
        train_loss = 0
        total, total_correct = 0, 0
        for batch, labels in self.train_loader:
            self.optimizer.zero_grad()
            loss, correct = self.forward_pass(batch, labels)
            loss.backward()
            self.optimizer.step()

            total += labels.shape[0]
            train_loss += loss.item()
            total_correct += correct
    
        train_loss, train_acc = [x/total for x in [train_loss, total_correct]]
        return train_loss, train_acc
    
    def save_model(self):
        self.model.eval()
        model_dict = dict()
        model_dict['model_type'] = self.model_type
        model_dict['acc'] = self.best_acc
        model_dict['state'] = self.model.state_dict()
        model_dict['adv_params'] = self.adv_params
        model_dict['spur_corr'] = self.spur_corr
        model_dict['flip_fraction'] = self.flip_fraction
        torch.save(model_dict, self.save_path)
    
    def restore_model(self):
        """Reload weights and run metadata from ``self.save_path``.

        Loads the checkpoint dict, applies its ``'state'`` to ``self.model``
        and overwrites ``model_type``, ``best_acc``, ``adv_params``,
        ``spur_corr`` and ``flip_fraction`` from it. The spurious-correlation
        function and the adversarial attack are then rebuilt so they match
        the restored settings.
        """
        # NOTE(review): depending on the installed PyTorch version's
        # ``weights_only`` default, torch.load may unpickle arbitrary objects;
        # only restore checkpoints from a trusted location.
        model_dict = torch.load(self.save_path, map_location=torch.device(device))
        self.model.load_state_dict(model_dict['state'])
        self.model_type = model_dict['model_type']
        self.best_acc = model_dict['acc']
        self.adv_params = model_dict['adv_params']
        self.spur_corr = model_dict['spur_corr']
        self.flip_fraction = model_dict['flip_fraction']
        self.init_spur_corr_fn()

        # forward_pass calls self.attack whenever adv_params is set, so the
        # flag and the attack object must follow the restored parameters.
        # Otherwise a stale (or missing) attack would be used.
        self.adv_train = (self.adv_params is not None)
        if self.adv_train:
            self.init_attack()

    def run(self):
        """Train for ``num_epochs`` epochs, evaluating and checkpointing.

        Each epoch calls ``train`` and prints its loss and accuracy. On every
        even-numbered epoch (counting from 0) and on the final epoch the
        model is also evaluated with ``test``; whenever the test accuracy
        strictly exceeds ``best_acc`` it is recorded and ``save_model`` is
        called.
        """
        train_line = 'Epoch: {}/{}.....Training Loss: {:.4f}.....Training Acc: {:.4f}'
        test_line = 'Epoch: {}/{}.....Testing Loss:  {:.4f}.....Testing Acc:  {:.4f}\n'

        for epoch in range(self.num_epochs):
            train_loss, train_acc = self.train()
            print(train_line.format(epoch, self.num_epochs, train_loss, train_acc))

            # Only evaluate on even epochs and on the very last one.
            is_final_epoch = (epoch + 1 == self.num_epochs)
            if epoch % 2 != 0 and not is_final_epoch:
                continue

            test_loss, test_acc = self.test()
            print(test_line.format(epoch, self.num_epochs, test_loss, test_acc))
            if test_acc > self.best_acc:
                self.best_acc = test_acc
                self.save_model()

        print('Training Finished.')