from resnet.ops import soft_loss, loss, err, train_transform, test_transform, trans_err, parameters_vec, reverse_sigmoid
import os
import gc
import json
import torch
import torch.optim as optim
import time
from resnet.resnet import ResNet18, ResNet50
from torchvision.datasets import CIFAR10
from torch.utils.data import DataLoader, SubsetRandomSampler, BatchSampler


class Experiment(object):
    """Teacher/student training experiment on CIFAR-10.

    An experiment owns a directory tree under ``experiment_dir/name`` that
    holds a ``config.json``, saved networks (``net/teacher``, ``net/student``)
    and CSV logs (``log/...``).  A ResNet50 teacher is trained (or copied from
    an existing file) and a ResNet18 student is then trained against a mix of
    the teacher's outputs and the ground-truth labels.

    Progress flags (``is_teacher_loaded``, ``is_student_trained``) are stored
    in the config and written back to ``config.json`` so that finished stages
    are skipped when the experiment is re-opened.
    """

    @staticmethod
    def default_config():
        """Return a freshly built default configuration dict.

        A new dict (with new nested lists) is created on every call so that
        no two experiments ever share, and accidentally mutate, the same
        configuration object.
        """
        return {
            'device_name': 'cuda:0',
            # teacher stage
            'is_teacher_loaded': False,
            'teacher_model_load_path': None,
            'teacher_lr': 0.1,
            'teacher_batch_size': 128,
            'teacher_weight_decay': 0.0005,
            'teacher_momentum': 0.9,
            'teacher_lr_decay': 0.1,
            'teacher_lr_decay_epoch': [135, 185, 240],
            'teacher_train_epoch': 250,
            # student stage
            'is_student_trained': False,
            'is_student_using_augmentation': False,
            'student_datanum': 10000,
            'student_lr': 0.1,
            'student_batch_size': 256,
            'student_weight_decay': 0.0005,
            'student_momentum': 0.9,
            'student_lr_decay': 0.1,
            'student_lr_decay_epoch': [135, 185, 240],
            'student_train_epoch': 250,
            # loss mixing: weight of the teacher term and softening temperature
            'soft_ratio': 1.0,
            'temperature': 1.0
        }

    def __init__(
        self,
        name,
        load_from_exists_config=True,
        config=None,
        experiment_dir='./experiment',
        data_dir='./data'
    ):
        """Create the experiment directories and load or store the config.

        Args:
            name: Name of the experiment; used as the sub-directory of
                ``experiment_dir`` that holds all of its files.
            load_from_exists_config: When true and ``config.json`` already
                exists for this experiment, the stored config is loaded and
                ``config`` is ignored.
            config: Configuration dict used (and written to ``config.json``)
                when no stored config is loaded.  ``None`` selects
                :meth:`default_config`.  A dict passed by the caller is stored
                by reference, so later flag updates are visible to the caller.
            experiment_dir: Root directory of all experiments.
            data_dir: Directory passed to the CIFAR10 dataset as its root.
        """
        self.name = name
        self.experiment_dir = experiment_dir
        self.data_dir = data_dir
        self.config_path = experiment_dir + '/' + name + '/config.json'

        self.makedirs()

        if load_from_exists_config and os.path.exists(self.config_path):
            with open(self.config_path, 'r', encoding='utf-8') as json_file:
                self.config = json.load(json_file)
            print('Load from existing config')
        else:
            # Build the defaults per instance: a dict literal in the signature
            # would be shared (and mutated) across every Experiment.
            self.config = self.default_config() if config is None else config
            self.save_config()

        self.device = torch.device(self.config['device_name'])

    def makedirs(self):
        """Create the experiment's directory tree and record the paths.

        Sets ``model_dir``, ``net_dir``, ``teacher_net_dir``,
        ``student_net_dir``, ``log_dir``, ``student_train_log_dir``,
        ``teacher_train_log_dir`` and ``other_log_dir`` on the instance.
        Existing directories are left untouched.
        """
        self.model_dir = self.experiment_dir + '/' + self.name
        self.net_dir = self.model_dir + '/net'
        self.teacher_net_dir = self.net_dir + '/teacher'
        self.student_net_dir = self.net_dir + '/student'
        self.log_dir = self.model_dir + '/log'
        self.student_train_log_dir = self.log_dir + '/student_train_log'
        self.teacher_train_log_dir = self.log_dir + '/teacher_train_log'
        self.other_log_dir = self.log_dir + '/others'

        # exist_ok avoids the check-then-create race of os.path.exists().
        for directory in (
            self.experiment_dir,
            self.model_dir,
            self.net_dir,
            self.teacher_net_dir,
            self.student_net_dir,
            self.log_dir,
            self.student_train_log_dir,
            self.teacher_train_log_dir,
            self.other_log_dir,
        ):
            os.makedirs(directory, exist_ok=True)

    def save_config(self):
        """Write the current config to ``config.json`` as indented JSON."""
        with open(self.config_path, 'w', encoding='utf-8') as json_file:
            json.dump(self.config, json_file, ensure_ascii=False, indent=4)

    @staticmethod
    def _append_line(path, line):
        """Append ``line`` to the text file at ``path``."""
        with open(path, 'a') as log_file:
            log_file.write(line)

    @staticmethod
    def _release_memory():
        """Free cached CUDA blocks and run a garbage-collection pass."""
        torch.cuda.empty_cache()
        gc.collect()

    def _build_optimizer(self, net, role):
        """Build the SGD optimizer and MultiStepLR scheduler for ``net``.

        ``role`` is ``'teacher'`` or ``'student'`` and selects the
        ``<role>_lr``, ``<role>_momentum``, ``<role>_weight_decay``,
        ``<role>_lr_decay_epoch`` and ``<role>_lr_decay`` config entries.
        """
        optimizer = optim.SGD(
            net.parameters(),
            lr=self.config[role + '_lr'],
            momentum=self.config[role + '_momentum'],
            weight_decay=self.config[role + '_weight_decay']
        )
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=self.config[role + '_lr_decay_epoch'],
            gamma=self.config[role + '_lr_decay']
        )
        return optimizer, scheduler

    def train_teacher(self):
        """Train the teacher network and keep the checkpoint with lowest test loss.

        Does nothing when the config already marks the teacher as loaded.
        Training starts from ``teacher_model_load_path`` when it is set,
        otherwise from a new ResNet50.  The starting network is saved as
        ``init.pth`` and the best network (by mean test loss) as
        ``teacher.pth`` in ``teacher_net_dir``.  Per-iteration training
        metrics and per-epoch test metrics are written to CSV files in
        ``teacher_train_log_dir``.  On completion ``is_teacher_loaded`` is set
        and the config is saved.
        """
        if self.config['is_teacher_loaded']:
            print('Teacher already exist!')
            return

        # data
        train_dataset = CIFAR10(self.data_dir, train=True, transform=train_transform, download=True)
        test_dataset = CIFAR10(self.data_dir, train=False, transform=test_transform, download=True)

        train_dataloader = DataLoader(
            train_dataset, batch_size=self.config['teacher_batch_size'], shuffle=True
        )
        test_dataloader = DataLoader(test_dataset, batch_size=1000, shuffle=True)

        # net
        if self.config['teacher_model_load_path'] is not None:
            # NOTE(review): torch.load unpickles a whole module object; only
            # use trusted files.  Recent PyTorch releases may require
            # weights_only=False for such checkpoints -- confirm against the
            # torch version in use.
            net = torch.load(self.config['teacher_model_load_path'], map_location=self.device)
        else:
            net = ResNet50(num_classes=10).to(self.device)

        torch.save(net, self.teacher_net_dir + '/init.pth')

        # optim
        optimizer, scheduler = self._build_optimizer(net, 'teacher')

        # print/saver
        train_log_path = self.teacher_train_log_dir + '/train_log.csv'
        test_log_path = self.teacher_train_log_dir + '/test_log.csv'

        with open(train_log_path, 'w') as f:
            f.write('iter,train_loss,train_error\n')

        with open(test_log_path, 'w') as f:
            f.write('epoch,test_loss,test_error\n')

        train_batch_num = len(train_dataloader)
        min_test_loss = float('inf')

        # run epoch
        for epoch in range(self.config['teacher_train_epoch']):

            timer = time.time()

            # The test phase below switches to eval mode; switch back here.
            net.train()

            for i, (train_images, train_labels) in enumerate(train_dataloader):

                global_step = epoch * train_batch_num + i

                images = train_images.to(self.device)
                labels = train_labels.to(self.device)

                train_outputs = torch.squeeze(net(images))
                train_loss = loss(train_outputs, labels)

                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                with torch.no_grad():
                    train_err = err(train_outputs, labels)

                self._append_line(
                    train_log_path,
                    '{:05d},{:.20e},{:.020f}\n'.format(
                        global_step, train_loss.item(), train_err.item()
                    )
                )

                # display
                if global_step % 50 == 0:
                    print(
                        'TRAIN> epcoh: {:5d}, iter:{:5d}, train loss:{:.06e}, train err:{:.5f}'
                        .format(epoch, global_step, train_loss.item(), train_err.item())
                    )

            # NOTE(review): passing the epoch to step() is deprecated in recent
            # PyTorch releases; kept because plain step() would move every
            # milestone one epoch earlier.
            scheduler.step(epoch)

            # test
            del train_loss, train_err, train_outputs, images, labels
            self._release_memory()

            # Evaluate in eval mode so layers such as batch norm / dropout use
            # their inference behaviour and test data cannot alter the model.
            net.eval()
            with torch.no_grad():
                test_loss, test_err = [], []
                for test_images, test_labels in test_dataloader:

                    images = test_images.to(self.device)
                    labels = test_labels.to(self.device)

                    test_outputs = torch.squeeze(net(images))
                    test_loss.append(loss(test_outputs, labels))
                    test_err.append(err(test_outputs, labels))

                test_loss = torch.mean(torch.tensor(test_loss))
                test_err = torch.mean(torch.tensor(test_err))

            self._append_line(
                test_log_path,
                '{:05d},{:.20e},{:.020f}\n'.format(
                    epoch, test_loss.item(), test_err.item()
                )
            )

            print(
                '\nTEST> epoch:{:5d}, test loss:{:.06e}, test err:{:.5f}'.format(
                    epoch, test_loss.item(), test_err.item()
                )
            )

            # save model (checkpoint is stored while the net is in eval mode)
            if test_loss.item() < min_test_loss:
                min_test_loss = test_loss.item()
                torch.save(net, self.teacher_net_dir + '/teacher.pth')

            del test_outputs, test_loss, test_err, images, labels
            self._release_memory()

            print('time used:{:.02f}s\n'.format(time.time() - timer))

        self.config['is_teacher_loaded'] = True
        self.save_config()

    def load_teacher(self):
        """Copy an existing teacher network into the experiment.

        Loads the file at ``teacher_model_load_path`` and re-saves it as
        ``teacher.pth`` in ``teacher_net_dir``, then sets
        ``is_teacher_loaded`` and saves the config.  Prints a message and
        returns without changes when the teacher is already marked as loaded
        or when no load path is configured.
        """
        if self.config['is_teacher_loaded']:
            print('Teacher already exist!')
            return

        if self.config['teacher_model_load_path'] is None:
            print('No source file of Teacher!')
            return

        # NOTE(review): torch.load unpickles a whole module object; only use
        # trusted files.
        net = torch.load(self.config['teacher_model_load_path'], map_location=self.device)
        torch.save(net, self.teacher_net_dir + '/teacher.pth')

        self.config['is_teacher_loaded'] = True
        self.save_config()

    def train_student(self):
        """Train the student network against the saved teacher.

        Requires ``is_teacher_loaded``; does nothing when the student is
        already marked as trained.  A random subset of ``student_datanum``
        training images is drawn once and reused for every epoch.  The
        training loss is
        ``soft_ratio * T**2 * soft_loss(student / T, teacher / T)
        + (1 - soft_ratio) * loss(student, labels)`` with
        ``T = temperature``.  The starting ResNet18 is saved as ``init.pth``
        and the network with the lowest mean test transfer error as
        ``student.pth`` in ``student_net_dir``.  Metrics are written to CSV
        files in ``student_train_log_dir``.  On completion
        ``is_student_trained`` is set and the config is saved.
        """
        if not self.config['is_teacher_loaded']:
            print('Teacher not ready!')
            return
        if self.config['is_student_trained']:
            print('Student already trained!')
            return

        # data
        if self.config['is_student_using_augmentation']:
            student_transform = train_transform
        else:
            student_transform = test_transform
        train_dataset = CIFAR10(
            self.data_dir, train=True, transform=student_transform, download=True
        )
        test_dataset = CIFAR10(self.data_dir, train=False, transform=test_transform, download=True)

        # Fixed random subset of the training set, reshuffled every epoch.
        indices = torch.randperm(len(train_dataset))[:self.config['student_datanum']]
        sampler = SubsetRandomSampler(indices)
        batch_sampler = BatchSampler(
            sampler, batch_size=self.config['student_batch_size'], drop_last=False
        )

        train_dataloader = DataLoader(train_dataset, batch_sampler=batch_sampler)
        test_dataloader = DataLoader(test_dataset, batch_size=1000)

        # net
        net = ResNet18(num_classes=10).to(self.device)
        torch.save(net, self.student_net_dir + '/init.pth')
        teacher_net = torch.load(self.teacher_net_dir + '/teacher.pth', map_location=self.device)
        # The teacher is frozen: keep it in eval mode so its forward passes
        # neither depend on the current batch nor modify its own state.
        teacher_net.eval()

        # optim
        optimizer, scheduler = self._build_optimizer(net, 'student')

        # print/saver
        train_log_path = self.student_train_log_dir + '/train_log.csv'
        test_log_path = self.student_train_log_dir + '/test_log.csv'

        with open(train_log_path, 'w') as f:
            f.write('iter,train_loss,train_error,tranfer_error\n')

        with open(test_log_path, 'w') as f:
            f.write('epoch,test_loss,test_error,test_trans_error\n')

        train_batch_num = len(train_dataloader)
        min_test_trans_err = float('inf')
        rho = self.config['soft_ratio']
        T = self.config['temperature']

        # run epoch
        for epoch in range(self.config['student_train_epoch']):

            timer = time.time()

            # The test phase below switches to eval mode; switch back here.
            net.train()

            for i, (train_images, train_labels) in enumerate(train_dataloader):

                global_step = epoch * train_batch_num + i

                images = train_images.to(self.device)
                labels = train_labels.to(self.device)

                with torch.no_grad():
                    train_teacher_outputs = torch.squeeze(teacher_net(images))
                train_outputs = torch.squeeze(net(images))

                train_loss = rho * T * T * soft_loss(
                    train_outputs / T, train_teacher_outputs / T
                ) + (1 - rho) * loss(train_outputs, labels)

                optimizer.zero_grad()
                train_loss.backward()
                optimizer.step()

                with torch.no_grad():
                    train_err = err(train_outputs, labels)
                    train_trans_err = trans_err(train_outputs, train_teacher_outputs)

                self._append_line(
                    train_log_path,
                    '{:05d},{:.20e},{:.020f},{:.020f}\n'.format(
                        global_step, train_loss.item(), train_err.item(),
                        train_trans_err.item()
                    )
                )

                # display
                if global_step % 50 == 0:
                    print(
                        'TRAIN> epcoh: {:5d}, iter:{:5d}, train loss:{:.06e}, train err:{:.5f}, transfer err:{:.5f}'
                        .format(
                            epoch, global_step, train_loss.item(), train_err.item(),
                            train_trans_err.item()
                        )
                    )

            # NOTE(review): passing the epoch to step() is deprecated in recent
            # PyTorch releases; kept because plain step() would move every
            # milestone one epoch earlier.
            scheduler.step(epoch)

            # test
            del train_loss, train_err, train_outputs, train_trans_err, train_teacher_outputs
            del images, labels
            self._release_memory()

            net.eval()
            with torch.no_grad():
                test_loss, test_err, test_trans_err = [], [], []
                for test_images, test_labels in test_dataloader:

                    images = test_images.to(self.device)
                    labels = test_labels.to(self.device)

                    test_teacher_outputs = torch.squeeze(teacher_net(images))
                    test_outputs = torch.squeeze(net(images))
                    test_loss.append(
                        rho * T * T * soft_loss(test_outputs / T, test_teacher_outputs / T) +
                        (1 - rho) * loss(test_outputs, labels)
                    )
                    test_err.append(err(test_outputs, labels))
                    test_trans_err.append(trans_err(test_outputs, test_teacher_outputs))

                test_loss = torch.mean(torch.tensor(test_loss))
                test_err = torch.mean(torch.tensor(test_err))
                test_trans_err = torch.mean(torch.tensor(test_trans_err))

            self._append_line(
                test_log_path,
                '{:05d},{:.20e},{:.020f},{:.020f}\n'.format(
                    epoch, test_loss.item(), test_err.item(), test_trans_err.item()
                )
            )

            print(
                '\nTEST> epoch:{:5d}, test loss:{:.06e}, test err:{:.5f}, test transfer error:{:.5f}'
                .format(epoch, test_loss.item(), test_err.item(), test_trans_err.item())
            )

            # save model: the checkpoint criterion is the transfer error, not
            # the test loss (checkpoint is stored while the net is in eval mode)
            if test_trans_err.item() < min_test_trans_err:
                min_test_trans_err = test_trans_err.item()
                torch.save(net, self.student_net_dir + '/student.pth')

            del test_outputs, test_loss, test_err, test_trans_err, test_teacher_outputs
            del images, labels
            self._release_memory()

            print('time used:{:.02f}s\n'.format(time.time() - timer))

        self.config['is_student_trained'] = True
        self.save_config()

    def student_weight_change(self):
        """Return the norm of the student's parameter change during training.

        Loads ``init.pth`` and ``student.pth`` from ``student_net_dir``,
        turns each into a vector with ``parameters_vec`` and returns
        ``torch.norm`` of their difference as a NumPy value.
        """
        init = torch.load(self.student_net_dir + '/init.pth', map_location=self.device)
        net = torch.load(self.student_net_dir + '/student.pth', map_location=self.device)

        weight_change = torch.norm(parameters_vec(net) - parameters_vec(init))

        return weight_change.cpu().detach().numpy()

    def effective_logits_std(self, net):
        """Compute statistics of ``reverse_sigmoid`` applied to softmax outputs.

        ``net`` is moved to the experiment device and evaluated (in eval mode,
        without gradients) on the CIFAR-10 test set.  Its train/eval mode is
        restored before returning.

        Args:
            net: Network to evaluate; called with batches of test images.

        Returns:
            A 4-tuple of NumPy values: mean and standard deviation of
            ``reverse_sigmoid`` over all softmax probabilities, followed by
            mean and standard deviation of ``reverse_sigmoid`` over each
            sample's largest softmax probability.
        """
        net = net.to(self.device)

        test_dataset = CIFAR10(self.data_dir, train=False, transform=test_transform, download=True)
        test_dataloader = DataLoader(test_dataset, batch_size=1000)

        # Measure in eval mode, then hand the net back in its original mode.
        was_training = net.training
        net.eval()
        try:
            with torch.no_grad():
                eff_logits, eff_right_logits = [], []
                for test_images, _ in test_dataloader:

                    probabilities = torch.softmax(net(test_images.to(self.device)), dim=1)
                    eff_logits.append(reverse_sigmoid(probabilities))

                    # Largest probability per sample, i.e. that of the
                    # predicted class.
                    top_probabilities = torch.max(probabilities, dim=1)[0]
                    eff_right_logits.append(reverse_sigmoid(top_probabilities))
        finally:
            net.train(was_training)

        eff_logits = torch.cat(eff_logits)
        eff_right_logits = torch.cat(eff_right_logits)

        return (
            torch.mean(eff_logits).cpu().detach().numpy(),
            torch.std(eff_logits).cpu().detach().numpy(),
            torch.mean(eff_right_logits).cpu().detach().numpy(),
            torch.std(eff_right_logits).cpu().detach().numpy(),
        )
