import sys
import os
sys.path.append('./')
import numpy as np
import tensorflow as tf
import math
from deel.datasets.util_generator_dataset import simple_generator,simple_dataset_generator
import matplotlib as mpl

from sklearn.datasets import make_moons,make_circles
import seaborn as sns


import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import pandas as pd
import seaborn as sns

def plot_moons(X, Y, alpha=1, sub=3000):
    """Scatter-plot the first ``sub`` samples of a 2-D dataset, coloured by label.

    Args:
        X: array of points; columns 0 and 1 are drawn as the x and y coordinates.
        Y: labels aligned with ``X``; passed to seaborn as ``hue``.
        alpha: marker transparency.
        sub: maximum number of leading samples to draw.
    """
    sns.scatterplot(x=X[:sub, 0], y=X[:sub, 1], hue=Y[:sub], alpha=alpha, palette="deep")

def test_display_levelset(model, dtset):
    """Plot the top-1 / top-2 score margin of a multi-class ``model`` over a grid.

    The grid is 120 x 120 points spanning the bounding box of the inputs in
    ``dtset['src_dataset']``, enlarged by 2 units on every side. Training points
    are drawn one layer per class, the margin is drawn as labelled contour
    lines, and grid points whose margin is below 0.1 are highlighted.

    Args:
        model: object with a Keras-style ``predict`` returning one score per
            class for each grid point (at least 2 classes, since the top 2
            scores are taken).
        dtset: dict whose ``'src_dataset'`` entry is ``(X, Y)``; ``Y`` is
            indexed per class column (``Y[:, c] == 1``), i.e. one-hot labels.

    Returns:
        The matplotlib figure that was drawn on.
    """
    X, Y = dtset['src_dataset']

    x = np.linspace(X[:, 0].min() - 2, X[:, 0].max() + 2, 120)
    y = np.linspace(X[:, 1].min() - 2, X[:, 1].max() + 2, 120)
    xx, yy = np.meshgrid(x, y, sparse=False)
    print(xx.min(), xx.max(), yy.min(), yy.max())

    X_pred = np.stack((xx.ravel(), yy.ravel()), axis=1)
    print(X_pred.shape)
    # Margin between the best and second-best class score; small values lie
    # close to a boundary between two classes.
    top_values = tf.math.top_k(model.predict(X_pred), 2).values
    pred = (top_values[:, 0] - top_values[:, 1]).numpy()
    # meshgrid's default 'xy' indexing yields arrays of shape (len(y), len(x)).
    Y_pred = pred.reshape(xx.shape)

    fig = plt.figure(figsize=(10, 7))
    ax1 = fig.add_subplot(111)
    # One scatter layer per label column (previously hard-coded to 4 columns,
    # which raised IndexError for fewer classes).
    for class_idx in range(Y.shape[1]):
        in_class = Y[:, class_idx] == 1
        sns.scatterplot(x=X[in_class, 0], y=X[in_class, 1], alpha=0.1, ax=ax1)
    cset = ax1.contour(xx, yy, Y_pred, cmap='twilight', levels=10)
    near_boundary = pred < 0.1
    sns.scatterplot(x=X_pred[near_boundary, 0], y=X_pred[near_boundary, 1], alpha=0.1, ax=ax1)
    ax1.clabel(cset, inline=1, fontsize=10)

    plt.show()
    # Return the figure actually drawn: after show() the "current" figure may
    # have been closed, so plt.gcf() could hand back a new, empty one.
    return fig

