import os
import shutil
import sys
os.environ["GIT_PYTHON_REFRESH"] = "quiet"
sys.path.append('./')
#sys.path.append('../')
import wandb
from wandb.keras import WandbCallback
import time
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import tensorflow as tf
from tensorflow.keras.callbacks import TensorBoard
from deel.lip.activations import GroupSort2, FullSort
from deel.lip.losses import KR, HKR
from deel.lip.custom_losses import wasserstein_acc
from tensorflow.keras.optimizers import Adam, SGD
import numpy as np
import seaborn as sns
from deel.utils.yaml_to_params import load_yaml_config,getParams, getFunctionFromModules, dumdict2yaml
from deel.datasets.load_dataset import load_dataset
from deel.utils.yaml_loader import load_model, loadFunctionList
from deel.utils.yaml_loader import load_optimizer_and_loss
from tensorflow.keras.callbacks import ReduceLROnPlateau, LearningRateScheduler
from tensorflow.keras import backend as K
import foolbox as fb
from foolbox.attacks import *
from deel.utils.adversarial_utils import compute_adversarial_robustness,wandb_log_robustness
from tensorflow.keras.losses import categorical_crossentropy, binary_crossentropy
from tensorflow.keras.metrics import top_k_categorical_accuracy,categorical_accuracy
import yaml
import matplotlib.pyplot as plt
#from tensorflow_riemopt.variable import assign_to_manifold
#import tensorflow_riemopt as riemopt
#from tensorflow_riemopt.manifolds import StiefelCayley,Euclidean
#from tensorflow_riemopt.optimizers.riemannian_adam import RiemannianAdam
from deel.lip.layers import (
    SpectralConv2D,
    SpectralDense,
    FrobeniusDense,
    ScaledAveragePooling2D,
    ScaledL2NormPooling2D,
    InvertibleDownSampling,
    ScaledGlobalAveragePooling2D)
from deel.utils.lip_utils import rescale_grad_unit,redressage_grad

def init_config(config):
    """Create the output folders of a run and snapshot its configuration.

    The run folder is ``config['result_path'] + config['run_name'] + "/"``
    (plain string concatenation, so ``result_path`` presumably ends with a
    path separator). The run folder and its ``curves/`` and ``models/``
    sub-folders are created when missing, then ``config`` is written to
    ``config.yml`` inside the run folder with ``yaml.dump``.

    Args:
        config: run configuration; must contain ``run_name`` and
            ``result_path``.
    """
    run_folder = config['result_path'] + config['run_name'] + "/"
    # "" stands for the run folder itself, which must exist first.
    for sub_dir in ("", "curves", "models"):
        target = run_folder + sub_dir
        if not os.path.exists(target):
            os.makedirs(target)
    with open(run_folder + 'config.yml', 'w') as handle:
        yaml.dump(config, handle)

def save_model(model, config):
    """Save the weights of ``model`` for the run described by ``config``.

    Weights are written with ``model.save_weights`` to
    ``<result_path><run_name>/models/<run_name>.h5``. The ``models/`` folder
    is not created here.

    Args:
        model: object exposing ``save_weights(path)`` (a Keras model here).
        config: run configuration; must contain ``run_name`` and
            ``result_path``.
    """
    run_name = config['run_name']
    models_dir = config['result_path'] + run_name + "/models/"
    model.save_weights(os.path.join(models_dir, run_name + ".h5"))






def save_metrics(model, loss_fct, dtset, config):
    """Evaluate ``model`` on the test set and log its FGM robustness.

    Steps:
      1. Evaluate ``model`` on ``dtset['test_XY']`` (batch size 128) and log
         the last value returned by ``evaluate`` as ``test_accuracy`` to
         wandb.
      2. Run ``compute_adversarial_robustness`` with an FGM attack over the
         listed epsilons.
      3. Save the resulting frame to
         ``<result_path><run_name>/robustness_fgsm.csv`` and log it with
         ``wandb_log_robustness``.

    Args:
        model: compiled Keras model.
        loss_fct: unused in this function.
        dtset: dataset dict providing ``test_XY``, ``test`` and
            ``batch_size``.
        config: run configuration with ``run_name`` and ``result_path``.
    """
    x_test, y_test = dtset['test_XY']
    evaluation = model.evaluate(x_test, y_test, batch_size=128)
    # NOTE(review): assumes accuracy is the last metric returned by
    # evaluate() -- confirm against the compiled metrics.
    wandb.log({'test_accuracy': evaluation[-1]})

    epsilons = [0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 1, 2, 3, 4, 5, 6, 7, 8]
    robustness_df = compute_adversarial_robustness(
        model,
        FGM(),
        test_gen=dtset['test'],
        test_size=8192,
        batch_size=dtset['batch_size'],
        last_layer=-1,
        force_recompute=False,
        bounds=[-2, 2],
        eps=epsilons,
    )

    run_folder = config['result_path'] + config['run_name'] + "/"
    robustness_df.to_csv(run_folder + "robustness_fgsm.csv")
    wandb_log_robustness(robustness_df)


