import numpy as np
import sklearn.datasets
from PIL import Image
import scipy.io
import torch
from torchvision.datasets import MNIST, CIFAR10, CIFAR100, STL10
from torch.utils.data import DataLoader
from torchvision import transforms
#from torch_geometric.datasets import Planetoid
from sklearn import preprocessing


### LOAD DATA ###
def parse_dataset(args):
    """Load the dataset named by ``args.dataset`` and normalize it.

    Parameters
    ----------
    args : namespace-like object
        Must provide ``dataset`` (a key of the table below) and ``normalize``
        (forwarded to ``normalize_data``).  ``args.N`` is read only for the
        datasets whose loader takes a sample count.

    Returns
    -------
    X, label, nclass
        ``X`` and ``label`` as returned by ``normalize_data``; ``nclass`` as
        reported by the selected loader.

    Raises
    ------
    ValueError
        If ``args.dataset`` is not one of the known names.
    """
    def _synthetic():
        # Fixed configuration of the synthetic union-of-subspaces problem.
        ambient_dim = 9
        subspace_dim = 6
        nclass = 5
        num_points_per_subspace = 50
        return synth_uos(ambient_dim, subspace_dim, nclass, num_points_per_subspace, 0.00)

    # Every entry is a zero-argument callable, so ``args.N`` and the loader
    # itself are only touched for the dataset that was actually requested.
    loaders = {
        'synthetic': _synthetic,
        'orl': lambda: load_orl(),
        'coil': lambda: load_coil(nclass=100),
        'coil_simclr': lambda: load_coil_simclr(nclass=100),
        'coil_scatter': lambda: load_coil_scatter(nclass=100),
        'mnist': lambda: load_mnist(N=args.N, include_test=True),
        # max N = 70000
        'mnist_scatter': lambda: load_mnist_scatter(N=args.N),
        # max N = 50000
        'cifar_10': lambda: load_cifar10(N=args.N, include_test=True),
        # max N = 60000
        'cifar_10_simclr': lambda: load_cifar10_simclr(N=args.N),
        'cifar_10_scatter': lambda: load_cifar10_scatter(N=args.N),
        'cifar_10_clip': lambda: load_cifar_clip(which='cifar10', N=args.N),
        'cifar_100_clip': lambda: load_cifar_clip(which='cifar100', N=args.N),
        'cifar_20_clip': lambda: load_cifar_clip(which='cifar20', N=args.N),
        'flowers_simclr': lambda: load_oxford_flowers_simclr(),
        'covtype': lambda: load_covtype(N=args.N),
        'stl10_simclr': lambda: load_stl10_simclr(),
        'emnist_digits_scatter': lambda: load_emnist_digits_scatter(N=args.N),
        'emnist_letters_scatter': lambda: load_emnist_letters_scatter(N=args.N),
    }

    load = loaders.get(args.dataset)
    if load is None:
        raise ValueError('Invalid dataset')

    X, label, nclass = load()
    X, label = normalize_data(X, label, args.normalize)
    return X, label, nclass

def _whiten_rows(X):
    """Center each row of X, then set every singular value above 1e-6 to 1."""
    # NOTE(review): the mean is taken along axis=1 (per row), not per column;
    # kept exactly as in the original implementation.
    X = X - np.mean(X, axis=1, keepdims=True)
    U, S, Vt = np.linalg.svd(X, full_matrices=False)
    S[S > 1e-6] = 1
    # Scaling the columns of U by S is the same product as U @ diag(S) @ Vt,
    # without materializing the dense diagonal matrix.
    return (U * S) @ Vt


def _unit_norm_rows(X):
    """Divide each row of X by its Euclidean norm (zero rows are not special-cased)."""
    return X / np.linalg.norm(X, axis=1, keepdims=True)


