import math

import torch
import numpy as np
from sklearn.model_selection import train_test_split
from datasets import load_dataset
from sklearn.decomposition import PCA
from transformers import AutoTokenizer, GPTNeoModel
import pandas as pd
import os


class DataSampler:
    """Base class for samplers that produce batches of input vectors.

    Args:
        n_dims: dimensionality of each sampled input vector.
    """

    def __init__(self, n_dims):
        self.n_dims = n_dims

    def sample_xs(self, n_points=None, b_size=None, n_dims_truncated=None, seeds=None):
        """Return a batch of inputs; concrete samplers must override this.

        The parameters mirror the subclasses in this module. A call made
        through the base type therefore fails with NotImplementedError rather
        than a TypeError about unexpected arguments. All parameters default to
        None, so the old zero-argument call still works.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{type(self).__name__} does not implement sample_xs"
        )


def get_data_sampler(data_name, n_dims, **kwargs):
    """Instantiate the sampler registered under ``data_name``.

    Args:
        data_name: one of "gaussian", "multigaussian", "nl", "nlreal".
        n_dims: dimensionality passed as the first constructor argument.
        **kwargs: forwarded unchanged to the sampler's constructor.

    Returns:
        A new sampler instance.

    Raises:
        NotImplementedError: if ``data_name`` is not a registered sampler name.
    """
    names_to_classes = {
        "gaussian": GaussianSampler,
        "multigaussian": MultiGaussianSampler,
        "nl": NLSyntheticSampler,
        "nlreal": NLEmbeddingSampler,
    }
    # Only the lookup is guarded, so a KeyError raised inside a constructor
    # is not misreported as an unknown sampler name.
    try:
        sampler_cls = names_to_classes[data_name]
    except KeyError:
        raise NotImplementedError(
            f"Unknown sampler: {data_name!r} "
            f"(expected one of {sorted(names_to_classes)})"
        ) from None
    return sampler_cls(n_dims, **kwargs)


def sample_transformation(eigenvalues, normalize=False):
    """Return a random symmetric matrix with the given eigenvalues.

    The matrix is built as ``U @ diag(eigenvalues) @ U.T``, where ``U`` is the
    left singular-vector matrix of a standard-normal square matrix, which makes
    ``U`` orthogonal.

    Args:
        eigenvalues: 1-D tensor or sequence of eigenvalues. Integer input is
            promoted to torch's default float dtype. Float input keeps its
            dtype and device, and the random basis is drawn to match.
        normalize: if True, rescale the result so that its squared Frobenius
            norm equals ``len(eigenvalues)``.

    Returns:
        An (n, n) tensor with the same dtype and device as the (float) eigenvalues.

    Raises:
        ValueError: if ``normalize`` is True and every eigenvalue is zero.
    """
    eig = torch.as_tensor(eigenvalues)
    if not eig.is_floating_point():
        eig = eig.to(torch.get_default_dtype())
    n_dims = eig.shape[0]

    # Match the basis dtype/device to the eigenvalues so the matmul below
    # does not fail on a float32-vs-float64 (or CPU-vs-GPU) mismatch.
    basis, _, _ = torch.linalg.svd(
        torch.randn(n_dims, n_dims, dtype=eig.dtype, device=eig.device)
    )
    t = basis @ torch.diag(eig) @ basis.T

    if normalize:
        norm_subspace = float(torch.sum(eig**2))
        if norm_subspace == 0.0:
            raise ValueError(
                "cannot normalize a transformation whose eigenvalues are all zero"
            )
        # ||t||_F^2 == sum(eig^2), so this scales it to exactly n_dims.
        t = t * math.sqrt(n_dims / norm_subspace)
    return t


class NLEmbeddingSampler(DataSampler):
    """Sample batches of PCA-reduced GPT-Neo embeddings of real text.

    On construction the sampler loads a cached DataFrame from ``data_path``.
    If no cache exists, it builds one from the ``amazon_polarity`` dataset:

    1. Each text is embedded as GPT-Neo-125M's last hidden state at its final
       token.
    2. A PCA with ``n_dims`` components is fit on a class-balanced held-out
       subset.
    3. Every training-split text is projected into that PCA space.

    A freshly built cache has the columns ``"text"``, ``"tokenized_sentences"``
    and ``"embeddings"``. Each embedding is a float32 tensor of shape
    (1, n_dims).

    Args:
        n_dims: number of PCA components, i.e. the output dimensionality.
        bias, scale: stored for interface parity with the other samplers.
            ``sample_xs`` does not apply them.
        data_path: location of the pickled embedding cache.
    """

    # Accepted names for the raw-text column, in order of preference.
    # The cache is written with "text". "sentence" is accepted for caches
    # written by older code, and "content" is used when reading the source
    # dataset. NOTE(review): the HF amazon_polarity schema is believed to be
    # label/title/content — confirm.
    _TEXT_COLUMNS = ("text", "sentence", "content")
    _MODEL_NAME = "EleutherAI/gpt-neo-125M"
    _CACHE_DIR = "/u/scr/hub"

    def __init__(self, n_dims, bias=None, scale=None, data_path="/data/pca_embeds.pkl"):
        super().__init__(n_dims)
        self.bias = bias
        self.scale = scale

        if os.path.exists(data_path):
            # NOTE(review): unpickling executes arbitrary code; only point
            # data_path at a trusted, locally generated cache.
            self.samples = pd.read_pickle(data_path)
        else:
            self.samples = self._build_samples(n_dims)
            self.samples.to_pickle(data_path)

    @classmethod
    def _text_column(cls, df):
        """Return the first column of ``df`` named in ``_TEXT_COLUMNS``."""
        for name in cls._TEXT_COLUMNS:
            if name in df.columns:
                return name
        raise KeyError(
            f"no text column among {cls._TEXT_COLUMNS}; found {list(df.columns)}"
        )

    def _tokenize(self, text):
        """Tokenize a single string into a (1, seq_len) tensor of input ids."""
        return self.tokenizer(text, return_tensors="pt").input_ids

    def _embed(self, input_ids):
        """Return the model's last hidden state at the final token, shape (1, hidden)."""
        # Inference only. Without no_grad every retained output would keep
        # its whole autograd graph (all activations) alive.
        with torch.no_grad():
            return self.model(input_ids).last_hidden_state[:, -1, :]

    def _build_samples(self, n_dims):
        """Embed amazon_polarity texts, fit a PCA and return the cache DataFrame."""
        df = load_dataset("amazon_polarity")["train"].to_pandas()
        text_col = self._text_column(df)
        train_df, test_df = train_test_split(df, test_size=0.2)

        # Fit the PCA on 1000 held-out texts (500 per label) drawn from
        # test_df, so the PCA data is disjoint from the training rows that
        # are embedded below.
        pca_train_set = (
            test_df.groupby("label")
            .apply(lambda x: x.sample(500, random_state=42))
            .reset_index(drop=True)
        )

        self.model = GPTNeoModel.from_pretrained(
            self._MODEL_NAME, cache_dir=self._CACHE_DIR
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            self._MODEL_NAME, cache_dir=self._CACHE_DIR
        )
        self.tokenizer.pad_token = self.tokenizer.eos_token

        # Concatenating the (1, hidden) rows gives a 2-D (n_texts, hidden)
        # matrix. torch.stack would add an extra axis, which PCA.fit rejects.
        pca_inputs = torch.cat(
            [self._embed(self._tokenize(text)) for text in pca_train_set[text_col]],
            dim=0,
        )
        self.pca = PCA(n_components=n_dims)
        self.pca.fit(pca_inputs.numpy())

        # .copy() so the column assignments below do not write into a view
        # of train_df (pandas SettingWithCopy).
        samples = train_df[[text_col]].rename(columns={text_col: "text"}).copy()
        samples["tokenized_sentences"] = samples["text"].map(self._tokenize)

        print("Generating PCA embeddings...")
        samples["embeddings"] = [
            torch.tensor(
                self.pca.transform(self._embed(ids).numpy()), dtype=torch.float32
            )
            for ids in samples["tokenized_sentences"]
        ]
        return samples

    def sample_xs(self, n_points, b_size, n_dims_truncated=None, seeds=None):
        """Sample ``b_size`` sequences of ``n_points`` distinct cached embeddings.

        Within each batch element, the points are centered by their mean and
        then each point is scaled to unit L2 norm.

        Args:
            n_points: rows drawn (without replacement) per batch element.
            b_size: number of batch elements.
            n_dims_truncated: accepted for interface parity; not applied.
            seeds: optional sequence of ``b_size`` ints. Batch element i is
                drawn with ``random_state=seeds[i]``, which makes it
                reproducible. When None, pandas falls back to numpy's global
                RNG, so batch elements differ from each other.

        Returns:
            ``(xs_b, nl_sentences)``. ``xs_b`` stacks the per-element
            (n_points, n_dims) tensors along a new first axis.
            ``nl_sentences`` is a list of ``b_size`` lists holding the
            sampled texts.
        """
        if seeds is not None:
            assert len(seeds) == b_size
        text_col = self._text_column(self.samples)

        xs_b = []
        nl_sentences = []
        for i in range(b_size):
            # A per-element seed (rather than a fixed constant) keeps batch
            # elements from all being the same draw.
            random_state = None if seeds is None else seeds[i]
            chosen = self.samples.sample(
                n=n_points, random_state=random_state, replace=False
            )
            nl_sentences.append(chosen[text_col].tolist())
            # Each cached embedding is (1, n_dims). Stacking on dim=1 gives
            # (1, n_points, n_dims), and squeeze(0) drops the leading axis.
            sequence = torch.stack(chosen["embeddings"].tolist(), dim=1).squeeze(0)
            xs_b.append(sequence)
        xs_b = torch.stack(xs_b, dim=0)

        # Center each batch element, then scale every point to unit norm. The
        # clamp avoids 0/0 = NaN for a point that equals its element's mean.
        xs_b = xs_b - xs_b.mean(dim=1, keepdim=True)
        norms = torch.norm(xs_b, dim=2, keepdim=True)
        xs_b = xs_b / norms.clamp_min(torch.finfo(xs_b.dtype).eps)
        return xs_b, nl_sentences


