import math
import os

import numpy as np
from omegaconf import OmegaConf
import torch

from ntldm.utils.utils import set_seed, standardize_array
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm
from einops import rearrange

from nlb_tools.nwb_interface import NWBDataset
from nlb_tools.make_tensors import make_train_input_tensors

from ntldm.data.lds import sequential_split
import pickle


class MonkeyDataset(Dataset):
    """Binned spikes and behavior for one NLB monkey task.

    The NLB "train" trial split plus the first half of the NLB "val" split form
    the training data (``is_train=True``); the second half of the "val" split
    is returned otherwise. The tensors produced by ``make_train_input_tensors``
    are cached as pickle files in ``datapath`` (one per split and bin width),
    so the NWB file only has to be parsed when a cache file is missing or
    ``new_data`` is set.

    Args:
        task: NLB dataset name passed to ``make_train_input_tensors``
            (e.g. "mc_maze"); also part of the cache file names.
        datapath: Directory containing the NWB data; cache pickles are
            written here too.
        bin_width: Bin width used to resample the NWB data; also part of the
            cache file names.
        is_train: Select the training portion (train split + first half of
            val) instead of the held-out second half of val.
        time_last: If True, ``__getitem__`` transposes each trial so that
            time is the last axis.
        new_data: If True, ignore existing cache pickles and rebuild
            (overwriting) them from the NWB data.
    """

    def __init__(
        self,
        task,
        datapath,
        bin_width=5,
        is_train=True,
        time_last=True,
        new_data=False,
    ):
        super().__init__()
        self.task = task
        self.datapath = datapath
        self.time_last = time_last

        # Parsing the NWB file is expensive: build it lazily, at most once,
        # and share it between the "train" and "val" splits.
        nwb_dataset = None

        def get_nwb_dataset():
            nonlocal nwb_dataset
            if nwb_dataset is None:
                nwb_dataset = NWBDataset(datapath)
                # Choose bin width and resample
                nwb_dataset.resample(bin_width)
            return nwb_dataset

        ## Try to load preprocessed monkey data from neural latents lib.
        ## If it doesn't exist, we create it. Both splits are always
        ## loaded/created so the on-disk cache is fully populated.
        data_dict_train = self._load_or_create_split(
            "train", bin_width, new_data, get_nwb_dataset
        )
        data_dict_val = self._load_or_create_split(
            "val", bin_width, new_data, get_nwb_dataset
        )

        # NLB "val" split, [B, L, C]: its first half is used for training,
        # its second half is held out.
        val_spikes = self._stack_spikes(data_dict_val)
        val_behavior = data_dict_val["train_behavior"]

        if is_train:
            half = len(val_spikes) // 2
            train_spikes = np.concatenate(
                [self._stack_spikes(data_dict_train), val_spikes[:half]], axis=0
            )
            train_behavior = np.concatenate(
                [data_dict_train["train_behavior"], val_behavior[:half]], axis=0
            )
        else:
            print("before  test split", val_spikes.shape, val_behavior.shape)
            train_spikes = val_spikes[len(val_spikes) // 2 :]
            train_behavior = val_behavior[len(val_behavior) // 2 :]

        self.train_spikes = torch.from_numpy(train_spikes).float()
        self.behavior = torch.from_numpy(train_behavior).float()
        print(
            f"Train spikes shape: {self.train_spikes.shape}, Behavior shape: {self.behavior.shape}"
        )

    @staticmethod
    def _stack_spikes(data_dict):
        """Concatenate held-in and held-out units along the last (channel) axis."""
        return np.concatenate(
            [data_dict["train_spikes_heldin"], data_dict["train_spikes_heldout"]],
            axis=-1,
        )

    def _cache_path(self, split, bin_width):
        """Return the pickle path caching NLB split ``split`` at ``bin_width``."""
        return os.path.join(
            self.datapath,
            f"monkey_{self.task}_data_dict_{split}_split_{bin_width}.pkl",
        )

    def _load_or_create_split(self, split, bin_width, new_data, get_nwb_dataset):
        """Load the cached tensors for ``split``, or build and cache them.

        Args:
            split: NLB trial split name ("train" or "val").
            bin_width: Bin width the tensors were/are computed at.
            new_data: Force a rebuild even if the cache file exists.
            get_nwb_dataset: Zero-arg callable returning the resampled NWB
                dataset; only called when the split has to be built.

        Returns:
            The dict produced by ``make_train_input_tensors``.
        """
        path = self._cache_path(split, bin_width)
        if not new_data and os.path.exists(path):
            # NOTE(security): pickle.load can execute arbitrary code stored in
            # the file; only point datapath at caches this class wrote or that
            # otherwise come from a trusted source.
            with open(path, "rb") as f:
                data_dict = pickle.load(f)
            print(f"Loaded {split} data dict from {path}")
            return data_dict

        data_dict = make_train_input_tensors(
            get_nwb_dataset(),
            dataset_name=self.task,
            trial_split=split,
            save_file=False,
            include_behavior=True,
        )
        # The context manager guarantees the pickle is flushed and closed.
        with open(path, "wb") as f:
            pickle.dump(data_dict, f)
        print(f"Saved {split} data dict to {path}")
        return data_dict

    def __len__(self):
        """Number of trials in this portion of the data."""
        return len(self.train_spikes)

    def __getitem__(self, idx):
        """Return trial ``idx`` as ``{"signal": spikes, "behavior": behavior}``.

        Trials are stored time-first (spikes [L, C], behavior [L, D]); with
        ``time_last`` both are transposed to [C, L] / [D, L].
        """
        if self.time_last:
            return {
                "signal": self.train_spikes[idx].T,  # [C, L]
                "behavior": self.behavior[idx].T,  # [2, L]
            }
        return {
            "signal": self.train_spikes[idx],  # [L, C]
            "behavior": self.behavior[idx],  # [L, 2]
        }


# create the latent dataset
class LatentMonkeyDataset(torch.utils.data.Dataset):
    """Autoencoder latents paired with the spikes and behavior they encode.

    At construction every batch of ``dataloader`` is passed through
    ``ae_model.encode``. Latents are then standardized per channel (statistics
    taken over dims 0 and 2) and optionally clamped to [-5, 5].

    Args:
        dataloader: Iterable with ``len`` yielding dicts with "signal" and
            "behavior" tensors (e.g. a DataLoader over ``MonkeyDataset``).
        ae_model: Model exposing ``eval()`` and ``encode(signal)``. It is
            switched to eval mode and left in eval mode.
        clip: Clamp the standardized latents to [-5, 5].
        latent_means: Precomputed per-channel means (e.g. from the training
            set); computed from these latents if this or ``latent_stds`` is None.
        latent_stds: Precomputed per-channel standard deviations.
        angle_time_idx: Time index (last axis of behavior) of the single
            sample used to compute ``behavior_angles``.
    """

    def __init__(
        self,
        dataloader,
        ae_model,
        clip=True,
        latent_means=None,
        latent_stds=None,
        angle_time_idx=50,
    ):
        self.full_dataloader = dataloader
        self.ae_model = ae_model
        self.latents, self.train_spikes, self.behavior = self.create_latents()
        # normalize to N(0, 1) per channel; keepdim keeps the reduced axes as
        # size 1 so the statistics broadcast back onto the latents.
        if latent_means is None or latent_stds is None:
            self.latent_means = self.latents.mean(dim=(0, 2), keepdim=True)
            self.latent_stds = self.latents.std(dim=(0, 2), keepdim=True)
        else:
            self.latent_means = latent_means
            self.latent_stds = latent_stds
        self.latents = (self.latents - self.latent_means) / self.latent_stds
        if clip:
            self.latents = self.latents.clamp(-5, 5)

        assert len(self.latents) == len(self.behavior) and len(self.latents) == len(
            self.train_spikes
        ), f"Lengths of latents, behavior, and spikes do not match: {len(self.latents)}, {len(self.behavior)}, {len(self.train_spikes)}"

        self.behavior = self.behavior / 1e3  # better this way

        self.behavior_cumsum = torch.cumsum(self.behavior, dim=-1)

        # Reach angle from the behavior sample at `angle_time_idx`, [B, 1].
        self.behavior_angles = self._reach_angles(self.behavior, angle_time_idx)

    @staticmethod
    def _reach_angles(behavior, time_idx):
        """Return ``atan2(behavior[:, 1, t], behavior[:, 0, t])`` as a [B, 1] tensor.

        Only the single sample at ``t = time_idx`` is used (no averaging or
        accumulation over earlier steps).

        Raises:
            IndexError: if ``time_idx`` is outside the last (time) axis.
        """
        length = behavior.shape[-1]
        if not -length <= time_idx < length:
            raise IndexError(
                f"behavior has {length} time steps; cannot read the reach "
                f"angle at time index {time_idx}"
            )
        angles = torch.atan2(behavior[:, 1, time_idx], behavior[:, 0, time_idx])
        return angles.unsqueeze(-1)

    def create_latents(self):
        """Encode every batch of ``self.full_dataloader`` with ``self.ae_model``.

        Puts the model in eval mode (and leaves it there) and disables
        gradient tracking while encoding.

        Returns:
            ``(latents, spikes, behavior)``: the per-batch CPU tensors
            concatenated along dim 0.
        """
        latents, spikes, behavior = [], [], []
        self.ae_model.eval()
        with torch.no_grad():
            for batch in tqdm(
                self.full_dataloader,
                total=len(self.full_dataloader),
                desc="Creating latent dataset",
            ):
                z = self.ae_model.encode(batch["signal"])
                latents.append(z.cpu())
                spikes.append(batch["signal"].cpu())
                behavior.append(batch["behavior"].cpu())
        return torch.cat(latents), torch.cat(spikes), torch.cat(behavior)

    def __len__(self):
        """Number of encoded trials."""
        return len(self.latents)

    def __getitem__(self, idx):
        """Return the spikes, normalized latent, scaled behavior and angle of trial ``idx``."""
        return {
            "signal": self.train_spikes[idx],
            "latent": self.latents[idx],
            "behavior": self.behavior[idx],
            "behavior_angle": self.behavior_angles[idx],
        }


def get_monkey_dataloaders(task, datapath, bin_width=5, batch_size=32, num_workers=2, new_data=False):
    """Build the train, validation and test DataLoaders for an NLB monkey task.

    The training loader wraps ``MonkeyDataset(is_train=True)``. The held-out
    ``MonkeyDataset(is_train=False)`` is split sequentially: its first quarter
    becomes the validation set and the remaining trials the test set.

    Args:
        task: NLB dataset name, e.g. "mc_maze".
        datapath: Directory holding the NWB data and the cached pickles.
        bin_width: Bin width of behavior and neural activity.
        batch_size: Batch size of all three loaders. Defaults to 32.
        num_workers: Worker processes per loader. Defaults to 2.
        new_data: If True, rebuild the cached data instead of loading it from disk.

    Returns:
        ``(train_loader, val_loader, test_loader)``; only the training loader shuffles.
    """
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=True,
    )

    train_dataset = MonkeyDataset(
        task, datapath, bin_width=bin_width, is_train=True, new_data=new_data
    )
    held_out = MonkeyDataset(
        task, datapath, bin_width=bin_width, is_train=False, new_data=new_data
    )

    # First quarter -> validation, remaining trials -> test (order preserved).
    n_held_out = len(held_out)
    n_val = n_held_out // 4
    val_dataset, test_dataset = sequential_split(held_out, [n_val, n_held_out - n_val])

    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader, test_loader = [
        DataLoader(subset, shuffle=False, **loader_kwargs)
        for subset in (val_dataset, test_dataset)
    ]

    print(f"Task: {task}, Bin width: {bin_width} ms, Data path: {datapath}")
    sizes = f"Train: {len(train_dataset)}, Val: {len(val_dataset)}, Test: {len(test_dataset)}"
    print(sizes)

    return train_loader, val_loader, test_loader


if __name__ == "__main__":
    # Smoke test: build the MC_Maze loaders and print the shapes of the first
    # batch of each split (train, val, test, in that order).
    task = "mc_maze"
    datapath = "data/000128/sub-Jenkins/"
    bin_width = 5
    batch_size = 32
    num_workers = 2

    loaders = get_monkey_dataloaders(
        task, datapath, bin_width, batch_size, num_workers
    )

    for loader in loaders:
        # `for ... break` (rather than next(iter(...))) so an empty loader
        # simply prints nothing.
        for first_batch in loader:
            print(first_batch["signal"].shape, first_batch["behavior"].shape)
            break
