import tensorflow as tf
import tree

from .base_algorithm import BaseAlgorithm


# Keras optimizer identifier used when TD0 is given no `optimizer_params`.
# No learning rate appears here on purpose: TD0 injects it from `alpha`.
DEFAULT_OPTIMIZER = dict(
    class_name='Adam',
    config=dict(),
)


class TD0(BaseAlgorithm):
    """TD(0) learner for a state-value function ``V``.

    Each call to ``update_V`` regresses ``V(s_0)`` toward the one-step
    bootstrapped target ``r + gamma * (1 - terminal) * V(s_1)``.
    """

    def __init__(self,
                 V,
                 alpha=2e-3,
                 gamma=0.9,
                 optimizer_params=None):
        """Create the learner and its optimizer.

        Args:
            V: Value-function wrapper. Must expose ``values(states)``,
                ``model.dtype`` and ``trainable_variables``.
            alpha: Learning rate handed to the optimizer.
            gamma: Discount factor applied to the bootstrapped value.
            optimizer_params: Keras optimizer identifier of the form
                ``{'class_name': ..., 'config': {...}}``. Defaults to
                ``DEFAULT_OPTIMIZER``. The learning rate always comes from
                ``alpha`` and must not be specified here.

        Raises:
            ValueError: If ``optimizer_params`` specifies ``learning_rate``
                (either at the top level or inside ``'config'``).
        """
        self.V = V
        self._alpha = alpha
        self._gamma = gamma
        optimizer_params = optimizer_params or DEFAULT_OPTIMIZER
        optimizer_config = optimizer_params.get('config', {})

        # Keras identifiers carry the learning rate inside 'config'. Checking
        # only the top level never caught a conflicting value, which the dict
        # merge below would then silently replace with `alpha`.
        if ('learning_rate' in optimizer_params
                or 'learning_rate' in optimizer_config):
            raise ValueError(
                "Pass the learning rate via `alpha`, not `optimizer_params`.")

        self._optimizer = tf.optimizers.get({
            'class_name': optimizer_params['class_name'],
            'config': {
                **optimizer_config,
                'learning_rate': alpha,
            },
        })

    @tf.function(experimental_relax_shapes=True)
    def update_V(self, state_0s, actions, state_1s, rewards, terminals, rhos):
        """Take one TD(0) optimizer step on ``V``.

        Args:
            state_0s: States whose values are regressed toward the target.
            actions: Unused by this update.
            state_1s: Successor states used for the bootstrapped target.
            rewards: Transition rewards; cast to ``V.model.dtype``.
            terminals: Termination flags; a true/1 flag removes the
                bootstrapped ``V(s_1)`` term from the target.
            rhos: Unused by this update. NOTE(review): if these are
                importance-sampling ratios, no off-policy correction is
                applied here — confirm this is intended.

        Returns:
            Dict with ``'V_loss'``: the mean of the per-sample losses.
        """
        dtype = self.V.model.dtype
        rewards = tf.cast(rewards, dtype)
        # Use the same dtype as the rewards; a hard-coded float32 here caused
        # a dtype mismatch for float64 value functions.
        not_terminals = 1.0 - tf.cast(terminals, dtype)

        # The target is computed outside the tape, so it is treated as a
        # constant (semi-gradient TD).
        V_s1 = self.V.values(state_1s)
        target = rewards + self._gamma * not_terminals * V_s1

        with tf.GradientTape() as tape:
            V_s0 = self.V.values(state_0s)
            # MSE averages over the last axis only, leaving per-sample losses.
            # NOTE(review): assumes V.values returns the same shape as
            # rewards/terminals (e.g. both (batch, 1)); otherwise the target
            # silently broadcasts — confirm against the caller.
            V_losses = 0.5 * tf.losses.MSE(y_pred=V_s0, y_true=target)

        # NOTE(review): differentiating a non-scalar target differentiates
        # its sum, so the optimised loss is the batch SUM, while the reported
        # 'V_loss' below is the batch MEAN.
        V_gradients = tape.gradient(V_losses, self.V.trainable_variables)
        self._optimizer.apply_gradients(
            zip(V_gradients, self.V.trainable_variables))

        # Fail loudly if the update produced NaN/Inf in any variable.
        tree.map_structure(
            lambda x: tf.debugging.check_numerics(x, 'V'),
            self.V.trainable_variables)

        return {
            'V_loss': tf.reduce_mean(V_losses)
        }
