# %%
import math
import multiprocessing
import os
from os.path import join

import numpy as np
import yaml
from omegaconf import OmegaConf

import accelerate
import dysts
import dysts.flows
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader, Dataset
from tqdm.auto import tqdm

from ntldm.networks import AutoEncoder, CountWrapper, SinusoidalPosEmb
from ntldm.utils.utils import set_seed, standardize_array

# Names of dysts.flows systems suggested to callers; this list is shown in the
# error message raised by get_attractor_dataset for an unknown system name.
# The trailing comments are the original author's notes on state dimension.
ATTRACTORS = [
    "ForcedFitzHughNagumo",  # 3dim
    "Hopfield",  # 6dim
    "Lorenz",  # 3dim
    "LorenzCoupled",  # 6dim
    "MackeyGlass",  # 10dim
]


def simulate_single_trajectory(args):
    """Simulate one trajectory of a dysts system and z-score it per dimension.

    :param args: 5-tuple ``(system_name, initial_condition, sequence_length,
        burn_in, random_seed)``. A single tuple is used so the function can be
        mapped over a list of jobs. ``initial_condition`` must expose
        ``.tolist()``.
    :return: array holding the trajectory after the first ``burn_in`` steps
        were dropped, centred and scaled along axis 0 using this trajectory's
        own mean and standard deviation.
    """
    system_name, initial_condition, sequence_length, burn_in, random_seed = args

    # Instantiate first, seed afterwards: the order of these two statements is
    # kept exactly as in the original implementation.
    flow_cls = getattr(dysts.flows, system_name)
    system = flow_cls()
    set_seed(random_seed)

    system.random_state = random_seed
    system.ic = initial_condition.tolist()

    total_steps = sequence_length + burn_in
    full_trace = system.make_trajectory(total_steps)

    # Drop the burn-in prefix, then standardise each state dimension.
    kept = np.array(full_trace[burn_in:])
    centered = kept - np.mean(kept, axis=0)
    return centered / np.std(centered, axis=0)


def simulate_attractor_parallel(
    system_name, num_seq=10, burn_in=100, sequence_length=1000, random_seed=42
):
    """Simulate the attractor system in parallel using multiprocessing.

    The system's default initial condition is replicated ``num_seq`` times and
    each copy is perturbed with Gaussian noise (standard deviation 2.0). Every
    perturbed initial condition is then handed to
    ``simulate_single_trajectory`` in a worker process, with seed
    ``random_seed + i`` for the i-th sequence.

    :param system_name: attribute name of the system class in ``dysts.flows``.
    :param num_seq: number of trajectories to simulate (must be >= 1).
    :param burn_in: number of leading steps each worker discards.
    :param sequence_length: number of steps kept per trajectory.
    :param random_seed: base seed for the perturbation and the workers.
    :return: float32 tensor stacking the worker results along a new leading
        axis, in the same order as the initial conditions.
    :raises ValueError: if ``num_seq`` is smaller than 1.
    """
    if num_seq < 1:
        # Previously this surfaced as an obscure error from the process pool
        # (or from ``repeat``); fail early with a clear message instead.
        raise ValueError(f"num_seq must be at least 1, got {num_seq}")

    system = getattr(dysts.flows, system_name)()  # Caution sets seed inside
    initial_conditions = system.ic
    # set the seed again
    set_seed(random_seed)
    initial_conditions = (
        np.array(initial_conditions).reshape(1, -1).repeat(num_seq, axis=0)
    )
    initial_conditions += (
        np.random.randn(*initial_conditions.shape) * 2.0
    )  # perturb initial conditions

    args_list = [
        (system_name, initial_conditions[i], sequence_length, burn_in, random_seed + i)
        for i in range(num_seq)
    ]

    # Never start more workers than there are trajectories to simulate.
    num_workers = min(num_seq, multiprocessing.cpu_count())

    with multiprocessing.Pool(num_workers) as pool:
        # imap preserves input order, so results[i] belongs to args_list[i].
        results = list(
            tqdm(
                pool.imap(simulate_single_trajectory, args_list),
                total=num_seq,
                desc=f"Simulating {system_name}",
            )
        )

    # Each trajectory was already normalized inside the worker. Stack into one
    # ndarray before building the tensor: torch.tensor() on a list of ndarrays
    # is very slow and emits a UserWarning.
    latents = torch.tensor(np.stack(results), dtype=torch.float32)

    return latents