class NLSyntheticSampler(DataSampler):
    """Sample inputs whose coordinates are independently -1 or +1.

    Args:
        n_dims: dimensionality of each input vector.
        bias: optional offset added (in place) after the optional ``scale``.
        scale: optional matrix; inputs are right-multiplied by it.
    """

    def __init__(self, n_dims, bias=None, scale=None):
        super().__init__(n_dims)
        self.bias = bias
        self.scale = scale

    def sample_xs(self, n_points, b_size, n_dims_truncated=None, seeds=None):
        """Return a (b_size, n_points, n_dims) float32 batch of random +/-1 entries.

        Args:
            n_points: points per batch element.
            b_size: number of batch elements.
            n_dims_truncated: if given, every coordinate from this index onward
                is set to -1. This happens after scale and bias are applied.
            seeds: optional sequence of ``b_size`` ints. Batch element i is
                drawn from a ``torch.Generator`` seeded with ``seeds[i]``, the
                same scheme ``GaussianSampler`` uses. When None, numpy's global
                RNG is used.

        Returns:
            ``(xs_b, None)``.
        """
        if seeds is None:
            signs = np.random.choice([-1, 1], (b_size, n_points, self.n_dims))
            xs_b = torch.tensor(signs, dtype=torch.float32)
        else:
            assert len(seeds) == b_size
            xs_b = torch.zeros(b_size, n_points, self.n_dims, dtype=torch.float32)
            generator = torch.Generator()
            for i, seed in enumerate(seeds):
                generator.manual_seed(seed)
                bits = torch.randint(0, 2, (n_points, self.n_dims), generator=generator)
                # Map {0, 1} -> {-1, +1}. The assignment casts to float32.
                xs_b[i] = bits * 2 - 1
        if self.scale is not None:
            xs_b = xs_b @ self.scale
        if self.bias is not None:
            xs_b += self.bias
        if n_dims_truncated is not None:
            xs_b[:, :, n_dims_truncated:] = -1
        return xs_b, None


