# Copyright 2020 The Weakly-Supervised Control Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, Optional, Tuple
import importlib
import os

import gym
from gym.spaces import Box
import numpy as np
from tensorflow.io import gfile

import multiworld
from multiworld.core.image_env import ImageEnv
from rlkit.core import logger
from rlkit.torch.networks import FlattenMlp

from weakly_supervised_control.disentanglement.train_utils import DisentanglementModel
from weakly_supervised_control.envs import register_all_envs
from weakly_supervised_control.envs.env_util import get_camera_fn
from weakly_supervised_control.vae.conv_vae import ConvVAE
from weakly_supervised_control.vae.vae_trainer import ConvVAETrainer


def replace_paths(d):
    """Resolves special string values of a (possibly nested) config dict in place.

    Every value of ``d`` is visited, recursing into values that are dicts.
    String values are rewritten as follows; all other values are left alone:

    * ``'multiworld/...'`` becomes an absolute path: the string joined onto
      the directory that contains the installed ``multiworld`` package.
    * ``'rlkit.<module>.<name>'`` or ``'multiworld.<module>.<name>'`` becomes
      the object ``<name>`` looked up on the imported module (e.g. a function).

    Returns None; ``d`` is mutated.
    """
    for key, value in d.items():
        if isinstance(value, dict):
            replace_paths(value)
        elif isinstance(value, str) and value.startswith('multiworld/'):
            # Resolve relative to the parent directory of the multiworld package.
            package_parent = os.path.dirname(
                os.path.dirname(multiworld.__file__))
            d[key] = os.path.join(package_parent, value)
        elif isinstance(value, str) and value.startswith(('rlkit.', 'multiworld.')):
            # 'pkg.mod.attr' -> import 'pkg.mod', then fetch attribute 'attr'.
            module_path, _, attr_name = value.rpartition('.')
            d[key] = getattr(importlib.import_module(module_path), attr_name)


def load_config(config_path: str):
    """Reads a Python-expression config file and resolves its special strings.

    The whole file is evaluated with ``eval``. The resulting object is passed
    to ``replace_paths``, which rewrites it in place, and is then returned.

    WARNING: ``eval`` executes arbitrary code contained in the file; only
    load config files you trust.
    """
    with open(config_path, 'r', encoding='utf-8') as f:
        config_text = f.read()
        variant = eval(config_text)
    replace_paths(variant)
    return variant


def disable_tensorflow_gpu():
    """Hides CUDA devices from TensorFlow while it enumerates local devices.

    Temporarily sets ``CUDA_VISIBLE_DEVICES=-1``, imports TensorFlow's
    ``device_lib`` and prints the devices it reports. Afterwards
    ``CUDA_VISIBLE_DEVICES`` is restored to its previous value, or removed if
    it was unset. The restore runs in a ``finally`` clause, so the
    environment is left intact even if the import or device listing raises.
    Any such exception still propagates to the caller.

    NOTE(review): this presumably relies on TensorFlow fixing its device list
    on first enumeration, so TF stays on CPU while other libraries still see
    the GPUs afterwards — confirm.
    """
    previous_devices = os.environ.get('CUDA_VISIBLE_DEVICES', None)
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    try:
        # This is necessary to prevent tensorflow from using the GPU.
        from tensorflow.python.client import device_lib
        print(device_lib.list_local_devices())
    finally:
        if previous_devices is None:
            os.environ.pop('CUDA_VISIBLE_DEVICES', None)
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = previous_devices


