# %%

import math
import os

if __name__ == "__main__":
    # When executed directly, move the working directory two levels up
    # from the current one (i.e. to ../../).
    _parent_dir = os.path.dirname(os.getcwd())
    os.chdir(os.path.dirname(_parent_dir))


import pickle

import numpy as np
import torch
from omegaconf import OmegaConf
from phonemizer.backend import FestivalBackend, EspeakBackend
from phonemizer.punctuation import Punctuation
from phonemizer.separator import Separator
from scipy.io import loadmat
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

from ntldm.data.lds import sequential_split

# %%


def get_bert_embeddings(phonemized_sentences):
    """Return one all-zero 768-dimensional tensor per input sentence.

    This is a placeholder for actual BERT embedding retrieval: the content
    of ``phonemized_sentences`` is ignored, only its length matters.
    """
    placeholders = []
    for _ in phonemized_sentences:
        # A fresh tensor per sentence, so entries never share storage.
        placeholders.append(torch.zeros(768))
    return placeholders


def load_mat_files(datapath, split, max_seqlen=512):
    """Load sentences, block identifiers and spike arrays from ``.mat`` files.

    Every ``.mat`` file found directly inside ``<datapath>/<split>`` is read
    with ``scipy.io.loadmat``. From each file the fields ``"sentenceText"``,
    ``"blockIdx"`` and ``"tx1"`` are read and iterated trial by trial. Trials
    whose spike array has more than ``max_seqlen`` rows are dropped.

    Args:
        datapath: Root directory that contains one sub-directory per split.
        split: Name of the sub-directory to read (e.g. ``"train"``).
        max_seqlen: Maximum allowed first-axis length of a spike array.

    Returns:
        A dict with three parallel lists:
        ``"sentences"`` (whitespace-stripped sentence text),
        ``"dates"`` (``"<file stem>_<blockIdx[0]>"`` strings) and
        ``"spikes"`` (numpy arrays).

    Note:
        Files are visited in ``os.listdir`` order, which is not guaranteed to
        be stable across machines. NOTE(review): consumers that match these
        lists by index against precomputed data depend on that order.
    """
    split_dir = os.path.join(datapath, split)

    sentences = []
    dates = []
    spikes = []

    for fname in os.listdir(split_dir):
        stem, ext = os.path.splitext(fname)
        # Skip stray entries (hidden files, READMEs, sub-directories, ...)
        # that loadmat would otherwise fail on.
        if ext.lower() != ".mat":
            continue

        matfile = loadmat(os.path.join(split_dir, fname))

        # Pair the per-trial fields within each file so a malformed file
        # cannot shift the alignment of trials coming from other files.
        trials = zip(matfile["sentenceText"], matfile["blockIdx"], matfile["tx1"][0])
        for sentence, block, spike in trials:
            spike = np.array(spike)
            if len(spike) > max_seqlen:
                continue
            sentences.append(sentence.strip())
            dates.append(stem + "_" + str(block[0]))
            spikes.append(spike)

    return {
        "sentences": sentences,
        "dates": dates,
        "spikes": spikes,
    }


# %%


