import pickle
import numpy as np
import torch
import torch.nn as nn

from elf.options import auto_import_options, PyOptionSpec
from rlpytorch.utils import assert_eq

import rlpytorch.behavior_clone.global_consts as gc
from rlpytorch.behavior_clone.cmd_heads import GlobClsHead
from rlpytorch.behavior_clone.module import MlpEncoder, RnnLanguageGenerator
from rlpytorch.behavior_clone.instruction_selector import RnnSelector
from rlpytorch.behavior_clone.utils import convert_to_raw_instruction
from rlpytorch.behavior_clone.inst_dict import put_forward_scout_inst

# for data process for sampling
from rlpytorch.behavior_clone.coach_dataset import CoachDataset
from rlpytorch.behavior_clone.instruction_encoder import LSTMInstructionEncoder
from rlpytorch.behavior_clone.instruction_encoder import MeanBOWInstructionEncoder
# from rlpytorch.behavior_clone.instruction_encoder import BOWInstructionEncoder
from rlpytorch.sampler import ContSoftmaxSampler
from rlpytorch.behavior_clone.conv_glob_encoder import GlobEncoder

from elfgames.rts.game_MC.executor_based_loader import ExecutorBasedLoader


class ConvRnnGenerator(nn.Module):
    @classmethod
    def get_option_spec(cls):
        """Declare the command-line options understood by this model.

        Registers the dataset, encoder and classifier options below and then
        merges in everything declared by ``GlobEncoder.get_option_spec()``.

        Returns:
            The populated ``PyOptionSpec``.
        """
        spec = PyOptionSpec()
        add_int = spec.addIntOption
        add_float = spec.addFloatOption
        add_str = spec.addStrOption

        # NOTE(review): __init__ reads `options.num_unit_type`, which is not
        # registered in this method -- presumably it arrives through the
        # GlobEncoder spec merged at the end; confirm.

        # human dataset
        add_str('inst_dict_path', 'path to dictionary', 'data3/train.json_min10_dict.pt')
        add_int('num_instructions', 'num instruction to use for softmax', 50)
        # NOTE(review): help text is identical to 'num_instructions' above;
        # looks copy-pasted.
        add_int('num_negatives', 'num instruction to use for softmax', 50)

        # prev instruction encoder
        add_int('word_emb_dim', 'size of word emb in rnn', 32)
        # NOTE(review): help text describes a size, not a dropout rate.
        add_float('word_emb_dropout', 'size of word emb in rnn', 0)
        add_int('inst_emb_dim', 'hid size of instruction encoder', 128)

        # count features encoder
        add_int('num_count_channels', '', CoachDataset.compute_num_count_channels())
        add_int('count_dim', '', 128)
        add_int('count_layers', '', 2)

        # classifiers
        add_int('cont_cls_dim', 'hid dim of continue classifier', 128)
        # NOTE(review): help text describes a hidden dim, not a dropout rate.
        add_float('cont_cls_dropout', 'hid dim of continue classifier', 0)
        add_int('inst_cls_dim', 'hid dim of instruction selector', 128)

        # others
        add_int('max_sentence_length', '', 10)
        add_int('forward_scout_inst', '', 1)

        # pull in the options owned by the global-feature encoder
        spec.merge(GlobEncoder.get_option_spec())

        return spec

    def load_inst_dict(self, inst_dict_path):
        print('loading cmd dict from: ', inst_dict_path)
        if inst_dict_path is None or inst_dict_path == '':
            return None

        inst_dict = pickle.load(open(inst_dict_path, 'rb'))
        inst_dict.set_max_sentence_length(self.options.max_sentence_length)
        # inst_dict.set_max_num_instructions(self.options.num_instructions)
        if self.options.forward_scout_inst:
            inst_dict = put_forward_scout_inst(inst_dict)
            print('added scout to inst dictionary')
        return inst_dict

    @auto_import_options
    def __init__(self, option_map, max_raw_chars, max_instruction_span):
        """Build the encoders, heads and sampler of the model.

        Args:
            option_map: option container, forwarded to ``GlobEncoder``.
                ``self.options`` is read throughout this constructor and is
                presumably populated from it by ``auto_import_options`` --
                TODO confirm.
            max_raw_chars: stored; passed to ``convert_to_raw_instruction``
                in ``sample``.
            max_instruction_span: stored; also sizes the ``frame_passed``
                embedding table (``max_instruction_span + 2`` rows).

        Raises:
            ValueError: if no instruction dictionary could be loaded
                (``options.inst_dict_path`` is empty).
        """
        super().__init__()

        self.max_raw_chars = max_raw_chars
        self.max_instruction_span = max_instruction_span
        # self.coach_mode = coach_mode

        # lazily-filled caches of parsed candidate instructions
        self.pos_candidate_inst = None
        self.neg_candidate_inst = None

        self.inst_dict = self.load_inst_dict(self.options.inst_dict_path)
        if self.inst_dict is None:
            # every module below needs the vocabulary; fail with a clear
            # message instead of an AttributeError on None
            raise ValueError(
                'an instruction dictionary is required; '
                'option inst_dict_path is empty')

        self.instruction_encoder = LSTMInstructionEncoder(
            self.inst_dict.total_vocab_size,
            self.options.word_emb_dim,
            self.options.word_emb_dropout,
            self.options.inst_emb_dim,
            self.inst_dict.pad_word_idx,
        )

        self.glob_encoder = GlobEncoder(option_map, self.instruction_encoder)

        # Three of the encoders below emit `count_dim // 2` features each;
        # keep that width in one place so the declared feature dims always
        # agree with what `_forward` concatenates (also for odd count_dim).
        half_count_dim = self.options.count_dim // 2

        # count encoders
        self.count_encoder = MlpEncoder(
            self.options.num_count_channels * 2,
            self.options.count_dim,
            self.options.count_dim,
            self.options.count_layers - 1,
            activate_out=True
        )
        self.cons_count_encoder = MlpEncoder(
            self.options.num_unit_type,
            half_count_dim,
            half_count_dim,
            self.options.count_layers - 1,
            activate_out=True
        )
        self.moving_avg_encoder = MlpEncoder(
            self.options.num_unit_type,
            half_count_dim,
            half_count_dim,
            self.options.count_layers - 1,
            activate_out=True
        )
        self.frame_passed_encoder = nn.Embedding(
            max_instruction_span + 2,
            half_count_dim,
        )

        # glob + prev instruction + count + cons_count + moving_avg
        self.glob_feat_dim = int(
            self.options.inst_emb_dim
            + self.options.count_dim
            + 2 * half_count_dim
            + self.glob_encoder.glob_dim
        )
        # same features plus the frame_passed embedding
        self.count_feat_dim = self.glob_feat_dim + half_count_dim

        self.value = nn.utils.weight_norm(
            nn.Linear(self.glob_feat_dim, 1), dim=None
        )
        self.cont_cls = GlobClsHead(
            self.count_feat_dim,
            self.options.cont_cls_dim,
            2,
            self.options.cont_cls_dropout
        )

        # the generator shares the word embedding of the instruction encoder
        self.inst_selector = RnnLanguageGenerator(
            self.instruction_encoder.emb,
            self.options.word_emb_dim,
            self.glob_feat_dim,
            self.options.inst_cls_dim,
            self.inst_dict.total_vocab_size)

        # self.inst_selector = RnnSelector(encoder, self.glob_feat_dim)
        self.sampler = ContSoftmaxSampler(
            'inst_cont', 'inst_cont_pi', 'inst', 'inst_pi')

    def format_coach_input(self, batch):
        # print('formating coach input')
        if 'hist_count' in batch:
            new_batch = {}
            for key, val in batch.batch.items():
                assert key.startswith('hist_')
                new_key = key.split('_', 1)[1]
                new_batch[new_key] = val
            batch = new_batch
            # print('train')

        data = {
            'prev_inst_idx': batch['prev_inst'],
            'count': batch['count'],
            'base_count': batch['base_count'],
            'cons_count': batch['cons_count'],
            'moving_enemy_count': batch['moving_enemy_count'],
            'resource_bin': batch['resource_bin'],
            'frame_passed': batch['frame_passed'],
            'pre_ins_prev_cmds': batch['previous_cmd'],
        }

        extra_data = self.glob_encoder.format_input(batch)
        data.update(extra_data)
        return data

    def get_input_keys(self):
        """Return the ``'input'`` entry of ``ExecutorBasedLoader``'s train spec."""
        train_spec = ExecutorBasedLoader._get_train_spec()
        return train_spec['input']

    def get_action(self, batch):
        a = {
            'inst_cont': batch['hist_inst_cont'],
            'inst': batch['hist_inst']
        }
        return a

    def get_policy(self, batch):
        pi = {
            'inst_cont_pi': batch['hist_inst_cont_pi'],
            'inst_pi': batch['hist_inst_pi'],
        }
        return pi

    def _forward(self, batch):
        """shared forward function to compute glob feature
        """
        prev_inst_feat = self.instruction_encoder(
            batch['prev_inst'], batch['prev_inst_len'])
        count_input = torch.cat(
            [batch['count'], batch['base_count'] - batch['count']], 1)
        count_feat = self.count_encoder(count_input)
        cons_count_feat = self.cons_count_encoder(batch['cons_count'])
        moving_avg_feat = self.moving_avg_encoder(batch['moving_enemy_count'])
        # resource_feat = self.resource_encoder(batch['resource_bin'])
        # print(batch['frame_passed'].max())
        frame_passed_feat = self.frame_passed_encoder(batch['frame_passed'])

        _, _, _, glob, _ = self.glob_encoder(batch)
        # print('>>>', glob.size())

        glob_feat = torch.cat([
            glob,
            prev_inst_feat,
            count_feat,
            cons_count_feat,
            moving_avg_feat,
            # resource_feat
        ], dim=1)
        count_feat = torch.cat([
            glob,
            prev_inst_feat,
            count_feat,
            cons_count_feat,
            moving_avg_feat,
            # resource_feat,
            frame_passed_feat
        ], dim=1)

        return glob_feat, count_feat

    def compute_loss(self, batch):
        """Supervised loss used for pre-training the model on a dataset.

        Sums the continue-classifier loss and the language loss; the language
        loss is multiplied by ``1 - cont`` so it only counts where
        ``is_base_frame`` is set.

        Returns:
            ``(loss, all_loss)``: the mean total loss and a dict with the mean
            total, continue and (masked) language losses.
        """
        batch = self._format_language_input(batch)
        glob_feat, cont_feat = self._forward(batch)

        cont = 1 - batch['is_base_frame']
        cont_loss = self.cont_cls.compute_loss(cont_feat, cont)
        lang_loss = self.inst_selector.compute_loss(
            batch['inst_input'], batch['inst'], glob_feat)

        assert_eq(cont_loss.size(), lang_loss.size())
        # zero the language loss wherever the previous instruction continues
        masked_lang_loss = (1 - cont.float()) * lang_loss
        total = (cont_loss + masked_lang_loss).mean()

        return total, {
            'loss': total,
            'cont_loss': cont_loss.mean(),
            'lang_loss': masked_lang_loss.mean()
        }

    def compute_eval_loss(self, batch):
        """Evaluation loss computed against the cached candidate instructions.

        The language term is the negative log-probability that the generator
        assigns to the candidate selected by ``batch['inst_idx']``, multiplied
        by ``1 - cont`` exactly as in ``compute_loss``.

        Returns:
            ``(loss, all_loss)``: the mean total loss and a dict with the mean
            total, continue and (masked) language losses.
        """
        batch = self._format_language_input_with_candidate(batch)
        glob_feat, cont_feat = self._forward(batch)

        cont = 1 - batch['is_base_frame']
        cont_loss = self.cont_cls.compute_loss(cont_feat, cont)

        lang_logp = self.inst_selector.compute_prob(
            batch['inst_input'], batch['inst'], glob_feat, log=True)
        # pick, per row, the log-prob of the target candidate
        target_col = batch['inst_idx'].unsqueeze(1)
        lang_loss = -lang_logp.gather(1, target_col).squeeze(1)

        assert_eq(cont_loss.size(), lang_loss.size())
        masked_lang_loss = (1 - cont.float()) * lang_loss
        total = (cont_loss + masked_lang_loss).mean()

        return total, {
            'loss': total,
            'cont_loss': cont_loss.mean(),
            'lang_loss': masked_lang_loss.mean()
        }

    def forward(self, batch):
        """forward function use by RL
        """
        # print('>>>> start of forward')
        # print('>>>>', type(batch))
        # batch = self.format_coach_input(batch)
        # In RL, the data is always in one-hot format, need to process
        if not isinstance(batch, dict):
            # print('=====train========')
            # for key in sorted(batch.batch.keys()):
            #     print(key, ':', batch[key].sum().item())
            # # print(batch['frame_passed'])
            # print('========')
            # # assert False
            batch = self.format_coach_input(batch)

        batch = self._format_rl_language_input(batch)
        # TODO: if works
        # batch['frame_passed'] = torch.min(batch['frame_passed'], torch.ones(1).cuda() *
        # print(batch['frame_passed'])
        # print('============')
        glob_feat, cont_feat = self._forward(batch)
        v = torch.tanh(self.value(glob_feat).squeeze())
        cont_prob = self.cont_cls.compute_prob(cont_feat)
        inst_prob = self.inst_selector.compute_prob(
            batch['cand_inst_input'], batch['cand_inst'], glob_feat)

        output = {
            'pi': {
                'inst_cont_pi': cont_prob,
                'inst_pi': inst_prob,
            },
            'V': v
        }
        # print('>>>>end of forward')
        return output

    def sample(self, batch):
        """Sample an instruction; used by the actor in ELF and for visually
        evaluating the model.

        Returns:
            ``(inst, inst_len, inst_cont, reply)`` where
              inst: LongTensor of parsed tokens, one row per sampled
                instruction (presumably [batch, max_sentence_len] -- confirm);
              inst_len: LongTensor with the parsed length of each row;
              inst_cont: the sampled continue decision;
              reply: dict with the samples, their policies and ``raw_inst``.
        """
        policy = self.forward(batch)['pi']
        cont_pi = policy['inst_cont_pi']
        inst_pi = policy['inst_pi']
        drawn = self.sampler.sample(cont_pi, inst_pi, batch['prev_inst_idx'])

        # actor reply for RL with ELF
        reply = {
            'inst_cont': drawn['inst_cont'],
            'inst_cont_pi': cont_pi,
            'inst': drawn['inst'],
            'inst_pi': inst_pi,
        }

        # convert each sampled index to the formats needed by the executor
        token_rows = []
        row_lengths = []
        raw_rows = []
        for sampled_idx in reply['inst']:
            text = self.inst_dict.get_inst(int(sampled_idx.item()))
            tokens, length = self.inst_dict.parse(text, True)
            token_rows.append(tokens)
            row_lengths.append(length)
            raw_rows.append(convert_to_raw_instruction(text, self.max_raw_chars))

        device = reply['inst_cont'].device
        inst = torch.LongTensor(token_rows).to(device)
        inst_len = torch.LongTensor(row_lengths).to(device)
        reply['raw_inst'] = torch.LongTensor(raw_rows).to(device)
        return inst, inst_len, reply['inst_cont'], reply

    def _format_rl_language_input(self, batch):
        prev_inst, prev_inst_len = self._parse_batch_inst(
            batch['prev_inst_idx'].cpu().numpy(), batch['prev_inst_idx'].device)
        batch['prev_inst'] = prev_inst
        batch['prev_inst_len'] = prev_inst_len

        inst, inst_len = self._get_pos_candidate_inst(prev_inst.device)
        batch['cand_inst'] = inst

        start = torch.zeros(inst.size(0), 1) + self.inst_dict.start_word_idx
        start = start.long().to(inst.device)
        inst_input = torch.cat([start, inst[:, :-1]], 1)
        batch['cand_inst_input'] = inst_input

        return batch

    def _format_language_input(self, batch):
        """convert prev_inst and inst from one hot to rnn format,
        add inst_input for RNN
        """
        inst = batch['inst']

        start = torch.zeros(inst.size(0), 1) + self.inst_dict.start_word_idx
        start = start.long().to(inst.device)
        inst_input = torch.cat([start, inst[:, :-1]], 1)
        batch['inst_input'] = inst_input

        return batch

    def _format_language_input_with_candidate(self, batch):
        """convert prev_inst and inst from one hot to rnn format,
        add inst_input for RNN
        """
        # print(batch.keys())
        # prev_inst, prev_inst_len = self._parse_batch_inst(
        #     batch['prev_inst_idx'].cpu().numpy(), batch['prev_inst_idx'].device)

        # batch['prev_inst'] = prev_inst
        # batch['prev_inst_len'] = prev_inst_len

        # inst, _ = self._parse_batch_inst(batch['inst'])
        # batch['inst'] = inst

        inst, _ = self._get_pos_candidate_inst(batch['inst'].device)
        batch['inst'] = inst

        start = torch.zeros(inst.size(0), 1) + self.inst_dict.start_word_idx
        start = start.long().to(inst.device)
        inst_input = torch.cat([start, inst[:, :-1]], 1)
        batch['inst_input'] = inst_input

        return batch

    # def _format_supervised_language_input(self, batch):
    #     device = batch['prev_inst'].device
    #     pos_inst, pos_inst_len = self._get_pos_candidate_inst(device)
    #     neg_inst, neg_inst_len = self._get_neg_candidate_inst(device, batch['inst_idx'])
    #     batch['pos_cand_inst'] = pos_inst
    #     batch['pos_cand_inst_len'] = pos_inst_len
    #     batch['neg_cand_inst'] = neg_inst
    #     batch['neg_cand_inst_len'] = neg_inst_len
    #     return batch

    def _get_pos_candidate_inst(self, device):
        if (self.pos_candidate_inst is not None
            and self.pos_candidate_inst[0].device == device):
            inst, inst_len = self.pos_candidate_inst
        else:
            inst, inst_len = self._parse_batch_inst(
                range(self.options.num_instructions), device)
            self.pos_candidate_inst = (inst, inst_len)

        return inst, inst_len

    def _parse_batch_inst(self, indices, device):
        inst = []
        inst_len = []
        for idx in indices:
            parsed, l = self.inst_dict.parse(self.inst_dict.get_inst(idx), True)
            inst.append(parsed)
            inst_len.append(l)

        inst = torch.LongTensor(inst).to(device)
        inst_len = torch.LongTensor(inst_len).to(device)
        return inst, inst_len