# deprecated
def simulate_attractor(
    system_name, num_seq=10, burn_in=100, sequence_length=1000, random_seed=42
):
    """Sequentially simulate ``num_seq`` trajectories of a dysts system.

    The default initial condition of the system is replicated and perturbed
    with unit-scale Gaussian noise. Each trajectory is simulated on a freshly
    instantiated system, and the stacked result is standardised with the mean
    and standard deviation taken jointly over sequences and time steps.

    :param system_name: attribute name of the system class in ``dysts.flows``.
    :param num_seq: number of trajectories to simulate.
    :param burn_in: number of leading steps dropped from each trajectory.
    :param sequence_length: number of steps kept per trajectory.
    :param random_seed: base seed; trajectory ``i`` uses ``random_seed + i``.
    :return: float32 tensor of shape
        ``(num_seq, sequence_length, system_dimension)``.
    """
    flow_cls = getattr(dysts.flows, system_name)
    reference_system = flow_cls()
    set_seed(random_seed)

    system_dimension = len(reference_system.ic)
    default_ic = reference_system.ic

    # One perturbed copy of the default initial condition per sequence.
    perturbation_scale = 1.0
    start_points = torch.tensor(default_ic).unsqueeze(0).repeat(num_seq, 1)
    jitter = torch.randn(start_points.shape, generator=None) * perturbation_scale
    start_points = start_points + jitter

    latents = torch.empty(
        (num_seq, sequence_length, system_dimension), dtype=torch.float32
    )
    total_steps = sequence_length + burn_in
    for i in tqdm(range(num_seq), desc=f"Simulating {system_name}"):
        system = flow_cls()  # fresh system for every trajectory
        system.random_state = random_seed + i
        system.ic = start_points[i].numpy().tolist()

        trajectory = system.make_trajectory(total_steps)
        latents[i] = torch.tensor(trajectory[burn_in:])

    # Global standardisation across all sequences and time steps.
    centered = latents - torch.mean(latents, dim=(0, 1), keepdim=True)
    return centered / torch.std(centered, dim=(0, 1), keepdim=True)