def normalize_data(X, label, option):
    """Normalize the data matrix ``X`` according to ``option``.

    Parameters
    ----------
    X : numpy.ndarray or torch.Tensor
        Data matrix indexed as (row, column); torch tensors are converted
        with ``.numpy()``.
    label : numpy.ndarray or torch.Tensor
        Labels; converted to numpy if given as a tensor, otherwise returned
        untouched.
    option : str
        ``'none'``            - leave ``X`` as is.
        ``'unit_sph'``        - scale every row to unit Euclidean norm.
        ``'whiten'``          - center every row and flatten the spectrum
                                (singular values above 1e-6 become 1).
        ``'whiten_unit_sph'`` - ``'whiten'`` followed by ``'unit_sph'``.

    Returns
    -------
    X, label : numpy.ndarray, numpy.ndarray
        ``X`` keeps its float16 dtype if it came in as float16; the
        computation itself is then done in float32.

    Raises
    ------
    ValueError
        If ``option`` is not one of the values listed above.
    """
    if isinstance(X, torch.Tensor):
        X = X.numpy()
    if isinstance(label, torch.Tensor):
        label = label.numpy()

    # Remember whether the input was half precision so it can be restored.
    was_float16 = X.dtype == np.float16
    if was_float16:
        print('X is float16, converting to float32')
        X = X.astype(np.float32)

    if option == 'none':
        pass
    elif option == 'unit_sph':
        X = _unit_norm_rows(X)
    elif option == 'whiten':
        X = _whiten_rows(X)
    elif option == 'whiten_unit_sph':
        X = _unit_norm_rows(_whiten_rows(X))
    else:
        raise ValueError('Invalid normalization option')

    # Diagnostics: spread of the row norms and norm of the column-wise mean.
    X_norm = np.linalg.norm(X, axis=1, keepdims=True)
    print("new X_norm mean:", X_norm.mean(), "std:", X_norm.std())

    X_mean = X.mean(axis=0)
    print("new X_mean norm:", np.linalg.norm(X_mean))

    if was_float16:
        print('converting back to float16')
        X = X.astype(np.float16)

    return X, label


def load_orl(zero_mean=False):
    """Fetch the ORL (Olivetti) faces through sklearn as unit-norm rows.

    The original note describes the images as 64 x 64; each row of the
    returned matrix is one sample taken from the ``'data'`` entry.

    Parameters
    ----------
    zero_mean : bool
        If True, subtract each row's own mean before normalizing.

    Returns
    -------
    norm_X, label, nclass
        Row-normalized data, the ``'target'`` labels, and the number of
        distinct labels.
    """
    faces = sklearn.datasets.fetch_olivetti_faces()
    label = faces['target']
    X = faces['data']
    if zero_mean:
        X = X - np.mean(X, axis=1, keepdims=True)
    row_norms = np.linalg.norm(X, axis=1, keepdims=True)
    return X / row_norms, label, np.unique(label).shape[0]

def load_coil(nclass=100, return_X=True):
    """Load the cached COIL-100 array from ``data/coil_100.mat``.

    Parameters
    ----------
    nclass : int
        Number of classes to keep; the first ``72 * nclass`` samples are
        retained (72 samples per class is assumed by this slicing).
    return_X : bool
        If True, return the samples flattened to rows and divided by the
        global maximum; otherwise return the array as stored.

    Returns
    -------
    X_or_data, label, nclass
    """
    mat = scipy.io.loadmat('data/coil_100.mat')
    n_keep = 72 * nclass
    images = mat['data'][:n_keep]
    label = mat['label'].flatten()[:n_keep]

    if not return_X:
        return images, label, nclass

    flat = images.reshape(images.shape[0], -1)
    return flat / flat.max(), label, nclass


