import argparse
from teachDRL.spinup.utils.run_utils import setup_logger_kwargs
from teachDRL.spinup.algos.sac.sac import sac
from teachDRL.spinup.algos.sac import core
import gym
# import teachDRL.gym_flowers
from teachDRL.teachers.teacher_controller import TeacherController
from collections import OrderedDict
import os
import numpy as np
import sys


def _str_to_bool(value):
    """Convert a command-line token into a boolean.

    ``argparse`` with ``type=bool`` calls ``bool()`` on the raw string, so every
    non-empty value (including ``"False"``) becomes ``True``. This converter
    recognises the usual spellings instead.

    Args:
        value: the raw option value (a string, or an already-parsed bool).

    Returns:
        The parsed boolean.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised spelling.
    """
    if isinstance(value, bool):
        return value
    text = value.strip().lower()
    if text in ('true', 't', 'yes', 'y', '1'):
        return True
    # The empty string was the only falsy spelling under ``type=bool``; keep accepting it.
    if text in ('false', 'f', 'no', 'n', '0', ''):
        return False
    raise argparse.ArgumentTypeError('expected a boolean value, got %r' % (value,))


def main(args):
    """Parse the command line, configure the teacher and launch SAC training.

    Args:
        args: list of command-line tokens, or ``None`` to let argparse read
            ``sys.argv[1:]``.

    Raises:
        SystemExit: on invalid or missing options (argparse error, code 2), or
            when the Oracle teacher is requested for a parameter space it does
            not support (code 1).
    """
    # Argument definition
    parser = argparse.ArgumentParser()

    parser.add_argument('--exp_name', type=str, default='test')
    parser.add_argument('--seed', '-s', type=int, default=0)
    parser.add_argument('--log_path', type=str, default=None)

    # Deep RL student arguments, so far only works with SAC
    parser.add_argument('--hid', type=int, default=-1)  # number of neurons in hidden layers
    parser.add_argument('--l', type=int, default=1)  # number of hidden layers
    parser.add_argument('--gamma', type=float, default=0.99)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--gpu_id', type=int, default=-1)  # default is no GPU
    parser.add_argument('--ent_coef', type=float, default=0.005)
    parser.add_argument('--max_ep_len', type=int, default=2000)
    parser.add_argument('--steps_per_ep', type=int, default=200000)  # nb of env steps per epochs (stay above max_ep_len)
    parser.add_argument('--buf_size', type=int, default=2000000)
    parser.add_argument('--nb_test_episodes', type=int, default=50)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--train_freq', type=int, default=10)
    parser.add_argument('--batch_size', type=int, default=1000)

    # Parameterized bipedal walker arguments, so far only works with bipedal-walker-continuous-v0
    parser.add_argument('--env', type=str, default="bipedal-walker-continuous-v0")

    # Choose student (walker morphology)
    parser.add_argument('--leg_size', type=str, default="default")  # choose walker type ("short", "default" or "quadru")

    # Selection of parameter space
    # So far 3 choices: "--max_stump_h 3.0 --max_obstacle_spacing 6.0" (aka Stump Tracks) or "-hexa" (aka Hexagon Tracks)
    # or "-seq" (untested experimental env)
    parser.add_argument('--max_stump_h', type=float, default=None)
    parser.add_argument('--max_stump_w', type=float, default=None)
    parser.add_argument('--max_stump_r', type=float, default=None)
    parser.add_argument('--roughness', type=float, default=None)
    parser.add_argument('--max_obstacle_spacing', type=float, default=None)
    parser.add_argument('--max_gap_w', type=float, default=None)
    parser.add_argument('--step_h', type=float, default=None)
    parser.add_argument('--step_nb', type=float, default=None)
    parser.add_argument('--hexa_shape', '-hexa', action='store_true')
    parser.add_argument('--stump_seq', '-seq', action='store_true')

    # Teacher-specific arguments:
    parser.add_argument('--teacher', type=str, default="ALP-GMM")  # ALP-GMM, Covar-GMM, RIAC, Oracle, Random

    # ALPGMM (Absolute Learning Progress - Gaussian Mixture Model) related arguments
    parser.add_argument('--gmm_fitness_fun', '-fit', type=str, default=None)
    parser.add_argument('--nb_em_init', type=int, default=None)
    parser.add_argument('--min_k', type=int, default=None)
    parser.add_argument('--max_k', type=int, default=None)
    parser.add_argument('--fit_rate', type=int, default=None)
    parser.add_argument('--weighted_gmm', '-wgmm', action='store_true')
    parser.add_argument('--alp_max_size', type=int, default=None)

    # CovarGMM related arguments
    parser.add_argument('--absolute_lp', '-alp', action='store_true')

    # RIAC related arguments
    parser.add_argument('--max_region_size', type=int, default=None)
    parser.add_argument('--alp_window_size', type=int, default=None)

    # Disagreement teacher related arguments
    parser.add_argument('--size_ensemble', type=int, default=3)
    parser.add_argument('--presample_size', type=int, default=1000)
    parser.add_argument('--disagreement_str', type=str, default='id')
    parser.add_argument('--random_task_ratio', type=float, default=0.1)
    parser.add_argument('--reuse_past_goals_ratio', type=float, default=0.1)

    # type=bool would turn "--use_her False" into True, hence the explicit converter.
    parser.add_argument('--use_her', type=_str_to_bool, default=True)

    args = parser.parse_args(args)

    # Without this check os.path.expanduser(None) fails with an opaque TypeError.
    if args.log_path is None:
        parser.error('--log_path is required')

    log_path = os.path.expanduser(args.log_path)
    os.makedirs(log_path, exist_ok=True)
    logger_kwargs = dict(
        output_dir=log_path, exp_name=args.exp_name,
    )

    # Bind this run to specific GPU if there is one
    if args.gpu_id != -1:
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.gpu_id)

    # Set up Student's DeepNN architecture if provided
    ac_kwargs = dict()
    if args.hid != -1:
        ac_kwargs['hidden_sizes'] = [args.hid] * args.l

    # Set bounds for environment's parameter space format:[min, max, nb_dimensions] (if no nb_dimensions, assumes only 1)
    param_env_bounds = OrderedDict()
    # Legacy walker parameter spaces, currently disabled:
    # if args.max_stump_h is not None:
    #     param_env_bounds['stump_height'] = [0, args.max_stump_h]
    # if args.max_stump_w is not None:
    #     param_env_bounds['stump_width'] = [0, args.max_stump_w]
    # if args.max_stump_r is not None:
    #     param_env_bounds['stump_rot'] = [0, args.max_stump_r]
    # if args.max_obstacle_spacing is not None:
    #     param_env_bounds['obstacle_spacing'] = [0, args.max_obstacle_spacing]
    # if args.hexa_shape:
    #     param_env_bounds['poly_shape'] = [0, 4.0, 12]
    # if args.stump_seq:
    #     param_env_bounds['stump_seq'] = [0, 6.0, 10]

    # Set Teacher hyperparameters (only options given on the command line are forwarded)
    params = {}
    if args.teacher == 'ALP-GMM':
        if args.gmm_fitness_fun is not None:
            params['gmm_fitness_fun'] = args.gmm_fitness_fun
        if args.min_k is not None and args.max_k is not None:
            params['potential_ks'] = np.arange(args.min_k, args.max_k, 1)
        if args.weighted_gmm:
            params['weighted_gmm'] = args.weighted_gmm
        if args.nb_em_init is not None:
            params['nb_em_init'] = args.nb_em_init
        if args.fit_rate is not None:
            params['fit_rate'] = args.fit_rate
        if args.alp_max_size is not None:
            params['alp_max_size'] = args.alp_max_size
    elif args.teacher == 'Covar-GMM':
        if args.absolute_lp:
            params['absolute_lp'] = args.absolute_lp
    elif args.teacher == "RIAC":
        if args.max_region_size is not None:
            params['max_region_size'] = args.max_region_size
        if args.alp_window_size is not None:
            params['alp_window_size'] = args.alp_window_size
    elif args.teacher == "Oracle":
        if 'stump_height' in param_env_bounds and 'obstacle_spacing' in param_env_bounds:
            params['window_step_vector'] = [0.1, -0.2]  # order must match param_env_bounds construction
        elif 'poly_shape' in param_env_bounds:
            params['window_step_vector'] = [0.1] * 12
        elif 'stump_seq' in param_env_bounds:
            params['window_step_vector'] = [0.1] * 10
        else:
            print('Oracle not defined for this parameter space')
            # sys.exit rather than the site-provided exit(), which is not always available.
            sys.exit(1)
    elif args.teacher == 'Disagreement':
        # Explicit check instead of an assert, which is stripped under "python -O".
        if args.size_ensemble <= 0:
            parser.error('--size_ensemble must be positive for the Disagreement teacher')
    elif args.teacher == 'Random':
        pass

    # Only the Disagreement teacher uses an ensemble.
    if args.teacher != 'Disagreement':
        args.size_ensemble = 0

    from gym import Wrapper, GoalEnv
    from gym.wrappers import FlattenObservation

    class RoboticsWrapper(Wrapper):
        """Wrapper that lets a teacher choose the goal of the wrapped goal environment.

        Subclasses provide ``compute_goal`` to turn teacher-sampled parameters
        into a goal for the unwrapped environment.
        """

        def reset(self, *, random_goal=False):
            """Reset the simulation and return the observation.

            A new goal is sampled from the environment only when ``random_goal``
            is true; otherwise the current goal is left as it is.
            """
            # Zero the step counter stored two wrapper levels down (self.env.env).
            self.env.env._elapsed_steps = 0
            base_env = self.unwrapped
            GoalEnv.reset(base_env)
            # Keep retrying until the simulator reports a successful reset.
            while not base_env._reset_sim():
                pass
            if random_goal:
                base_env.goal = base_env._sample_goal().copy()
            return self.observation(base_env._get_obs())

        def compute_goal(self, *args, **kwargs):
            """Map teacher parameters to a goal; must be overridden by subclasses."""
            raise NotImplementedError

        def set_environment(self, **param_dict):
            """Install a goal on the unwrapped environment and return the new observation.

            ``param_dict`` holds either a single ``goal`` entry, used as is, or
            the keyword arguments expected by ``compute_goal``.
            """
            if 'goal' not in param_dict:
                new_goal = self.compute_goal(**param_dict)
            else:
                # An explicit goal may not be combined with other parameters.
                assert len(param_dict) == 1
                new_goal = param_dict['goal']
            self.unwrapped.goal = new_goal
            return self.observation(self.unwrapped._get_obs())

    # def _sample_goal(self):
    #     if self.has_object:
    #         goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
    #         goal += self.target_offset
    #         goal[2] = self.height_offset
    #         if self.target_in_the_air and self.np_random.uniform() < 0.5:
    #             goal[2] += self.np_random.uniform(0, 0.45)
    #     else:
    #         goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
    #     return goal.copy()

    class FetchWrapper(RoboticsWrapper):
        """Goal wrapper for the Fetch environments."""

        def compute_goal(self, *, goal_noise):
            """Turn ``goal_noise`` (values expected in [0, 1]) into a target position.

            The noise is rescaled to [-1, 1] and multiplied by ``target_range``
            to offset the initial gripper position.
            """
            centred_noise = goal_noise * 2 - 1
            goal = self.initial_gripper_xpos[:3] + centred_noise * self.target_range
            if not self.has_object:
                return goal

            goal += self.target_offset
            goal[2] = self.height_offset
            # The third noise component decides whether the target is lifted and by how much.
            if self.target_in_the_air and goal_noise[2] < 0.5:
                goal[2] += goal_noise[2] / 0.5 * 0.45
            return goal


    # def _sample_goal(self):
    #     thumb_name = 'robot0:S_thtip'
    #     finger_names = [name for name in FINGERTIP_SITE_NAMES if name != thumb_name]
    #     finger_name = self.np_random.choice(finger_names)
    #
    #     thumb_idx = FINGERTIP_SITE_NAMES.index(thumb_name)
    #     finger_idx = FINGERTIP_SITE_NAMES.index(finger_name)
    #     assert thumb_idx != finger_idx
    #
    #     # Pick a meeting point above the hand.
    #     meeting_pos = self.palm_xpos + np.array([0.0, -0.09, 0.05])
    #     meeting_pos += self.np_random.normal(scale=0.005, size=meeting_pos.shape)
    #
    #     # Slightly move meeting goal towards the respective finger to avoid that they
    #     # overlap.
    #     goal = self.initial_goal.copy().reshape(-1, 3)
    #     for idx in [thumb_idx, finger_idx]:
    #         offset_direction = (meeting_pos - goal[idx])
    #         offset_direction /= np.linalg.norm(offset_direction)
    #         goal[idx] = meeting_pos - 0.005 * offset_direction
    #
    #     if self.np_random.uniform() < 0.1:
    #         # With some probability, ask all fingers to move back to the origin.
    #         # This avoids that the thumb constantly stays near the goal position already.
    #         goal = self.initial_goal.copy()
    #     return goal.flatten()

    from gym.envs.robotics.hand.reach import FINGERTIP_SITE_NAMES
    class HandReachWrapper(RoboticsWrapper):
        """Goal wrapper for the HandReach environments."""

        def compute_goal(self, *, goal_noise, normal_noise):
            """Build a fingertip goal where the thumb meets one other finger.

            Args:
                goal_noise: values expected in [0, 1]; ``goal_noise[0]`` selects
                    the finger that meets the thumb and ``goal_noise[1]``
                    triggers the "all fingers back to origin" goal when below 0.1.
                normal_noise: perturbation of the meeting point, scaled by 0.005.

            Returns:
                The flattened goal array.
            """
            thumb_name = 'robot0:S_thtip'
            thumb_idx = 4
            assert FINGERTIP_SITE_NAMES[thumb_idx] == thumb_name

            # Map the noise onto one of the finger indices below the thumb.
            # The previous "int(noise * 4) // 4" was 0 for every noise in [0, 1),
            # so only the first finger was ever chosen. The clamp covers noise == 1.
            nb_fingers = thumb_idx
            finger_idx = min(int(goal_noise[0] * nb_fingers), nb_fingers - 1)
            assert thumb_idx != finger_idx

            # Pick a meeting point above the hand.
            meeting_pos = self.palm_xpos + np.array([0.0, -0.09, 0.05])
            meeting_pos += normal_noise * 0.005

            # Slightly move meeting goal towards the respective finger to avoid that they
            # overlap.
            goal = self.initial_goal.copy().reshape(-1, 3)
            for idx in [thumb_idx, finger_idx]:
                offset_direction = (meeting_pos - goal[idx])
                offset_direction /= np.linalg.norm(offset_direction)
                goal[idx] = meeting_pos - 0.005 * offset_direction

            if goal_noise[1] < 0.1:
                # With some probability, ask all fingers to move back to the origin.
                # This avoids that the thumb constantly stays near the goal position already.
                goal = self.initial_goal.copy()
            return goal.flatten()

    # def _sample_goal(self):
    #     # Select a goal for the object position.
    #     target_pos = None
    #     if self.target_position == 'random':
    #         assert self.target_position_range.shape == (3, 2)
    #         offset = self.np_random.uniform(self.target_position_range[:, 0], self.target_position_range[:, 1])
    #         assert offset.shape == (3,)
    #         target_pos = self.sim.data.get_joint_qpos('object:joint')[:3] + offset
    #     elif self.target_position in ['ignore', 'fixed']:
    #         target_pos = self.sim.data.get_joint_qpos('object:joint')[:3]
    #     else:
    #         raise error.Error('Unknown target_position option "{}".'.format(self.target_position))
    #     assert target_pos is not None
    #     assert target_pos.shape == (3,)
    #
    #     # Select a goal for the object rotation.
    #     target_quat = None
    #     if self.target_rotation == 'z':
    #         angle = self.np_random.uniform(-np.pi, np.pi)
    #         axis = np.array([0., 0., 1.])
    #         target_quat = quat_from_angle_and_axis(angle, axis)
    #     elif self.target_rotation == 'parallel':
    #         angle = self.np_random.uniform(-np.pi, np.pi)
    #         axis = np.array([0., 0., 1.])
    #         target_quat = quat_from_angle_and_axis(angle, axis)
    #         parallel_quat = self.parallel_quats[self.np_random.randint(len(self.parallel_quats))]
    #         target_quat = rotations.quat_mul(target_quat, parallel_quat)
    #     elif self.target_rotation == 'xyz':
    #         angle = self.np_random.uniform(-np.pi, np.pi)
    #         axis = self.np_random.uniform(-1., 1., size=3)
    #         target_quat = quat_from_angle_and_axis(angle, axis)
    #     elif self.target_rotation in ['ignore', 'fixed']:
    #         target_quat = self.sim.data.get_joint_qpos('object:joint')
    #     else:
    #         raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation))
    #     assert target_quat is not None
    #     assert target_quat.shape == (4,)
    #
    #     target_quat /= np.linalg.norm(target_quat)  # normalized quaternion
    #     goal = np.concatenate([target_pos, target_quat])
    #     return goal

    from gym import error
    from gym.envs.robotics.hand.manipulate import quat_from_angle_and_axis, rotations

    class HandManipulateWrapper(RoboticsWrapper):
        """Goal wrapper for the HandManipulate environments."""

        def compute_goal(self, *, goal_noise):
            """Build a position-plus-quaternion goal from ``goal_noise``.

            Args:
                goal_noise: values expected in [0, 1]. ``goal_noise[0]`` places
                    the position offset inside ``target_position_range`` (the
                    same fraction on every axis), ``goal_noise[1]`` sets the
                    rotation angle in [-pi, pi], ``goal_noise[2]`` selects the
                    parallel quaternion and ``goal_noise[2:5]`` gives the axis
                    for ``'xyz'`` rotations.

            Returns:
                The target position concatenated with the normalized target quaternion.

            Raises:
                gym.error.Error: for an unknown ``target_position`` or
                    ``target_rotation`` option.
            """
            # Select a goal for the object position.
            target_pos = None
            if self.target_position == 'random':
                assert self.target_position_range.shape == (3, 2)
                offset = goal_noise[0] * (self.target_position_range[:, 1] - self.target_position_range[:, 0]) + \
                    self.target_position_range[:, 0]
                assert offset.shape == (3,)
                target_pos = self.sim.data.get_joint_qpos('object:joint')[:3] + offset
            elif self.target_position in ['ignore', 'fixed']:
                target_pos = self.sim.data.get_joint_qpos('object:joint')[:3]
            else:
                raise error.Error('Unknown target_position option "{}".'.format(self.target_position))
            assert target_pos is not None
            assert target_pos.shape == (3,)

            # Select a goal for the object rotation.
            target_quat = None
            if self.target_rotation == 'z':
                angle = goal_noise[1] * 2 * np.pi - np.pi
                axis = np.array([0., 0., 1.])
                target_quat = quat_from_angle_and_axis(angle, axis)
            elif self.target_rotation == 'parallel':
                angle = goal_noise[1] * 2 * np.pi - np.pi
                axis = np.array([0., 0., 1.])
                target_quat = quat_from_angle_and_axis(angle, axis)
                # Scale the noise to an index and clamp the rare noise == 1 case.
                # The previous "int(noise * n) // n" was 0 for every noise in [0, 1),
                # so the first parallel quaternion was always used.
                nb_parallel = len(self.parallel_quats)
                parallel_quat_idx = min(int(goal_noise[2] * nb_parallel), nb_parallel - 1)
                parallel_quat = self.parallel_quats[parallel_quat_idx]
                target_quat = rotations.quat_mul(target_quat, parallel_quat)
            elif self.target_rotation == 'xyz':
                angle = goal_noise[1] * 2 * np.pi - np.pi
                axis = goal_noise[2:5] * 2 - 1
                target_quat = quat_from_angle_and_axis(angle, axis)
            elif self.target_rotation in ['ignore', 'fixed']:
                target_quat = self.sim.data.get_joint_qpos('object:joint')
            else:
                raise error.Error('Unknown target_rotation option "{}".'.format(self.target_rotation))
            assert target_quat is not None
            assert target_quat.shape == (4,)

            target_quat /= np.linalg.norm(target_quat)  # normalized quaternion
            goal = np.concatenate([target_pos, target_quat])
            return goal

    # def _sample_goal(self):
    #     goal_ind = self.grid_free_index[np.random.choice(len(self.grid_free_index))]
    #     return (goal_ind + self.np_random.uniform(low=0, high=1, size=2)) / self.grid_size * 2 - 1

    class MazeWrapper(Wrapper):
        """Wrapper that lets a teacher choose the goal of the wrapped Maze environment."""

        def reset(self, *, random_goal=False):
            """Reset the environment and return the observation.

            A new goal is sampled from the environment only when ``random_goal``
            is true; otherwise the current goal is left as it is.
            """
            # Zero the step counter stored two wrapper levels down (self.env.env).
            self.env.env._elapsed_steps = 0
            GoalEnv.reset(self.unwrapped)
            if random_goal:
                self.unwrapped.goal = self.unwrapped._sample_goal()
            obs = self.observation(self.unwrapped._get_obs())
            return obs

        def compute_goal(self, *, goal_noise):
            """Turn ``goal_noise`` (values expected in [0, 1]) into a maze goal.

            ``goal_noise[0]`` selects an entry of ``grid_free_index`` and
            ``goal_noise[1:3]`` is the offset added to it before the result is
            rescaled with ``/ grid_size * 2 - 1``.
            """
            # Scale the noise to an index and clamp the rare noise == 1 case.
            # The previous "int(noise * n) // n" was 0 for every noise in [0, 1),
            # so the first free cell was always used.
            nb_free = len(self.grid_free_index)
            idx = min(int(goal_noise[0] * nb_free), nb_free - 1)
            goal_ind = self.grid_free_index[idx]
            return (goal_ind + goal_noise[1:3]) / self.grid_size * 2 - 1

        def set_environment(self, **param_dict):
            """Install a goal on the unwrapped environment and return the new observation.

            ``param_dict`` holds either a single ``goal`` entry, used as is, or
            the keyword arguments expected by ``compute_goal``.
            """
            if 'goal' in param_dict:
                # An explicit goal may not be combined with other parameters.
                assert len(param_dict) == 1
                goal = param_dict['goal']
            else:
                goal = self.compute_goal(**param_dict)
            self.unwrapped.goal = goal
            obs = self.observation(self.unwrapped._get_obs())
            return obs

    env_init = {}
    # env_init['leg_size'] = args.leg_size

    # env = env_f()

    # Choose the goal wrapper, the teacher's sampling space and the total number of
    # training steps from the environment-name prefix.
    # Bounds format: [min, max, nb_dimensions].
    if args.env.startswith('Fetch'):
        env_f = lambda: FetchWrapper(FlattenObservation(gym.make(args.env)))
        param_env_bounds['goal_noise'] = [0, 1, 3]
        total_steps = int(4E6)  # int(1E6)
    elif args.env.startswith('HandManipulate'):
        env_f = lambda: HandManipulateWrapper(FlattenObservation(gym.make(args.env)))
        param_env_bounds['goal_noise'] = [0, 1, 5]
        total_steps = int(4E6)
    elif args.env.startswith('HandReach'):
        env_f = lambda: HandReachWrapper(FlattenObservation(gym.make(args.env)))
        param_env_bounds['goal_noise'] = [0, 1, 2]
        param_env_bounds['normal_noise'] = [-np.inf, np.inf, 3]
        total_steps = int(4E6)
    elif args.env.startswith('Maze'):
        env_f = lambda: MazeWrapper(FlattenObservation(gym.make(args.env)))
        param_env_bounds['goal_noise'] = [0, 1, 3]
        total_steps = int(4E5)
    else:
        # The separating space was missing, which glued the env name to the message.
        raise NotImplementedError(args.env + ' is not supported')

    # 'goal' is reserved by set_environment for passing an explicit goal.
    assert 'goal' not in param_env_bounds

    # Episode length, steps per epoch and number of epochs are derived from the
    # environment spec and the step budget; the command-line values are overridden.
    env = env_f()
    args.max_ep_len = env.spec.max_episode_steps
    args.steps_per_ep = args.max_ep_len * 100
    args.epochs = total_steps // args.steps_per_ep

    reward_bounds = (-args.max_ep_len, 0)  # assume binary reward in {-1, 0}

    # Initialize teacher
    teacher_params = {}
    if args.teacher == 'Disagreement':
        # sample from goal space rather than noise space
        del param_env_bounds['goal_noise']

        from gym.spaces import flatdim

        def presample(o, g_list=None):
            """Pair one flat observation with a batch of candidate goals.

            Args:
                o: flat observation whose shape equals ``env.observation_space.shape``.
                g_list: optional list of ``args.presample_size`` goals; when
                    omitted, goals are sampled from the unwrapped environment.

            Returns:
                A tuple ``(obs, g_list)`` where ``obs`` has one row per goal:
                the ``desired_goal`` slice is replaced by the candidate goal and
                every other slice of ``o`` is repeated unchanged.
            """
            batch = args.presample_size
            if g_list is not None:
                assert len(g_list) == batch
            else:
                g_list = [env.unwrapped._sample_goal() for _ in range(batch)]
            assert o.shape == env.observation_space.shape

            # Walk the dict observation space in order, cutting the flat vector
            # into the slice that belongs to each entry.
            columns = []
            offset = 0
            for key, space in env.unwrapped.observation_space.spaces.items():
                width = flatdim(space)
                piece = o[offset:offset + width]
                offset += width
                if key == 'desired_goal':
                    columns.append(g_list)
                else:
                    columns.append(np.repeat(piece[np.newaxis, :], repeats=batch, axis=0))

            return np.concatenate(columns, axis=1), g_list

        teacher_params = {
            'presample_fun': presample,
            'disagreement_str': args.disagreement_str,
            'random_task_ratio': args.random_task_ratio,
            'reuse_past_goals_ratio': args.reuse_past_goals_ratio,
            'presample_size': args.presample_size,
        }

    Teacher = TeacherController(
        args.teacher,
        args.nb_test_episodes,
        param_env_bounds,
        reward_bounds=reward_bounds,
        seed=args.seed,
        teacher_params={**params, **teacher_params},
    )

    # Launch Student training
    training_kwargs = dict(
        actor_critic=core.mlp_actor_critic,
        ac_kwargs=ac_kwargs,
        gamma=args.gamma,
        seed=args.seed,
        epochs=args.epochs,
        logger_kwargs=logger_kwargs,
        alpha=args.ent_coef,
        max_ep_len=args.max_ep_len,
        steps_per_epoch=args.steps_per_ep,
        replay_size=args.buf_size,
        env_init=env_init,
        env_name=args.env,
        nb_test_episodes=args.nb_test_episodes,
        lr=args.lr,
        train_freq=args.train_freq,
        batch_size=args.batch_size,
        Teacher=Teacher,
        teacher=args.teacher,
        size_ensemble=args.size_ensemble,
        use_her=args.use_her,
        presample_size=args.presample_size,
        disagreement_str=args.disagreement_str,
        random_task_ratio=args.random_task_ratio,
        reuse_past_goals_ratio=args.reuse_past_goals_ratio,
    )
    sac(env_f, **training_kwargs)


if __name__ == "__main__":
    # Hand the command-line tokens over explicitly; argparse would read the same
    # sys.argv[1:] when given None.
    main(sys.argv[1:])