class LatentAttractor(Dataset):
    """Poisson spike-count dataset driven by simulated attractor trajectories."""

    def __init__(
        self,
        system_name,
        n_neurons,
        sequence_length=100,
        noise_std=0.05,
        n_ic=5,
        mean_spike_count=500.0,
        random_seed=42,
        softplus_beta=1,
        original=True,
    ):
        """
        Build the dataset in four steps:

        1. Simulate ``n_ic`` latent trajectories with
           ``simulate_attractor_parallel``.
        2. Draw a random readout matrix ``C`` (rows sorted by first column).
        3. Turn the projected latents into Poisson rates and sample counts.
        4. Keep latents, rates and samples for item access.

        :param system_name: Name of the attractor system
        :param n_neurons: Number of output dimensions (rows of ``C``)
        :param sequence_length: Length of the sequence to generate
        :param noise_std: Accepted for interface compatibility; this
            constructor never reads it
        :param n_ic: Number of initial conditions (= number of sequences)
        :param mean_spike_count: Mean spike count used for the bias term of
            the ``original`` observation model
        :param random_seed: Seed passed to ``set_seed`` and to the simulator
        :param softplus_beta: ``beta`` of the softplus applied to the rates
        :param original: If True, use unit-norm Gaussian readout rows plus a
            log-count bias; otherwise use a uniform readout in [-1, 1),
            exponentiated rates with a small random offset, and a time step
            of 0.01 when sampling
        """
        set_seed(random_seed)

        self.system_name = system_name
        self.n_neurons = n_neurons
        self.sequence_length = sequence_length

        # Step 1: latent trajectories, one per initial condition.
        latents = simulate_attractor_parallel(
            system_name,
            num_seq=n_ic,
            sequence_length=sequence_length,
            random_seed=random_seed,
        )
        latent_dim = latents.shape[-1]

        if original:
            # Step 2: Gaussian readout with unit-norm rows, ordered by the
            # weight on the first latent dimension.
            readout = torch.randn(n_neurons, latent_dim, dtype=torch.float32)
            readout /= torch.norm(readout, dim=1, keepdim=True)
            row_order = readout[:, 0].argsort()
            self.C = readout[row_order]

            # Bias: log of the mean count per time step, one entry per neuron.
            per_step_count = torch.tensor(mean_spike_count) / sequence_length
            self.b = torch.log(per_step_count) * torch.ones(n_neurons, 1)

            # Step 3: project, add bias, squash with softplus, sample counts.
            projected = torch.einsum("ij,klj->kli", self.C, latents)
            self.log_rates = projected + self.b.squeeze()
            self.poisson_rates = F.softplus(self.log_rates, beta=softplus_beta)
            self.samples = torch.poisson(self.poisson_rates)
        else:
            # Step 2: uniform readout in [-1, 1), same row ordering.
            readout = torch.rand(n_neurons, latent_dim, dtype=torch.float32) - 0.5
            readout *= 2
            row_order = readout[:, 0].argsort()
            self.C = readout[row_order]

            self.log_rates = torch.einsum("ij,klj->kli", self.C, latents)

            # Step 3: exponential rates with a small per-neuron random offset.
            baseline_fr = 0  # spk/s
            bias = 0.01  # spk/s
            dt = 0.01
            fr_mod = (torch.rand(1, n_neurons) - 0.5) * bias  # spk/s
            unsquashed = torch.exp(self.log_rates) + baseline_fr + fr_mod
            self.poisson_rates = F.softplus(unsquashed, beta=softplus_beta)
            self.samples = torch.poisson(self.poisson_rates * dt)

        # Step 4: expose everything needed by __getitem__.
        self.latents = latents
        self.rates = self.poisson_rates

    def __len__(self):
        """
        Number of sequences (first axis of the sampled counts).
        """
        return len(self.samples)

    def __getitem__(self, index):
        """
        Return one sequence as a dict with keys "signal", "latents", "rates".

        Each entry is the indexed 2-D tensor with its two axes swapped, i.e.
        [D, L] per item (batching into [B, D, L] is left to the DataLoader).
        """
        sources = {
            "signal": self.samples,
            "latents": self.latents,
            "rates": self.rates,
        }
        return {key: tensor[index].permute(1, 0) for key, tensor in sources.items()}