def load_coil_raw(resize=True, gray=True):
    """Build the COIL-100 array from image files and cache it as a .mat file.

    Reads ``data/coil-100/obj{k}__{deg}.png`` for k = 1..100 and
    deg = 0, 5, ..., 355 (7200 images), then writes the result to
    ``data/coil_100.mat`` under the keys ``'data'`` and ``'label'``.

    Parameters
    ----------
    resize : bool
        If True, every image is resized to 32 x 32.  If False, images keep
        their stored size (previously this crashed because the buffer was
        hard-coded to 32 x 32).
    gray : bool
        If True (default), images are converted to single-channel ``'L'``
        mode.  If False, they are converted to ``'RGB'`` and the data array
        gains a trailing channel axis (previously this flag was ignored).

    Returns
    -------
    data, label
        ``data`` is a float array with one image per leading index;
        ``label`` holds the zero-based object number of each image.
    """
    pathname = 'data/coil-100'
    n_objects = 100
    angles = range(0, 360, 5)
    n_images = n_objects * len(angles)
    mode = 'L' if gray else 'RGB'

    label = np.zeros(n_images, dtype=int)
    data = None  # allocated once the per-image shape is known
    idx = 0
    for obj_num in range(1, n_objects + 1):
        for deg in angles:
            # The context manager releases the underlying file handle.
            with Image.open(f'{pathname}/obj{obj_num}__{deg}.png') as src:
                img = src.convert(mode)
                if resize:
                    # resample=1 is kept from the original code; presumably
                    # Pillow's LANCZOS filter - confirm against Pillow docs.
                    img = img.resize((32, 32), resample=1)
                pixels = np.array(img)
            if data is None:
                data = np.zeros((n_images,) + pixels.shape)
            data[idx] = pixels
            label[idx] = obj_num - 1
            idx += 1

    save_dict = {'data': data, 'label': label}
    scipy.io.savemat('data/coil_100.mat', save_dict)
    return data, label

def load_mnist(N=1000, include_test=False):
    """Load the first ``N`` MNIST training images as flattened rows.

    The dataset is downloaded into ``data/`` when missing.

    Parameters
    ----------
    N : int
        Batch size used to pull the leading training samples (no shuffling).
    include_test : bool
        If True, the first batch of up to 10000 test samples is appended.

    Returns
    -------
    X, label, nclass
        Torch tensors when ``include_test`` is False; numpy arrays (the
        result of ``np.concatenate``) when it is True.  ``nclass`` is 10.
    """
    train_set = MNIST('data/', download=True, train=True, transform=transforms.ToTensor())
    # Only the first batch is needed, so take it straight off the iterator.
    X, label = next(iter(DataLoader(train_set, batch_size=N, shuffle=False)))
    X = X.reshape(X.shape[0], -1)

    if include_test:
        test_set = MNIST('data/', download=True, train=False, transform=transforms.ToTensor())
        X_test, label_test = next(iter(DataLoader(test_set, batch_size=10000, shuffle=False)))
        X_test = X_test.reshape(X_test.shape[0], -1)
        X = np.concatenate((X, X_test), axis=0)
        label = np.concatenate((label, label_test))

    return X, label, 10

def load_cifar10(N=1000, include_test=False, download=False):
    """Load the first ``N`` CIFAR-10 training images as flattened rows.

    Parameters
    ----------
    N : int
        Batch size used to pull the leading training samples (no shuffling).
    include_test : bool
        If True, the first batch of up to 10000 test samples is appended.
    download : bool
        Forwarded to the training-set constructor only.

    Returns
    -------
    X, label, nclass
        Torch tensors when ``include_test`` is False; numpy arrays (the
        result of ``np.concatenate``) when it is True.  ``nclass`` is 10.
    """
    def _first_batch(dataset, batch_size):
        # Return the leading batch with every image flattened to one row.
        images, targets = next(iter(DataLoader(dataset, batch_size=batch_size, shuffle=False)))
        return images.reshape(images.shape[0], -1), targets

    train_set = CIFAR10('data/', train=True, transform=transforms.ToTensor(), download=download)
    X, label = _first_batch(train_set, N)

    if include_test:
        test_set = CIFAR10('data/', train=False, transform=transforms.ToTensor())
        X_test, label_test = _first_batch(test_set, 10000)
        X = np.concatenate((X, X_test), axis=0)
        label = np.concatenate((label, label_test))

    return X, label, 10

def load_cifar_clip(which, N=15000):
    """Load precomputed CLIP features for a CIFAR variant.

    Parameters
    ----------
    which : str
        ``'cifar10'``, ``'cifar100'`` or ``'cifar20'`` (read from the
        ``cifar100coarse`` feature file).
    N : int
        Number of leading samples to keep.

    Returns
    -------
    X, label, nclass
        The first ``N`` entries of the stored ``'features'`` and
        ``'labels'``, plus the class count tied to ``which``.

    Raises
    ------
    ValueError
        If ``which`` is not one of the supported names.
    """
    # (feature file, number of classes) for each supported variant.
    sources = {
        'cifar10': ('./data/cifar10-clipfeat.pt', 10),
        'cifar100': ('./data/cifar100-clipfeat.pt', 100),
        'cifar20': ('./data/cifar100coarse-clipfeat.pt', 20),
    }
    if which not in sources:
        raise ValueError('invalid dataset')
    filename, nclass = sources[which]

    # The file holds a dict with keys 'features' and 'labels'.
    payload = torch.load(filename)
    X = payload['features'][:N]
    label = payload['labels'][:N]

    # Diagnostics: spread of the row norms and norm of the mean feature.
    row_norms = X.norm(dim=1, keepdim=True).to(torch.float32)
    print("X_norm mean:", row_norms.mean().item(), "std:", row_norms.std().item())

    feature_mean = X.to(torch.float32).mean(dim=0)
    print("X_mean norm:", feature_mean.norm().item())

    return X, label, nclass

