import json
import os
import scipy
import numba

import scanpy as sc
import pandas as pd
import numpy as np

from annoy import AnnoyIndex
from sklearn.decomposition import PCA
from sklearn.datasets import make_swiss_roll, make_s_curve

from typing import Optional


def data_prep(dataset: str='MNIST', size: int=-1, dim: Optional[int]=None, pca: bool=False):
    """A custom function that prepares the dataset for pacmap training.
    """
    DATA_DIR = "" # Insert your data dir
    if dataset == 'MNIST':
        X = np.load(f'{DATA_DIR}/mnist_images.npy', allow_pickle=True).reshape(70000, -1)
        labels = np.load(f'{DATA_DIR}/mnist_labels.npy', allow_pickle=True)
    elif dataset == 'FMNIST':
        X = np.load(f'{DATA_DIR}/fmnist_images.npy', allow_pickle=True).reshape(70000, -1)
        labels = np.load(f'{DATA_DIR}/fmnist_labels.npy', allow_pickle=True)
    elif dataset == 'mammoth':
        with open(f'{DATA_DIR}/mammoth_3d.json', 'r') as f:
            X = json.load(f)
        X = np.array(X)
        with open(f'{DATA_DIR}/mammoth_umap.json', 'r') as f:
            labels = json.load(f)
        labels = labels['labels']
        labels = np.array(labels)
        labels = np.sort(labels)
    elif dataset == 'hierarchy':
        X = np.load(f'{DATA_DIR}/hierarchical_threelayer_dataset.npy', allow_pickle=True)
        labels = np.load(f"{DATA_DIR}/hierarchical_threelayer_label.npy", allow_pickle=True)
    elif dataset == 'kazer':
        X = np.load(f"{DATA_DIR}/hiv_70.npy", allow_pickle=True)
        labels = np.load(f"{DATA_DIR}/kazer_raw_labels.npy", allow_pickle=True)
    elif dataset == 'stuart':
        X = np.load(f"{DATA_DIR}/seurat_bmnc_rna_70.npy", allow_pickle=True)
        labels = np.load(f"{DATA_DIR}/stuart_labels.npy", allow_pickle=True)
    elif dataset == 'muraro':
        X = np.load(f"{DATA_DIR}/muraro_log_pca.npy", allow_pickle=True)
        labels = np.load(f"{DATA_DIR}/muraro_labels.npy", allow_pickle=True)
    elif dataset == 'kang':
        X = np.load(f"{DATA_DIR}/kang_log_pca.npy", allow_pickle=True)
        labels = np.load(f"{DATA_DIR}/kang_labels.npy", allow_pickle=True)
    elif dataset == 'circle':
        if size == -1:
            size = 5000  # Set default
        X, labels = make_circle(10, num_points_per_cluster=size//10)
    elif dataset == '20NG':
        X = np.load(f"{DATA_DIR}/20NG.npy")
        labels = np.load("{DATA_DIR}/20NG_labels.npy")
    elif dataset == 'coil20':
        X = np.load(f"{DATA_DIR}/coil_20.npy", allow_pickle=True)
        X = X.reshape([X.shape[0], -1])
        labels = np.load(f"{DATA_DIR}/coil_20_labels.npy", allow_pickle=True)
    elif dataset == 'coil100':
        X = np.load(f"{DATA_DIR}/coil_100.npy")
        X = X.reshape([X.shape[0], -1])
        labels = np.load(f"{DATA_DIR}/coil_100_labels.npy")
    elif dataset == 'usps':
        X = np.load(f"{DATA_DIR}/USPS.npy")
        X = X.reshape([X.shape[0], -1])
        labels = np.load(f"{DATA_DIR}/USPS_labels.npy")
    else:
        raise ValueError(f'Unsupported dataset: {dataset}')
    if size != -1:
        X = X[:size, :dim]
        labels = labels[:size]
    else:
        X = X[:, :dim]
    X = X.copy(order='C')
    X = X.astype(np.float32)
    if pca and X.shape[1] > 100:
        reducer = PCA(n_components=100)
        X = reducer.fit_transform(X)
    return X, labels


def generate_pair(
        X,
        n_neighbors,
        n_MN,
        n_FP,
        distance='euclidean',
        verbose=True,
        random_state=None
):
    '''Generate pairs for the dataset.
    '''
    n, dim = X.shape
    # sample more neighbors than needed
    n_neighbors_extra = min(n_neighbors + 50, n - 1)
    tree = AnnoyIndex(dim, metric=distance)
    if random_state is not None:
        tree.set_seed(random_state)
    for i in range(n):
        tree.add_item(i, X[i, :])
    tree.build(20)

    option = distance_to_option(distance=distance)

    nbrs = np.zeros((n, n_neighbors_extra), dtype=np.int32)
    knn_distances = np.empty((n, n_neighbors_extra), dtype=np.float32)

    for i in range(n):
        nbrs_ = tree.get_nns_by_item(i, n_neighbors_extra + 1)
        nbrs[i, :] = nbrs_[1:]
        # This line is subject to change
        for j in range(n_neighbors_extra):
            knn_distances[i, j] = tree.get_distance(i, nbrs[i, j])
    print_verbose("Found nearest neighbor", verbose)
    sig = np.maximum(np.mean(knn_distances[:, 3:6], axis=1), 1e-10)
    print_verbose("Calculated sigma", verbose)
    scaled_dist = scale_dist(knn_distances, sig, nbrs)
    print_verbose("Found scaled dist", verbose)
    pair_neighbors = sample_neighbors_pair(X, scaled_dist, nbrs, n_neighbors)
    if random_state is None:
        pair_MN = sample_MN_pair(X, n_MN, option)
        pair_FP = sample_FP_pair(X, pair_neighbors, n_neighbors, n_FP)
    else:
        pair_MN = sample_MN_pair_deterministic(X, n_MN, random_state, option)
        pair_FP = sample_FP_pair_deterministic(
            X, pair_neighbors, n_neighbors, n_FP, random_state)
    return pair_neighbors, pair_MN, pair_FP, tree


def distance_to_option(distance='euclidean'):
    '''A helper function that translates distance metric to int options.
    Such a translation is useful for numba acceleration.
    '''
    if distance == 'euclidean':
        option = 0
    elif distance == 'manhattan':
        option = 1
    elif distance == 'angular':
        option = 2
    elif distance == 'hamming':
        option = 3
    else:
        raise NotImplementedError('Distance other than euclidean, manhattan,' +
                                  'angular or hamming is not supported')
    return option


def print_verbose(msg, verbose, **kwargs):
    if verbose:
        print(msg, **kwargs)


@numba.njit("f4(f4[:])")
def l2_norm(x):
    """
    L2 norm of a vector.
    """
    result = 0.0
    for i in range(x.shape[0]):
        result += x[i] ** 2
    return np.sqrt(result)


@numba.njit("f4(f4[:],f4[:])")
def euclid_dist(x1, x2):
    """
    Euclidean distance between two vectors.
    """
    result = 0.0
    for i in range(x1.shape[0]):
        result += (x1[i] - x2[i]) ** 2
    return np.sqrt(result)


@numba.njit("f4(f4[:],f4[:])")
def manhattan_dist(x1, x2):
    """
    Manhattan distance between two vectors.
    """
    result = 0.0
    for i in range(x1.shape[0]):
        result += np.abs(x1[i] - x2[i])
    return result


@numba.njit("f4(f4[:],f4[:])")
def angular_dist(x1, x2):
    """
    Angular (i.e. cosine) distance between two vectors.
    """
    x1_norm = np.maximum(l2_norm(x1), 1e-20)
    x2_norm = np.maximum(l2_norm(x2), 1e-20)
    result = 0.0
    for i in range(x1.shape[0]):
        result += x1[i] * x2[i]
    return np.sqrt(2.0 - 2.0 * result / x1_norm / x2_norm)


@numba.njit("f4(f4[:],f4[:])")
def hamming_dist(x1, x2):
    """
    Hamming distance between two vectors.
    """
    result = 0.0
    for i in range(x1.shape[0]):
        if x1[i] != x2[i]:
            result += 1.0
    return result


@numba.njit()
def calculate_dist(x1, x2, distance_index):
    if distance_index == 0:  # euclidean
        return euclid_dist(x1, x2)
    elif distance_index == 1:  # manhattan
        return manhattan_dist(x1, x2)
    elif distance_index == 2:  # angular
        return angular_dist(x1, x2)
    elif distance_index == 3:  # hamming
        return hamming_dist(x1, x2)


@numba.njit("i4[:](i4,i4,i4[:])", nogil=True)
def sample_FP(n_samples, maximum, reject_ind):
    result = np.empty(n_samples, dtype=np.int32)
    for i in range(n_samples):
        reject_sample = True
        while reject_sample:
            j = np.random.randint(maximum)
            for k in range(i):
                if j == result[k]:
                    break
            for k in range(reject_ind.shape[0]):
                if j == reject_ind[k]:
                    break
            else:
                reject_sample = False
        result[i] = j
    return result


@numba.njit("i4[:,:](f4[:,:],f4[:,:],i4[:,:],i4)", parallel=True, nogil=True)
def sample_neighbors_pair(X, scaled_dist, nbrs, n_neighbors):
    n = X.shape[0]
    pair_neighbors = np.empty((n*n_neighbors, 2), dtype=np.int32)

    for i in numba.prange(n):
        scaled_sort = np.argsort(scaled_dist[i])
        for j in numba.prange(n_neighbors):
            pair_neighbors[i*n_neighbors + j][0] = i
            pair_neighbors[i*n_neighbors + j][1] = nbrs[i][scaled_sort[j]]
    return pair_neighbors


@numba.njit("i4[:,:](i4,f4[:,:],f4[:,:],i4[:,:],i4)", parallel=True, nogil=True)
def sample_neighbors_pair_basis(n_basis, X, scaled_dist, nbrs, n_neighbors):
    '''Sample Nearest Neighbor pairs for additional data.'''
    n = X.shape[0]
    pair_neighbors = np.empty((n*n_neighbors, 2), dtype=np.int32)

    for i in numba.prange(n):
        scaled_sort = np.argsort(scaled_dist[i])
        for j in numba.prange(n_neighbors):
            pair_neighbors[i*n_neighbors + j][0] = n_basis + i
            pair_neighbors[i*n_neighbors + j][1] = nbrs[i][scaled_sort[j]]
    return pair_neighbors


@numba.njit("i4[:,:](f4[:,:],i4,i4)", nogil=True)
def sample_MN_pair(X, n_MN, option=0):
    '''Sample Mid Near pairs.'''
    n = X.shape[0]
    pair_MN = np.empty((n*n_MN, 2), dtype=np.int32)
    for i in numba.prange(n):
        for jj in range(n_MN):
            sampled = np.random.randint(0, n, 6)
            dist_list = np.empty((6), dtype=np.float32)
            for t in range(sampled.shape[0]):
                dist_list[t] = calculate_dist(
                    X[i], X[sampled[t]], distance_index=option)
            min_dic = np.argmin(dist_list)
            dist_list = np.delete(dist_list, [min_dic])
            sampled = np.delete(sampled, [min_dic])
            picked = sampled[np.argmin(dist_list)]
            pair_MN[i*n_MN + jj][0] = i
            pair_MN[i*n_MN + jj][1] = picked
    return pair_MN


@numba.njit("i4[:,:](f4[:,:],i4,i4,i4)", nogil=True)
def sample_MN_pair_deterministic(X, n_MN, random_state, option=0):
    '''Sample Mid Near pairs using the given random state.'''
    n = X.shape[0]
    pair_MN = np.empty((n*n_MN, 2), dtype=np.int32)
    for i in numba.prange(n):
        for jj in range(n_MN):
            # Shifting the seed to prevent sampling the same pairs
            np.random.seed(random_state + i * n_MN + jj)
            sampled = np.random.randint(0, n, 6)
            dist_list = np.empty((6), dtype=np.float32)
            for t in range(sampled.shape[0]):
                dist_list[t] = calculate_dist(
                    X[i], X[sampled[t]], distance_index=option)
            min_dic = np.argmin(dist_list)
            dist_list = np.delete(dist_list, [min_dic])
            sampled = np.delete(sampled, [min_dic])
            picked = sampled[np.argmin(dist_list)]
            pair_MN[i*n_MN + jj][0] = i
            pair_MN[i*n_MN + jj][1] = picked
    return pair_MN


@numba.njit("i4[:,:](f4[:,:],i4[:,:],i4,i4)", parallel=True, nogil=True)
def sample_FP_pair(X, pair_neighbors, n_neighbors, n_FP):
    '''Sample Further pairs.'''
    n = X.shape[0]
    pair_FP = np.empty((n * n_FP, 2), dtype=np.int32)
    for i in numba.prange(n):
        for k in numba.prange(n_FP):
            FP_index = sample_FP(
                n_FP, n, pair_neighbors[i*n_neighbors: i*n_neighbors + n_neighbors, 1])
            pair_FP[i*n_FP + k][0] = i
            pair_FP[i*n_FP + k][1] = FP_index[k]
    return pair_FP


@numba.njit("i4[:,:](f4[:,:],i4[:,:],i4,i4,i4)", parallel=True, nogil=True)
def sample_FP_pair_deterministic(X, pair_neighbors, n_neighbors, n_FP, random_state):
    '''Sample Further pairs using the given random state.'''
    n = X.shape[0]
    pair_FP = np.empty((n * n_FP, 2), dtype=np.int32)
    for i in numba.prange(n):
        for k in numba.prange(n_FP):
            np.random.seed(random_state+i*n_FP+k)
            FP_index = sample_FP(
                n_FP, n, pair_neighbors[i*n_neighbors: i*n_neighbors + n_neighbors, 1])
            pair_FP[i*n_FP + k][0] = i
            pair_FP[i*n_FP + k][1] = FP_index[k]
    return pair_FP


@numba.njit("f4[:,:](f4[:,:],f4[:],i4[:,:])", parallel=True, nogil=True)
def scale_dist(knn_distance, sig, nbrs):
    '''Scale the distance'''
    n, num_neighbors = knn_distance.shape
    scaled_dist = np.zeros((n, num_neighbors), dtype=np.float32)
    for i in numba.prange(n):
        for j in numba.prange(num_neighbors):
            scaled_dist[i, j] = knn_distance[i, j] ** 2 / \
                sig[i] / sig[nbrs[i, j]]
    return scaled_dist


def make_circle(num_clusters=10, num_points_per_cluster=10, sep=True):
    '''Create a number of points in 2D that will form a circle.
    '''
    num_points = num_clusters * num_points_per_cluster # Total number of points
    angles = []

    # Generate a set of angles that will appear on a circle
    if sep:
        for i in range(0, 2 * num_clusters, 2):
            angles += list(np.random.uniform(low=i / (2 * num_clusters) * 2 * np.pi,
                                             high=(i + 1) / (2 * num_clusters) * 2 * np.pi,
                                             size=int(num_points / num_clusters)))
    else:
        angles = np.random.rand(num_points) * 2 * np.pi
        angles = np.sort(angles)

    circle = np.array([[np.cos(a), np.sin(a)] for a in angles])
    colors = []
    for i in range(num_points):
        colors.append(int(num_clusters * angles[i] / (2 * np.pi)))
    colors = np.array(colors)
    return circle, colors