def get_attractor_dataset(
    system_name,
    n_neurons=128,
    sequence_length=500,
    noise_std=0.05,
    n_ic=100,
    mean_spike_count=200,
    batch_size=100,
    train_frac=0.7,
    valid_frac=0.15,
    random_seed=42,
    softplus_beta=1,
):
    """
    generates and splits data from a given attractor system into train, val, and test sets.
    this function uses the LatentAttractor dataset class defined above.

    parameters:
    - system_name: name of the attractor system to simulate (an attribute of dysts.flows).
    - n_neurons: number of neurons (output dimensionality).
    - sequence_length: length of each sequence in timesteps.
    - noise_std: forwarded unchanged to LatentAttractor.
    - n_ic: number of initial conditions to simulate (= total number of sequences).
    - mean_spike_count: mean spike count for the poisson observation model.
    - batch_size: size of the batches for data loading.
    - train_frac: fraction of the dataset to use for training.
    - valid_frac: fraction of the dataset to use for validation; if
      train_frac + valid_frac exceeds 1, the validation set receives only the
      sequences left over after training and the test set is empty.
    - random_seed: seed for random number generation for reproducibility.
    - softplus_beta: softplus beta forwarded to LatentAttractor.

    returns:
    - a tuple of DataLoader objects for training, validation, and test sets
      (only the training loader shuffles).

    raises:
    - ValueError: if dysts.flows has no attribute named system_name.
    """
    try:
        getattr(dysts.flows, system_name)
    except AttributeError as err:
        # chain the original lookup failure so the traceback stays informative
        raise ValueError(
            f"Invalid system_name: {system_name}. Ideally, choose from {ATTRACTORS}"
        ) from err

    # set the seed for reproducibility
    set_seed(random_seed)

    # total number of sequences to be generated
    total_sequences = n_ic
    # calculate number of sequences for train, valid, test sets based on fractions.
    # the lengths are clamped so that none of them can become negative when the
    # fractions sum to more than 1. previously a negative test length was handed
    # to random_split, which only produced a usable split through slice
    # truncation; the clamped lengths yield the same subsets explicitly.
    num_train = min(int(train_frac * total_sequences), total_sequences)
    num_valid = min(int(valid_frac * total_sequences), total_sequences - num_train)
    num_test = total_sequences - num_train - num_valid

    # generate the dataset
    attractor_dataset = LatentAttractor(
        system_name=system_name,
        n_neurons=n_neurons,
        sequence_length=sequence_length,
        noise_std=noise_std,
        n_ic=n_ic,
        mean_spike_count=mean_spike_count,
        random_seed=random_seed,
        softplus_beta=softplus_beta,
    )

    # split dataset into non-overlapping train, val, test subsets; a dedicated
    # generator keeps the split independent of the global RNG state
    train_dataset, valid_dataset, test_dataset = torch.utils.data.random_split(
        attractor_dataset,
        [num_train, num_valid, num_test],
        generator=torch.Generator().manual_seed(random_seed),
    )

    # create dataloaders for each subset
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=False)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    return train_dataloader, valid_dataloader, test_dataloader


class LatentDataset(torch.utils.data.Dataset):
    """Dataset of autoencoder latents, standardised and optionally clipped.

    All batches of ``dataloader`` are encoded once at construction time with
    ``ae_model.encode``; items are rows of the resulting latent tensor.
    """

    def __init__(
        self, dataloader, ae_model, clip=True, latent_means=None, latent_stds=None
    ):
        """
        :param dataloader: iterable of batches; each batch is indexed with
            ``"signal"`` to obtain the encoder input.
        :param ae_model: object providing ``eval()`` and ``encode(signal)``.
        :param clip: if True, clamp the normalised latents to [-5, 5].
        :param latent_means: normalisation mean to reuse; computed from the
            encoded data when either statistic is missing.
        :param latent_stds: normalisation std to reuse; computed from the
            encoded data when either statistic is missing.
        """
        self.full_dataloader = dataloader
        self.ae_model = ae_model
        self.latents = self.create_latents()

        # normalize to N(0, 1); statistics are only reused when BOTH are given
        stats_provided = latent_means is not None and latent_stds is not None
        if stats_provided:
            self.latent_means = latent_means
            self.latent_stds = latent_stds
        else:
            # reduce over axes 0 and 2, keeping them as size-1 axes
            self.latent_means = self.latents.mean(dim=(0, 2), keepdim=True)
            self.latent_stds = self.latents.std(dim=(0, 2), keepdim=True)

        normalized = (self.latents - self.latent_means) / self.latent_stds
        self.latents = normalized.clamp(-5, 5) if clip else normalized

    def create_latents(self):
        """Encode every batch of the dataloader and concatenate the results on CPU."""
        self.ae_model.eval()
        progress = tqdm(
            self.full_dataloader,
            total=len(self.full_dataloader),
            desc="Creating latent dataset",
        )
        encoded_batches = []
        for batch in progress:
            with torch.no_grad():
                encoded = self.ae_model.encode(batch["signal"])
            encoded_batches.append(encoded.cpu())
        return torch.cat(encoded_batches)

    def __len__(self):
        """Number of encoded sequences."""
        return len(self.latents)

    def __getitem__(self, idx):
        """Return the normalised latent(s) at ``idx``."""
        return self.latents[idx]