def load_coil_simclr(nclass=100):
    """Load features from ``data/coil100_simclr.mat``.

    Only samples whose label is strictly below ``nclass`` are kept.

    Returns
    -------
    X, label, nclass
    """
    mat = scipy.io.loadmat('data/coil100_simclr.mat')
    label = mat['label'].flatten()
    keep = label < nclass
    return mat['X'][keep, :], label[keep], nclass

def load_cifar10_simclr(N=60000):
    """Load the first ``N`` samples from ``data/cifar10_simclr.mat``.

    Returns
    -------
    X, label, nclass
        Leading rows of ``'X'``, the matching flattened ``'label'`` entries,
        and the class count (10).
    """
    mat = scipy.io.loadmat('data/cifar10_simclr.mat')
    head = slice(None, N)
    return mat['X'][head, :], mat['label'].flatten()[head], 10

def load_cifar100_simclr(N=60000, coarse=True):
    """Load the first ``N`` samples from ``data/cifar100_simclr.mat``.

    Parameters
    ----------
    N : int
        Number of leading samples to keep.
    coarse : bool
        If True, read ``'coarse_label'`` and report 20 classes; otherwise
        read ``'label'`` and report 100 classes.

    Returns
    -------
    X, label, nclass
    """
    mat = scipy.io.loadmat('data/cifar100_simclr.mat')
    label_key, nclass = ('coarse_label', 20) if coarse else ('label', 100)
    label = mat[label_key].flatten()[:N]
    X = mat['X'][:N, :]
    return X, label, nclass

def load_oxford_flowers_simclr():
    """Load all samples from ``data/oxford_flowers102_simclr.mat``.

    Returns
    -------
    X, label, nclass
        The stored ``'X'`` array, the flattened ``'label'`` array, and the
        class count (102).
    """
    mat = scipy.io.loadmat('data/oxford_flowers102_simclr.mat')
    return mat['X'], mat['label'].flatten(), 102

def load_mnist_scatter(N=10000):
    """Load the first ``N`` samples from ``data/MNIST_70000_scatter.mat``.

    The stored ``'data'`` array is transposed before use, so samples end up
    as rows of a C-contiguous matrix.

    Returns
    -------
    X, label, nclass
        ``nclass`` is 10.
    """
    mat = scipy.io.loadmat('data/MNIST_70000_scatter.mat')
    samples_as_rows = np.ascontiguousarray(mat['data'].T)
    label = mat['label'].flatten()
    return samples_as_rows[:N, :], label[:N], 10