def gradient_norm(grads):
    """Return the mean of the L2 norms of the gradient tensors in ``grads``.

    This is the *average per-tensor* norm (sum of ``tf.norm(g)`` divided by
    the number of tensors), not the global norm of the concatenated gradient.
    ``None`` entries (e.g. gradients of variables the loss does not depend
    on) are skipped instead of crashing ``tf.norm``.

    Args:
        grads: iterable of gradient tensors, possibly containing ``None``.

    Returns:
        Scalar tensor: mean per-tensor norm of the non-``None`` gradients.

    Raises:
        ValueError: if ``grads`` holds no non-``None`` gradient (previously
            an unexplained ZeroDivisionError).
    """
    norms = [tf.norm(g) for g in grads if g is not None]
    if not norms:
        raise ValueError("gradient_norm: no gradient to average")
    return sum(norms) / len(norms)


def diff_grad(grad_1, grad_2):
    """Return the total absolute difference between two gradient lists.

    Tensors are paired positionally (``zip`` stops at the shorter list). The
    result is ``sum(reduce_sum(|g2 - g1|))`` over the pairs, or ``0`` when
    there is no pair.
    """
    return sum(
        tf.reduce_sum(tf.math.abs(after - before))
        for before, after in zip(grad_1, grad_2)
    )


def apply_constraints(model, verbose=False):
    """Project the kernels of constrained layers with their kernel constraint.

    Only layers whose exact type is ``SpectralConv2D`` or ``FrobeniusDense``
    (subclasses are not matched) get ``kernel := kernel_constraint(kernel)``.

    Args:
        model: Keras model whose ``layers`` are scanned.
        verbose: print each projected layer and its constraint.
    """
    for layer in model.layers:
        if type(layer) not in (SpectralConv2D, FrobeniusDense):
            continue
        projected = layer.kernel_constraint(layer.kernel)
        layer.kernel.assign(projected)
        if verbose:
            print(layer, layer.kernel_constraint)


def compute_singular(var, one=False):
    """Print the shape and largest singular value of each variable in ``var``.

    Each variable is read with ``K.get_value`` and flattened to a 2-D matrix
    ``(-1, last_dim)`` before the SVD.

    Args:
        var: iterable of variables (e.g. ``model.trainable_variables``).
        one: stop after the first variable.
    """
    for variable in var:
        values = K.get_value(variable)
        matrix = np.reshape(values, (-1, values.shape[-1]))
        singular_values = np.linalg.svd(matrix)[1]
        # numpy returns singular values in descending order: [0] is the largest.
        print(values.shape, singular_values[0])
        if one:
            return


@tf.function
def train_margin(x, y, model, loss_fn, t, optim_marg):
    """Run one optimisation step on the margin variable ``t`` only.

    The loss (data term plus ``add_model_regularizer_loss(model)``) is
    differentiated with respect to ``t`` alone. The step is applied with
    ``optim_marg``, and ``t`` is then clipped to ``[0.0001, 200]``.

    Returns:
        Tuple ``(loss_value, acc, grad_norm, regul)``: data loss, per-sample
        categorical accuracy, mean gradient norm and regularisation term.
    """
    with tf.GradientTape() as tape:
        predictions = model(x, training=True)
        data_loss = loss_fn(y, predictions)
        regularization = add_model_regularizer_loss(model)
        total_loss = data_loss + regularization

    margin_grads = tape.gradient(total_loss, [t])
    optim_marg.apply_gradients(zip(margin_grads, [t]))
    # Keep the margins strictly positive and bounded after the update.
    t.assign(tf.clip_by_value(t, clip_value_min=0.0001, clip_value_max=200))

    accuracy = categorical_accuracy(y, predictions)
    return data_loss, accuracy, gradient_norm(margin_grads), regularization


