import copy

from ray import tune

from softlearning.utils.dict import deep_update
from softlearning.utils.git import get_git_rev

from examples.development import (
    get_variant_spec as get_development_variant_spec)
from examples.development.variants import (
    ALGORITHM_PARAMS_BASE,
    get_algorithm_params)
import itertools


from examples.development.variants import (
    get_variant_spec as development_get_variant_spec,
    MAX_PATH_LENGTH_PER_UNIVERSE_DOMAIN_TASK,
    EPOCH_LENGTH_PER_UNIVERSE_DOMAIN_TASK,
    TOTAL_STEPS_PER_UNIVERSE_DOMAIN_TASK,
    ALGORITHM_PARAMS_ADDITIONAL,
    ENVIRONMENT_PARAMS_PER_UNIVERSE_DOMAIN_TASK,
    DEFAULT_KEY)

# Register the sparse cart-pole swing-up tasks under the `dm_control` universe.
# The task entries are merged into the `cartpole` table (created on demand)
# rather than replacing that table wholesale, so parameters that the imported
# table may already hold for other cart-pole tasks are left in place.
ENVIRONMENT_PARAMS_PER_UNIVERSE_DOMAIN_TASK.setdefault('dm_control', {})
ENVIRONMENT_PARAMS_PER_UNIVERSE_DOMAIN_TASK['dm_control'].setdefault(
    'cartpole', {}
).update({
    'custom_swingup_sparse': {
        'control_cost_weight': 1e-1,
    },
    'custom_swingup_vision_sparse': {
        'control_cost_weight': 1e-1,
    },
})


