from design_baselines.utils import spearman
from design_baselines.utils import soft_noise
from design_baselines.utils import cont_noise
from collections import defaultdict
from tensorflow_probability import distributions as tfpd
import tensorflow as tf



class MaximumLikelihood(tf.Module):
    """Trainer for a single probabilistic forward model fit by maximum
    likelihood, with a penalty on the gradient norm of its mean prediction.

    Every training step minimises the weighted negative log-likelihood of
    the labels and adds a finite-difference estimate of the gradient of
    ``lambda_ * rho * ||dG/dbeta||``, where ``G`` is the sum of the model's
    mean predictions on the batch divided by ``len_x_data``. The multiplier
    ``lambda_`` is updated after every step and threaded through the
    ``train`` / ``launch`` loops.
    """

    def __init__(self,
                 forward_model,
                 forward_model_optim=tf.keras.optimizers.Adam,
                 forward_model_lr=0.001,
                 noise_std=0.0):
        """Build a maximum-likelihood trainer for one forward model

        Args:

        forward_model: model
            the model to train; it must expose ``trainable_variables`` and
            ``get_distribution(x, training=...)`` returning a distribution
            with ``log_prob`` and ``mean``
        forward_model_optim: __class__
            the optimizer class used to update the forward model
        forward_model_lr: float
            the learning rate passed to the optimizer constructor
        noise_std: float
            the noise level passed to ``cont_noise`` when corrupting inputs
        """

        super().__init__()
        self.fm = forward_model
        self.optim = forward_model_optim(
            learning_rate=forward_model_lr)
        self.noise_std = noise_std

    def _mean_prediction_grads(self, x, len_x_data):
        """Return d G_B(beta) / d beta, where G_B(beta) is the sum of the
        model's mean predictions on x divided by len_x_data (one entry per
        trainable variable, None where a variable does not affect G)."""

        variables = self.fm.trainable_variables
        with tf.GradientTape() as tape:
            tape.watch(variables)
            prediction = self.fm.get_distribution(x, training=True).mean()
            objective = tf.reduce_sum(prediction) / len_x_data
        return tape.gradient(objective, variables)

    @tf.function(experimental_relax_shapes=True)
    def train_step(self,
                   x,
                   y,
                   b,
                   len_x_data,
                   eta_lambda_,
                   epsilon,
                   lambda_,
                   rho,
                   r):
        """Perform one penalised gradient-descent step on the forward model

        Args:

        x: tf.Tensor
            a batch of training inputs shaped like [batch_size, channels]
        y: tf.Tensor
            a batch of training labels shaped like [batch_size, 1]
        b: tf.Tensor
            bootstrap weights, multiplied elementwise with the per-example
            negative log-likelihood
        len_x_data: int or tf.Tensor
            the normaliser dividing the summed mean prediction in G_B(beta)
        eta_lambda_: float
            the step size of the update applied to lambda_
        epsilon: float
            the threshold subtracted from rho * ||dG/dbeta|| in the
            lambda_ update
        lambda_: float or tf.Tensor
            the current multiplier on the gradient-norm penalty
        rho: float
            the weight of the gradient-norm term
        r: float
            the radius of the weight perturbation used by the
            finite-difference estimate; must be non-zero

        Returns:

        statistics: dict
            a dictionary that contains logging information
        new_lambda: tf.Tensor
            lambda_ + eta_lambda_ * (rho * ||dG/dbeta|| - epsilon)
        """

        # corrupt the inputs with noise
        x0 = cont_noise(x, self.noise_std)

        statistics = dict()
        variables = self.fm.trainable_variables

        # gradient of the weighted negative log-likelihood, \nabla_\beta L_1(B)
        with tf.GradientTape() as tape:
            d = self.fm.get_distribution(x0, training=True)
            nll = -d.log_prob(y)

            # evaluate how correct the rank of the model predictions are
            rank_correlation = spearman(y[:, 0], d.mean()[:, 0])

            # weighted average of the per-example loss using bootstrap weights
            denom = tf.reduce_sum(b)
            total_loss = tf.math.divide_no_nan(
                tf.reduce_sum(b * nll), denom)

        original_loss_grads = tape.gradient(total_loss, variables)

        # gradient \nabla_\beta G_B(\beta) on the clean inputs and its norm
        # (global_norm ignores None entries)
        dG_beta_dbeta = self._mean_prediction_grads(x, len_x_data)
        norm_dG_beta_dbeta = tf.linalg.global_norm(dG_beta_dbeta)

        # perturbation r * dG / ||dG||, computed once; divide_no_nan makes it
        # zero (instead of NaN) when the gradient vanishes, which would
        # otherwise corrupt the weights permanently
        perturbations = [
            None if g is None
            else r * tf.math.divide_no_nan(g, norm_dG_beta_dbeta)
            for g in dG_beta_dbeta]

        # snapshot the weights so they can be restored exactly afterwards
        # (w + d - d is not guaranteed to equal w in floating point)
        backups = [v.read_value() for v in variables]

        # move to the perturbed surrogate \hat\beta
        for v, delta in zip(variables, perturbations):
            if delta is not None:
                v.assign_add(delta)

        # gradient \nabla_\beta G_B(\hat\beta) at the perturbed weights
        dG_beta_hat_dbeta = self._mean_prediction_grads(x, len_x_data)

        # restore the original surrogate \beta
        for v, saved, delta in zip(variables, backups, perturbations):
            if delta is not None:
                v.assign(saved)

        # (dG(beta_hat) - dG(beta)) / r is a finite-difference Hessian-vector
        # product along dG/||dG||, i.e. an estimate of d||dG||/dbeta
        coef = lambda_ * rho / r
        final_grads = []
        for g_loss, g0, g1 in zip(
                original_loss_grads, dG_beta_dbeta, dG_beta_hat_dbeta):
            if g0 is not None:
                penalty = coef * (g1 - g0)
                g_loss = penalty if g_loss is None else g_loss + penalty
            final_grads.append(g_loss)

        self.optim.apply_gradients(zip(final_grads, variables))

        # NOTE(review): lambda_ is not clipped at zero; if this is meant as a
        # dual-ascent step on rho * ||dG|| <= epsilon, confirm that negative
        # multipliers are intended.
        new_lambda = lambda_ + eta_lambda_ * (
            rho * norm_dG_beta_dbeta - epsilon)

        statistics['train/nll'] = nll
        statistics['train/rank_corr'] = rank_correlation

        return statistics, new_lambda

    @tf.function(experimental_relax_shapes=True)
    def validate_step(self,
                      x,
                      y):
        """Evaluate the forward model on one validation batch in inference
        mode without applying any update

        Args:

        x: tf.Tensor
            a batch of validation inputs shaped like [batch_size, channels]
        y: tf.Tensor
            a batch of validation labels shaped like [batch_size, 1]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        # corrupt the inputs with noise
        x0 = cont_noise(x, self.noise_std)

        statistics = dict()

        # calculate the negative log-likelihood of the labels
        d = self.fm.get_distribution(x0, training=False)
        nll = -d.log_prob(y)

        # evaluate how correct the rank of the model predictions are
        rank_correlation = spearman(y[:, 0], d.mean()[:, 0])

        statistics['validate/nll'] = nll
        statistics['validate/rank_corr'] = rank_correlation

        return statistics

    def train(self,
              eta_lambda_,
              epsilon,
              lambda_,
              rho,
              r,
              dataset):
        """Run one epoch of penalised training over a dataset

        Args:

        eta_lambda_, epsilon, lambda_, rho, r:
            forwarded to ``train_step``; lambda_ is updated after every batch
        dataset: tf.data.Dataset
            the training dataset already batched and prefetched, yielding
            (x, y, b) tuples

        Returns:

        statistics: dict
            a dictionary mapping names to loss values for logging
        lambda_: float or tf.Tensor
            the multiplier after the last batch of the epoch
        """

        statistics = defaultdict(list)
        for x, y, b in dataset:
            len_x_data = x.shape[0]
            stat, lambda_ = self.train_step(
                x, y, b, len_x_data, eta_lambda_, epsilon, lambda_, rho, r)

            for name, tensor in stat.items():
                statistics[name].append(tensor)

        for name in statistics.keys():
            statistics[name] = tf.concat(statistics[name], axis=0)
        return statistics, lambda_

    def validate(self,
                 dataset):
        """Evaluate the forward model over a whole validation dataset

        Args:

        dataset: tf.data.Dataset
            the validation dataset already batched and prefetched, yielding
            (x, y) tuples

        Returns:

        statistics: dict
            a dictionary mapping names to loss values for logging
        """

        statistics = defaultdict(list)
        for x, y in dataset:
            for name, tensor in self.validate_step(x, y).items():
                statistics[name].append(tensor)
        for name in statistics.keys():
            statistics[name] = tf.concat(statistics[name], axis=0)
        return statistics

    def launch(self,
               eta_lambda_,
               epsilon,
               lambda_,
               rho,
               r,
               train_data,
               validate_data,
               logger,
               epochs,
               start_epoch=0,
               header=""):
        """Launch training and validation for the model for the specified
        number of epochs, and log statistics

        Args:

        eta_lambda_, epsilon, lambda_, rho, r:
            forwarded to ``train``; lambda_ is carried across epochs
        train_data: tf.data.Dataset
            the training dataset already batched and prefetched
        validate_data: tf.data.Dataset
            the validation dataset already batched and prefetched
        logger: Logger
            an object whose ``record(name, value, step)`` method is called
            for every statistic
        epochs: int
            the number of epochs through the data sets to take
        start_epoch: int
            the step index used for the first epoch when logging
        header: str
            a prefix prepended to every logged statistic name
        """

        for e in range(start_epoch, start_epoch + epochs):
            stat, lambda_ = self.train(
                eta_lambda_, epsilon, lambda_, rho, r, train_data)
            for name, loss in stat.items():
                logger.record(header + name, loss, e)
            for name, loss in self.validate(validate_data).items():
                logger.record(header + name, loss, e)

    def get_saveables(self):
        """Collects and returns stateful objects that are serializeable
        using the tensorflow checkpoint format

        Returns:

        saveables: dict
            a dict containing the forward model and its optimizer
        """

        # the original looped over self.bootstraps, which this single-model
        # trainer never defines, so every call raised AttributeError
        return {'forward_model': self.fm,
                'forward_model_optim': self.optim}


class VAETrainer(tf.Module):
    """Trainer for a variational auto-encoder whose loss is the
    reconstruction negative log-likelihood plus a beta-weighted KL term
    against a standard normal prior."""

    def __init__(self,
                 vae,
                 vae_optim=tf.keras.optimizers.Adam,
                 vae_lr=0.001, beta=1.0):
        """Build a trainer for a single variational auto-encoder

        Args:

        vae: model
            the model to train; it must expose ``encode(x, training=...)``
            returning a distribution over latents, ``decode(z)`` returning a
            distribution over inputs, and ``trainable_variables``
        vae_optim: __class__
            the optimizer class used to update the vae
        vae_lr: float
            the learning rate passed to the optimizer constructor
        beta: float
            the weight of the KL divergence term in the training loss
        """

        super().__init__()
        self.vae = vae
        self.beta = beta

        # create the optimizer for the vae
        self.vae_optim = vae_optim(learning_rate=vae_lr)

    @tf.function(experimental_relax_shapes=True)
    def train_step(self,
                   x):
        """Perform one gradient-descent step on the vae

        Args:

        x: tf.Tensor
            a batch of training inputs shaped like [batch_size, channels]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        statistics = dict()

        with tf.GradientTape() as tape:

            latent = self.vae.encode(x, training=True)
            # NOTE(review): decodes from the posterior mean rather than a
            # sample; confirm this is the intended training objective
            z = latent.mean()
            prediction = self.vae.decode(z)

            nll = -prediction.log_prob(x)

            kld = latent.kl_divergence(
                tfpd.MultivariateNormalDiag(
                    loc=tf.zeros_like(z), scale_diag=tf.ones_like(z)))

            total_loss = tf.reduce_mean(
                nll) + tf.reduce_mean(kld) * self.beta

        variables = self.vae.trainable_variables

        self.vae_optim.apply_gradients(zip(
            tape.gradient(total_loss, variables), variables))

        statistics['vae/train/nll'] = nll
        statistics['vae/train/kld'] = kld

        return statistics

    @tf.function(experimental_relax_shapes=True)
    def validate_step(self,
                      x):
        """Evaluate the vae on one validation batch in inference mode
        without applying any update

        Args:

        x: tf.Tensor
            a batch of validation inputs shaped like [batch_size, channels]

        Returns:

        statistics: dict
            a dictionary that contains logging information
        """

        statistics = dict()

        # inference mode: the original passed training=True here, which runs
        # training-time behaviour (e.g. dropout / batch-norm updates) during
        # validation, unlike MaximumLikelihood.validate_step
        latent = self.vae.encode(x, training=False)
        z = latent.mean()
        prediction = self.vae.decode(z)

        nll = -prediction.log_prob(x)

        kld = latent.kl_divergence(
            tfpd.MultivariateNormalDiag(
                loc=tf.zeros_like(z), scale_diag=tf.ones_like(z)))

        statistics['vae/validate/nll'] = nll
        statistics['vae/validate/kld'] = kld

        return statistics

    def train(self,
              dataset):
        """Run one epoch of training over a dataset

        Args:

        dataset: tf.data.Dataset
            the training dataset already batched and prefetched, yielding
            (x, y) tuples; only x is used

        Returns:

        statistics: dict
            a dictionary mapping names to loss values for logging
        """

        statistics = defaultdict(list)
        for x, _ in dataset:
            for name, tensor in self.train_step(x).items():
                statistics[name].append(tensor)
        for name in statistics.keys():
            statistics[name] = tf.concat(statistics[name], axis=0)
        return statistics

    def validate(self,
                 dataset):
        """Evaluate the vae over a whole validation dataset

        Args:

        dataset: tf.data.Dataset
            the validation dataset already batched and prefetched, yielding
            (x, y) tuples; only x is used

        Returns:

        statistics: dict
            a dictionary mapping names to loss values for logging
        """

        statistics = defaultdict(list)
        for x, _ in dataset:
            for name, tensor in self.validate_step(x).items():
                statistics[name].append(tensor)
        for name in statistics.keys():
            statistics[name] = tf.concat(statistics[name], axis=0)
        return statistics

    def launch(self,
               train_data,
               validate_data,
               logger,
               epochs):
        """Launch training and validation for the model for the specified
        number of epochs, and log statistics

        Args:

        train_data: tf.data.Dataset
            the training dataset already batched and prefetched
        validate_data: tf.data.Dataset
            the validation dataset already batched and prefetched
        logger: Logger
            an object whose ``record(name, value, step)`` method is called
            for every statistic
        epochs: int
            the number of epochs through the data sets to take
        """

        for e in range(epochs):
            for name, loss in self.train(train_data).items():
                logger.record(name, loss, e)
            for name, loss in self.validate(validate_data).items():
                logger.record(name, loss, e)