class PhonemeDataset(Dataset):
    """Dataset pairing padded spike recordings with sentence embeddings.

    On construction the spike data for ``split`` is loaded via
    ``load_mat_files`` and the precomputed embeddings are read from
    ``<datapath>/embeddings.pkl``. Both are zero-padded to a common length
    and stored as float tensors together with masks that are 1 on real
    samples and 0 on padding.

    Each item is a dict with the keys ``"original_sentence"``,
    ``"phonemized_sentence"``, ``"embedding"``, ``"embedding_mask"``,
    ``"signal"`` and ``"mask"``.
    """

    def __init__(self, datapath, split="train", time_last=True, max_seqlen=512):
        """Load and preprocess one split.

        Args:
            datapath: Directory containing the split folders and
                ``embeddings.pkl``.
            split: Split name; used both as sub-directory name and as key
                into the loaded embedding dict.
            time_last: If true, ``__getitem__`` returns ``"signal"`` and
                ``"mask"`` transposed (time axis last).
            max_seqlen: Length every spike array is padded to; longer
                recordings are dropped by ``load_mat_files``.

        Raises:
            ValueError: If a sentence stored with the embeddings differs
                from the sentence at the same index in the spike data.
        """
        super().__init__()
        self.datapath = datapath
        self.time_last = time_last
        self.max_seqlen = max_seqlen

        self.data = load_mat_files(datapath, split, max_seqlen=max_seqlen)
        # NOTE: torch.load unpickles the file, so only use it on trusted data.
        self.embedding_dict = torch.load(os.path.join(datapath, "embeddings.pkl"))
        self.embedding_dict = self.embedding_dict[split]

        original_sentences = self.data["sentences"]

        self.original_sentences = []
        self.phonemized_sentences = []
        self.embeddings = []

        # Embeddings are matched to recordings purely by position, so verify
        # that the sentence text agrees at every index before accepting them.
        for i, (sent, phon, emb) in enumerate(
            zip(
                self.embedding_dict["original_sentence"],
                self.embedding_dict["phonemized_sentence"],
                self.embedding_dict["embedding"],
            )
        ):
            if sent.strip() != original_sentences[i].strip():
                raise ValueError(
                    f"Sentences do not match: {sent} != {original_sentences[i]}"
                )
            self.original_sentences.append(sent)
            self.phonemized_sentences.append(phon)
            self.embeddings.append(emb)

        self.spikes, self.masks = self.process_spikes(self.data["spikes"])

        self.embeddings, self.embedding_masks = self.process_embeddings(self.embeddings)
        assert (
            len(self.original_sentences)
            == len(self.phonemized_sentences)
            == len(self.embeddings)
            == len(self.spikes)
            == len(self.masks)
        ), "Lengths of all lists should be the same"

    def phonemize_sentences(self, sentences):
        """Strip punctuation from ``sentences`` and phonemize them with espeak.

        Phones are separated by a space and words by ``"|"``.
        """
        text = Punctuation(';:,.!"?()-').remove(sentences)
        backend = EspeakBackend("en-us", preserve_punctuation=True, with_stress=True)
        # backend = FestivalBackend("en-us", preserve_punctuation=True)
        separator = Separator(phone=" ", word="|")
        return backend.phonemize(
            text,
            separator=separator,
            strip=True,
        )

    def process_embeddings(self, embeddings):
        """Pad embeddings to a common length and build matching masks.

        Args:
            embeddings: Sequence of torch tensors indexed as ``[1, D, L]``
                (the original comment states D = 512 — TODO confirm), where
                L may differ between entries.

        Returns:
            Tuple ``(embeddings, masks)`` of float tensors, both shaped
            ``[N, L_max, D]``. Real samples occupy the first rows of each
            entry and padding the trailing rows; the mask is 1 on real
            samples and 0 on padding.
        """
        embedding_masks = []
        processed_embeddings = []
        max_embedding_length = max([emb.shape[-1] for emb in embeddings])

        for emb in embeddings:
            emb = emb[0].permute(1, 0).numpy()  # [1, D, L] -> [L, D]
            mask = np.ones((max_embedding_length, emb.shape[1]), dtype=np.float32)
            if len(emb) < max_embedding_length:
                padding_length = max_embedding_length - len(emb)
                mask[len(emb) :] = 0
                # Pad at the END of the length axis so the zeros line up with
                # the zeroed tail of the mask (same layout as process_spikes).
                # Padding at the front would make the mask flag zeros as
                # valid and real data as padding.
                emb = np.pad(emb, ((0, padding_length), (0, 0)), mode="constant")
            processed_embeddings.append(emb)
            embedding_masks.append(mask)
        return (
            torch.from_numpy(np.array(processed_embeddings)).float(),
            torch.from_numpy(np.array(embedding_masks)).float(),
        )

    def process_spikes(self, spikes):
        """Crop spikes to 128 columns, pad to ``max_seqlen`` and build masks.

        Args:
            spikes: Sequence of 2-D arrays with at most ``self.max_seqlen``
                rows each.

        Returns:
            Tuple ``(spikes, masks)`` of float tensors shaped
            ``[N, max_seqlen, C]`` with ``C <= 128``; the mask is 1 on real
            samples and 0 on the trailing padding.
        """
        processed_spikes = []
        masks = []
        for spike in spikes:
            spike = spike[:, :128]  # limit to 128 (6v area) spikes
            mask = np.ones((self.max_seqlen, spike.shape[1]), dtype=np.float32)
            if len(spike) < self.max_seqlen:
                padding_length = self.max_seqlen - len(spike)
                mask[len(spike) :] = 0
                spike = np.pad(spike, ((0, padding_length), (0, 0)), mode="constant")
            processed_spikes.append(spike)
            masks.append(mask)
        return (
            torch.from_numpy(np.array(processed_spikes)).float(),
            torch.from_numpy(np.array(masks)).float(),
        )

    def __len__(self):
        """Return the number of recordings in the dataset."""
        return len(self.spikes)

    def __getitem__(self, idx):
        """Return the sample at ``idx`` as a dict.

        ``"signal"`` and ``"mask"`` are transposed when ``self.time_last``
        is true; the embedding tensors are returned as stored.
        """
        return {
            "original_sentence": self.original_sentences[idx],
            "phonemized_sentence": self.phonemized_sentences[idx],
            "embedding": self.embeddings[idx],
            "embedding_mask": self.embedding_masks[idx],
            "signal": (self.spikes[idx].T if self.time_last else self.spikes[idx]),
            "mask": (self.masks[idx].T if self.time_last else self.masks[idx]),
        }