def test_display_levelset_binary(model, dtset, levels=20, coeff=1.1, filename=None):
    """Draw the level sets of a single-output ``model`` over a square grid.

    The grid has 120 x 120 points on a square box covering both input
    coordinates of ``dtset['src_dataset']``, padded outwards by a relative
    factor ``coeff - 1``. The plot shows the points with ``Y == 1`` and
    ``Y == 0``, ``levels`` labelled contour lines of the model output, and the
    zero level set as a thick red dashed line.

    Args:
        model: object with a Keras-style ``predict`` that returns exactly one
            value per grid point (the output is reshaped onto the grid).
        dtset: dict whose ``'src_dataset'`` entry is ``(X, Y)`` with 0/1 labels.
        levels: contour levels passed to ``Axes.contour``.
        coeff: relative margin; every bound is pushed outwards by
            ``(coeff - 1) * |bound|``.
        filename: if given, save the figure to this path and close it;
            otherwise call ``plt.show()``.
    """
    X, Y = dtset['src_dataset']

    # Pad each bound outwards by (coeff - 1) * |bound|. For a negative minimum
    # or positive maximum this equals the former ``bound * coeff``; unlike that
    # formula it does not shrink the box (and clip data) when a minimum is
    # positive or a maximum is negative.
    margin = coeff - 1
    lows = (X[:, 0].min(), X[:, 1].min())
    highs = (X[:, 0].max(), X[:, 1].max())
    # Square box: both axes share the same range.
    x_min = y_min = min(v - margin * abs(v) for v in lows)
    x_max = y_max = max(v + margin * abs(v) for v in highs)

    x = np.linspace(x_min, x_max, 120)
    y = np.linspace(y_min, y_max, 120)
    xx, yy = np.meshgrid(x, y, sparse=False)
    print(xx.min(), xx.max(), yy.min(), yy.max())

    X_pred = np.stack((xx.ravel(), yy.ravel()), axis=1)
    print(X_pred.shape)
    pred = model.predict(X_pred)
    print(pred.shape)
    # meshgrid's default 'xy' indexing yields arrays of shape (len(y), len(x)).
    Y_pred_f = pred.reshape(xx.shape)

    fig = plt.figure(figsize=(10, 10))
    ax1 = fig.add_subplot(111)
    # Draw the left/bottom axes through the origin and hide the right/top frame.
    ax1.spines['left'].set_position('zero')
    ax1.spines['right'].set_color('none')
    ax1.spines['bottom'].set_position('zero')
    ax1.spines['top'].set_color('none')

    # Fine minor grid every 0.2 units in addition to the major grid.
    grid_x_ticks = np.arange(x_min, x_max, 0.2)
    grid_y_ticks = np.arange(y_min, y_max, 0.2)
    ax1.set_xticks(grid_x_ticks, minor=True)
    ax1.set_yticks(grid_y_ticks, minor=True)
    ax1.grid(True, 'major', ls='solid', lw=0.5, color='gray')
    ax1.grid(True, 'minor', ls='solid', lw=0.2, color='gray')

    palette = sns.color_palette()
    sns.scatterplot(x=X[Y == 1, 0], y=X[Y == 1, 1], color=palette[1], alpha=0.2, ax=ax1)
    sns.scatterplot(x=X[Y == 0, 0], y=X[Y == 0, 1], color=palette[0], alpha=0.2, ax=ax1)
    cset = ax1.contour(xx, yy, Y_pred_f, cmap='plasma', levels=levels)
    ax1.clabel(cset, inline=1, fontsize=10)
    # Emphasise the zero level set (presumably the decision boundary of a
    # signed output).
    cset = ax1.contour(xx, yy, Y_pred_f, [0.0], colors='red', linestyles='dashed', linewidths=6)
    ax1.clabel(cset, inline=1, fontsize=14)
    ax1.patch.set_edgecolor('black')
    if filename is not None:
        fig.savefig(filename)
        plt.close(fig)
    else:
        plt.show()
        
def kmoons_dataset_binary(circle_or_moons = 1,# 0 for circle , 1 for moons
                   n_samples=5000,
                   factor=0.4,
                   noise=0.05,
                   center = False,
                   reduce = False,
                   random_state=None):
    """Generate a 2-D binary toy dataset (two moons or two concentric circles).

    Args:
        circle_or_moons: 0 uses scikit-learn's ``make_circles``; any other
            value uses ``make_moons``.
        n_samples: total number of generated points.
        factor: inner/outer circle scale ratio (only used for circles).
        noise: standard deviation of the noise added by scikit-learn.
        center: if True, subtract the per-column mean from ``X``.
        reduce: if True, divide ``X`` by its global standard deviation (over
            all entries, so the aspect ratio of the shape is kept).
        random_state: forwarded to scikit-learn for reproducible draws;
            ``None`` (the default) keeps the previous non-deterministic behaviour.

    Returns:
        ``(X, Y)`` as produced by scikit-learn: 2-D points and 0/1 labels.
    """
    if circle_or_moons == 0:
        X, Y = make_circles(n_samples=n_samples, noise=noise, factor=factor,
                            random_state=random_state)
    else:
        X, Y = make_moons(n_samples=n_samples, noise=noise, random_state=random_state)
    if center:
        X = X - X.mean(axis=0)
    if reduce:
        X = X / X.std()
    return X, Y
    