def sbm(on_fn, off_fn, n_blocks=10, block_size=20, p=.3, q=.05, rng=None):
    """ (Symmetric) Stochastic Block Model
    Adapted from https://github.com/dmlc/dgl/blob/master/python/dgl/data/sbm.py
    Parameters
    ----------
    on_fn : callable
        Generates weights for on-diagonal (intra-community) blocks; passed
        to ``scipy.sparse.random`` as ``data_rvs``.
    off_fn : callable
        Generates weights for off-diagonal (inter-community) blocks.
    n_blocks : int
        Number of blocks.
    block_size : int
        Block size.
    p : float
        Probability for intra-community edge.
    q : float
        Probability for inter-community edge.
    rng : numpy.random.RandomState, optional
        Random state to draw from; a fresh one is created when omitted.
    Returns
    -------
    K : scipy sparse matrix
        Symmetric adjacency matrix of the generated graph, zero diagonal.
    label : numpy.ndarray
        Block index of every node.
    n_blocks : int
        The number of blocks, echoed back.
    """
    # The module only imports scipy.io; import the submodule used here
    # explicitly instead of relying on it being loaded as a side effect.
    import scipy.sparse

    n = n_blocks * block_size
    # Built-in int replaces np.int, which was removed in NumPy 1.24.
    label = np.repeat(np.arange(n_blocks, dtype=int), block_size)
    rng = np.random.RandomState() if rng is None else rng

    rows = []
    cols = []
    weights = []
    # Only blocks on or above the block diagonal are sampled; the lower
    # half is recovered by symmetrization below.
    for i in range(n_blocks):
        for j in range(i, n_blocks):
            density = p if i == j else q
            data_rvs = on_fn if i == j else off_fn
            block = scipy.sparse.random(block_size, block_size, density=density,
                                        random_state=rng, data_rvs=data_rvs)
            rows.append(block.row + i * block_size)
            cols.append(block.col + j * block_size)
            weights.append(block.data)

    rows = np.hstack(rows)
    cols = np.hstack(cols)
    weights = np.hstack(weights)
    a = scipy.sparse.coo_matrix((weights, (rows, cols)), shape=(n, n))
    # Keep the strict upper triangle and mirror it.
    upper = scipy.sparse.triu(a, 1)
    K = upper + upper.transpose()
    return K, label, n_blocks

def load_torch_geo(data):
    """Unpack features, labels and split indices from a graph data object.

    Parameters
    ----------
    data : object
        Must expose ``x``, ``y``, ``train_mask``, ``val_mask`` and
        ``test_mask`` (presumably a torch_geometric ``Data`` object - the
        code only relies on these attributes).

    Returns
    -------
    X, label, nclass, train_idx, valid_idx, test_idx
        ``nclass`` is the number of distinct values in ``data.y``; the index
        tensors list the positions where the corresponding mask is set.
    """
    def _mask_to_index(mask):
        # Positions along the first axis where the mask is non-zero.
        return torch.nonzero(mask, as_tuple=True)[0]

    label = data.y
    nclass = torch.unique(label).size(0)
    return (
        data.x,
        label,
        nclass,
        _mask_to_index(data.train_mask),
        _mask_to_index(data.val_mask),
        _mask_to_index(data.test_mask),
    )

"""
def load_planetoid(name, return_idx=False):
    data = Planetoid('data/', name)[0]
    X, label, nclass, train_idx, valid_idx, test_idx = load_torch_geo(data)
    
    if return_idx:
        return X, label, nclass, train_idx, valid_idx, test_idx
    
    return X, label, nclass
"""

def load_coil_scatter(nclass=100):
    """Load the scattered COIL dataset from ``data/COIL_scatter.mat``.

    The stored ``'data'`` array is transposed so samples become rows, every
    row is scaled to unit norm, and the stored labels are shifted down by
    one.  Only the first ``nclass * 72`` samples are returned.

    Returns
    -------
    norm_X, label, nclass
    """
    mat = scipy.io.loadmat("data/COIL_scatter.mat")
    n_keep = nclass * 72

    X = np.ascontiguousarray(mat["data"].T)
    # Normalize all rows first, then keep the leading ones.
    norm_X = (X / np.linalg.norm(X, axis=1, keepdims=True))[:n_keep, :]
    label = (np.array(mat["label"]).squeeze() - 1)[:n_keep]

    print(norm_X.shape, label.shape)
    return norm_X, label, nclass

def load_covtype(N=200000):
    """Load the first ``N`` rows of ``data/covtype.data``.

    Note: standardizes features (``StandardScaler``), then normalizes every
    data point to unit Euclidean norm.

    Parameters
    ----------
    N : int
        Number of leading rows to keep.

    Returns
    -------
    X, label, nclass
        ``X`` is built from all columns but the last; ``label`` is the last
        column shifted down by one and cast to int; ``nclass`` is 7.
    """
    fulldata = np.genfromtxt('data/covtype.data', delimiter=',')
    X = fulldata[:N, :-1]
    X = preprocessing.StandardScaler().fit_transform(X)
    X = X / np.linalg.norm(X, axis=1, keepdims=True)
    # Built-in int replaces np.int, which was removed in NumPy 1.24.
    label = (fulldata[:N, -1] - 1).astype(int)
    nclass = 7
    return X, label, nclass

