"""
Language utilities
"""

import torch
from babyai.levels.verifier import INSTRS
from gym_minigrid.minigrid import COLOR_NAMES
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

# Reserved vocabulary entries. Each special token is paired with the fixed
# integer id it occupies in every vocabulary built by this module.
PAD_TOKEN, PAD_INDEX = "<PAD>", 0
SOS_TOKEN, SOS_INDEX = "<SOS>", 1
EOS_TOKEN, EOS_INDEX = "<EOS>", 2


# Seed word-to-index mapping holding only the reserved entries above.
_W2I_STARTER = dict(
    zip(
        (PAD_TOKEN, SOS_TOKEN, EOS_TOKEN),
        (PAD_INDEX, SOS_INDEX, EOS_INDEX),
    )
)


def get_lang(env, FLAGS):
    """Select the goal-language bundle matching the configured goal type.

    Args:
        env: Environment object; its ``height`` and ``width`` attributes are
            read only when coordinate ("xy") goals are selected.
        FLAGS: Options object. ``language_goals`` picks between coordinate
            goals (``None`` or ``"xy"``) and the precomputed instruction
            tables; ``onehot_language_goals`` is consulted only in the latter
            case.

    Returns:
        A dict with the keys ``"vocab"``, ``"lang"``, ``"lang_len"`` and
        ``"lang_templates"``.
    """
    if FLAGS.language_goals in {None, "xy"}:
        xy_lang, xy_lang_len, xy_vocab = preprocess_xy_instrs(env.height, env.width)
        return {
            "vocab": xy_vocab,
            "lang": xy_lang,
            "lang_len": xy_lang_len,
            # Token ids start after the 3 reserved entries (PAD/SOS/EOS), so
            # subtracting 3 recovers the coordinate. The subtraction must
            # happen inside the braces for BOTH coordinates.
            "lang_templates": [
                f"({x[1].item() - 3}, {x[2].item() - 3})" for x in xy_lang
            ],
        }

    if FLAGS.onehot_language_goals:
        # One opaque token per instruction.
        return {
            "vocab": VOCAB_ONEHOT,
            "lang": LANG_ONEHOT,
            "lang_len": LANG_LEN_ONEHOT,
            "lang_templates": INSTR_TEMPLATES,
        }

    # Word-level encoding of the instructions.
    return {
        "vocab": VOCAB,
        "lang": LANG,
        "lang_len": LANG_LEN,
        "lang_templates": INSTR_TEMPLATES,
    }


class W2I:
    """Word-to-index mapping that assigns a fresh id to every unseen word.

    The mapping starts as a copy of ``_W2I_STARTER``. Indexing with a word
    that is not present registers it with the next free id (the current size
    of the mapping) and returns that id.
    """

    def __init__(self):
        self._w2i = _W2I_STARTER.copy()

    def __getitem__(self, i):
        """Return the id of word ``i``, registering it first if unseen."""
        # ``len`` is evaluated before insertion, so a new word receives the
        # next consecutive id.
        return self._w2i.setdefault(i, len(self._w2i))

    def __contains__(self, word):
        """Report membership without registering ``word``.

        Without this method ``word in w2i`` falls back to calling
        ``__getitem__`` with 0, 1, 2, ...; since ``__getitem__`` never raises,
        that loop would never end and would pollute the mapping.
        """
        return word in self._w2i

    def __iter__(self):
        """Iterate over the known words in insertion (id) order.

        Defined for the same reason as ``__contains__``: the implicit
        ``__getitem__``-based iteration would never terminate.
        """
        return iter(self._w2i)

    def __len__(self):
        """Return the number of registered words, reserved tokens included."""
        return len(self._w2i)

    def __dict__(self):
        # NOTE(review): a method named ``__dict__`` shadows the normal
        # instance-dict attribute, so ``obj.__dict__`` is this bound method
        # rather than a dict. Kept unchanged so existing ``w2i.__dict__()``
        # calls keep working; prefer ``to_vocab()["w2i"]`` in new code.
        return self._w2i

    def to_vocab(self):
        """Return a vocabulary dict with ``"w2i"``, ``"i2w"`` and ``"size"``.

        ``"w2i"`` is the live internal mapping (not a copy), while ``"i2w"``
        and ``"size"`` are computed at call time; call this after all words
        have been registered.
        """
        return {
            "w2i": self._w2i,
            "i2w": {v: k for k, v in self._w2i.items()},
            "size": len(self._w2i),
        }


