from pytorch_pretrained_bert.tokenization import BertTokenizer
from pytorch_pretrained_bert.modeling import BertForMaskedLM
from pytorch_pretrained_bert.optimization import BertAdam

import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

import random

from classifier import Classifier


# Name of the pretrained checkpoint from which both the tokenizer and the
# masked-LM generator model are loaded.
BERT_MODEL = 'bert-base-uncased'


class Generator:
    """Label-conditional masked-LM generator used for data augmentation.

    Wraps a ``BertForMaskedLM``. Every token's segment (token-type) id is the
    index of the example's label in ``label_list`` (see
    ``_convert_example_to_features``), so masked tokens are predicted
    conditioned on the label.
    """

    def __init__(self,
                 label_list,
                 learning_rate,
                 warmup_proportion,
                 num_train_steps,
                 device, use_bert_adam=False):
        """Load tokenizer and model, move the model to *device*, build the optimizer.

        Args:
            label_list: all labels; a label's index is used as the segment id
                of that label's examples.
            learning_rate: optimizer learning rate.
            warmup_proportion: warm-up fraction (only used by ``BertAdam``).
            num_train_steps: total number of steps (only used by ``BertAdam``).
            device: device the model and every batch are moved to.
            use_bert_adam: use ``BertAdam`` instead of ``torch.optim.Adam``.
        """
        self._label_list = label_list

        self._tokenizer = BertTokenizer.from_pretrained(
            BERT_MODEL, do_lower_case=True)

        self._model = BertForMaskedLM.from_pretrained(BERT_MODEL)
        if len(self._label_list) != 2:
            # The pretrained token-type table is reused only for exactly two
            # labels; otherwise replace it with a freshly initialised table
            # holding one row per label, keeping the pretrained width.
            embeddings = self._model.bert.embeddings
            width = embeddings.token_type_embeddings.embedding_dim
            embeddings.token_type_embeddings = nn.Embedding(
                len(label_list), width)
            embeddings.token_type_embeddings.weight.data.normal_(
                mean=0.0, std=0.02)

        self._device = device
        self._model.to(self._device)

        self._optimizer = _get_optimizer(
            self._model,
            learning_rate, warmup_proportion, num_train_steps, use_bert_adam)

        self._train_dataset = None
        self._train_dataloader = None

        self._dev_dataset = None
        self._dev_dataloader = None

        # Optional external scorer and the augmentation records it scores;
        # both are read by eval_with_evaluator.
        self._evaluator = None
        self._aug_dicts = None

    def eval_with_evaluator(self, max_seq_length, log_file=None):
        """Score the stored augmentation records with ``self._evaluator``.

        For every record in ``self._aug_dicts`` two one-hot sequences are
        built, one per rank (0 and 1). Each masked position is replaced by
        the token predicted at that rank. The sequences and the records'
        examples are passed to ``self._evaluator.eval_given_ids``, and the
        returned accuracy is printed (and written to *log_file* if given).

        Does nothing when no evaluator or no records are set.
        """
        if self._evaluator is None or self._aug_dicts is None:
            return

        pred_example_ids = []
        for aug_dict in self._aug_dicts:
            init_ids = aug_dict['init_ids']
            probs = aug_dict['probs']
            predict_ids = aug_dict['predict_ids']
            mask_idx = aug_dict['mask_idx']

            for rank in range(2):
                ids_onehot = torch.zeros_like(probs).scatter_(
                    1, init_ids[0].unsqueeze(1), 1.)
                # Assumes row ``pos`` of predict_ids ranks the candidates for
                # position mask_idx[pos] (the layout produced by
                # ``probs[mask_idx].sort(descending=True)``).
                for pos, t in enumerate(mask_idx):
                    pid_onehot = torch.zeros_like(probs[0:1]).scatter_(
                        1, predict_ids[pos:pos + 1, rank:rank + 1], 1.)
                    ids_onehot = torch.cat(
                        [ids_onehot[:t], pid_onehot, ids_onehot[t + 1:]], dim=0)
                pred_example_ids.append(ids_onehot)

        examples = [ex['example'] for ex in self._aug_dicts]

        accu = self._evaluator.eval_given_ids(
            pred_example_ids, examples, max_seq_length)

        print('\nGenerator augment acc: %.4f' % accu)
        if log_file is not None:
            print('Generator augment acc: %.4f' % accu, file=log_file)

    def load_train_data(self, train_examples, max_seq_length, batch_size):
        """Build the shuffled training dataset/dataloader from *train_examples*."""
        self._train_dataset = GeneratorDataset(
            examples=train_examples,
            label_list=self._label_list,
            tokenizer=self._tokenizer,
            max_seq_length=max_seq_length)

        self._train_dataloader = DataLoader(
            self._train_dataset, batch_size=batch_size, shuffle=True)

    def load_dev_data(self, dev_examples, max_seq_length, batch_size):
        """Build the unshuffled dev dataset/dataloader from *dev_examples*."""
        self._dev_dataset = GeneratorDataset(
            examples=dev_examples,
            label_list=self._label_list,
            tokenizer=self._tokenizer,
            max_seq_length=max_seq_length)

        self._dev_dataloader = DataLoader(
            self._dev_dataset, batch_size=batch_size, shuffle=False)

    def eval_loss(self):
        """Return the masked-LM loss on the dev set, averaged per example.

        Leaves the model in eval mode. Raises ZeroDivisionError if the dev
        loader yields no examples.
        """
        self._model.eval()
        sum_loss = 0.
        nsamples = 0
        with torch.no_grad():
            for batch in self._dev_dataloader:
                batch = tuple(t.to(self._device) for t in batch)
                _, input_ids, input_mask, segment_ids, masked_ids = batch

                loss = self._model(input_ids, segment_ids, input_mask, masked_ids)

                # Assumes the model returns a batch-averaged loss: weight it
                # by the batch size so dividing by nsamples gives a mean.
                batch_size = len(input_ids)
                sum_loss += loss.item() * batch_size
                nsamples += batch_size

        return sum_loss / nsamples

    def train_epoch(self):
        """Train for one epoch over the training loader; return the dev loss.

        Roughly every third of the epoch the dev loss and the mean training
        loss since the previous report are printed.
        """
        # At least 1 so loaders with fewer than three batches do not cause a
        # modulo-by-zero below.
        display_steps = max(1, len(self._train_dataloader) // 3)

        avg_loss = 0.
        self._model.train()
        for step, batch in enumerate(self._train_dataloader):
            batch = tuple(t.to(self._device) for t in batch)
            _, input_ids, input_mask, segment_ids, masked_ids = batch

            self._model.zero_grad()
            loss = self._model(input_ids, segment_ids, input_mask, masked_ids)
            loss.backward()
            self._optimizer.step()

            avg_loss += loss.item()
            if (step + 1) % display_steps == 0:
                eval_loss = self.eval_loss()
                # eval_loss() put the model in eval mode; resume training.
                self._model.train()
                print("Generator Training, step {}, eval_loss: {}, train_loss: {}".format(
                    step, eval_loss, avg_loss / display_steps))
                avg_loss = 0.

        return self.eval_loss()

    @staticmethod
    def _sample_mask_positions(seq_len):
        """Return sorted positions to mask in a sequence of *seq_len* real tokens.

        For ``seq_len >= 4``, ``max(seq_len // 7, 2)`` distinct positions are
        drawn from ``1 .. seq_len - 2``, so the first and last tokens
        (``[CLS]``/``[SEP]``) are never chosen. Shorter sequences mask
        position 1 only.
        """
        if seq_len >= 4:
            return sorted(random.sample(
                list(range(1, seq_len - 1)), max(seq_len // 7, 2)))
        return [1]

    def _mask_and_predict(self, example, max_seq_length):
        """Mask random tokens of *example* and run the masked LM on it.

        Returns:
            init_ids: (1, max_seq_length) token ids with ``[MASK]`` written
                in place at the chosen positions.
            masked_ids: the original ids at those positions.
            mask_idx: sorted list of masked positions.
            seq_len: number of non-padding tokens.
            logits: the model's scores for this single example, one row per
                position.
        """
        features = _convert_example_to_features(
            example=example,
            label_list=self._label_list,
            max_seq_length=max_seq_length,
            tokenizer=self._tokenizer)

        init_ids, _, input_mask, segment_ids, _ = \
            (t.view(1, -1).to(self._device) for t in features)

        seq_len = int(torch.sum(input_mask).item())
        mask_idx = self._sample_mask_positions(seq_len)

        masked_ids = init_ids[0][mask_idx]
        init_ids[0][mask_idx] = \
            self._tokenizer.convert_tokens_to_ids(['[MASK]'])[0]
        logits = self._model(init_ids, segment_ids, input_mask)[0]
        return init_ids, masked_ids, mask_idx, seq_len, logits

    def _augment_example(self, example, max_seq_length, softmax_temperature,
                         log_file=None, return_all=False, num_aug=1):
        """Draw *num_aug* soft (Gumbel-softmax) augmentations of *example*.

        Unmasked positions are one-hot rows of the input ids. Masked
        positions hold ``F.gumbel_softmax(logits, tau=softmax_temperature)``
        rows, sampled independently for each augmentation.

        Args:
            log_file: when given together with ``return_all=True``, the
                masked text, label, original tokens and the top-10
                predictions of the last sample are written to it.
            return_all: enables the logging described above.
            num_aug: number of samples drawn.

        Returns:
            Tensor stacking the *num_aug* samples along a new leading
            dimension; each sample has the shape of the model's logits.
        """
        init_ids, masked_ids, mask_idx, seq_len, logits = \
            self._mask_and_predict(example, max_seq_length)

        aug_probs_all = []
        for _ in range(num_aug):
            probs = F.gumbel_softmax(logits, tau=softmax_temperature, hard=False)  # TODO
            aug_probs = torch.zeros_like(probs).scatter_(
                1, init_ids[0].unsqueeze(1), 1.)
            # torch.cat (rather than in-place writes) keeps the sampled rows
            # connected to the autograd graph.
            for t in mask_idx:
                aug_probs = torch.cat(
                    [aug_probs[:t], probs[t:t + 1], aug_probs[t + 1:]], dim=0)
            aug_probs_all.append(aug_probs)

        aug_probs = torch.stack(aug_probs_all, dim=0)

        if return_all and log_file is not None:
            predict_probs, predict_ids = probs[mask_idx].sort(descending=True)
            print('=' * 50, file=log_file)
            print('text:', self._tokenizer.convert_ids_to_tokens(
                init_ids[0][:seq_len].tolist()), file=log_file)
            print('label:', example.label, file=log_file)
            print('origin tokens:', self._tokenizer.convert_ids_to_tokens(
                masked_ids.tolist()), file=log_file)
            for i in range(10):
                print('predict tokens {}:'.format(i),
                      self._tokenizer.convert_ids_to_tokens(
                          predict_ids[:, i].tolist()),
                      predict_probs[:, i].tolist(), file=log_file)
            log_file.flush()

        return aug_probs

    def _finetune_example(self, classifier, example, max_seq_length,
                          softmax_temperature, finetune_batch_size, num_aug=1):
        """Sample augmentations of *example* and hand them to the classifier.

        ``classifier.finetune_generator`` presumably back-propagates its loss
        into this generator through the soft samples (TODO confirm).
        """
        aug_probs = self._augment_example(
            example, max_seq_length, softmax_temperature, num_aug=num_aug)
        classifier.finetune_generator(
            example, aug_probs, max_seq_length,
            sample_dev_batches=-1,  # TODO
            finetune_batch_size=finetune_batch_size)

    def finetune_and_augment_batch(
            self, classifier, examples, max_seq_length,
            softmax_temperature, log_file, finetune_generator=True, num_aug=1):
        """Optionally fine-tune the generator, then augment every example.

        Returns:
            List of ``(example, aug_probs)`` pairs, one per input example.
        """
        if finetune_generator:
            self._model.train()
            self._model.zero_grad()

            # For training efficiency only half of the batch fine-tunes the
            # generator; one optimizer step follows all of those examples.
            finetune_examples = random.sample(examples, len(examples) // 2)  # TODO #5

            for example in finetune_examples:
                self._finetune_example(
                    classifier, example, max_seq_length, softmax_temperature,
                    finetune_batch_size=len(finetune_examples), num_aug=num_aug)
            # Skip the step when nothing was fine-tuned: stepping on empty
            # gradients would still apply momentum / weight-decay updates.
            if finetune_examples:
                self._optimizer.step()

        self._model.eval()
        aug_examples = []
        # Reset so eval_with_evaluator never scores stale records; nothing in
        # this class currently appends to it.
        self._aug_dicts = []
        with torch.no_grad():
            for example in examples:
                aug_probs = self._augment_example(
                    example, max_seq_length, softmax_temperature,
                    log_file=log_file, return_all=True, num_aug=num_aug)
                aug_examples.append((example, aug_probs))

        self.eval_with_evaluator(max_seq_length, log_file)

        return aug_examples

    def augment_example_cbert(self, example, max_seq_length, writer):
        """Write *example* and two C-BERT augmentations via ``writer.writerow``.

        Random positions are masked and refilled with the model's top-1 and
        then top-2 predictions. Each row is ``[text, label]``, where text is
        the space-joined tokens without the first and last ones.
        """
        def _to_string(token_list):
            return ' '.join(token_list[1:-1])

        with torch.no_grad():
            init_ids, masked_ids, mask_idx, seq_len, logits = \
                self._mask_and_predict(example, max_seq_length)

            _, predict_ids = logits[mask_idx].sort(descending=True)

            # Original text: put the masked tokens back first.
            init_ids[0][mask_idx] = masked_ids
            writer.writerow([_to_string(self._tokenizer.convert_ids_to_tokens(
                init_ids[0][:seq_len].tolist())), example.label])

            # Augmented texts.
            for rank in range(2):
                init_ids[0][mask_idx] = predict_ids[:, rank]
                writer.writerow([_to_string(self._tokenizer.convert_ids_to_tokens(
                    init_ids[0][:seq_len].tolist())), example.label])

    @property
    def model(self):
        """The wrapped masked-LM model."""
        return self._model

    @property
    def optimizer(self):
        """The optimizer built in ``__init__``."""
        return self._optimizer


class GeneratorDataset(Dataset):
    """Map-style dataset that turns labelled examples into masked-LM features.

    Features are recomputed on every access instead of being cached, so each
    epoch sees a freshly sampled random mask for every example.
    """

    def __init__(self, examples, label_list, tokenizer, max_seq_length):
        self._examples = examples
        self._label_list = label_list
        self._tokenizer = tokenizer
        self._max_seq_length = max_seq_length

    def __len__(self):
        """Number of underlying examples."""
        return len(self._examples)

    def __getitem__(self, index):
        """Return the feature tensors for the example at *index*."""
        example = self._examples[index]
        return _convert_example_to_features(
            example=example,
            tokenizer=self._tokenizer,
            label_list=self._label_list,
            max_seq_length=self._max_seq_length)


def _get_optimizer(model, learning_rate, warmup_proportion,
                   num_train_steps, use_bert_adam):
    """Build the optimizer for *model* with two weight-decay groups.

    Parameters whose name contains ``bias``, ``LayerNorm.bias`` or
    ``LayerNorm.weight`` get ``weight_decay=0.0``; all others get ``0.01``.
    ``warmup_proportion`` and ``num_train_steps`` are only forwarded to
    ``BertAdam``; plain ``optim.Adam`` ignores them.
    """
    exempt_markers = ('bias', 'LayerNorm.bias', 'LayerNorm.weight')
    decayed, exempt = [], []
    for name, param in model.named_parameters():
        bucket = exempt if any(m in name for m in exempt_markers) else decayed
        bucket.append(param)

    param_groups = [
        {'params': decayed, 'weight_decay': 0.01},
        {'params': exempt, 'weight_decay': 0.0},
    ]

    if not use_bert_adam:
        return optim.Adam(param_groups, lr=learning_rate)
    return BertAdam(
        param_groups,
        lr=learning_rate,
        warmup=warmup_proportion,
        t_total=num_train_steps)



def _get_label_id(label, label_list):
    """Return the index of *label* in *label_list*.

    If a label occurs more than once, the last index wins.

    Raises:
        KeyError: if *label* is not in *label_list*.
    """
    # A distinct loop-variable name keeps the ``label`` argument from being
    # overwritten by the last element of label_list.
    label_map = {lab: i for i, lab in enumerate(label_list)}
    return label_map[label]


def _convert_example_to_features(
        example, label_list, max_seq_length, tokenizer,
        masked_lm_prob=0.15, max_predictions_per_seq=20):
    """Tokenize *example* and build label-conditioned masked-LM features.

    Adapted from
    https://github.com/IIEKES/cbert_aug/blob/master/aug_dataset_wo_ft.py#L119

    The index of ``example.label`` in *label_list* is used as the segment
    (token-type) id of every real token; this is how the masked LM is
    conditioned on the label. Selected positions are masked with BERT's
    80/10/10 rule: 80% ``[MASK]``, 10% unchanged, 10% replaced by a random
    token drawn from the same sentence.

    Args:
        example: object with ``text_a`` and ``label`` attributes.
        label_list: all labels; a label's index is its segment id.
        max_seq_length: length of every returned tensor (text is truncated
            to leave room for ``[CLS]`` and ``[SEP]``, then zero-padded).
        tokenizer: provides ``tokenize`` and ``convert_tokens_to_ids``.
        masked_lm_prob: fraction of the sequence length (including
            ``[CLS]``/``[SEP]``) selected for prediction; at least 1.
        max_predictions_per_seq: upper bound on selected positions.

    Returns:
        Tuple of 1-D tensors ``(init_ids, input_ids, input_mask,
        segment_ids, masked_lm_labels)``: original ids, ids after masking,
        1 for real tokens / 0 for padding, segment ids (0 for padding), and
        the original id at each selected position (-1 elsewhere).

    Raises:
        KeyError: if ``example.label`` is not in *label_list*.
    """
    label_map = {label: i for i, label in enumerate(label_list)}

    tokens_a = tokenizer.tokenize(example.text_a)
    segment_id = label_map[example.label]
    # Account for [CLS] and [SEP] with "- 2"
    if len(tokens_a) > max_seq_length - 2:
        tokens_a = tokens_a[0:(max_seq_length - 2)]

    # This is a conditional MLM, so every token (including [CLS]/[SEP])
    # carries the label as its segment id.
    tokens = ["[CLS]", *tokens_a, "[SEP]"]
    segment_ids = [segment_id] * len(tokens)
    masked_lm_labels = [-1] * max_seq_length

    # Every position except [CLS] and [SEP] may be masked.
    cand_indexes = [i for i, token in enumerate(tokens)
                    if token != "[CLS]" and token != "[SEP]"]

    random.shuffle(cand_indexes)
    len_cand = len(cand_indexes)

    output_tokens = list(tokens)

    num_to_predict = min(max_predictions_per_seq,
                         max(1, int(round(len(tokens) * masked_lm_prob))))

    # cand_indexes are distinct, so the first num_to_predict of the shuffled
    # list are exactly the positions to predict.
    for index in cand_indexes[:num_to_predict]:
        # 80% of the time, replace with [MASK]
        if random.random() < 0.8:
            masked_token = "[MASK]"
        # 10% of the time, keep original
        elif random.random() < 0.5:
            masked_token = tokens[index]
        # 10% of the time, replace with a random word from this sentence
        else:
            masked_token = tokens[cand_indexes[
                random.randint(0, len_cand - 1)]]

        masked_lm_labels[index] = \
            tokenizer.convert_tokens_to_ids([tokens[index]])[0]
        output_tokens[index] = masked_token

    init_ids = tokenizer.convert_tokens_to_ids(tokens)
    input_ids = tokenizer.convert_tokens_to_ids(output_tokens)

    # The mask has 1 for real tokens and 0 for padding tokens. Only real
    # tokens are attended to.
    input_mask = [1] * len(input_ids)

    # Zero-pad up to the sequence length. NOTE(review): padding uses segment
    # id 0 rather than the label's id; padded positions have input_mask 0.
    pad = [0] * (max_seq_length - len(input_ids))
    init_ids.extend(pad)
    input_ids.extend(pad)
    input_mask.extend(pad)
    segment_ids.extend(pad)

    assert len(init_ids) == max_seq_length
    assert len(input_ids) == max_seq_length
    assert len(input_mask) == max_seq_length
    assert len(segment_ids) == max_seq_length

    return (torch.tensor(init_ids),
            torch.tensor(input_ids),
            torch.tensor(input_mask),
            torch.tensor(segment_ids),
            torch.tensor(masked_lm_labels))


def _rev_wordpiece(str):
    """Undo WordPiece splitting and return the inner text as one string.

    ``[PAD]`` tokens are dropped and every ``##``-prefixed piece is glued
    onto the token before it. The first and last remaining tokens (normally
    ``[CLS]`` and ``[SEP]``) are excluded from the space-joined result. The
    caller's list is not modified.

    Args:
        str: sequence of WordPiece tokens. (The name shadows the builtin;
            it is kept for compatibility with keyword callers.)
    """
    tokens = list(str)
    # Walk backwards: deleting index i never shifts positions still to be
    # visited, and chained continuation pieces fold into their head token.
    for i in range(len(tokens) - 1, 0, -1):
        if tokens[i] == '[PAD]':
            del tokens[i]
        elif tokens[i].startswith('##'):
            tokens[i - 1] += tokens[i][2:]
            # Delete by position: list.remove() would drop the *first* equal
            # piece, which may be an earlier, not-yet-merged one.
            del tokens[i]
    return " ".join(tokens[1:-1])