class Lorenz_AE_OUTPUT(Dataset):
    """
    Dataset of autoencoder outputs computed on the training split of an
    attractor dataset (named after the Lorenz system it was first used with).

    At construction time the autoencoder described by a YAML config is
    rebuilt, its weights are loaded, the attractor dataset is regenerated from
    the same config, and every training batch is passed through the model.
    The second element of the model output is collected, transposed and
    standardized, and served item by item.

    Supports positional embeddings.
    """

    def __init__(
        self,
        with_time_emb=False,
        cond_time_dim=32,
        filepath=None,
        filename_cfg=None,
        filename_model=None,
        n_latents=16,
        max_samples=1000,
    ):
        """
        :param with_time_emb: if True, items carry a "cond" entry holding the
            sinusoidal position embedding.
        :param cond_time_dim: dimension passed to ``SinusoidalPosEmb``.
        :param filepath: directory containing the config and the checkpoint.
        :param filename_cfg: YAML config file name inside ``filepath``.
        :param filename_model: state-dict checkpoint file name inside
            ``filepath``.
        :param n_latents: stored as ``self.num_channels``; not otherwise used
            in this class.
        :param max_samples: accepted but not used in this class.
        """
        super().__init__()

        self.with_time_emb = with_time_emb
        self.cond_time_dim = cond_time_dim
        # NOTE(review): hard-coded; the embedding below is built for this
        # length rather than cfg_ae.dataset.signal_length. Confirm they agree.
        self.signal_length = 256
        self.num_channels = n_latents

        with open(join(filepath, filename_cfg), "r") as f:
            cfg_ae = OmegaConf.create(yaml.safe_load(f))

        ae_model = AutoEncoder(
            C_in=cfg_ae.model.C_in,
            C=cfg_ae.model.C,
            C_latent=cfg_ae.model.C_latent,
            L=cfg_ae.dataset.signal_length,
            kernel=cfg_ae.model.kernel,
            num_blocks=cfg_ae.model.num_blocks,
            num_lin_per_mlp=cfg_ae.model.get("num_lin_per_mlp", 2),  # default 2
        )

        ae_model = CountWrapper(
            ae_model, use_sin_enc=cfg_ae.model.get("use_sin_enc", False)
        )

        # Load onto CPU first so a checkpoint saved on a GPU can be restored on
        # a CPU-only host; the model is moved to its device further below.
        # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
        state_dict = torch.load(join(filepath, filename_model), map_location="cpu")
        ae_model.load_state_dict(state_dict)

        # load the dataset
        train_dataloader, val_dataloader, test_dataloader = get_attractor_dataset(
            system_name=cfg_ae.dataset.system_name,
            n_neurons=cfg_ae.model.C_in,
            sequence_length=cfg_ae.dataset.signal_length,
            noise_std=0.05,
            n_ic=cfg_ae.dataset.n_ic,
            mean_spike_count=cfg_ae.dataset.get("mean_rate", 0.5)
            * cfg_ae.dataset.signal_length,
            train_frac=cfg_ae.dataset.split_frac,
            random_seed=cfg_ae.training.random_seed,
            batch_size=cfg_ae.training.batch_size,
            softplus_beta=cfg_ae.dataset.get("softplus_beta", 2.0),
        )

        # perform the same accelerator setup
        accelerator = accelerate.Accelerator(
            mixed_precision=cfg_ae.training.precision, log_with="wandb"
        )

        (
            ae_model,
            train_dataloader,
            val_dataloader,
            test_dataloader,
        ) = accelerator.prepare(
            ae_model, train_dataloader, val_dataloader, test_dataloader
        )

        # now we can use the model to embed the data
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {device}")

        all_outputs = []
        ae_model = ae_model.to(device)

        # Inference only: disabling autograd avoids building and retaining a
        # graph for every batch, which previously inflated memory use.
        with torch.no_grad():
            for batch in train_dataloader:
                signals = batch["signal"].to(device)

                # Forward pass through the model
                # NOTE(review): the input is cast to bfloat16 regardless of
                # cfg_ae.training.precision; confirm this is intended.
                outputs = ae_model(signals.to(torch.bfloat16))

                # Detach and move the output tensor to CPU
                output_tensor = outputs[1].detach().cpu()
                del outputs
                torch.cuda.empty_cache()  # Optionally clear GPU cache

                # Append the output to the list
                all_outputs.append(output_tensor)

        # Concatenate the output tensors along the batch dimension
        all_outputs = torch.cat(all_outputs, dim=0)

        # Move output tensor to numpy
        temp_array = all_outputs.numpy()

        # Clear intermediate variables
        ae_model = ae_model.cpu()
        torch.cuda.empty_cache()  # Optionally clear GPU cache

        # Swap the last two axes, then standardize over axes 0 and 2
        temp_array = temp_array.transpose(0, 2, 1)

        self.data_array = standardize_array(temp_array, ax=(0, 2))
        del temp_array

        # Position embedding for every time index, stored with its two axes
        # swapped relative to the SinusoidalPosEmb output.
        temp_emb = SinusoidalPosEmb(cond_time_dim).forward(
            torch.arange(self.signal_length)
        )
        self.emb = torch.transpose(temp_emb, 0, 1)

    def __getitem__(self, index, cond_channel=None):
        """
        Return a dict with "signal" (float32 tensor built from
        ``data_array[index]``) and, when time embeddings are enabled, "cond".

        ``cond_channel`` is accepted but not used.
        """
        return_dict = {}
        return_dict["signal"] = torch.from_numpy(np.float32(self.data_array[index]))
        cond = self.get_cond()
        if cond is not None:
            return_dict["cond"] = cond
        return return_dict

    def get_cond(self):
        """Return the position embedding if ``with_time_emb`` is set, else None."""
        cond = None
        if self.with_time_emb:
            cond = self.emb
        return cond

    def __len__(self):
        """Number of stored sequences (first axis of ``data_array``)."""
        return len(self.data_array)


