import numpy as np
from deel.datasets.util_generator_dataset import simple_generator,simple_dataset_generator
import tensorflow as tf

def _build_adv(X, model = None, batch_adv = 128,scale=None,add = 0., coeff_grad = 1.1):
    """Move samples along the model's input gradient, batch by batch.

    For every batch ``x`` the model output ``f(x)`` and its input gradient
    ``g = d f / d x`` are computed, and the batch is replaced by::

        x - c * f(x) / ||g||^2 * g

    where ``||g||`` is the per-sample L2 norm of the flattened gradient and
    ``c`` is drawn per output element uniformly from
    ``[add, add + coeff_grad)``.  With ``c == 1`` this resembles a Newton
    step toward ``f(x) == 0`` (assumes one scalar output per sample --
    TODO confirm against the models used).

    Args:
        X: numpy array of samples; the first axis is the batch axis.
        model: callable invoked as ``model(x, training=False)`` (e.g. a
            Keras model). Required despite the ``None`` default.
        batch_adv: number of samples per gradient computation (> 0).
        scale: optional ``(low, high)``; when given, results are clipped to it.
        add: offset of the random step coefficient.
        coeff_grad: width of the random step-coefficient range.

    Returns:
        float32 numpy array with the same leading dimension as ``X``.

    Raises:
        ValueError: if ``model`` is None or ``batch_adv`` is not positive.
    """
    if model is None:
        raise ValueError("_build_adv requires a model")
    if batch_adv <= 0:
        # A non-positive step would never advance through X.
        raise ValueError("batch_adv must be a positive integer")
    if X.shape[0] == 0:
        # np.concatenate([]) raises; return an empty array of the usual dtype.
        return np.empty(X.shape, dtype=np.float32)

    X_adv = []
    for start in range(0, X.shape[0], batch_adv):
        # A watched constant is enough for an input gradient; there is no need
        # to allocate a fresh tf.Variable for every batch.
        inp = tf.constant(X[start:start + batch_adv], dtype=tf.float32)
        with tf.GradientTape() as tape:
            tape.watch(inp)
            prediction = model(inp, training=False)
        grads = tape.gradient(prediction, inp)
        flat_grads = tf.reshape(grads, [grads.shape[0], -1])
        coeff = add + coeff_grad * tf.random.uniform(prediction.shape)
        # Per-sample L2 norm of the flattened gradient, shape (batch, 1).
        norm = tf.expand_dims(tf.norm(flat_grads, axis=1), axis=1)
        # divide_no_nan: a zero gradient yields a zero step instead of NaN/inf;
        # for non-zero norms the result is identical to a plain division.
        modif = tf.math.divide_no_nan(coeff * prediction, norm * norm) * flat_grads
        new_v = (inp - tf.reshape(modif, inp.shape)).numpy()
        if scale is not None:
            new_v = np.clip(new_v, scale[0], scale[1])
        X_adv.append(new_v)
    return np.concatenate(X_adv)

def get_random_distribution(X, X_prev_vs = None, model = None,change = 0.1, batch_adv = 128, scale=None, coeff_grad = 1.1,add = 0.):
    """Sample the "random" counter-set that is contrasted with ``X``.

    Fresh samples with the same shape as ``X`` are drawn uniformly in
    ``[low, high]``, where ``(low, high)`` is ``scale`` or, when ``scale`` is
    None, the global min/max of ``X``.

    If ``X_prev_vs`` is given, a copy of it is returned with only
    ``int(len(X) * change)`` randomly chosen rows replaced by fresh samples,
    so the random set evolves gradually between calls. Indices are drawn
    from ``range(len(X))``, so ``X_prev_vs`` is assumed to have at least as
    many rows as ``X``. The caller's ``X_prev_vs`` is not modified.

    If ``model`` is given, the samples are then moved with ``_build_adv``
    (forwarding ``batch_adv``, ``scale``, ``add`` and ``coeff_grad``) and
    clipped back into ``[low, high]``.

    Returns:
        numpy array of random / refined samples.
    """
    nb = X.shape[0]
    if scale is None:
        low, high = X.min(), X.max()
    else:
        low, high = scale
    Y = np.random.uniform(low=low, high=high, size=X.shape)

    if X_prev_vs is not None:
        # Work on a copy so the caller's array is left untouched.
        X_prev_vs = X_prev_vs.copy()
        nb_diff = int(nb * change)
        print("change ", nb_diff)
        ind = np.random.choice(nb, nb_diff, replace=False)
        X_prev_vs[ind] = Y[ind]
        Y = X_prev_vs

    if model is not None:
        # Previously the refined samples were computed and then discarded
        # (the un-refined Y was clipped instead); keep and clip the result.
        Y_adv = _build_adv(Y, model=model, batch_adv=batch_adv, scale=scale,
                           add=add, coeff_grad=coeff_grad)
        Y = np.clip(Y_adv, low, high)
    return Y

def class_versus_random(X,batch_size = 16,X_prev_vs = None, model = None, change = 0.1,batch_adv = 128, scale = None,coeff_grad = 1.1,add = 0.):
    """Build a binary "real vs. random" training set.

    The random counter-set comes from ``get_random_distribution``, to which
    ``X_prev_vs``, ``model``, ``change``, ``batch_adv``, ``scale``,
    ``coeff_grad`` and ``add`` are forwarded. Real samples are labelled 1 and
    random samples 0.

    Returns:
        dict with keys
          - ``'train'``: ``simple_dataset_generator`` over the concatenated
            data with labels reshaped to ``(n, 1)`` and ``shuffle=True``;
          - ``'trainSize'``: always None;
          - ``'batch_size'``: the given batch size;
          - ``'src_dataset'``: ``(data, labels)`` with flat labels;
          - ``'curent_vs'``: the random counter-set, which can be passed back
            as ``X_prev_vs`` on the next call.
    """
    X_vs = get_random_distribution(X, model=model, X_prev_vs=X_prev_vs, change=change,
                                   batch_adv=batch_adv, scale=scale, add=add,
                                   coeff_grad=coeff_grad)
    # Size each label block from its own part so labels stay aligned with
    # ``res`` even when X_prev_vs has a different number of rows than X.
    Y = np.concatenate((np.ones(X.shape[0], dtype=int),
                        np.zeros(X_vs.shape[0], dtype=int)))
    res = np.concatenate((X, X_vs))
    # No validation/test split is produced here.
    dtset = {
        'train': simple_dataset_generator(batch_size, res, Y.reshape(Y.shape[0], 1), shuffle=True),
        'trainSize': None,
        'batch_size': batch_size,
        'src_dataset': (res, Y),
        'curent_vs': X_vs,  # key spelling kept for existing callers
    }
    return dtset
