import copy

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
import tree

from softlearning.models.feedforward import feedforward_model
from softlearning.models.utils import create_inputs
from softlearning.utils.tensorflow import apply_preprocessors
from softlearning import preprocessors as preprocessors_lib
from softlearning.utils.tensorflow import cast_and_concat

from softlearning.value_functions.vanilla import (
    create_ensemble_value_function,
    double_feedforward_Q_function,
    feedforward_Q_function,
)
from softlearning.value_functions.base_value_function import (
    StateActionValueFunction)


# Shorthand alias for the TensorFlow Probability distributions module.
# NOTE(review): not referenced in the code visible here; it may be used by
# importers of this module — confirm before removing.
tfd = tfp.distributions


def random_prior_ensemble_feedforward_Q_function(N, *args, **kwargs):
    """Build an ensemble of ``N`` random-prior feedforward Q-functions.

    ``random_prior_feedforward_Q_function`` is handed to
    ``create_ensemble_value_function`` as the per-member constructor; every
    remaining positional and keyword argument is forwarded unchanged.

    Returns:
        Whatever ``create_ensemble_value_function`` returns.
    """
    member_constructor = random_prior_feedforward_Q_function
    ensemble = create_ensemble_value_function(
        N, member_constructor, *args, **kwargs)
    return ensemble


def random_prior_feedforward_Q_function(input_shapes,
                                        *args,
                                        preprocessors=None,
                                        observation_keys=None,
                                        prior_loc=0.0,
                                        prior_scale=1.0,
                                        name='random_prior_feedforward_Q_function',
                                        **kwargs):
    """Build a feedforward Q-function with an additive, frozen random prior.

    The combined output is ``predictor(x) + (prior_loc + prior(x) *
    prior_scale)``, where ``predictor`` and ``prior`` are two separately
    constructed ``feedforward_model`` networks applied to the same
    concatenated, preprocessed inputs. The prior is built with
    ``trainable=False`` and with kernel/bias regularizers disabled.

    Args:
        input_shapes: Structure of input shapes passed to ``create_inputs``.
        *args: Positional arguments forwarded to both ``feedforward_model``
            calls.
        preprocessors: Optional structure matching the inputs; each entry is
            resolved with ``preprocessors_lib.get``. ``None`` means a ``None``
            preprocessor for every input.
        observation_keys: Forwarded to ``StateActionValueFunction``.
        prior_loc: Offset added to the prior network's output.
        prior_scale: Multiplier applied to the prior network's output.
        name: Name of the combined Keras model. The sub-networks are named
            ``f'{name}-predictor'`` and ``f'{name}-prior'``.
        **kwargs: Keyword arguments forwarded to both ``feedforward_model``
            calls. ``trainable``, if given, must be falsy.
            ``kernel_initializer`` and ``bias_initializer`` are not accepted.

    Returns:
        A ``StateActionValueFunction`` wrapping the combined Keras model.

    Raises:
        AssertionError: If ``kwargs`` violates the constraints above, or if
            the prior network ends up with regularization losses or
            trainable variables.
    """
    # Validate caller kwargs before building any layers, so a bad call fails
    # without first constructing the predictor network.
    assert 'trainable' not in kwargs or not kwargs['trainable'], (
        "'trainable' must be omitted or falsy.")
    assert 'kernel_initializer' not in kwargs, (
        "'kernel_initializer' is not supported.")
    assert 'bias_initializer' not in kwargs, (
        "'bias_initializer' is not supported.")

    inputs = create_inputs(input_shapes)

    if preprocessors is None:
        preprocessors = tree.map_structure(lambda _: None, inputs)

    preprocessors = tree.map_structure_up_to(
        inputs, preprocessors_lib.get, preprocessors)

    preprocessed_inputs = apply_preprocessors(preprocessors, inputs)

    # A single concatenated tensor feeds both the predictor and the prior.
    model_inputs = tf.keras.layers.Lambda(
        cast_and_concat
    )(preprocessed_inputs)

    Q_predictor_model_out = feedforward_model(
        *args,
        output_shape=[1],
        name=f'{name}-predictor',
        **kwargs,
    )(model_inputs)

    # The prior reuses the caller's architecture kwargs but is frozen and
    # unregularized. 'trainable' is set inside this dict rather than passed
    # next to ``**Q_prior_kwargs``. Otherwise a caller-supplied
    # ``trainable=False`` (permitted by the assert above) would raise
    # "got multiple values for keyword argument 'trainable'". No initializers
    # are passed, so the prior uses feedforward_model's defaults.
    Q_prior_kwargs = {
        **kwargs,
        'trainable': False,
        'kernel_regularizer': None,
        'bias_regularizer': None,
    }

    Q_prior_model = feedforward_model(
        *args,
        output_shape=[1],
        name=f'{name}-prior',
        **Q_prior_kwargs,
    )

    Q_prior_model_out = Q_prior_model(model_inputs)
    # Affine rescaling of the frozen prior's output.
    Q_prior_model_out = tf.keras.layers.Lambda(
        lambda x: prior_loc + x * prior_scale,
    )(Q_prior_model_out)

    Q_model_out = tf.keras.layers.Add()(
        [Q_predictor_model_out, Q_prior_model_out])

    Q_model = tf.keras.Model(inputs, Q_model_out, name=name)

    # Sanity checks: the prior must add no regularization losses and hold no
    # trainable weights, so that training only updates the predictor.
    assert not Q_model.get_layer(f'{name}-prior').losses
    assert not Q_prior_model.trainable_variables, (
        Q_prior_model.trainable_variables)

    Q_function = StateActionValueFunction(
        model=Q_model, observation_keys=observation_keys, name=name)

    return Q_function


