import sys
import os
sys.path.append('./')
sys.path.append('../')
import numpy as np
import pickle as pkl
import cv2
import pandas as pd
from matplotlib import pyplot as plt
from deel.datasets.util_generator_dataset import simple_generator
from deel.utils.ImageTransformer import ImageTransformer
import tensorflow as tf
def display_img_vect(img):
    """Draw ``img`` with ``imshow`` on the axes of a freshly created figure.

    The figure is only created and drawn into, not shown: call
    ``plt.show()`` afterwards (or use an interactive backend) to see it.
    Returns None.
    """
    _, axes = plt.subplots()
    axes.imshow(img)

def load_celeb_a_data(df_name, pkl_name, target, neg=-1):
    """Load one CelebA split: labels from a CSV column, images from a pickle.

    Args:
        df_name: path of a CSV file readable by ``pandas.read_csv``.
        pkl_name: path of a pickle file holding the image array.
        target: name of the CSV column used as the label.
        neg: when 0, labels equal to -1 are remapped to 0; any other value
            leaves the labels untouched.

    Returns:
        ``(X, Y)``. ``X`` is the unpickled array with its last axis
        reversed, cast to float32 and mapped by ``2 * X - 1``. ``Y`` is the
        label column cast to int8 and reshaped to ``(n, 1)``.
    """
    df = pd.read_csv(df_name)
    Y = df[target].values.astype(np.int8)
    # SECURITY: pickle.load can execute arbitrary code; only load trusted files.
    # The context manager closes the handle (the original leaked it).
    with open(pkl_name, 'rb') as f:
        X = np.asarray(pkl.load(f))
    # Reverse the last axis -- presumably BGR -> RGB for cv2-written images
    # (TODO confirm against the script that produced the pickle).
    X = X[..., ::-1].astype(np.float32)
    # Affine map 2*X - 1: yields [-1, 1] only if the pickled values are in
    # [0, 1] -- NOTE(review): confirm the stored range.
    X = 2 * X - 1
    Y = Y.reshape(Y.shape[0], 1)
    if neg == 0:
        Y[Y == -1] = 0
    return X, Y

def simple_generator_aug(batch_size, X, Y, trans, aug=False, shuffle=True, cat=False):
    """Yield ``(batch_x, batch_y)`` float32 mini-batches forever.

    Args:
        batch_size: number of samples per batch.
        X: array of samples; ``X[idx]`` with an integer index array must
            return the stacked samples.
        Y: 2-D label array of shape ``(n, k)`` aligned with ``X``.
        trans: object with a ``random_transform(sample)`` method; only used
            when ``aug`` is true.
        aug: apply ``trans.random_transform`` to every sample of the batch.
        shuffle: if true, draw indices uniformly at random *with replacement*;
            otherwise walk the data sequentially (see comments below).
        cat: if true, one-hot encode the labels into 2 classes. Negative
            labels (the -1 negatives produced with ``neg=-1``) are treated
            as class 0.

    Yields:
        ``batch_x`` of shape ``(batch_size,) + X[0].shape`` and ``batch_y``
        of shape ``(batch_size, Y.shape[1])``. When ``cat`` is true and
        ``Y.shape[1] == 1``, ``batch_y`` has shape ``(batch_size, 2)``.

    Raises:
        ValueError: on the first ``next()`` if ``Y`` has no rows.
    """
    n_samples = Y.shape[0]
    if n_samples == 0:
        raise ValueError('simple_generator_aug needs at least one sample')
    one_hot = np.eye(2, dtype=np.float32)
    index = 0
    while True:
        if shuffle:
            ind = np.random.randint(0, n_samples, size=batch_size)
        else:
            # The modulo is a no-op whenever n_samples >= batch_size (the index
            # bookkeeping below keeps every batch in range). It only lets a
            # dataset smaller than one batch wrap around instead of raising
            # IndexError.
            ind = np.arange(index, index + batch_size) % n_samples
            # Restart at 0 when the *next* batch would overrun the data; a
            # trailing partial batch is therefore never served.
            if index + 2 * batch_size > n_samples:
                index = 0
            else:
                index = index + batch_size
        # Fancy indexing returns a fresh copy, so in-place augmentation below
        # never mutates X.
        batch_x = np.asarray(X[ind], dtype=np.float32)
        if aug:
            for i in range(batch_size):
                batch_x[i] = trans.random_transform(batch_x[i])
        batch_y = np.asarray(Y[ind], dtype=np.float32)
        if cat:
            labels = batch_y.astype(np.int64)
            # Without this, -1 would index the last column and negatives
            # would be encoded exactly like positives.
            labels[labels < 0] = 0
            # Mirror keras.utils.to_categorical: drop a trailing axis of size 1.
            if labels.ndim > 1 and labels.shape[-1] == 1:
                labels = labels[..., 0]
            batch_y = one_hot[labels]
        yield batch_x, batch_y
    
