# -*- coding: utf-8 -*-
# !/usr/bin/python

import sys
import time
import torch
sys.path.append("..")
import random
from baselines.basic_trainer import BasicTrainer
from baselines.utils import DLFS


class TRTrainer(BasicTrainer):
    """Task-by-task trainer with a two-phase update and exemplar replay.

    For every task the trainer first runs a "fast updating" phase in which
    only parameters whose name contains ``"col"``, ``"table"`` or ``"lf"``
    receive gradients, then a "slow updating" phase in which all parameters
    are trained on the current task and, once more than one task has been
    observed, exemplars stored for earlier tasks are replayed with the
    ``"col"``/``"table"``/``"lf"`` parameters frozen.
    """

    # Substrings of a parameter name that select the group trained during the
    # fast phase and frozen during replay.
    _CT_PARAM_KEYS = ("col", "table", "lf")

    def __init__(self, args, model_save_path):
        """Initialise the base trainer and the replay bookkeeping."""
        super(TRTrainer, self).__init__(args, model_save_path)

        # CL component
        self.past_task_id = -1  # index of the most recently started task
        self.observed_task_ids = []  # task indices in the order they were started
        self.memory_data = {}  # task index -> exemplars sampled from that task's train set

    @classmethod
    def _is_ct_param(cls, name):
        """Return True if ``name`` contains any of the ``_CT_PARAM_KEYS`` substrings."""
        return any(key in name for key in cls._CT_PARAM_KEYS)

    def _set_trainable(self, ct_trainable, other_trainable):
        """Set ``requires_grad`` for the col/table/lf group and for all other parameters."""
        for name, param in self.model.named_parameters():
            param.requires_grad = ct_trainable if self._is_ct_param(name) else other_trainable

    def fix_ct_param(self):
        """Freeze every parameter except those named with "col", "table" or "lf"."""
        self._set_trainable(ct_trainable=True, other_trainable=False)

    def fix_ts_param(self):
        """Freeze the parameters named with "col", "table" or "lf"; train the rest."""
        self._set_trainable(ct_trainable=False, other_trainable=True)

    def unfix_param(self):
        """Enable gradients for every model parameter."""
        self._set_trainable(ct_trainable=True, other_trainable=True)

    @staticmethod
    def _mean_loss(total_loss, count):
        """Return ``total_loss / count``, or 0.0 when no example was processed."""
        return total_loss / count if count else 0.0

    def _run_batches(self, examples, report_loss=0.0, example_num=0):
        """Run one pass over ``examples`` in mini-batches with gradient accumulation.

        The optimizer steps every ``args.accumulation_step`` batches and always
        after the final batch, so no gradient is left pending on return.

        Args:
            examples: sliceable sequence of training examples.
            report_loss: running summed loss to continue accumulating into.
            example_num: running example count to continue accumulating into.

        Returns:
            The updated ``(report_loss, example_num)`` pair.
        """
        batch_size = self.args.batch_size
        total = len(examples)

        # Start from clean gradients so nothing leaks in from an earlier pass.
        self.optimizer.zero_grad()

        for step, start in enumerate(range(0, total, batch_size)):
            end = min(start + batch_size, total)
            report_loss, example_num, loss = self.train_one_batch(
                examples[start:end], report_loss, example_num)
            loss.backward()

            if (step + 1) % self.args.accumulation_step == 0 or end == total:
                if self.args.clip_grad > 0.:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad)
                self.optimizer.step()
                self.optimizer.zero_grad()

        return report_loss, example_num

    def _replay_memory(self, current_task_id):
        """Train on stored exemplars of every previously observed task.

        At most ``args.batch_size`` exemplars are drawn per past task.

        Returns:
            ``(replay_loss, replay_example_num)`` summed over all past tasks.
        """
        replay_loss, replay_example_num = 0.0, 0

        # The last observed id is the current task; everything before it is replayed.
        for past_task_id in self.observed_task_ids[:-1]:
            memory = self.memory_data[past_task_id]
            replay_examples = random.sample(memory, min(len(memory), self.args.batch_size))

            assert past_task_id != current_task_id

            random.shuffle(replay_examples)
            # Each past task's loss is added exactly once, after all its batches ran.
            task_loss, replay_example_num = self._run_batches(
                replay_examples, 0.0, replay_example_num)
            replay_loss += task_loss

        return replay_loss, replay_example_num

    def train(self):
        """Train on ``args.task_num`` tasks in order and evaluate after each one.

        Per task: sample exemplars into ``memory_data``, run ``args.epoch // 2``
        fast-updating epochs, then up to ``args.epoch`` slow-updating epochs
        with replay and dev-set early stopping, and finally evaluate on the
        task's test split.

        Returns:
            The tuple ``(avg_acc_list, whole_acc_list, bwt_list, fwt_list)`` read
            from attributes that are not set in this class (presumably maintained
            by the base trainer / ``eval_task_stream`` -- confirm there).
        """
        for i in range(self.args.task_num):
            best_result = {"acc": 0.0, "epoch": 0}
            examples = self.task_controller.task_list[i]["train"]
            epoch_eval = self.args.epoch_eval
            patience = 0

            if i != self.past_task_id:
                self.observed_task_ids.append(i)
                self.past_task_id = i

            # Exemplar selection for later replay. Other samplers (RANDOM, FSS,
            # BALANCE, LFS, GSS) were previously tried in place of DLFS.
            self.memory_data[i] = []
            sampled_examples = DLFS(examples=examples,
                                    memory_size=self.args.memory_size)
            self.memory_data[i].extend(sampled_examples)

            # ---- Fast updating: only the col/table/lf parameters are trained.
            self.fix_ct_param()

            for epoch in range(self.args.epoch // 2):
                self.model.train()
                epoch_begin = time.time()
                random.shuffle(examples)  # NOTE: shuffles the task's train list in place

                report_loss, example_num = self._run_batches(examples)

                print("Task {} Fast Updating, Epoch Train {},  Loss {}, Time {}".format(
                    i, epoch, self._mean_loss(report_loss, example_num), time.time() - epoch_begin))

            # ---- Slow updating: all parameters on the current task, then replay.
            for epoch in range(self.args.epoch):
                self.model.train()

                # Replay in the previous epoch may have left parameters frozen.
                self.unfix_param()
                epoch_begin = time.time()
                random.shuffle(examples)

                report_loss, example_num = self._run_batches(examples)

                replay_example_num = 0
                if len(self.observed_task_ids) > 1:
                    self.fix_ts_param()
                    replay_loss, replay_example_num = self._replay_memory(i)
                    report_loss += replay_loss

                print("Task {} Slow Updating, Epoch Train {}, Loss {}, Time {}".format(
                    i, epoch,
                    self._mean_loss(report_loss, example_num + replay_example_num),
                    time.time() - epoch_begin))

                # Skip dev evaluation during the first `epoch_eval` epochs.
                if epoch < epoch_eval:
                    continue

                start_time = time.time()
                dev_acc, _beam_acc, (_right, _wrong, _), _write_data = self.epoch_acc(
                    self.task_controller.task_list[i]["dev"])
                print('Evaluation: \tEpoch: %d\tTime: %.4f\tDev acc: %.4f\n' % (epoch, time.time() - start_time, dev_acc))

                # Ties count as an improvement, so the latest best epoch is kept.
                if dev_acc >= best_result['acc']:
                    best_result['acc'], best_result['epoch'] = dev_acc, epoch
                    self.save(self.model, name="model.bin")
                    patience = 0
                else:
                    patience += 1

                if patience > self.args.max_patience:
                    break

            # NOTE(review): if no epoch reached `epoch_eval`, nothing was saved for
            # this task above; confirm what `self.load` restores in that case.
            self.load(self.model)
            start_time = time.time()
            test_acc, _beam_acc, (_right, _wrong, _), _write_data = self.epoch_acc(
                self.task_controller.task_list[i]["test"])
            print('Evaluation: \tTime: %.4f\tTest acc: %.4f\n' % (time.time() - start_time, test_acc))

            self.first_acc_list[i] = test_acc
            self.eval_task_stream(i, test_acc)

        return self.avg_acc_list, self.whole_acc_list, self.bwt_list, self.fwt_list

    def train_one_batch(self, examples, report_loss, example_num):
        """Compute the loss of one mini-batch and update the running statistics.

        The model's forward pass returns a pair of per-example scores; their
        negations are used as the sketch loss and the lf loss (presumably the
        scores are log-likelihoods -- confirm against the model).

        Args:
            examples: the mini-batch passed to ``self.model.forward``.
            report_loss: running summed loss (Python float) to add to.
            example_num: running example count to add to.

        Returns:
            ``(report_loss, example_num, loss)`` where ``loss`` is the tensor
            ``mean(lf loss) + mean(sketch loss)`` to back-propagate.
        """
        score = self.model.forward(examples)
        loss_sketch = -score[0]
        loss_lf = -score[1]

        # Summed (not averaged) loss as a plain float, used for logging only.
        batch_loss_sum = torch.sum(loss_sketch).item() + torch.sum(loss_lf).item()

        loss_sketch = torch.mean(loss_sketch)
        loss_lf = torch.mean(loss_lf)

        loss = loss_lf + loss_sketch

        report_loss += batch_loss_sum
        example_num += len(examples)
        return report_loss, example_num, loss