def synth_uos(ambient_dim, subspace_dim, num_subspaces, num_points_per_subspace, noise_level=0.0):
    """
    from https://github.com/ChongYou/subspace-clustering/blob/master/gen_union_of_subspaces.py
    Generate a union of subspaces under a random model: each subspace is
    spanned by an orthonormalized Gaussian basis, and each point is a
    unit-norm Gaussian coefficient vector mapped through that basis.
    Parameters
    -----------
    ambient_dim : int
        Dimension of the ambient space
    subspace_dim : int
        Dimension of each subspace (all subspaces have the same dimension)
    num_subspaces : int
        Number of subspaces to be generated
    num_points_per_subspace : int
        Number of data points from each of the subspaces
    noise_level : float
        Scale of the Gaussian noise added to the data
    Returns
    -------
    data, label, num_subspaces
        ``data`` has one point per row, grouped by subspace; ``label`` holds
        the subspace index of every row.
    """
    total_points = num_points_per_subspace * num_subspaces
    data = np.empty((total_points, ambient_dim))
    label = np.empty(total_points, dtype=int)

    for k in range(num_subspaces):
        # Draw order (basis first, then coefficients) is kept so seeded
        # runs reproduce the same data.
        basis = scipy.linalg.orth(np.random.normal(size=(ambient_dim, subspace_dim)))
        coeff = np.random.normal(size=(subspace_dim, num_points_per_subspace))
        coeff = preprocessing.normalize(coeff, norm='l2', axis=0, copy=False)

        rows = slice(k * num_points_per_subspace, (k + 1) * num_points_per_subspace)
        data[rows, :] = np.matmul(basis, coeff).T
        label[rows] = k

    noise = np.random.normal(size=(total_points, ambient_dim))
    data += noise * noise_level

    return data, label, num_subspaces

def load_stl10_simclr(unlabeled=False):
    """Load features from ``data/stl10_simclr.mat``.

    Parameters
    ----------
    unlabeled : bool
        If False (default), samples with a negative label are dropped;
        if True, everything in the file is returned.

    Returns
    -------
    X, label, nclass
        ``nclass`` is 10.
    """
    mat = scipy.io.loadmat('data/stl10_simclr.mat')
    X = mat['X']
    label = mat['label'].flatten()
    if unlabeled:
        return X, label, 10

    has_label = label >= 0
    return X[has_label, :], label[has_label], 10

def load_emnist_digits(N=10000):
    """Load the first ``N`` samples of EMNIST digits.

    Reads ``data/emnist/emnist-digits.mat``.  The nested ``'dataset'``
    struct is indexed by position: field 0 is used as the training split and
    field 1 as the test split, each with images at position 0 and labels at
    position 1.  Training samples come first, followed by test samples.

    Parameters
    ----------
    N : int
        Number of leading samples to keep.

    Returns
    -------
    X, label, nclass
        ``X`` is float, divided by the global maximum of the stacked images;
        ``label`` is an int vector; ``nclass`` is 10.
    """
    filename = 'data/emnist/emnist-digits.mat'
    fulldata = scipy.io.loadmat(filename)
    dataset = fulldata['dataset'][0, 0]
    train_split = dataset[0][0, 0]
    test_split = dataset[1][0, 0]
    train_X, train_label = train_split[0], train_split[1]
    test_X, test_label = test_split[0], test_split[1]

    # Built-in float/int replace np.float/np.int, removed in NumPy 1.24.
    X = np.vstack((train_X, test_X)).astype(float)
    X = X / X.max()
    label = np.squeeze(np.vstack((train_label, test_label))).astype(int)

    X = X[:N, :]
    label = label[:N]
    nclass = 10
    return X, label, nclass

def load_emnist_digits_scatter(N=10000):
    """Load the first ``N`` samples from ``data/emnist_digits_scatter.mat``.

    Returns
    -------
    X, label, nclass
        Leading rows of ``'X'``, the matching squeezed ``'label'`` entries,
        and the class count (10).
    """
    mat = scipy.io.loadmat('data/emnist_digits_scatter.mat')
    X = mat['X'][:N, :]
    label = np.squeeze(mat['label'])[:N]
    return X, label, 10