def _celeba_glasses_files(path, split, size):
    """Return ``(csv_path, pkl_path)`` for one CelebA-glasses split.

    ``split`` is ``'train'`` or ``'test'``; ``size`` is the image side length
    encoded in the pickle file name (e.g. ``celeba_glasses_train_64x64.pkl``).
    """
    csv_name = os.path.join(path, 'celeba_glasses_{}.csv'.format(split))
    pkl_name = os.path.join(path, 'celeba_glasses_{}_{}x{}.pkl'.format(split, size, size))
    return csv_name, pkl_name


def celeba_glass_generator(batch_size, path, size, neg=-1, cat=False):
    """Build train/valid/test batch generators for the CelebA 'Eyeglasses' task.

    Reads ``celeba_glasses_{train,test}.csv`` and
    ``celeba_glasses_{train,test}_{size}x{size}.pkl`` from directory ``path``
    (via ``load_celeb_a_data``), prints each split's size and the fraction of
    labels equal to 1, and wraps the arrays in ``simple_generator_aug``.

    Args:
        batch_size: samples per batch for every generator.
        path: directory containing the CSV/pickle files.
        size: image side length used in the pickle file names.
        neg: forwarded to ``load_celeb_a_data`` (0 remaps -1 labels to 0).
        cat: forwarded to ``simple_generator_aug`` (one-hot labels).

    Returns:
        dict with keys ``'train'`` (augmented generator over the train split),
        ``'trainSize'`` (None), ``'valid'`` and ``'test'`` (non-augmented
        generators over the test split), ``'validSize'`` (hard-coded 640),
        ``'testSize'`` (number of test samples), ``'batch_size'`` and
        ``'src_dataset'`` (the raw ``(X_train, Y_train)`` arrays).
    """
    trans = ImageTransformer(rotation_range=7,
                             zoom_range=0.01,
                             fill_mode='nearest',
                             flip_horizontal=True,
                             height_shift_range=0.07,
                             width_shift_range=0.07)
    train_csv, train_pkl = _celeba_glasses_files(path, 'train', size)
    X_train, Y_train = load_celeb_a_data(train_csv, train_pkl, 'Eyeglasses', neg=neg)
    print('size train :', X_train.shape[0], 'freq',
          Y_train[Y_train == 1].shape[0] / X_train.shape[0])
    test_csv, test_pkl = _celeba_glasses_files(path, 'test', size)
    X_test, Y_test = load_celeb_a_data(test_csv, test_pkl, 'Eyeglasses', neg=neg)
    # The original mislabeled this line as 'size train :' and omitted 'freq'.
    print('size test :', X_test.shape[0], 'freq',
          Y_test[Y_test == 1].shape[0] / X_test.shape[0])
    # NOTE(review): 'valid' and 'test' use the default shuffle=True, i.e. they
    # sample the test split with replacement, so running testSize/batch_size
    # steps does not visit each test image exactly once -- confirm intended.
    dtset = {'train': simple_generator_aug(batch_size, X_train, Y_train, trans, aug=True, cat=cat),
             'trainSize': None,
             'valid': simple_generator_aug(batch_size, X_test, Y_test, trans, aug=False, cat=cat),
             'validSize': 640,
             'test': simple_generator_aug(batch_size, X_test, Y_test, trans, aug=False, cat=cat),
             'testSize': X_test.shape[0],
             'batch_size': batch_size,
             'src_dataset': (X_train, Y_train)}
    return dtset


if __name__ == "__main__":
    # Smoke test: build the generators, print one batch of test labels and
    # show the first test image of that batch.
    dtset = celeba_glass_generator(batch_size=32,
                                   path='/data/',
                                   size=64, cat=True, neg=0)
    X, Y = next(dtset['test'])
    # The loader maps images by 2*x - 1; undo it and clip so imshow receives
    # floats in [0, 1] (assumes the pickled images were in [0, 1] -- confirm).
    display_img_vect(np.clip((X[0] + 1) / 2, 0, 1))
    print(Y)
    # Without show(), a non-interactive run exits before the figure appears.
    plt.show()
    