# %%
if __name__ == "__main__":
    # Manual smoke test: build one dataset, plot an example sequence, then
    # check the batch shapes produced by the three dataloaders.

    attractor_dataset = LatentAttractor(
        # "ForcedFitzHughNagumo",
        "LorenzCoupled",
        n_neurons=128,
        sequence_length=1024,
        noise_std=0.05,
        n_ic=100,
        mean_spike_count=200,
        random_seed=42,
    )
    import matplotlib.pyplot as plt

    example = attractor_dataset[1]
    print(attractor_dataset.samples.shape)
    print(example["signal"].shape)

    # Sampled spike counts first, then the underlying rates.
    for key in ("signal", "rates"):
        plt.matshow(example[key], aspect="auto")
        plt.colorbar()
        plt.show()

    # plt.hist(attractor_dataset.poisson_rates.flatten(), bins=100)
    # plt.show()

    # plt.hist(attractor_dataset.samples.flatten(), bins=100)
    # plt.show()

    train_dataloader, valid_dataloader, test_dataloader = get_attractor_dataset(
        "Lorenz", n_neurons=128, sequence_length=500, noise_std=0.05, n_ic=100
    )

    # Print the shapes of the first batch (if any) of every split.
    for loader in (train_dataloader, valid_dataloader, test_dataloader):
        for batch in loader:
            print(batch["signal"].shape, batch["latents"].shape, batch["rates"].shape)
            break
# %%