def get_phoneme_dataloaders(datapath, batch_size=32, num_workers=4, max_seqlen=512):
    """Build train, validation and test dataloaders for the phoneme data.

    The ``"train"`` split becomes the training set. The ``"test"`` split is
    passed to ``sequential_split`` with sizes ``[n // 4, n - n // 4]``; the
    first part is used for validation and the second for testing. Only the
    training loader shuffles.

    Args:
        datapath: Directory handed to ``PhonemeDataset``.
        batch_size: Batch size used by all three loaders.
        num_workers: Worker count used by all three loaders.
        max_seqlen: Forwarded to ``PhonemeDataset``.

    Returns:
        Tuple ``(train_dataloader, val_dataloader, test_dataloader)``.
    """
    train_dataset = PhonemeDataset(datapath, split="train", max_seqlen=max_seqlen)
    heldout_dataset = PhonemeDataset(datapath, split="test", max_seqlen=max_seqlen)

    n_val = len(heldout_dataset) // 4
    n_test = len(heldout_dataset) - n_val
    val_dataset, test_dataset = sequential_split(heldout_dataset, [n_val, n_test])

    # Settings shared by all loaders; only the shuffle flag differs.
    shared_kwargs = {
        "batch_size": batch_size,
        "num_workers": num_workers,
        "pin_memory": True,
    }
    train_dataloader = DataLoader(train_dataset, shuffle=True, **shared_kwargs)
    val_dataloader = DataLoader(val_dataset, shuffle=False, **shared_kwargs)
    test_dataloader = DataLoader(test_dataset, shuffle=False, **shared_kwargs)

    print(f"Data path: {datapath}")
    print(
        f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}"
    )

    return train_dataloader, val_dataloader, test_dataloader


# %%
if __name__ == "__main__":
    import time

    datapath = "data/phoneme/competitionData"
    batch_size = 32
    num_workers = 2

    # Manual inspection of a single dataset, kept for debugging:
    # train_dataset = PhonemeDataset(datapath, split="train", max_seqlen=512)
    # print(
    #     train_dataset.original_sentences[0],
    #     train_dataset.phonemized_sentences[0],
    #     train_dataset.embeddings[0].shape,
    #     train_dataset.embedding_masks[0].shape,
    #     train_dataset.spikes[0].shape,
    #     train_dataset.masks[0].shape,
    # )

    # Time how long it takes to build all three dataloaders.
    start = time.time()
    train_loader, val_loader, test_loader = get_phoneme_dataloaders(
        datapath, batch_size, num_workers
    )
    elapsed = time.time() - start
    print(f"Time taken for dataloader: {elapsed:.2f}s")

    # Smoke test: print the tensor shapes of the first batch of each loader.
    for loader in (train_loader, val_loader, test_loader):
        for data in loader:
            print(data["signal"].shape, data["mask"].shape)
            break

# %%