def _layer_regularizer_loss(model, lambda_orth):
    """Sum explicit kernel/bias regularizer terms of ``model``'s layers, recursing into nested models."""
    loss = 0
    for layer in model.layers:
        if getattr(layer, 'layers', None):  # the layer itself is a model
            # Previously called an undefined ``add_model_loss`` (NameError).
            # Only the explicit terms are collected here: nested ``.losses``
            # are already part of the outer ``model.losses``.
            loss += _layer_regularizer_loss(layer, lambda_orth)
        if lambda_orth != 0 and getattr(layer, 'kernel_regularizer', None):
            loss += lambda_orth * layer.kernel_regularizer(layer.kernel)
        # Layers built with use_bias=False have no bias to regularize.
        if getattr(layer, 'bias_regularizer', None) and getattr(layer, 'bias', None) is not None:
            loss += layer.bias_regularizer(layer.bias)
    return loss


def add_model_regularizer_loss(model, lambda_orth = 0):
    """Return the regularisation loss of ``model``.

    The result is the sum of three kinds of terms:
      * ``lambda_orth * kernel_regularizer(kernel)`` for every layer with a
        kernel regularizer (only when ``lambda_orth != 0``);
      * ``bias_regularizer(bias)`` for every layer with a bias regularizer
        and a bias;
      * ``tf.reduce_sum(model.losses)``.
    Nested sub-models are searched recursively for the first two terms.

    NOTE(review): Keras usually also registers kernel/bias regularizer
    penalties in ``model.losses``, in which case they are counted twice
    here -- confirm against the deel-lip layers before relying on the
    exact value.

    Args:
        model: Keras model.
        lambda_orth: weight of the kernel-regularizer terms (0 disables them).

    Returns:
        Scalar regularisation loss.
    """
    loss = _layer_regularizer_loss(model, lambda_orth)
    loss += tf.reduce_sum(model.losses)
    return loss




@tf.function
def train_step(x, y, model, loss_fn, t, optimizer, optim_prox=None, optim_margin=False,
               redress=False, lambda_orth=0):
    """Run one optimisation step on ``model`` (and optionally the margins ``t``).

    Args:
        x, y: input batch and targets. ``y`` is presumably one-hot, since it
            is compared to 1 and multiplied with the logits below -- TODO
            confirm.
        model: Keras model, called with ``training=True``.
        loss_fn: callable ``loss_fn(y, logits)``.
        t: margin variable; it is updated only when ``optim_margin`` is True.
        optimizer: optimizer applying the gradients.
        optim_prox: unused here; accepted for interface compatibility.
        optim_margin: also differentiate and update ``t``.
        redress: pass the model-weight gradients through
            ``redressage_grad(..., coeff=0.1)`` before applying them.
        lambda_orth: kernel-regularizer weight forwarded to
            ``add_model_regularizer_loss``.

    Returns:
        dict of tensors:
          * ``loss``, ``acc`` (per-sample categorical accuracy), ``regul``
            and ``grad_norm``;
          * margin statistics: ``robustness`` (mean true-class score),
            ``avg_value``, ``abs_margin``, ``margin`` and ``margin_std``.
    """
    with tf.GradientTape() as w_tape:
        logits = model(x, training=True)
        loss_value = loss_fn(y, logits)
        regul = add_model_regularizer_loss(model, lambda_orth=lambda_orth)
        final_loss = loss_value + regul

    weights = list(model.trainable_weights)
    if optim_margin:
        weights.append(t)
    grads = w_tape.gradient(final_loss, weights)

    if redress:
        if optim_margin:
            # Redress only the model weights, then re-attach the margin
            # gradient. Previously it was dropped, and zip() below silently
            # skipped updating ``t``.
            grads = list(redressage_grad(grads[:-1], weights[:-1], coeff=0.1)) + [grads[-1]]
        else:
            grads = redressage_grad(grads, weights, coeff=0.1)

    optimizer.apply_gradients(zip(grads, weights))

    acc = categorical_accuracy(y, logits)
    results = {"loss": loss_value, "acc": acc, "regul": regul, "grad_norm": gradient_norm(grads)}

    # Margin statistics: true-class score vs. best competing score. The
    # true-class entry is replaced by the batch minimum so it cannot be the max.
    others = tf.where(y == 1, tf.reduce_min(logits), logits)
    true_score = tf.reduce_sum(logits * y, axis=1)
    best_other = tf.reduce_max(others, axis=1)
    margin = true_score - best_other
    results["robustness"] = tf.reduce_mean(true_score)
    results["avg_value"] = tf.reduce_mean(tf.abs(logits))
    results["abs_margin"] = tf.reduce_mean(tf.abs(margin))
    results["margin"] = tf.reduce_mean(margin)
    results["margin_std"] = tf.math.reduce_std(margin)
    return results


