import os
import sys
import time
import yaml
import tensorflow as tf
import wandb
from wandb.keras import WandbCallback
from tensorflow.keras.metrics import (AUC, binary_accuracy,
                                      categorical_accuracy,
                                      top_k_categorical_accuracy)

from deel.datasets.oneclass_dataset import (class_versus_random,
                                            get_random_distribution)
from deel.utils.lip_utils import redressage_grad, rescale_grad_unit
import random
from deel.utils.yaml_loader import load_model, loadFunctionList
from deel.utils.lip_trainer import  fit_hkr, train_epoch_binary
from deel.utils.yaml_to_params import load_yaml_config,getParams, getFunctionFromModules, dumdict2yaml
from deel.utils.yaml_loader import load_optimizer_and_loss
from tensorflow.keras.callbacks import ReduceLROnPlateau, LearningRateScheduler
from tensorflow.keras import backend as K
from deel.utils.lip_utils import rescale_grad_unit,redressage_grad,gradient_norm, Logger,get_random_name,set_global_variable
from deel.utils.yaml_to_params import load_yaml_config,getParams, getFunctionFromModules, dumdict2yaml
from deel.datasets.load_dataset import load_dataset
from deel.utils.lip_res_model import set_add_coeff
from deel.lip.normalizers import set_stop_grad_spectral,set_grad_passthrough_bjorck

def init_config(config):
    """Prepare the result folder of a run and redirect stdout to its log file.

    Creates ``<result_path><run_name>/`` together with its ``plots/`` and
    ``models/`` sub-folders, dumps ``config`` to ``config.yml`` in it and
    replaces ``sys.stdout`` with a ``Logger`` built on ``logfile.log`` in the
    same folder (presumably tee-ing output to that file -- see Logger).

    Args:
        config: dict with at least ``run_name`` and ``result_path`` keys.
            ``result_path`` is concatenated as a plain string, so it is
            expected to end with a path separator.
    """
    tf.get_logger().setLevel('ERROR')
    tf.config.set_soft_device_placement(True)
    folder = config['result_path'] + config['run_name'] + "/"
    # exist_ok=True replaces the racy "os.path.exists then makedirs" pattern.
    for sub_folder in ("", "plots", "models"):
        os.makedirs(folder + sub_folder, exist_ok=True)
    with open(folder + 'config.yml', 'w') as file:
        yaml.dump(config, file)
    sys.stdout = Logger(folder + "logfile.log")


    
    