def load_emnist_letters(N=10000):
    """Load the first ``N`` samples of EMNIST letters.

    Reads ``data/emnist/emnist-letters.mat``.  The nested ``'dataset'``
    struct is indexed by position: field 0 is used as the training split and
    field 1 as the test split, each with images at position 0 and labels at
    position 1.  Training samples come first, followed by test samples.

    Parameters
    ----------
    N : int
        Number of leading samples to keep.

    Returns
    -------
    X, label, nclass
        ``X`` is float, divided by the global maximum of the stacked images;
        ``label`` is an int vector shifted down by one relative to the
        stored labels; ``nclass`` is 26.
    """
    filename = 'data/emnist/emnist-letters.mat'
    fulldata = scipy.io.loadmat(filename)
    dataset = fulldata['dataset'][0, 0]
    train_split = dataset[0][0, 0]
    test_split = dataset[1][0, 0]
    train_X, train_label = train_split[0], train_split[1]
    test_X, test_label = test_split[0], test_split[1]

    # Built-in float/int replace np.float/np.int, removed in NumPy 1.24.
    X = np.vstack((train_X, test_X)).astype(float)
    X = X / X.max()
    label = np.squeeze(np.vstack((train_label, test_label))).astype(int)
    # Shift after the int cast so the subtraction cannot wrap around in an
    # unsigned storage dtype.
    label = label - 1

    X = X[:N, :]
    label = label[:N]
    nclass = 26
    return X, label, nclass

def load_emnist_letters_scatter(N=10000):
    """Load the first ``N`` samples from ``data/emnist_letters_scatter.mat``.

    Returns
    -------
    X, label, nclass
        Leading rows of ``'X'``, the matching squeezed ``'label'`` entries,
        and the class count (26).
    """
    mat = scipy.io.loadmat('data/emnist_letters_scatter.mat')
    head = slice(None, N)
    features = mat['X']
    targets = np.squeeze(mat['label'])
    return features[head, :], targets[head], 26

def load_cifar10_scatter(N=10000):
    """Load the first ``N`` samples from ``data/cifar10_scatter.mat``.

    Returns
    -------
    X, label, nclass
        Leading rows of ``'X'``, the matching squeezed ``'label'`` entries,
        and the class count (10).
    """
    contents = scipy.io.loadmat('data/cifar10_scatter.mat')
    label = np.squeeze(contents['label'])[:N]
    X = contents['X'][:N, :]
    return X, label, 10

def load_stl10(N=1000, include_test=False, download=False):
    """Load the first ``N`` STL-10 training images as flattened rows.

    Parameters
    ----------
    N : int
        Batch size used to pull the leading ``'train'`` samples (no shuffling).
    include_test : bool
        If True, the first batch of up to 10000 ``'test'`` samples is appended.
    download : bool
        Forwarded to the training-split constructor only.

    Returns
    -------
    X, label, nclass
        Torch tensors when ``include_test`` is False; numpy arrays (the
        result of ``np.concatenate``) when it is True.  ``nclass`` is 10.
    """
    train_split = STL10('data/', split='train', transform=transforms.ToTensor(), download=download)
    train_batches = iter(DataLoader(train_split, batch_size=N, shuffle=False))
    X, label = next(train_batches)
    X = X.reshape(X.shape[0], -1)

    if not include_test:
        return X, label, 10

    test_split = STL10('data/', split='test', transform=transforms.ToTensor())
    test_batches = iter(DataLoader(test_split, batch_size=10000, shuffle=False))
    X_test, label_test = next(test_batches)
    X_test = X_test.reshape(X_test.shape[0], -1)

    X = np.concatenate((X, X_test), axis=0)
    label = np.concatenate((label, label_test))
    return X, label, 10

def load_coil20_scatter():
    """Load all samples from ``data/COIL20_scatter_2.mat``.

    The stored ``'X'`` array is returned transposed.

    Returns
    -------
    X, label, nclass
        ``label`` is the squeezed ``'label'`` array; ``nclass`` is 20.
    """
    mat = scipy.io.loadmat('data/COIL20_scatter_2.mat')
    return mat['X'].T, np.squeeze(mat['label']), 20