def fit_constraints(model, train, validation, loss_fct, optimizer, steps_per_epoch=50,
                    validation_step=50,
                    callbacks=None,
                    epochs=20,
                    verbose=2,
                    redress=False,
                    optim_margin=False,
                    margin_only=False,
                    margin_lr=1.e-4,
                    optim_prox=None,
                    lambda_orth=0):
    """Custom training loop driving Keras callbacks.

    Each epoch runs ``steps_per_epoch`` training batches drawn from
    ``train`` and ``validation_step`` batches drawn from ``validation``.
    Per-batch logs go to ``on_batch_end`` and epoch means (plus ``time``) go
    to ``on_epoch_end``. A single iterator over each dataset is shared
    across epochs.

    Args:
        model: Keras model to train.
        train, validation: iterables of ``(x, y)`` batches.
        loss_fct: loss callable ``loss_fct(y, logits)``. It must expose a
            ``margins`` variable, which is passed to the step functions and
            printed every epoch.
        optimizer: optimizer used by ``train_step``.
        steps_per_epoch: number of training batches per epoch.
        validation_step: number of validation batches per epoch.
        callbacks: Keras callbacks; ``None`` means no callbacks.
        epochs: number of epochs.
        verbose: unused; kept for API compatibility.
        redress, optim_margin, optim_prox, lambda_orth: forwarded to
            ``train_step``.
        margin_only: train only ``loss_fct.margins`` via ``train_margin``.
        margin_lr: learning rate of the margin-only optimizer.
    """
    callbacks = [] if callbacks is None else callbacks
    for c in callbacks:
        c.set_model(model)

    train_it = iter(train)
    val_it = iter(validation)
    # train_margin() needs an optimizer for the margins; previously the name
    # ``optim_marg`` was undefined. It is created once so that its state
    # persists across steps.
    # NOTE(review): the optimizer type for this mode is not specified
    # anywhere visible; Adam is used here.
    optim_marg = Adam(learning_rate=margin_lr) if margin_only else None

    logs = {}
    for c in callbacks:
        c.on_train_begin(logs=logs)

    metrics = {
        "val_loss": tf.metrics.Mean(),
        "val_acc": tf.metrics.Mean(),
    }

    for e in range(epochs):
        start_time = time.time()
        for c in callbacks:
            c.on_epoch_begin(e, logs=None)

        for batch in range(steps_per_epoch):
            x, y = next(train_it)
            for c in callbacks:
                c.on_batch_begin(batch, logs=None)

            if margin_only:
                loss_value, acc, g_n, reg = train_margin(
                    x, y, model, loss_fct, loss_fct.margins, optim_marg)
                # Same core keys as train_step's results, so logging and the
                # epoch summary work in both modes (``results`` was
                # previously undefined on this path).
                results = {"loss": loss_value, "acc": acc, "regul": reg, "grad_norm": g_n}
            else:
                results = train_step(x, y, model, loss_fct, loss_fct.margins, optimizer,
                                     redress=redress, optim_prox=optim_prox,
                                     optim_margin=optim_margin, lambda_orth=lambda_orth)

            logs = {k: v.numpy().mean() for k, v in results.items()}
            for c in callbacks:
                c.on_batch_end(batch, logs=logs)
            for k, v in results.items():
                if k not in metrics:
                    metrics[k] = tf.metrics.Mean()
                metrics[k].update_state(v)

        for _ in range(validation_step):
            x, y = next(val_it)
            logits = model(x, training=False)
            loss_value = loss_fct(y, logits)
            acc = categorical_accuracy(y, logits)
            metrics["val_loss"].update_state(loss_value.numpy())
            metrics["val_acc"].update_state(acc.numpy().mean())

        total_time = time.time() - start_time
        logs = {k: m.result() for k, m in metrics.items()}
        logs['time'] = total_time

        for c in callbacks:
            c.on_epoch_end(e, logs=logs)

        print(f"Epoch {e + 1}/{epochs}")
        # val_acc previously printed the Metric object instead of its value.
        print(f"time : {total_time:.2f}s *** loss:", metrics["loss"].result(),
              "val_loss:", metrics["val_loss"].result(),
              " acc: ", metrics["acc"].result(),
              "% val_acc:", metrics["val_acc"].result())
        margins = loss_fct.margins
        print("max m", tf.reduce_max(margins).numpy(), "min m", tf.reduce_min(margins).numpy(),
              "mean m", tf.reduce_mean(margins).numpy())
        for m in metrics.values():
            m.reset_states()

    for c in callbacks:
        c.on_train_end(logs=logs)