def load_disentanglement_model(
        model_path: str, factors: str,
        train_dset,
        factor_indices: Tuple[int, ...]):
    """Loads a disentanglement model and the latent range of selected factors.

    Args:
        model_path: Checkpoint path passed to ``model.load_checkpoint``.
        factors: Passed through to the ``DisentanglementModel`` constructor.
        train_dset: Dataset the model is built with. All ``train_dset.size``
            samples are drawn via ``train_dset.sample`` to compute the
            latent range.
        factor_indices: Indices into the per-latent bounds returned by
            ``model.get_latent_range``. Any iterable of ints is accepted; it
            is materialized once, so a one-shot iterator selects the same
            entries for both the lower and the upper bound.

    Returns:
        ``(model, space)``, where ``space`` is a float32 ``Box`` whose bounds
        are the latent minima/maxima at ``factor_indices``.
    """
    model = DisentanglementModel(train_dset, factors)
    model.load_checkpoint(model_path)

    data = train_dset.sample(train_dset.size)
    z_min, z_max = model.get_latent_range(data)

    # Materialize once: the indices are iterated twice below.
    indices = list(factor_indices)
    z_min = np.array([z_min[i] for i in indices])
    z_max = np.array([z_max[i] for i in indices])
    space = Box(z_min, z_max, dtype=np.float32)

    return model, space


def load_presampled_goals(presampled_goals_path: str):
    """Loads presampled goals saved with NumPy, keeping only per-goal entries.

    The file is opened with ``gfile.GFile`` and read with
    ``np.load(..., allow_pickle=True)``. If the loaded object has an ``item``
    method, it is unwrapped with ``.item()``; for example, a dict saved via
    ``np.save`` comes back wrapped in a 0-d object array. The number of goals
    is the first dimension of ``'image_desired_goal'``. Only entries whose
    ``len`` equals that number are returned.

    WARNING: ``allow_pickle=True`` lets the file unpickle arbitrary objects,
    which can execute code; only load trusted files.
    """
    with gfile.GFile(presampled_goals_path, 'rb') as goals_file:
        loaded = np.load(goals_file, allow_pickle=True)
    if hasattr(loaded, 'item'):
        loaded = loaded.item()

    # Entries whose length differs from the number of goals must be dropped
    # (see multiworld.core.image_env.py:224).
    num_goals = loaded['image_desired_goal'].shape[0]
    per_goal = {}
    for key, value in loaded.items():
        if len(value) == num_goals:
            per_goal[key] = value
    return per_goal


def create_image_env(env_id: str, imsize: int = 48, presampled_goals=None):
    """Creates a camera-initialized gym env wrapped in a multiworld ImageEnv.

    Args:
        env_id: Gym environment id. This project's environments are
            registered first via ``register_all_envs()``.
        imsize: Image size passed positionally to ``ImageEnv``.
        presampled_goals: Forwarded to ``ImageEnv`` unchanged.

    Returns:
        An ``ImageEnv`` built with ``transpose=True`` and ``normalize=True``.
        Its wrapped env's camera is set up by ``get_camera_fn(env_id)``.
    """
    register_all_envs()
    base_env = gym.make(env_id)
    camera_fn = get_camera_fn(env_id)
    base_env.unwrapped.initialize_camera(camera_fn)

    image_env_options = {
        'transpose': True,
        'normalize': True,
        'presampled_goals': presampled_goals,
    }
    return ImageEnv(base_env, imsize, **image_env_options)