class BinaryLipTrainer():
    """Custom training loop for a binary classifier configured from a dict/YAML.

    Expected call order (from the attributes each method needs):
    ``__init__`` -> ``init_training`` (sets ``name``/``rep``, logging, wandb)
    -> ``load_model_and_parameters`` (sets model, optimizer, loss, callbacks,
    epochs) -> ``fit`` -> ``save_model``.
    """

    def __init__(self, full_config):
        """Read the training hyper-parameters from ``full_config``.

        ``full_config`` is mutated: a freshly generated ``xp_id`` is stored in
        it, so it ends up in the ``config.yml`` written by ``init_training``.
        """
        self.xp_id = get_random_name()
        self.full_config = full_config
        full_config["xp_id"] = self.xp_id
        self.rescale_grads = full_config.get('rescale_grads', False)
        self.grad_coeff = full_config.get('grad_coeff', 0.1)
        # Forwarded to redressage_grad when rescale_grads is enabled. It was
        # previously read in _train_step but never set (AttributeError).
        # NOTE(review): the default of False is an assumption -- confirm
        # against redressage_grad's own default.
        self.spectral = full_config.get('spectral', False)
        self.optim_margin = full_config.get('optim_margin', False)
        self.lambda_orth = full_config.get('lambda_orth', 0.)
        self.wandb = full_config.get('wandb', False)

        self.add_coeff = full_config.get('add_coeff', 0.5)
        set_global_variable(full_config, verbose=True)

    def init_training(self, full_config):
        """Create the run folders, dump the config, redirect stdout, init wandb.

        Folders: ``<result_path><run_name>/`` with ``curves/`` and ``models/``.
        The log file is ``<xp_id>_logfile.log`` in the run folder.
        """
        self.name = full_config['run_name']
        self.rep = full_config['result_path']
        folder = self.rep + self.name + "/"
        for sub_folder in ("", "curves", "models"):
            os.makedirs(folder + sub_folder, exist_ok=True)
        with open(folder + 'config.yml', 'w') as file:
            yaml.dump(full_config, file)
        filename = folder + self.xp_id + "_logfile.log"
        sys.stdout = Logger(filename)
        if self.wandb:
            # A W&B API key used to be hard-coded here; it is leaked in version
            # control and should be revoked. wandb now picks up credentials
            # from the WANDB_API_KEY environment variable or `wandb login`.
            os.environ["WANDB_MODE"] = "online"
            wandb.init(project=full_config['project_name'],
                       group=full_config['run_name'], sync_tensorboard=False,
                       name=self.xp_id,
                       config={"loss": full_config['loss']['type'],
                               "loss_param": str(full_config['loss']['params'])
                               })

    def load_model_and_parameters(self, full_config):
        """Build model, optimizer, loss and callbacks, then compile the model.

        If ``full_config`` has a ``weights`` entry, weights are restored from
        ``results/<expe_name>/models/<expe_name>.h5``.

        Returns:
            the compiled Keras model (also stored as ``self.model``).
        """
        self.model = load_model(getParams(full_config, 'network'))
        if "weights" in full_config:
            expe_name = full_config["weights"]["expe_name"]
            # NOTE(review): hard-coded "results/" root, unlike the configurable
            # result_path used elsewhere -- confirm this is intended.
            rep_mod = "results/" + expe_name + "/models/"
            self.model.load_weights(rep_mod + expe_name + '.h5')
        self.optimizer = load_optimizer_and_loss(getParams(full_config, 'optimizer'))
        self.callbacks = loadFunctionList(getParams(full_config, 'callbacks'))
        if self.wandb:
            self.callbacks.append(WandbCallback(save_model=False))
        self.loss_fct = load_optimizer_and_loss(getParams(full_config, 'loss'))
        # No compiled Keras metrics: the custom loop below tracks its own.
        self.metrics = None
        self.epochs = getParams(full_config, 'epochs')
        self.steps_per_epoch = getParams(full_config, 'steps_per_epoch')
        # Losses may expose learnable `margins` (optimized when optim_margin).
        self.margin_variables = getattr(self.loss_fct, 'margins', None)
        self.model.compile(loss=self.loss_fct, optimizer=self.optimizer, metrics=self.metrics)

        return self.model

    def save_model(self):
        """Save the model weights under ``<result_path><run_name>/models/``."""
        folder = self.rep + self.name + "/models/"
        self.model.save_weights(os.path.join(folder, self.name + ".h5"))
        # NOTE(review): this writes the same weights a second time; the
        # "_full" suffix suggests a full-model save (self.model.save) was
        # intended -- confirm before changing.
        self.model.save_weights(os.path.join(folder, self.name + "_full.h5"))

    @tf.function
    def _train_step(self, x, y):
        """Run one optimization step on the batch ``(x, y)``.

        Traced by ``tf.function``: the Python flags read from ``self``
        (``lambda_orth``, ``optim_margin``, ``rescale_grads``) are baked in at
        trace time.

        Returns:
            dict of tensors: ``loss``, ``accuracy`` (binary_accuracy with the
            logits thresholded at 0), ``regul``, ``grad_norm`` and the margin
            statistics ``robustness``, ``avg_value``, ``abs_margin``,
            ``margin``, ``margin_std``.
        """
        with tf.GradientTape() as w_tape:
            logits = self.model(x, training=True)
            loss_value = self.loss_fct(y, logits)
            if self.lambda_orth != 0:
                # Regularization terms registered by the layers (model.losses).
                regul = self.lambda_orth * tf.reduce_sum(self.model.losses)
            else:
                regul = 0
            final_loss = loss_value + regul

        weights = self.model.trainable_weights
        if self.optim_margin:
            # The loss' margins are learned jointly with the network weights.
            weights = weights + [self.margin_variables]
        grads = w_tape.gradient(final_loss, weights)
        if self.rescale_grads:
            if self.optim_margin:
                # Rescale only the network gradients, not the margin (last) one.
                grads[:-1] = redressage_grad(grads[:-1], weights[:-1],
                                             coeff=self.grad_coeff,
                                             spectral=self.spectral)
            else:
                grads = redressage_grad(grads, weights,
                                        coeff=self.grad_coeff,
                                        spectral=self.spectral)
        self.optimizer.apply_gradients(zip(grads, weights))

        acc = binary_accuracy(y, logits, threshold=0.)
        results = {"loss": loss_value, "accuracy": acc, "regul": regul,
                   "grad_norm": gradient_norm(grads)}

        y_pred = logits
        y_true = y
        # Entries where y_true == 1 are replaced by the batch minimum so they
        # can never be picked as the maximum of the "other" outputs below.
        H1 = tf.where(y_true == 1, tf.reduce_min(y_pred), y_pred)
        vYtrue = tf.reduce_sum(y_pred * y_true, axis=1)
        maxOthers = tf.reduce_max(H1, axis=1)
        results["robustness"] = tf.reduce_mean(vYtrue)
        results["avg_value"] = tf.reduce_mean(tf.abs(y_pred))
        results["abs_margin"] = tf.reduce_mean(tf.abs(vYtrue - maxOthers))
        results["margin"] = tf.reduce_mean(vYtrue - maxOthers)
        results["margin_std"] = tf.math.reduce_std(vYtrue - maxOthers)
        return results

    def _train_epoch(self, e,
                     train_it,
                     val_it,
                     metrics,
                     steps_per_epoch,
                     validation_step):
        """Train on ``steps_per_epoch`` batches, then validate.

        Args:
            e: epoch index, forwarded to the callbacks' ``on_epoch_end``.
            train_it: iterator yielding ``(x, y)`` training batches.
            val_it: iterator yielding ``(x, y)`` validation batches, or None.
            metrics: dict name -> Keras metric. Must already contain
                ``val_loss``, ``val_acc`` and ``AUC`` when ``val_it`` is given;
                means for the training results are created on first use.
                All metrics are reset before returning.
            steps_per_epoch: number of training batches to draw.
            validation_step: number of validation batches to draw.

        Returns:
            dict name -> metric result for the epoch, plus ``time`` (seconds
            spent in the training loop only).
        """
        start_time = time.time()
        for batch in range(steps_per_epoch):
            x, y = next(train_it)

            for c in self.callbacks:
                c.on_batch_begin(batch, logs=None)

            results = self._train_step(x, y)
            logs = {k: v.numpy().mean() for k, v in results.items()}
            for c in self.callbacks:
                c.on_train_batch_end(batch, logs=logs)
            for k, v in results.items():
                if k not in metrics:
                    metrics[k] = tf.metrics.Mean()
                metrics[k].update_state(v)

        total_time = time.time() - start_time

        if val_it is not None:
            for _ in range(validation_step):
                x, y = next(val_it)
                logits = self.model(x, training=False)
                loss_value = self.loss_fct(y, logits)
                # AUC expects scores in [0, 1], hence the sigmoid on the logits.
                metrics["AUC"].update_state(y, tf.math.sigmoid(logits))
                acc = binary_accuracy(y, logits, threshold=0.)
                metrics["val_loss"].update_state(loss_value.numpy())
                metrics["val_acc"].update_state(acc.numpy())
        logs = {k: metrics[k].result() for k in metrics.keys()}
        logs['time'] = total_time

        for c in self.callbacks:
            c.on_epoch_end(e, logs=logs)
        for k in metrics.keys():
            metrics[k].reset_states()
        return logs

    def fit(self, dtset):
        """Run the full training loop for ``self.epochs`` epochs.

        Args:
            dtset: mapping with keys ``train`` (iterable of ``(x, y)``
                batches), ``test`` (iterable or None), ``batch_size`` and
                ``testSize`` (number of test samples, or None).
        """
        train_it = iter(dtset['train'])
        validation = dtset['test']
        for c in self.callbacks:
            c.set_model(self.model)
        val_it = iter(validation) if validation is not None else None
        logs = {}
        for c in self.callbacks:
            c.on_train_begin(logs=logs)
        metrics = {
            "val_loss": tf.metrics.Mean(),
            "val_acc": tf.metrics.Mean(),
            # Previously a Mean(): Mean.update_state(y, p) treats p as sample
            # weights, so "val AUC" reported a weighted label mean, not an AUC.
            "AUC": AUC(),
        }
        if hasattr(self.loss_fct, 'margins'):
            margin_variables = self.loss_fct.margins
        else:
            margin_variables = None
        batch_size = dtset['batch_size']
        if dtset['testSize'] is not None:
            validation_steps = int(dtset['testSize'] / batch_size)
        else:
            validation_steps = 0
        for e in range(self.epochs):
            # NOTE(review): steps_per_epoch is divided by batch_size, so the
            # config value seems to be a sample count, not a step count.
            logs = self._train_epoch(e, train_it, val_it, metrics,
                                     self.steps_per_epoch // batch_size,
                                     validation_steps)
            print(f"time : {logs['time']:.2f}s *** loss: {logs['loss'].numpy():0.3f}",
                  "accuracy:", logs["accuracy"].numpy(), end=" ")
            if margin_variables is not None:
                print("margin", self.loss_fct.margins.numpy(), end=" ")
            if val_it is not None:
                print("val accuracy", logs["val_acc"].numpy(), " val AUC", logs['AUC'].numpy())
            else:
                # The prints above use end=" "; terminate this epoch's line.
                print()
            sys.stdout.flush()
        for c in self.callbacks:
            c.on_train_end(logs=logs)




