# -*- coding: utf-8 -*-
# !/usr/bin/python

import sys
import time
sys.path.append("..")
from baselines.basic_trainer import BasicTrainer


class ICTTrainer(BasicTrainer):
    """Sequential multi-task trainer built on ``BasicTrainer``.

    Trains on the tasks in ``self.task_controller.task_list`` one after another.
    For each task it keeps the checkpoint with the best dev accuracy (with
    patience-based early stopping), scores that checkpoint on the task's test
    split, and forwards the score to ``self.eval_task_stream``.
    """

    def __init__(self, args, model_save_path):
        super(ICTTrainer, self).__init__(args, model_save_path)

    def _test_task(self, task_idx):
        """Evaluate the current model on task ``task_idx``'s test split and record it.

        Prints timing and accuracy, stores the accuracy in
        ``self.first_acc_list[task_idx]`` and passes it to ``self.eval_task_stream``.
        """
        start_time = time.time()
        # Only the accuracy is used; the full unpack is kept so an unexpected
        # return shape from epoch_acc still fails loudly.
        test_acc, _beam_acc, (_right, _wrong, _), _write_data = self.epoch_acc(
            self.task_controller.task_list[task_idx]["test"])
        print('Evaluation: \tTime: %.4f\tTest acc: %.4f\n' % (time.time() - start_time, test_acc))
        self.first_acc_list[task_idx] = test_acc
        self.eval_task_stream(task_idx, test_acc)

    def train(self):
        """Train over the task stream and return the accumulated metric lists.

        If ``self.args.adaptation_cpt`` is a non-empty path, the model is loaded
        from it and evaluated on task 0's test split, and training starts at
        task 1. This also sets ``self.args.skip_task_one`` as a side effect.

        For each remaining task the model trains for up to ``self.args.epoch``
        epochs. Dev evaluation starts at epoch index ``self.args.epoch_eval``.
        Whenever dev accuracy matches or beats the best so far, the model is
        saved as ``model.bin``. Training stops once more than
        ``self.args.max_patience`` consecutive evaluations fail to do so. The
        saved checkpoint is then reloaded and scored on the task's test split.

        Returns:
            tuple: ``(self.avg_acc_list, self.whole_acc_list, self.bwt_list,
            self.fwt_list)``.
        """
        self.args.skip_task_one = self.args.adaptation_cpt != ''

        if self.args.skip_task_one:
            self.load(self.model, path=self.args.adaptation_cpt)
            start_task = 1
            self._test_task(0)
        else:
            start_task = 0

        for i in range(start_task, self.args.task_num):
            best_result = {"acc": 0.0, "epoch": 0}
            # Tracks whether model.bin was written for *this* task; without it,
            # the load below could pick up the previous task's checkpoint.
            saved_this_task = False

            n_epochs = self.args.epoch
            epoch_eval = self.args.epoch_eval

            patience = 0
            for epoch in range(n_epochs):
                start_time = time.time()
                loss = self.train_one_epoch(self.task_controller.task_list[i]["train"])
                print("\nTask {}, Epoch Train {}, Loss {}, Time {}".format(i, epoch, loss, time.time() - start_time))

                # Skip dev evaluation during the warm-up epochs.
                if epoch < epoch_eval:
                    continue

                start_time = time.time()
                dev_acc, _beam_acc, (_right, _wrong, _), _write_data = self.epoch_acc(
                    self.task_controller.task_list[i]["dev"])
                print('Evaluation: \tEpoch: %d\tTime: %.4f\tDev acc: %.4f\n' % (epoch, time.time() - start_time, dev_acc))

                # ">=" means ties favour the later epoch.
                if dev_acc >= best_result['acc']:
                    best_result['acc'], best_result['epoch'] = dev_acc, epoch
                    self.save(self.model, name="model.bin")
                    saved_this_task = True
                    patience = 0
                else:
                    patience += 1

                if patience > self.args.max_patience:
                    break

            if not saved_this_task:
                # No dev evaluation saved a checkpoint (e.g. epoch_eval >= epoch,
                # or dev accuracy was NaN). Fall back to the final weights so the
                # reload below does not restore a stale checkpoint.
                self.save(self.model, name="model.bin")

            # NOTE(review): assumes load() without a path reads the model.bin
            # written by save() above — confirm in BasicTrainer.
            self.load(self.model)
            self._test_task(i)

        return self.avg_acc_list, self.whole_acc_list, self.bwt_list, self.fwt_list
