import random
import numpy as np
import time
from IPython import display
from d2l import torch as d2l
import torch
import matplotlib.pyplot as plt
import matplotlib as mpl


# %% set random seed
def set_seed(seed_id):
    """Seed every random number generator used by this module.

    The generators are seeded in this order: torch (CPU), torch CUDA,
    NumPy's global RNG, then Python's ``random`` module.

    Args:
        seed_id: seed value passed unchanged to each generator.
    """
    seeders = (
        torch.manual_seed,
        torch.cuda.manual_seed,
        np.random.seed,
        random.seed,
    )
    for seeder in seeders:
        seeder(seed_id)

# %% set GPU
def try_gpu(i=0):
    """Return gpu(i) if exists, otherwise return cpu()

    Args:
        i: index of the requested CUDA device.

    Returns:
        ``torch.device('cuda:i')`` when at least ``i + 1`` CUDA devices are
        visible, otherwise ``torch.device('cpu')``.
    """
    available = i + 1 <= torch.cuda.device_count()
    return torch.device(f'cuda:{i}') if available else torch.device('cpu')

# %% construct gaussian graph
def cal_dist_matrix(X):
    """Return the pairwise squared Euclidean distances between the rows of X.

    Uses the expansion ||x_i - x_j||^2 = ||x_i||^2 + ||x_j||^2 - 2 <x_i, x_j>.

    Args:
        X: 2-D tensor of shape (n, d); each row is one sample.

    Returns:
        Tensor of shape (n, n) whose entry (i, j) is ||X[i] - X[j]||^2.
    """
    sq_norms = torch.sum(X ** 2, dim=1, keepdim=True)  # (n, 1)
    # Broadcasting (n, 1) + (1, n) replaces the explicit repeat to (n, n).
    D = sq_norms + sq_norms.T - 2 * torch.matmul(X, X.T)
    # The expansion suffers from floating-point cancellation and can produce
    # tiny negative values (notably on the diagonal); squared distances are
    # never negative, so clamp them to 0.
    return D.clamp_min(0)

def cal_dist_gaussian(X, sigma=1, num_neighbours=5):
    """Build a sparse Gaussian-kernel affinity graph over the rows of X.

    Steps:
      1. Gaussian kernel exp(-||x_i - x_j||^2 / (2 sigma^2)) for all pairs.
      2. Row-normalise the off-diagonal kernel values (self-affinity excluded,
         1e-6 added to each row sum to avoid division by zero).
      3. Symmetrise: P = (A + A^T) / 2, floored at 1e-21.
      4. In every row keep only the ``num_neighbours`` largest entries and set
         the others to 0. This pruning is done per row, so the returned matrix
         is generally NOT symmetric.

    Args:
        X: array-like of shape (n, d) (NumPy array or tensor); one sample
            per row.
        sigma: Gaussian kernel bandwidth.
        num_neighbours: number of entries kept per row. Values >= n keep every
            entry; values <= 0 zero the whole row.

    Returns:
        NumPy array of shape (n, n) in torch's default floating dtype.
    """
    # as_tensor avoids the copy + warning torch.tensor() emits for tensor
    # inputs; detach so the final .numpy() also works for grad-tracking input.
    X_new = torch.as_tensor(X).detach()
    n = X_new.shape[0]
    E = cal_dist_matrix(X_new)

    exp_E = torch.exp(-E / (2 * sigma ** 2))
    # Exclude self-affinity from the normalisation: A keeps a zero diagonal.
    exp_E.fill_diagonal_(0)
    A = exp_E / (exp_E.sum(dim=1, keepdim=True) + 1e-6)
    # Keep the historical output dtype/device (a CPU default-dtype matrix).
    A = A.to(device='cpu', dtype=torch.get_default_dtype())

    P_dist = (A + A.T) / 2
    P_dist = P_dist.clamp_min(1e-21)

    # Zero the (n - k) smallest entries of each row, keeping the k largest.
    # Clamping k prevents a negative slice stop when num_neighbours > n, which
    # previously zeroed all but the last num_neighbours columns.
    k = min(max(num_neighbours, 0), n)
    _, sorted_ind = P_dist.sort(dim=1, descending=False)
    P_dist.scatter_(1, sorted_ind[:, :n - k], 0.0)

    return P_dist.numpy()


# %% plot edge and graph
def plot_connection(X, S_ori, ind, num_neighbours=5):
    """Scatter-plot the samples and draw edges to each sample's strongest neighbours.

    Args:
        X: array or tensor of shape (n, d) holding the samples.
        S_ori: (n, n) affinity matrix, as a NumPy array (e.g. the output of
            ``cal_dist_gaussian``) or a tensor.
        ind: pair of column indices of X used for the horizontal and vertical
            axes.
        num_neighbours: number of edges drawn per sample, taken from the
            largest entries of its row in S_ori; clamped to n.

    Returns:
        The created matplotlib Figure.

    Note:
        ``plt.style.use`` and the ``mpl.rcParams`` edits below modify global
        matplotlib state and persist after this call.
    """
    fig = plt.figure(figsize=(10, 7))
    plt.style.use('classic')
    mpl.rcParams['axes.spines.right'] = False
    mpl.rcParams['axes.spines.top'] = False
    fig.set_facecolor('w')
    # No `c=` values are given, so a cmap would be ignored (and warned about).
    plt.scatter(X[:, ind[0]], X[:, ind[1]], s=500, alpha=0.8,
                edgecolor='k', linewidth=1)

    # Accept NumPy input as well: ndarray.sort has no dim/descending arguments.
    S = torch.as_tensor(S_ori).detach().cpu()
    s_max = S.max()
    # Normalise edge widths to at most 1; skip for an all-zero/negative
    # matrix, which would otherwise yield NaN or flipped widths.
    if s_max > 0:
        S = S / s_max

    k = min(num_neighbours, S.shape[1])  # avoid IndexError when n < k
    sorted_dist, sorted_ind = S.sort(dim=1, descending=True)
    for i in range(X.shape[0]):
        for j in range(k):
            ind_j = int(sorted_ind[i, j])
            plt.plot([X[i, ind[0]], X[ind_j, ind[0]]],
                     [X[i, ind[1]], X[ind_j, ind[1]]],
                     linewidth=float(sorted_dist[i, j]), c='r')
    plt.tick_params(labelsize=20, pad=2)
    plt.xlabel('$X_1$', fontsize=30)
    plt.ylabel('$X_2$', fontsize=30)
    plt.tight_layout()
    return fig