def preprocess_instrs(instrs):
    """Encode instructions word by word into a padded index tensor.

    Each instruction's surface string (``instr.surface(None)``) is split on
    single spaces, every word is mapped to an id through a fresh ``W2I``, and
    the sequence is wrapped in SOS/EOS markers.

    Returns:
        ``(lang, lang_len, vocab)``: the batch-first padded id tensor, a
        tensor of unpadded sequence lengths (SOS/EOS included), and the
        vocabulary dict produced by ``W2I.to_vocab``.
    """
    word_ids = W2I()

    encoded = []
    for instr in instrs:
        words = instr.surface(None).split(" ")
        encoded.append(
            torch.tensor([SOS_INDEX, *(word_ids[word] for word in words), EOS_INDEX])
        )

    # Shorter sequences are right-padded with pad_sequence's default value.
    padded = pad_sequence(encoded, batch_first=True)
    lengths = torch.tensor([len(seq) for seq in encoded])

    return padded, lengths, word_ids.to_vocab()


def preprocess_instrs_onehot(instrs):
    """Encode each instruction as a single opaque token.

    The instruction text is ignored: the i-th instruction becomes the
    sequence ``[SOS, id_of(str(i)), EOS]``.

    Returns:
        ``(lang, lang_len, vocab)``: the batch-first id tensor, a tensor of
        sequence lengths, and the vocabulary dict from ``W2I.to_vocab``.
    """
    ids = W2I()

    rows = [
        torch.tensor([SOS_INDEX, ids[str(position)], EOS_INDEX])
        for position, _ in enumerate(instrs)
    ]

    lang = pad_sequence(rows, batch_first=True)
    lang_len = torch.tensor([len(row) for row in rows])

    return lang, lang_len, ids.to_vocab()


def get_instr_templates(instrs):
    """Return one template string per instruction with colors masked out.

    Each instruction's surface string is split on single spaces; any word
    found in ``COLOR_NAMES`` is replaced by the placeholder ``"C"`` and the
    words are joined back with single spaces.
    """

    def _mask_colors(sentence):
        # Swap color words for the placeholder, leave everything else as-is.
        return " ".join(
            "C" if word in COLOR_NAMES else word for word in sentence.split(" ")
        )

    return [_mask_colors(instr.surface(None)) for instr in instrs]


# Lookup tables derived from INSTRS, computed once when the module is imported.
# Word-level encoding of every instruction.
(LANG, LANG_LEN, VOCAB) = preprocess_instrs(INSTRS)
# One-token-per-instruction encoding of the same instructions.
(LANG_ONEHOT, LANG_LEN_ONEHOT, VOCAB_ONEHOT) = preprocess_instrs_onehot(INSTRS)
# Color-masked template for each instruction, plus the sorted distinct set.
INSTR_TEMPLATES = get_instr_templates(INSTRS)
INSTR_TEMPLATES_UNIQUE = sorted(set(INSTR_TEMPLATES))