tf.config.set_soft_device_placement(True)
print(tf.config.list_physical_devices('GPU'))

# Configuration file: first CLI argument, or the Fashion-MNIST default.
filename = "./configs/fashion_mnist_base.yml"
if len(sys.argv) >= 2:
    filename = sys.argv[1]

full_config = load_yaml_config(filename)
init_config(full_config)
print(filename, full_config['run_name'])

# W&B credentials are read from the environment (WANDB_API_KEY) or from
# `wandb login`. API keys must not be hard-coded in source; the key that was
# previously committed here should be revoked.
os.environ["WANDB_MODE"] = "online"
project_name = full_config.get('project_name', "cifar-rieman")
print(project_name)
wandb.init(project=project_name, sync_tensorboard=False, group=full_config['run_name'],
           config={
               "loss": full_config['loss']['type'],
               "loss_param": str(full_config['loss']['params']),
               # NOTE(review): hard-coded label, independent of the dataset
               # actually configured (the default config is fashion_mnist).
               "dataset": "CIFAR-10",
           })

name = full_config['run_name']
rep = full_config['result_path']
log_path = rep + name + "/logfile.log"
# Everything printed from here on goes to the run's log file; the console is
# restored at the end of the script.
original_stdout = sys.stdout
log_file = open(log_path, "w")
sys.stdout = log_file
sys.stdout.flush()

dtset = load_dataset(getParams(full_config, 'dataset'))
batch_size = dtset['batch_size']
model = load_model(getParams(full_config, 'network'))
if "weights" in full_config:
    # Warm start from another experiment's saved weights.
    # NOTE(review): the "results/" folder is hard-coded here rather than
    # taken from result_path -- confirm this is intended.
    expe_name = full_config["weights"]["expe_name"]
    rep_mod = "results/" + expe_name + "/models/"
    model.load_weights(rep_mod + expe_name + '.h5')
optimizer = load_optimizer_and_loss(getParams(full_config, 'optimizer'))
callbacks = loadFunctionList(getParams(full_config, 'callbacks'))
loss_fct = load_optimizer_and_loss(getParams(full_config, 'loss'))
metrics = loadFunctionList(getParams(full_config, 'metrics'))
print(metrics)
sys.stdout.flush()
redress = full_config.get('redress', False)

print("redress", redress)

epochs = getParams(full_config, 'epochs')
steps_per_epoch = getParams(full_config, 'steps_per_epoch')

validation_steps = 100
if callbacks:  # previously an IndexError when no callback was configured
    print(callbacks[0])
sys.stdout.flush()
if dtset['testSize'] is not None:
    validation_steps = int(dtset['testSize'] / dtset['batch_size'])
callbacks.append(WandbCallback(save_model=False))
train = tf.data.Dataset.from_generator(dtset['train'], (tf.float32, tf.float32)).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)
test = tf.data.Dataset.from_generator(dtset['test'], (tf.float32, tf.float32)).prefetch(buffer_size=tf.data.experimental.AUTOTUNE)

optim_margin = False
margin_lr = 1.e-3
if "optim_margin" in full_config:
    optim_margin = full_config['optim_margin']['margin']
    margin_lr = full_config['optim_margin']['margin_lr']
margin_only = full_config.get('margin_only', False)
lambda_orth = full_config.get('lambda_orth', 0)
print('lambda_orth :', lambda_orth)
optim_prox = None
if "projected" in full_config:
    optim_prox = SGD(learning_rate=full_config["projected"]["learning_rate"])
    print("projected gradient", optim_prox.get_config())

model.compile(loss=loss_fct, optimizer=optimizer, metrics=metrics)
fit_constraints(model, train, test, loss_fct, optimizer,
                steps_per_epoch=steps_per_epoch // batch_size,
                validation_step=validation_steps,
                callbacks=callbacks,
                epochs=epochs,
                verbose=2,
                redress=redress,
                optim_prox=optim_prox,
                optim_margin=optim_margin,
                margin_only=margin_only,
                # Previously a hard-coded 1.e-3 that ignored the configured value.
                margin_lr=margin_lr,
                lambda_orth=lambda_orth)

save_model(model, full_config)

wandb.finish()
# Restore the console before closing the log file, so later writes
# (e.g. at interpreter exit) do not hit a closed stream.
sys.stdout = original_stdout
log_file.close()