class GaussianSampler(DataSampler):
    """Standard-normal input sampler with optional affine post-processing.

    Args:
        n_dims: dimensionality of each input vector.
        bias: optional offset added (in place) after the optional ``scale``.
        scale: optional matrix; inputs are right-multiplied by it.
    """

    def __init__(self, n_dims, bias=None, scale=None):
        super().__init__(n_dims)
        self.scale = scale
        self.bias = bias

    def _seeded_normals(self, n_points, b_size, seeds):
        """Build a (b_size, n_points, n_dims) normal tensor, one seed per batch row."""
        out = torch.zeros(b_size, n_points, self.n_dims)
        rng = torch.Generator()
        assert len(seeds) == b_size
        for row, seed in enumerate(seeds):
            rng.manual_seed(seed)
            out[row] = torch.randn(n_points, self.n_dims, generator=rng)
        return out

    def sample_xs(self, n_points, b_size, n_dims_truncated=None, seeds=None):
        """Draw a batch of N(0, I) inputs.

        Args:
            n_points: points per batch element.
            b_size: number of batch elements.
            n_dims_truncated: if given, coordinates from this index onward are
                zeroed after scale and bias are applied.
            seeds: optional sequence of ``b_size`` ints. Row i is then drawn
                from a generator seeded with ``seeds[i]``.

        Returns:
            ``(xs, None)``.
        """
        xs = (
            torch.randn(b_size, n_points, self.n_dims)
            if seeds is None
            else self._seeded_normals(n_points, b_size, seeds)
        )
        if self.scale is not None:
            xs = xs @ self.scale
        if self.bias is not None:
            xs += self.bias
        if n_dims_truncated is not None:
            xs[:, :, n_dims_truncated:] = 0
        return xs, None