def get_variant_spec(args):
    """Assemble the variant spec for a BayesianBellmanActorCritic run.

    Starts from the development variant spec for ``args`` and then overrides
    its sampler, Q-function, algorithm and policy sections with the settings
    defined below. Several values are ``tune.grid_search`` /
    ``tune.sample_from`` placeholders rather than concrete numbers.

    Args:
        args: Experiment arguments. This function reads ``args.universe``,
            ``args.domain``, ``args.task`` and ``args.algorithm``; the object
            is also forwarded to ``get_development_variant_spec``.

    Returns:
        The object returned by ``get_development_variant_spec(args)``,
        modified in place with the entries set here.

    Raises:
        AssertionError: If ``args.algorithm`` is not
            ``'BayesianBellmanActorCritic'``.
    """
    hidden_units = 256

    universe = args.universe
    domain = args.domain
    task = args.task

    variant_spec = get_development_variant_spec(args)

    assert args.algorithm == 'BayesianBellmanActorCritic', args.algorithm

    # ------------------------------------------------------------------
    # Helpers evaluated lazily through `tune.sample_from`.
    # ------------------------------------------------------------------
    def resolved_config(spec):
        # Use the nested 'config' entry when the spec has one, otherwise
        # fall back to the spec itself.
        return spec.get('config', spec)

    def algorithm_setting(spec, key):
        return resolved_config(spec)['algorithm_params']['config'][key]

    def sample_alpha_lr(spec):
        update_type = algorithm_setting(
            spec, 'exploitation_policy_update_type')
        return float('nan') if update_type == 'vbac' else 3e-4

    # Both tables are keyed by `exploitation_policy_q_type`; an unknown type
    # raises `KeyError` when the value is sampled.
    critic_updates_by_q_type = {
        'exploitation_Q': 1,
        'exploitation_Q_target': 1,
        'exploration_Q': 0,
        'exploration_Q_target': 0,
    }
    ensemble_size_by_q_type = {
        'exploitation_Q': 2,
        'exploitation_Q_target': 2,
        'exploration_Q': 0,
        'exploration_Q_target': 0,
    }

    def sample_exploitation_num_critic_updates(spec):
        q_type = algorithm_setting(spec, 'exploitation_policy_q_type')
        return critic_updates_by_q_type[q_type]

    def sample_exploitation_ensemble_size(spec):
        q_type = algorithm_setting(spec, 'exploitation_policy_q_type')
        return ensemble_size_by_q_type[q_type]

    def sample_policy_layer_sizes(spec):
        policy_config = resolved_config(spec)['policy_params']['config']
        return policy_config['hidden_layer_sizes']

    def sample_Q_params(spec):
        return resolved_config(spec)['Q_params']

    def sample_exploitation_Q_params(spec):
        return resolved_config(spec)['exploitation_Q_params']

    # ------------------------------------------------------------------
    # Algorithm settings.
    # ------------------------------------------------------------------
    algorithm_overrides = {
        'class_name': args.algorithm,
        'config': {
            'eval_n_episodes': 10,
            'policy_lr': 3e-4,
            'Q_lr': 3e-4,
            'alpha_lr': tune.sample_from(sample_alpha_lr),
            'reward_scale': 1.0,
            'discount': 0.99,
            # Value used for BBAC; the BAC setting would be `tau = 1.0`.
            'tau': 5e-3,
            'target_entropy': 'auto',
            'target_update_interval': 1,

            'exploration_num_actor_updates': 1,
            'exploration_num_critic_updates': 1,
            'exploration_num_target_q_samples': 1,
            'exploration_q_target_reduce_type': 'mean',
            'exploration_q_target_type': 'virel',
            'exploration_policy_update_type': 'vbac',
            'exploration_policy_num_q_samples': 1,
            'exploration_policy_q_reduce_type': 'mean',
            'exploration_policy_q_type': 'Q',
            'exploration_policy_q_ensemble_subset_size': None,

            'exploitation_num_actor_updates': 1,
            'exploitation_num_critic_updates': tune.sample_from(
                sample_exploitation_num_critic_updates),
            'exploitation_num_target_q_samples': 1,
            'exploitation_q_target_ensemble_subset_size': 2,
            'exploitation_q_target_reduce_type': 'min',
            'exploitation_q_target_type': 'sac',
            'exploitation_policy_update_type': 'sac',
            'exploitation_policy_num_q_samples': 1,
            'exploitation_policy_q_reduce_type': 'mean',
            'exploitation_policy_q_type': 'exploitation_Q',
            'exploitation_policy_q_ensemble_subset_size': None,
        },
    }

    # These two tasks run without any warm-up samples.
    no_warmup_tasks = (
        ('Pendulum', 'v0'),
        ('MountainCar', 'Continuous-v0'),
    )
    if (domain, task) in no_warmup_tasks:
        algorithm_overrides['config']['num_warmup_samples'] = 0

    # ------------------------------------------------------------------
    # Q-function prior / ensemble settings.
    # ------------------------------------------------------------------
    prior_loc = 0.0
    prior_scale_by_task = {
        ('MountainCar', 'Continuous-v0'): 100.0,
        ('cartpole', 'custom_swingup_sparse'): tune.grid_search(
            [4.0, 8.0, 16.0, 32.0, 64.0, 128.0, 256.0]),
    }
    # Every task without an explicit entry uses a unit prior scale.
    prior_scale = prior_scale_by_task.get((domain, task), 1.0)

    prior_loss_weight = tune.grid_search(
        [0.0, 3e-8, 3e-7, 3e-6, 3e-5, 3e-4, 3e-3])

    exploration_ensemble_size = tune.grid_search([1, 2, 4, 8, 16, 32])
    exploitation_ensemble_size = tune.sample_from(
        sample_exploitation_ensemble_size)
    activation = 'relu'

    # ------------------------------------------------------------------
    # Sampler.
    # ------------------------------------------------------------------
    sampler_params = variant_spec['sampler_params']
    sampler_params['class_name'] = 'BBACSampler'
    sampler_params['config']['exploitation_policy_sample_ratio'] = (
        tune.grid_search([0.0]))

    # ------------------------------------------------------------------
    # Exploration Q-function; hidden sizes follow the sampled policy config.
    # ------------------------------------------------------------------
    variant_spec['Q_params'] = {
        'class_name': 'random_prior_ensemble_feedforward_Q_function',
        'config': {
            'N': exploration_ensemble_size,
            'hidden_layer_sizes': tune.sample_from(
                sample_policy_layer_sizes),
            'observation_keys': None,
            'preprocessors': None,
            'activation': activation,
            'kernel_regularizer': {
                'class_name': 'l2',
                'config': {
                    'l': prior_loss_weight,
                },
            },
            'prior_loc': prior_loc,
            'prior_scale': prior_scale,
        },
    }

    variant_spec['algorithm_params'] = deep_update(
        ALGORITHM_PARAMS_BASE,
        get_algorithm_params(universe, domain, task),
        algorithm_overrides,
    )

    # ------------------------------------------------------------------
    # Policies: the exploration policy is derived from the (already
    # updated) exploitation policy parameters.
    # ------------------------------------------------------------------
    variant_spec['policy_params'] = deep_update(
        variant_spec['policy_params'],
        {'config': {'hidden_layer_sizes': (hidden_units, hidden_units)}},
    )

    variant_spec['exploration_policy_params'] = deep_update(
        variant_spec['policy_params'],
        {'config': {'hidden_layer_sizes': (hidden_units, hidden_units)}},
    )

    # The target network spec is whatever `Q_params` resolves to.
    variant_spec['Q_target_params'] = tune.sample_from(sample_Q_params)

    # ------------------------------------------------------------------
    # Exploitation Q-function and its target.
    # ------------------------------------------------------------------
    variant_spec['exploitation_Q_params'] = {
        'class_name': 'ensemble_feedforward_Q_function',
        'config': {
            'N': exploitation_ensemble_size,
            'hidden_layer_sizes': tune.sample_from(
                sample_policy_layer_sizes),
            'observation_keys': None,
            'preprocessors': None,
            'activation': activation,
        },
    }

    variant_spec['exploitation_Q_target_params'] = tune.sample_from(
        sample_exploitation_Q_params)

    # ------------------------------------------------------------------
    # Environment tweaks: these domains get their action and observation
    # ranges rescaled to [-1, 1]; the assignment replaces any existing
    # training kwargs.
    # ------------------------------------------------------------------
    training_environment = variant_spec['environment_params']['training']
    if training_environment['domain'] in ('MountainCar', 'Pendulum'):
        training_environment['kwargs'] = {
            'rescale_action_range': (-1.0, 1.0),
            'rescale_observation_range': (-1.0, 1.0),
        }

    # Record the git revision reported for this file alongside the spec.
    variant_spec['git_sha-bbac'] = get_git_rev(__file__)
    return variant_spec