def kmoons_dataset(kmoons=2,
                   circle_or_moons = 1,# 0 for circle , 1 for moons
                   n_samples=5000,
                   factor=0.4,
                   noise=0.05,
                   random_state=None):
    """Build a multi-class toy dataset from shifted copies of a two-class sample.

    ``kmoons // 2`` copies of one moons/circles sample are laid side by side
    along the x axis. Each copy is shifted by the x-extent of class 0 of the
    base sample, and copy ``c`` (0-based) carries labels ``2c`` and ``2c + 1``.

    Args:
        kmoons: number of one-hot label columns; must be >= 2. For an odd value
            the last column is never populated, because copies add labels in pairs.
        circle_or_moons: 0 uses ``make_circles``; any other value uses ``make_moons``.
        n_samples: number of points per copy.
        factor: inner/outer circle scale ratio (only used for circles).
        noise: standard deviation of the noise added by scikit-learn.
        random_state: forwarded to scikit-learn for reproducible draws;
            ``None`` (the default) keeps the previous non-deterministic behaviour.

    Returns:
        ``(newX, newY)``: the stacked points and one-hot labels produced by
        ``tf.keras.utils.to_categorical`` with ``kmoons`` columns.

    Raises:
        ValueError: if ``kmoons < 2`` (previously an obscure failure inside
            ``to_categorical``).
    """
    if kmoons < 2:
        raise ValueError(f"kmoons must be >= 2, got {kmoons}")
    n_copies = (kmoons // 2) - 1  # extra copies beyond the base sample
    if circle_or_moons == 0:
        X, Y = make_circles(n_samples=n_samples, noise=noise, factor=factor,
                            random_state=random_state)
    else:
        X, Y = make_moons(n_samples=n_samples, noise=noise, random_state=random_state)
    XX = [X]
    YY = [Y]
    if n_copies > 0:
        # Horizontal offset between consecutive copies: x-extent of class 0.
        shift = np.max(X[Y == 0, 0]) - np.min(X[Y == 0, 0])
        for kk in range(1, n_copies + 1):
            X2 = X.copy()
            X2[:, 0] = X[:, 0] + shift * kk
            XX.append(X2)
            YY.append(Y + kk * 2)
    newX = np.concatenate(XX, axis=0)
    newY = np.concatenate(YY, axis=0)
    newY = tf.keras.utils.to_categorical(newY, kmoons)

    return newX, newY

    
def kmoons_generator(batch_size, kmoons=2,
                   circle_or_moons = 1,# 0 for circle , 1 for moons
                   n_samples=5000,
                   factor=0.4,
                   noise=0.05):
    """Wrap a multi-class ``kmoons_dataset`` sample into a dataset dictionary.

    Only the ``'train'`` entry holds a batch generator (built with
    ``simple_generator``). Validation and test entries are ``None``, and the
    raw ``(X, Y)`` arrays are kept under ``'src_dataset'``.
    """
    points, labels = kmoons_dataset(kmoons,
                                    circle_or_moons=circle_or_moons,
                                    n_samples=n_samples,
                                    factor=factor,
                                    noise=noise)
    return dict(
        train=simple_generator(batch_size, points, labels),
        trainSize=None,
        valid=None,
        validSize=None,
        test=None,
        testSize=None,
        batch_size=batch_size,
        src_dataset=(points, labels),
    )

def kmoons_generator_binary(batch_size,
                   circle_or_moons = 1,# 0 for circle , 1 for moons
                   n_samples=5000,
                   factor=0.4,
                   tf_dataset = False,
                   center = False,
                   reduce = False,
                   noise=0.05):
    """Build train/test batch generators for a binary moons/circles problem.

    The train and test sets are two independent draws of
    ``kmoons_dataset_binary`` with identical settings. Labels are fed to the
    generators as column vectors of shape ``(n, 1)``.

    Args:
        batch_size: batch size passed to the generators.
        circle_or_moons, n_samples, factor, noise, center, reduce: forwarded
            to ``kmoons_dataset_binary``.
        tf_dataset: if True, replace ``'train'``/``'test'`` with prefetched
            ``tf.data.Dataset.from_generator`` pipelines.

    Returns:
        dict with keys ``train``, ``trainSize``, ``valid``, ``validSize``,
        ``test``, ``testSize``, ``batch_size`` and ``src_dataset`` (the raw
        training ``(X, Y)``).
    """
    X, Y = kmoons_dataset_binary(circle_or_moons=circle_or_moons, n_samples=n_samples,
                                 factor=factor, noise=noise, center=center, reduce=reduce)
    X_test, Y_test = kmoons_dataset_binary(circle_or_moons=circle_or_moons, n_samples=n_samples,
                                           factor=factor, noise=noise, center=center, reduce=reduce)

    # Reshape each label vector by its own length (the test labels were
    # previously reshaped using the training set's length).
    Y_col = Y.reshape(-1, 1)
    Y_test_col = Y_test.reshape(-1, 1)

    dtset = {'train': simple_generator(batch_size, X, Y_col), 'trainSize': None,
             'valid': None, 'validSize': None,
             'test': simple_generator(batch_size, X_test, Y_test_col), 'testSize': n_samples,
             'batch_size': batch_size, 'src_dataset': (X, Y)}
    if tf_dataset:
        # NOTE(review): from_generator expects a callable returning an
        # iterator; this assumes simple_dataset_generator(...) returns such a
        # callable — confirm against its definition.
        dtset["train"] = tf.data.Dataset.from_generator(
            simple_dataset_generator(batch_size, X, Y_col),
            (tf.float32, tf.float32)).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
        dtset["test"] = tf.data.Dataset.from_generator(
            simple_dataset_generator(batch_size, X_test, Y_test_col),
            (tf.float32, tf.float32)).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
    return dtset
    
def kmoons_generator_oneclass(circle_or_moons = 1,# 0 for circle , 1 for moons
                   n_samples=5000,
                   factor=0.4,
                   noise=0.05,
                   center = False,
                   reduce = False,
                   ):
    """Return only the inputs ``X`` of a binary moons/circles sample (labels dropped).

    Centering and global-std reduction are delegated to ``kmoons_dataset_binary``,
    which applies them in the same order (center first, then reduce). This
    replaces the previous duplicated implementation.
    """
    X, _ = kmoons_dataset_binary(circle_or_moons=circle_or_moons, n_samples=n_samples,
                                 factor=factor, noise=noise, center=center, reduce=reduce)
    return X

def koch_generator_binary(batch_size):
    """Load the binary 2-D dataset stored in ``koch.csv`` next to this module.

    The CSV must have columns ``x``, ``y`` (inputs) and ``class`` (label, cast
    to int16). Rows are shuffled once using NumPy's global RNG. There is no
    held-out split: the test generator iterates over the same shuffled data.

    Args:
        batch_size: batch size passed to ``simple_generator``.

    Returns:
        dict with keys ``train``, ``trainSize``, ``valid``, ``validSize``,
        ``test``, ``testSize``, ``batch_size`` and ``src_dataset``.
    """
    # Resolve relative to this module's directory. The previous
    # ``__file__[:-17]`` only worked for one specific 17-character file name.
    path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "koch.csv")
    df = pd.read_csv(path)
    X = df[['x', 'y']].values
    Y = df['class'].values.astype(np.int16)

    order = np.random.permutation(Y.shape[0])
    X = X[order]
    Y = Y[order]

    X_test, Y_test = X, Y
    dtset = {'train': simple_generator(batch_size, X, Y.reshape(-1, 1)), 'trainSize': None,
             'valid': None, 'validSize': None,
             'test': simple_generator(batch_size, X_test, Y_test.reshape(-1, 1)), 'testSize': None,
             'batch_size': batch_size, 'src_dataset': (X, Y)}
    return dtset

import matplotlib.pyplot as plt
if __name__ == "__main__":
    # Quick visual check of the multi-class k-moons dataset.
    k = 6
    # kmoons_generator returns a dict (and takes batch_size first); the raw
    # arrays come from kmoons_dataset.
    X, Y = kmoons_dataset(k)
    X = X * 10
    print(X.shape)
    print(Y.shape)
    print(np.max(Y))
    # Y is one-hot (to_categorical); recover integer labels to select classes.
    labels = np.argmax(Y, axis=1)
    for kk in range(Y.shape[1]):
        X1 = X[labels == kk]
        sns.scatterplot(x=X1[:1000, 0], y=X1[:1000, 1])
    plt.show()