def train_vae(
    vae: ConvVAE,
    dset,
    num_epochs: int = 0,  # Do not pre-train by default
    save_period: int = 5,
    test_p: float = 0.1,  # data proportion to use for test
    vae_trainer_kwargs: Optional[Dict] = None,
):
    """Optionally pre-trains ``vae`` on ``dset``, then pickles it as ``vae.pkl``.

    While training, rlkit's tabular logger writes to ``vae_progress.csv``
    instead of ``progress.csv``. The original output is restored afterwards,
    even if training raises.

    Args:
        vae: Model handed to ``ConvVAETrainer``, saved, and returned.
        dset: Dataset exposing ``data`` (a batch of images whose last axis
            holds 3 channels) and ``factors`` (one row per image).
        num_epochs: Training epochs to run; 0 (default) skips training and
            only saves ``vae``.
        save_period: Reconstructions and samples are saved on every epoch
            divisible by ``save_period``.
        test_p: Fraction of examples, taken from the front, held out as test.
        vae_trainer_kwargs: Extra keyword arguments for ``ConvVAETrainer``.
            Defaults to none.

    Returns:
        ``(vae, train_data, test_data, train_factors, test_factors)``. The
        data arrays are flattened uint8 images; the factors are min-max
        scaled with a single global min and range.

    Raises:
        ValueError: If the images do not have 3 channels, or float image
            data does not lie in [0, 1].
    """
    if vae_trainer_kwargs is None:
        vae_trainer_kwargs = {}

    # Preprocess before touching the logger, so a validation error cannot
    # leave tabular output redirected.
    data = _flatten_images_to_uint8(dset.data)
    factors = _normalize_factors(dset.factors)

    # Split into train and test set: the first `test_size` rows are held out.
    test_size = int(data.shape[0] * test_p)
    test_data = data[:test_size]
    train_data = data[test_size:]
    test_factors = factors[:test_size]
    train_factors = factors[test_size:]

    logger.remove_tabular_output(
        'progress.csv', relative_to_snapshot_dir=True
    )
    logger.add_tabular_output(
        'vae_progress.csv', relative_to_snapshot_dir=True
    )
    try:
        trainer = ConvVAETrainer(train_data, test_data, vae,
                                 train_factors=train_factors,
                                 test_factors=test_factors,
                                 **vae_trainer_kwargs)

        for epoch in range(num_epochs):
            should_save_imgs = (epoch % save_period == 0)
            trainer.train_epoch(epoch)
            trainer.test_epoch(epoch, save_reconstruction=should_save_imgs)
            if should_save_imgs:
                trainer.dump_samples(epoch)
            trainer.update_train_weights()

        logger.save_extra_data(vae, 'vae.pkl', mode='pickle')
    finally:
        # Always restore the default tabular output.
        logger.remove_tabular_output(
            'vae_progress.csv',
            relative_to_snapshot_dir=True,
        )
        logger.add_tabular_output(
            'progress.csv',
            relative_to_snapshot_dir=True,
        )

    return vae, train_data, test_data, train_factors, test_factors


def _flatten_images_to_uint8(images):
    """Flattens a batch of 3-channel, channels-last images into uint8 rows.

    The three trailing axes are reversed (``transpose(0, 3, 2, 1)``) so the
    channel axis comes first. Each image is then flattened into one row.
    Non-uint8 data must lie in [0, 1]; it is scaled by 255 and cast to uint8,
    which truncates the values.
    """
    n = images.shape[0]
    # NOTE(review): for (N, H, W, C) input this yields (N, C, W, H) order,
    # presumably matching multiworld's ImageEnv(transpose=True) — confirm.
    data = images.transpose(0, 3, 2, 1)
    if data.shape[1] != 3:
        raise ValueError(
            'Expected 3 channels in the last image axis, got {}.'.format(
                data.shape[1]))
    data = data.reshape((n, -1))

    # Un-normalize images.
    if data.dtype != np.uint8:
        # Written as `not (...)` so NaN values are rejected too.
        if not (np.min(data) >= 0.0 and np.max(data) <= 1.0):
            raise ValueError('Non-uint8 image data must lie in [0, 1].')
        data = (data * 255).astype(np.uint8)
    return data


def _normalize_factors(factors):
    """Min-max scales ``factors`` to [0, 1] using one global min and range.

    If all values are equal, the range is zero. A scale of 1 is then used, so
    the result is all zeros rather than NaNs from a 0/0 division.
    """
    factors = np.asarray(factors)
    # NOTE(review): a single global min/range is used rather than one per
    # factor column — confirm this is intended.
    value_range = np.ptp(factors)
    if value_range == 0:
        value_range = 1
    return (factors - np.min(factors)) / value_range