class LanguageEncoder(nn.Module):
    """Embed token ids and summarise each sequence with an LSTM.

    Args:
        vocab: Vocabulary dict; only ``vocab["size"]`` is read here, to size
            the embedding table.
        embedding_dim: Width of the token embeddings.
        hidden_dim: Width of the LSTM hidden state.
    """

    def __init__(self, vocab, embedding_dim=64, hidden_dim=256):
        super().__init__()
        self.vocab = vocab
        self.embedding_dim = embedding_dim
        self.hidden_dim = hidden_dim

        self.embedding = nn.Embedding(vocab["size"], embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)

    def forward(self, lang, lang_len):
        """Return the LSTM's final hidden state for each sequence.

        Args:
            lang: Batch-first tensor of token ids.
            lang_len: Tensor of unpadded sequence lengths; it is moved to the
                CPU before packing.

        Returns:
            ``h[0]`` of the LSTM's final hidden state, i.e. one vector of
            width ``hidden_dim`` per input sequence.
        """
        # Packing lets the LSTM stop at each sequence's true length, so
        # padded positions do not influence the result.
        packed = pack_padded_sequence(
            self.embedding(lang),
            lang_len.cpu(),
            batch_first=True,
            enforce_sorted=False,
        )
        _, (final_hidden, _) = self.rnn(packed)
        return final_hidden[0]


def preprocess_xy_instrs(height, width):
    """Encode every ``(x, y)`` cell as the sequence ``[SOS, x, y, EOS]``.

    ``x`` ranges over ``range(height)`` (outer loop) and ``y`` over
    ``range(width)`` (inner loop); each coordinate is looked up by its
    decimal string in a fresh ``W2I``.

    Returns:
        ``(lang, lang_len, vocab)``: an id tensor with one row per cell, a
        tensor of row lengths, and the vocabulary dict from
        ``W2I.to_vocab``.
    """
    ids = W2I()

    rows = [
        [SOS_INDEX, ids[str(x)], ids[str(y)], EOS_INDEX]
        for x in range(height)
        for y in range(width)
    ]

    lang = torch.tensor(rows)
    lang_len = torch.tensor([len(row) for row in rows])

    return lang, lang_len, ids.to_vocab()


def xy_instr_to_goal_number(instr, height, vocab, width=None):
    """Convert an encoded ``[SOS, x, y, EOS]`` instruction to a flat goal id.

    Args:
        instr: 1-D tensor of token ids; positions 1 and 2 hold the x and y
            tokens.
        height: Stride used when ``width`` is not given (legacy behaviour).
        vocab: Vocabulary dict; ``vocab["i2w"]`` maps token ids back to the
            decimal strings of the coordinates.
        width: Optional row stride. ``preprocess_xy_instrs`` enumerates ``y``
            over ``range(width)`` inside each ``x``, so ``x * width + y`` is
            the collision-free numbering of its cells. When omitted, the
            original ``x * height + y`` is returned, which is only
            collision-free when ``width <= height``.

    Returns:
        The integer ``x * stride + y``.
    """
    instr = instr.cpu().numpy()
    x = int(vocab["i2w"][instr[1]])
    y = int(vocab["i2w"][instr[2]])
    # NOTE(review): callers with non-square grids should pass ``width``;
    # the fallback keeps existing call sites unchanged.
    stride = height if width is None else width
    return (x * stride) + y


def to_text(langs, vocab=VOCAB):
    """Decode a batch of token-id sequences into space-joined strings.

    Args:
        langs: 2-D batch of token ids, either a ``torch.Tensor`` (detached
            and converted to a NumPy array) or an object exposing ``ndim``
            that iterates row by row.
        vocab: Vocabulary dict whose ``"i2w"`` entry maps ids to words.
            Defaults to the module-level word vocabulary.

    Returns:
        A list with one string per row; SOS, EOS and PAD ids are skipped.

    Raises:
        ValueError: If ``langs`` is one-dimensional (a single sequence rather
            than a batch).
    """
    if isinstance(langs, torch.Tensor):
        langs = langs.detach().cpu().numpy()

    if langs.ndim == 1:
        raise ValueError("This function operates on batches")

    # Decode with the vocabulary the caller passed in; the global VOCAB is
    # only the default and is wrong for e.g. one-hot or xy vocabularies.
    i2w = vocab["i2w"]
    special = {SOS_INDEX, EOS_INDEX, PAD_INDEX}
    return [
        " ".join(i2w[tok] for tok in lang if tok not in special)
        for lang in langs
    ]