class MultiGaussianSampler(DataSampler):
    """Sample each batch element from a symmetric two-component Gaussian mixture.

    For every batch element a direction ``w ~ N(0, I)`` is drawn. Each point
    is then ``N(0, I) + s * 10 * w``, where the sign ``s`` is drawn
    independently per point and is +1 or -1 with probability 0.5.

    Args:
        n_dims: dimensionality of each input vector.
        bias: optional offset added (in place) after the optional ``scale``.
        scale: optional matrix; inputs are right-multiplied by it.
        z: stored as ``self.z``; not read by ``sample_xs``.
    """

    # Multiplier on w that sets how far each mixture mean is from the origin.
    _SHIFT_SCALE = 10

    def __init__(self, n_dims, bias=None, scale=None, z=None):
        super().__init__(n_dims)
        self.bias = bias
        self.scale = scale
        self.z = z
        # The parameter has shape [1], so .sample(shape) returns shape + (1,).
        self.pb = torch.distributions.bernoulli.Bernoulli(torch.tensor([0.5]))

    def sample_xs(self, n_points, b_size, n_dims_truncated=None, seeds=None):
        """Draw a batch of mixture samples together with each element's direction.

        Args:
            n_points: points per batch element.
            b_size: number of batch elements.
            n_dims_truncated: if given, coordinates from this index onward are
                zeroed after scale and bias are applied.
            seeds: optional sequence of ``b_size`` ints. When given, element i's
                noise, direction and signs all come from a generator seeded with
                ``seeds[i]``, so the whole element is reproducible. The noise is
                drawn first, so it matches ``GaussianSampler`` for the same seed.

        Returns:
            ``(xs_b, w_b)``: xs_b is (b_size, n_points, n_dims) and w_b is
            (b_size, n_dims, 1).
        """
        if seeds is None:
            # Draw order (w, z, noise) is kept from the original implementation.
            w_b = torch.randn(b_size, self.n_dims, 1)
            z_b = self.pb.sample((b_size, n_points))  # (b_size, n_points, 1)
            xs_b = torch.randn(b_size, n_points, self.n_dims)
            signs = (z_b * 2 - 1) * self._SHIFT_SCALE
            # (b, n_points, 1) * (b, 1, n_dims) gives a per-point shift of +/-10 w.
            xs_b += signs * w_b.transpose(1, 2)
        else:
            assert len(seeds) == b_size
            w_b = torch.zeros(b_size, self.n_dims, 1)
            xs_b = torch.zeros(b_size, n_points, self.n_dims)
            generator = torch.Generator()
            prob = self.pb.probs.item()
            for i, seed in enumerate(seeds):
                generator.manual_seed(seed)
                xs_b[i] = torch.randn(n_points, self.n_dims, generator=generator)
                w_b[i] = torch.randn(self.n_dims, 1, generator=generator)
                z = torch.bernoulli(
                    torch.full((n_points, 1), prob), generator=generator
                )
                # (n_points, 1) * (1, n_dims) -> (n_points, n_dims) shift.
                xs_b[i] += (z * 2 - 1) * self._SHIFT_SCALE * w_b[i].T

        if self.scale is not None:
            xs_b = xs_b @ self.scale
        if self.bias is not None:
            xs_b += self.bias
        if n_dims_truncated is not None:
            xs_b[:, :, n_dims_truncated:] = 0
        return xs_b, w_b
