# -*- coding: utf-8 -*-
# !/usr/bin/python

import sys
import torch
import time
import random
sys.path.append("..")
from baselines.basic_trainer import BasicTrainer


class MAMLTrainer(BasicTrainer):
    """Trainer that adds a MAML-style meta-training phase before fine-tuning.

    For every task after the first, ``train`` first runs several epochs of
    :meth:`meta_train` on the task's training examples and then runs the ordinary
    ``train_one_epoch`` loop inherited from ``BasicTrainer``.
    """

    def __init__(self, args, model_save_path):
        super(MAMLTrainer, self).__init__(args, model_save_path)

    def _sample_meta_tasks(self, examples):
        """Group ``examples`` by schema and sample ``args.meta_task_num`` (support, query) pairs.

        The schema key of an example is ``" ||| ".join(example.cols_set)``. Each set is
        built by drawing ``args.n_way`` distinct schemas and then ``args.k_shot``
        examples from each. Both counts are clamped to what is available, so small
        schema pools no longer make ``random.sample`` raise ``ValueError``.

        Returns:
            A list of ``[support_set, query_set]`` pairs (lists of examples).
        """
        schema_dict = {}
        for example in examples:
            schema_dict.setdefault(" ||| ".join(example.cols_set), []).append(example)

        # random.sample needs a sequence: dict views are rejected on Python >= 3.11.
        schemas = list(schema_dict)
        n_way = min(self.args.n_way, len(schemas))
        k_shot = self.args.k_shot

        def draw_set():
            # Draw schemas first, then examples per schema (same RNG order as before).
            sampled = []
            for schema in random.sample(schemas, n_way):
                pool = schema_dict[schema]
                sampled.extend(random.sample(pool, min(k_shot, len(pool))))
            return sampled

        meta_tasks = []
        for _ in range(self.args.meta_task_num):
            support_set = draw_set()
            query_set = draw_set()
            meta_tasks.append([support_set, query_set])
        return meta_tasks

    def meta_train(self, examples):
        """Run one meta-training pass over ``examples`` and return the summed meta loss.

        For each sampled (support, query) task:

        1. back-propagate the support loss and take an inner optimizer step
           (gradients clipped to norm 1.0);
        2. compute the query loss with the updated weights;
        3. take an outer optimizer step on ``0.5 * grad(support) + 0.5 * grad(query)``,
           where ``grad(support)`` is the support gradient at the pre-update weights
           (again clipped to norm 1.0).

        The weights adapted in step 1 are not restored before the outer step.
        Tasks with an empty support or query set are skipped.

        Set ``args.detect_anomaly`` to a truthy value to enable autograd anomaly
        detection (a slow debugging aid that used to be always on).

        Returns:
            The sum over tasks of ``0.5 * support_loss + 0.5 * query_loss`` as a
            detached tensor, or ``0.0`` if no task was run.
        """
        if getattr(self.args, "detect_anomaly", False):
            torch.autograd.set_detect_anomaly(True)

        params = list(self.model.parameters())
        # Start from clean gradients so the first inner step only sees the support loss.
        self.optimizer.zero_grad()

        meta_report_loss = 0.0
        for support_set, query_set in self._sample_meta_tasks(examples):
            if not support_set or not query_set:
                continue

            # Inner step on the support set.
            _, _, s_loss = self.train_one_batch(support_set, 0, len(support_set))
            s_loss.backward()
            # Keep an unclipped copy of the support gradient at the pre-update weights.
            # Back-propagating through s_loss again after optimizer.step() would read
            # parameters the step modified in place, which autograd rejects; reuse
            # this copy instead of retain_graph=True.
            support_grads = [None if p.grad is None else p.grad.detach().clone() for p in params]

            torch.nn.utils.clip_grad_norm_(params, 1.0)
            self.optimizer.step()
            self.optimizer.zero_grad()

            # Outer step: query loss at the adapted weights + stashed support gradient.
            _, _, q_loss = self.train_one_batch(query_set, 0, len(query_set))
            meta_loss = 0.5 * s_loss.detach() + 0.5 * q_loss
            meta_loss.backward()
            for param, s_grad in zip(params, support_grads):
                if s_grad is None:
                    continue
                if param.grad is None:
                    param.grad = s_grad.mul_(0.5)
                else:
                    param.grad.add_(s_grad, alpha=0.5)

            meta_report_loss += meta_loss.detach()

            torch.nn.utils.clip_grad_norm_(params, 1.0)
            self.optimizer.step()
            self.optimizer.zero_grad()
            torch.cuda.empty_cache()
        return meta_report_loss

    def _eval_dev_and_checkpoint(self, dev_data, epoch, best_result, model_name):
        """Evaluate on ``dev_data``; on a new best (ties included) save the model.

        ``best_result`` (keys ``"acc"`` and ``"epoch"``) is updated in place.

        Returns:
            True if the dev accuracy tied or beat ``best_result["acc"]``.
        """
        start_time = time.time()
        dev_acc, _beam_acc, _, _ = self.epoch_acc(dev_data)
        print('Evaluation: \tEpoch: %d\tTime: %.4f\tDev acc: %.4f\n' % (epoch, time.time() - start_time, dev_acc))

        if dev_acc < best_result['acc']:
            return False
        best_result['acc'], best_result['epoch'] = dev_acc, epoch
        self.save(self.model, name=model_name)
        return True

    def _evaluate_test_split(self, task_id):
        """Evaluate the current model on task ``task_id``'s test split and record the result."""
        start_time = time.time()
        test_acc, _beam_acc, _, _ = self.epoch_acc(self.task_controller.task_list[task_id]["test"])
        print('Evaluation: \tTime: %.4f\tTest acc: %.4f\n' % (time.time() - start_time, test_acc))
        self.first_acc_list[task_id] = test_acc
        self.eval_task_stream(task_id, test_acc)

    def train(self):
        """Train sequentially over the task stream and return the continual-learning metrics.

        For each task ``i``:

        * if ``i > 0``, run up to ``args.meta_epochs`` (default 20) epochs of
          :meth:`meta_train`, evaluating on the dev split from epoch
          ``args.meta_epoch_eval`` (default 15) onwards;
        * run up to ``args.epoch`` epochs of ``train_one_epoch``, evaluating on the
          dev split from epoch ``args.epoch_eval`` onwards;
        * reload the best dev checkpoint (``model.bin``) written for this task, if
          any, and evaluate on the test split.

        Both phases save ``model.bin`` whenever dev accuracy ties or beats the best
        seen so far for the task. Each phase stops after more than
        ``args.max_patience`` consecutive non-improving evaluations.

        If ``args.adaptation_cpt`` is non-empty, task 0 is not trained. Instead,
        ``model_0.bin`` is loaded from that path and evaluated.

        Returns:
            ``(self.avg_acc_list, self.whole_acc_list, self.bwt_list, self.fwt_list)``.
        """
        self.args.skip_task_one = self.args.adaptation_cpt != ''

        if self.args.skip_task_one:
            self.load(self.model, name='model_0.bin', path=self.args.adaptation_cpt)
            start_task = 1
            self._evaluate_test_split(0)
        else:
            start_task = 0

        # Meta-phase schedule (formerly hard-coded as 20 epochs, eval from epoch 15).
        meta_epochs = getattr(self.args, 'meta_epochs', 20)
        meta_epoch_eval = getattr(self.args, 'meta_epoch_eval', 15)
        model_name = 'model.bin'

        for i in range(start_task, self.args.task_num):
            task = self.task_controller.task_list[i]
            best_result = {"acc": 0.0, "epoch": 0}
            saved = False  # whether a checkpoint was written for *this* task

            if i > 0:
                patience = 0
                for epoch in range(meta_epochs):
                    start_time = time.time()
                    meta_loss = self.meta_train(task["train"])
                    print("\nTask {}, Epoch Train {}, Meta Loss {}, Time {}".format(
                        i, epoch, meta_loss, time.time() - start_time))

                    if epoch < meta_epoch_eval:
                        continue

                    if self._eval_dev_and_checkpoint(task["dev"], epoch, best_result, model_name):
                        saved = True
                        patience = 0
                    else:
                        patience += 1

                    if patience > self.args.max_patience:
                        break

            # best_result carries over from the meta phase; only patience restarts.
            patience = 0
            for epoch in range(self.args.epoch):
                start_time = time.time()
                loss = self.train_one_epoch(task["train"])
                print("\nTask {}, Epoch Train {}, Loss {}, Time {}".format(i, epoch, loss, time.time() - start_time))

                if epoch < self.args.epoch_eval:
                    continue

                if self._eval_dev_and_checkpoint(task["dev"], epoch, best_result, model_name):
                    saved = True
                    patience = 0
                else:
                    patience += 1

                if patience > self.args.max_patience:
                    break

            # Without a dev evaluation for this task, model.bin (if present) belongs
            # to an earlier task, so keep the final in-memory weights instead.
            if saved:
                self.load(self.model, name=model_name)
            self._evaluate_test_split(i)

        return self.avg_acc_list, self.whole_acc_list, self.bwt_list, self.fwt_list
