import numpy as np
import tensorflow as tf
import time
from teachDRL.spinup.algos.sac import core
from teachDRL.spinup.algos.sac.core import get_vars
from teachDRL.spinup.utils.logx import EpochLogger

class ReplayBuffer:
    """
    Fixed-capacity ring buffer of single transitions for SAC agents.

    Transitions are written at ``ptr``; once ``max_size`` transitions have
    been stored, each new transition overwrites the oldest one.
    """

    def __init__(self, obs_dim, act_dim, size):
        # One row per stored transition; everything is kept as float32.
        self.obs1_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.obs2_buf = np.zeros((size, obs_dim), dtype=np.float32)
        self.acts_buf = np.zeros((size, act_dim), dtype=np.float32)
        self.rews_buf = np.zeros(size, dtype=np.float32)
        self.done_buf = np.zeros(size, dtype=np.float32)
        # ptr: next row to write, size: number of valid rows, max_size: capacity.
        self.ptr = 0
        self.size = 0
        self.max_size = size

    def store(self, obs, act, rew, next_obs, done):
        """Write one transition at the current write position and advance it."""
        slot = self.ptr
        fields = (
            (self.obs1_buf, obs),
            (self.obs2_buf, next_obs),
            (self.acts_buf, act),
            (self.rews_buf, rew),
            (self.done_buf, done),
        )
        for target, value in fields:
            target[slot] = value
        # Wrap around so the oldest entry is the next one to be overwritten.
        self.ptr = (slot + 1) % self.max_size
        if self.size < self.max_size:
            self.size += 1

    def sample_batch(self, batch_size=32):
        """Return ``batch_size`` transitions drawn uniformly with replacement.

        The result is a dict with keys ``obs1``, ``obs2``, ``acts``, ``rews``
        and ``done``, each indexed by the same random rows.
        """
        rows = np.random.randint(0, self.size, size=batch_size)
        sources = {
            'obs1': self.obs1_buf,
            'obs2': self.obs2_buf,
            'acts': self.acts_buf,
            'rews': self.rews_buf,
            'done': self.done_buf,
        }
        return {key: data[rows] for key, data in sources.items()}

class HERReplayBuffer:
    """Episode-level FIFO replay buffer used for hindsight experience replay.

    Whole episodes of exactly ``T`` steps are stored; sampling is delegated
    to the ``sample_transitions`` callable given at construction time.
    """
    # TODO: rewrite baselines/her/replay_buffer.py
    #  sample_batch hands every stored episode to sample_transitions on each
    #  call, which is inefficient for large buffers.
    #  (store already overwrites the oldest episode first once the buffer is full.)

    def __init__(self, obs_dim, act_dim, size_in_transitions, T, sample_transitions):
        """Creates a replay buffer.
        Args:
            obs_dim (int): size of a flattened observation
            act_dim (int): size of an action
            size_in_transitions (int): the size of the buffer, measured in transitions
            T (int): the time horizon for episodes
            sample_transitions (function): a function that samples from the replay buffer;
                called as ``sample_transitions(buffers, batch_size)``
        Raises:
            ValueError: if ``T`` is not positive or the buffer cannot hold at
                least one episode of ``T`` transitions.
        """
        if T <= 0:
            raise ValueError(f"T must be a positive number of steps, got {T}")
        capacity = size_in_transitions // T
        if capacity < 1:
            # A zero-episode buffer would only fail later, inside store().
            raise ValueError(
                f"size_in_transitions={size_in_transitions} is too small to "
                f"hold a single episode of T={T} transitions"
            )

        self.size = capacity  # max_size, measured in episodes
        self.T = T
        self.sample_transitions = sample_transitions

        # self.buffers is {key: array(size_in_episodes x T or T+1 x dim_key)}
        # obs1 keeps T+1 observations per episode so that the successor
        # observation of every step is available (see sample_batch).
        self.buffers = dict(
            obs1=np.empty([self.size, T+1, obs_dim], dtype=np.float32),
            acts=np.empty([self.size, T, act_dim], dtype=np.float32),
            rews=np.empty([self.size, T], dtype=np.float32),
            done=np.empty([self.size, T], dtype=np.float32),
        )

        # memory management
        self.ptr = 0  # index of the episode slot written next
        self.current_size = 0  # number of episode slots holding valid data
        self.n_transitions_stored = 0  # total timesteps ever seen

    def store(self, episode):
        """Store one episode, overwriting the oldest one when the buffer is full.

        Args:
            episode (dict): must provide an entry for every key of
                ``self.buffers`` (``obs1``, ``acts``, ``rews``, ``done``) that
                can be assigned into one episode slot of that buffer.
        """
        for key, storage in self.buffers.items():
            storage[self.ptr] = episode[key]

        self.n_transitions_stored += self.T
        self.current_size = min(self.size, self.current_size+1)
        self.ptr = (self.ptr+1) % self.size

    def sample_batch(self, batch_size=32):
        """Sample ``batch_size`` transitions through ``sample_transitions``.

        Returns ``None`` when ``batch_size`` is 0; otherwise whatever
        ``sample_transitions`` returns. At least one episode must have been
        stored.
        """
        if batch_size == 0:
            return None
        assert self.current_size > 0, "cannot sample from an empty HERReplayBuffer"
        # Only expose the slots that hold real episodes (the arrays are
        # allocated with np.empty, so the remaining slots are uninitialised).
        buffers = {key: storage[:self.current_size]
                   for key, storage in self.buffers.items()}

        # Successor observations are the stored observations shifted by one step.
        buffers['obs2'] = buffers['obs1'][:, 1:, :]
        return self.sample_transitions(buffers, batch_size)


def make_sample_her_transitions(reward_fun, flatten_fun, unflatten_fun, replay_strategy="future", replay_k=4):
    """Build the sampling function used for HER experience replay.

    Args:
        reward_fun (function): re-computes the reward after goal substitution;
            called with keyword arguments ``achieved_goal``, ``desired_goal``
            and ``info`` (always an empty dict here)
        flatten_fun (function): turns a dict of per-key observation arrays back
            into one flat observation array
        unflatten_fun (function): splits a batch of flat observations into a
            dict; the ``'achieved_goal'`` and ``'desired_goal'`` entries are
            the ones read and written here
        replay_strategy (in ['future', 'none']): the HER replay strategy; any
            value other than 'future' disables goal substitution
        replay_k (int): the ratio between HER replays and regular replays
            (e.g. k = 4 -> 4 times as many HER replays as regular replays)

    Returns:
        A function ``(episode_batch, batch_size_in_transitions) -> dict`` of
        sampled transitions.
    """
    # Probability that a sampled transition gets its goal replaced.
    future_p = 1 - (1. / (1 + replay_k)) if replay_strategy == 'future' else 0

    def _sample_her_transitions(episode_batch, batch_size_in_transitions):
        """episode_batch is {key: array(buffer_size x T x dim_key)}
        """
        acts_shape = episode_batch['acts'].shape
        horizon = acts_shape[1]
        n_episodes = acts_shape[0]
        n_samples = batch_size_in_transitions

        # Draw (episode, time step) pairs uniformly and gather the raw transitions.
        ep_ids = np.random.randint(0, n_episodes, n_samples)
        steps = np.random.randint(horizon, size=n_samples)
        transitions = {}
        for name, data in episode_batch.items():
            transitions[name] = data[ep_ids, steps].copy()

        # Decide which samples are relabelled, and for each sample pick a
        # time step strictly after the sampled one within the same episode.
        relabel = np.where(np.random.uniform(size=n_samples) < future_p)
        offsets = (np.random.uniform(size=n_samples) * (horizon - steps)).astype(int)
        future_steps = (steps + 1 + offsets)[relabel]

        obs_now = unflatten_fun(transitions['obs1'])
        obs_next = unflatten_fun(transitions['obs2'])
        future_obs = unflatten_fun(episode_batch['obs1'][ep_ids[relabel], future_steps])

        # Substitute the goal achieved later in the episode for the relabelled
        # samples; the remaining samples keep their original goal.
        goals = obs_now['desired_goal']
        goals[relabel] = future_obs['achieved_goal']

        # The reward has to be recomputed because the goal may have changed.
        rewards = reward_fun(achieved_goal=obs_next['achieved_goal'],
                             desired_goal=goals,
                             info={})

        # Both observations of a transition carry the same (possibly new) goal.
        obs_now['desired_goal'] = goals
        obs_next['desired_goal'] = goals
        transitions['obs1'] = flatten_fun(obs_now)
        transitions['obs2'] = flatten_fun(obs_next)
        transitions['rews'] = rewards

        transitions = {name: values.reshape(n_samples, *values.shape[1:])
                       for name, values in transitions.items()}

        assert(transitions['acts'].shape[0] == batch_size_in_transitions)

        return transitions

    return _sample_her_transitions

"""

Soft Actor-Critic

(With slight variations that bring it closer to TD3)

"""
def sac(env_fn, actor_critic=core.mlp_actor_critic, ac_kwargs=dict(), seed=0,
        steps_per_epoch=200000, epochs=100, replay_size=int(1e6), gamma=0.99,
        polyak=0.995, lr=1e-3, alpha=0.005, batch_size=1000, start_steps=10000,
        max_ep_len=2000, logger_kwargs=dict(), save_freq=1, env_init=dict(),
        env_name='unknown', nb_test_episodes=50, train_freq=10, Teacher=None,
        size_ensemble=0, ve_str='v', use_her=False, **kwargs,
        ):
    """
    Train a Soft Actor-Critic agent, optionally together with an ensemble of
    value networks, a HER replay buffer and a Teacher object that configures
    the environment between episodes.

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: A function which takes in placeholder symbols 
            for state, ``x_ph``, and action, ``a_ph``, and returns the main 
            outputs from the agent's Tensorflow computation graph:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``mu``       (batch, act_dim)  | Computes mean actions from policy
                                           | given states.
            ``pi``       (batch, act_dim)  | Samples actions from policy given 
                                           | states.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``. Critical: must be differentiable
                                           | with respect to policy parameters all
                                           | the way through action sampling.
            ``q1``       (batch,)          | Gives one estimate of Q* for 
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ``q2``       (batch,)          | Gives another estimate of Q* for 
                                           | states in ``x_ph`` and actions in
                                           | ``a_ph``.
            ``q1_pi``    (batch,)          | Gives the composition of ``q1`` and 
                                           | ``pi`` for states in ``x_ph``: 
                                           | q1(x, pi(x)).
            ``q2_pi``    (batch,)          | Gives the composition of ``q2`` and 
                                           | ``pi`` for states in ``x_ph``: 
                                           | q2(x, pi(x)).
            ``v``        (batch,)          | Gives the value estimate for states
                                           | in ``x_ph``. 
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic 
            function you provided to SAC. The dict is copied before the
            ``action_space`` entry is added, so the caller's dict is not
            modified.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs) 
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs to run and train agent.

        replay_size (int): Maximum length of replay buffer.

        gamma (float): Discount factor. (Always between 0 and 1.)

        polyak (float): Interpolation factor in polyak averaging for target 
            networks. Target networks are updated towards main networks 
            according to:

            .. math:: \\theta_{\\text{targ}} \\leftarrow 
                \\rho \\theta_{\\text{targ}} + (1-\\rho) \\theta

            where :math:`\\rho` is polyak. (Always between 0 and 1, usually 
            close to 1.)

        lr (float): Learning rate (used for both policy and value learning).

        alpha (float): Entropy regularization coefficient. (Equivalent to 
            inverse of reward scale in the original SAC paper.)

        batch_size (int): Minibatch size for SGD.

        start_steps (int): Number of steps for uniform-random action selection,
            before running real policy. Helps exploration. After that, a
            uniform-random action is still taken with probability 0.3
            (``RANDOM_ACT_EPS``).

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): Currently unused: the model-saving code in the
            end-of-epoch wrap-up is commented out.

        env_init (dict): If non-empty, passed to ``env.env.my_init`` of both
            the training and the test environment right after creation.

        env_name (str): Not used by the algorithm itself; it only ends up in
            the saved configuration.

        nb_test_episodes (int): Number of deterministic test episodes run at
            the end of each epoch.

        train_freq (int): After an episode of length ``ep_len``,
            ``ceil(ep_len / train_freq)`` update steps are performed.

        Teacher: Optional object providing ``set_env_params``,
            ``record_train_episode``, ``record_test_episode`` and ``dump``.
            When ``None``, all of these calls are skipped.

        size_ensemble (int): Number of networks in the value ensemble. With 0
            no ensemble is built or trained.

        ve_str (str): Kind of ensemble. Only ``'v'`` is implemented; any other
            value raises ``NotImplementedError``.

        use_her (bool): Store whole episodes in a ``HERReplayBuffer`` and
            sample with goal relabelling instead of using ``ReplayBuffer``.

        **kwargs: Accepted and otherwise ignored; they are part of the saved
            configuration.

    """

    logger = EpochLogger(**logger_kwargs)
    # Snapshot of the call arguments (plus ``logger``) used as saved config.
    # Do not introduce new locals above this line: they would leak into it.
    hyperparams = locals()
    if Teacher is not None:
        del hyperparams['Teacher']  # remove teacher to avoid serialization error
    logger.save_config(hyperparams)

    tf.set_random_seed(seed)
    np.random.seed(seed)

    env, test_env = env_fn(), env_fn()

    # initialize environment (choose between short, default or quadrupedal walker)
    if env_init:
        env.env.my_init(env_init)
        test_env.env.my_init(env_init)

    obs_dim = env.env.observation_space.shape[0]
    print(obs_dim)
    act_dim = env.env.action_space.shape[0]

    # Action limit for clamping: critically, assumes all dimensions share the same bound!
    act_limit = env.action_space.high[0]

    # Share information about action space with policy architecture.
    # Work on a copy so neither the shared default dict nor the caller's
    # dict keeps a reference to this run's action space.
    ac_kwargs = dict(ac_kwargs)
    ac_kwargs['action_space'] = env.action_space

    # Inputs to computation graph
    x_ph, a_ph, x2_ph, r_ph, d_ph = core.placeholders(obs_dim, act_dim, obs_dim, None, None)

    # Every ensemble member gets its own set of placeholders so that each one
    # can be trained on a different slice of a batch.
    ve = [dict() for _ in range(size_ensemble)]
    for e in range(size_ensemble):
        (ve[e]['x_ph'], ve[e]['a_ph'], ve[e]['x2_ph'],
         ve[e]['r_ph'], ve[e]['d_ph']) = core.placeholders(obs_dim, act_dim, obs_dim, None, None)

    # Main outputs from computation graph
    with tf.variable_scope('main'):
        mu, pi, logp_pi, q1, q2, q1_pi, q2_pi, v = actor_critic(x_ph, a_ph, **ac_kwargs)

    # Target value network
    with tf.variable_scope('target'):
        _, _, _, _, _, _, _, v_targ = actor_critic(x2_ph, a_ph, **ac_kwargs)

    # Value ensemble: only the value head of each actor_critic instance is used.
    if ve_str == 'v':
        for e in range(size_ensemble):
            with tf.variable_scope(f've_main/{e}'):
                _, _, _, _, _, _, _, ve[e]['v'] = actor_critic(ve[e]['x_ph'], ve[e]['a_ph'], **ac_kwargs)
            with tf.variable_scope(f've_target/{e}'):
                _, _, _, _, _, _, _, ve[e]['v_targ'] = actor_critic(ve[e]['x2_ph'], ve[e]['a_ph'], **ac_kwargs)
    elif ve_str == 'q':
        raise NotImplementedError
    else:
        raise NotImplementedError

    def compute_vals(o):
        """Evaluate every ensemble value network on the observations ``o``."""
        feed_dict = {}
        for e in range(size_ensemble):
            feed_dict.update({ve[e]['x_ph']: o})
        vals = sess.run([ve[e]['v'] for e in range(size_ensemble)], feed_dict=feed_dict)
        vals = np.asarray(vals)
        # assert vals.shape == (size_ensemble, o.shape[0])
        return vals

    # Experience buffer
    if not use_her:
        replay_buffer = ReplayBuffer(obs_dim=obs_dim, act_dim=act_dim, size=replay_size)
    else:
        from gym.spaces import flatdim

        def unflatten_fun_2d(flatten_obs):
            """Split a (batch, obs_dim) array into one column block per sub-space."""
            obs = dict()
            start = 0
            for k, v in env.unwrapped.observation_space.spaces.items():
                obs[k] = flatten_obs[:, start:start+flatdim(v)]
                start += flatdim(v)
            return obs

        def flatten_fun_2d(unflatten_obs):
            """Concatenate the per-key blocks back in observation-space order."""
            obs = []
            for k in env.unwrapped.observation_space.spaces.keys():
                obs.append(unflatten_obs[k])
            obs = np.concatenate(obs, axis=1)
            return obs

        sample_her_transitions = make_sample_her_transitions(
            reward_fun=env.compute_reward,
            flatten_fun=flatten_fun_2d,
            unflatten_fun=unflatten_fun_2d
        )
        replay_buffer = HERReplayBuffer(
            obs_dim=obs_dim, act_dim=act_dim, T=env.spec.max_episode_steps,
            sample_transitions=sample_her_transitions,
            size_in_transitions=replay_size,
        )

    # Count variables
    var_counts = tuple(core.count_vars(scope) for scope in 
                       ['main/pi', 'main/q1', 'main/q2', 'main/v', 'main'])
    print(('\nNumber of parameters: \t pi: %d, \t' + \
           'q1: %d, \t q2: %d, \t v: %d, \t total: %d\n')%var_counts)

    print(f"Number of parameters for value ensemble: {[core.count_vars('ve_main/' + str(e)) for e in range(size_ensemble)]}]")

    # Min Double-Q:
    min_q_pi = tf.minimum(q1_pi, q2_pi)

    # Targets for Q and V regression
    q_backup = tf.stop_gradient(r_ph + gamma*(1-d_ph)*v_targ)
    v_backup = tf.stop_gradient(min_q_pi - alpha * logp_pi)

    # Bootstrap target of each ensemble member. The whole target is treated
    # as a constant, exactly like q_backup above.
    for e in range(size_ensemble):
        ve[e]['v_backup'] = tf.stop_gradient(
            ve[e]['r_ph'] + gamma*(1-ve[e]['d_ph'])*ve[e]['v_targ'])

    # Soft actor-critic losses
    pi_loss = tf.reduce_mean(alpha * logp_pi - q1_pi)
    q1_loss = 0.5 * tf.reduce_mean((q_backup - q1)**2)
    q2_loss = 0.5 * tf.reduce_mean((q_backup - q2)**2)
    v_loss = 0.5 * tf.reduce_mean((v_backup - v)**2)
    value_loss = q1_loss + q2_loss + v_loss

    for e in range(size_ensemble):
        ve[e]['loss'] = 0.5 * tf.reduce_mean((ve[e]['v_backup'] - ve[e]['v'])**2)

    # Policy train op 
    # (has to be separate from value train op, because q1_pi appears in pi_loss)
    pi_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    train_pi_op = pi_optimizer.minimize(pi_loss, var_list=get_vars('main/pi'))

    # Value train op
    # (control dep of train_pi_op because sess.run otherwise evaluates in nondeterministic order)
    value_optimizer = tf.train.AdamOptimizer(learning_rate=lr)
    value_params = get_vars('main/q') + get_vars('main/v')
    with tf.control_dependencies([train_pi_op]):
        train_value_op = value_optimizer.minimize(value_loss, var_list=value_params)

    for e in range(size_ensemble):
        ve[e]['optimizer'] = tf.train.AdamOptimizer(learning_rate=lr)
        ve[e]['train_op'] = ve[e]['optimizer'].minimize(ve[e]['loss'], var_list=get_vars(f've_main/{e}'))

    # Polyak averaging for target variables
    # (control flow because sess.run otherwise evaluates in nondeterministic order)
    with tf.control_dependencies([train_value_op]):
        target_update = tf.group([tf.assign(v_targ, polyak*v_targ + (1-polyak)*v_main)
                                  for v_main, v_targ in zip(get_vars('main'), get_vars('target'))])

    for e in range(size_ensemble):
        with tf.control_dependencies([ve[e]['train_op']]):
            ve[e]['target_update'] = tf.group([
                tf.assign(v_targ, polyak*v_targ + (1-polyak)*v_main)
                for v_main, v_targ in zip(get_vars(f've_main/{e}'), get_vars(f've_target/{e}'))
            ])

    # All ops to call during one training step
    step_ops = [pi_loss, q1_loss, q2_loss, v_loss, q1, q2, v, logp_pi, 
                train_pi_op, train_value_op, target_update]

    for e in range(size_ensemble):
        ve[e]['step_ops'] = [ve[e]['loss'], ve[e]['v'], ve[e]['train_op'], ve[e]['target_update']]

    # Initializing targets to match main variables
    # NOTE(review): there is no separate initialisation op for the
    # ve_target/* networks. Whether get_vars('main') / get_vars('target') also
    # cover the ve_main/* and ve_target/* scopes depends on how core.get_vars
    # matches scope names -- confirm the ensemble targets start as intended.
    target_init = tf.group([tf.assign(v_targ, v_main)
                              for v_main, v_targ in zip(get_vars('main'), get_vars('target'))])

    config = tf.ConfigProto()
    config.gpu_options.allow_growth = True
    sess = tf.Session(config=config)
    sess.run(tf.global_variables_initializer())
    sess.run(target_init)

    o = env.reset(random_goal=False)
    if Teacher is not None:
        o = Teacher.set_env_params(env, sample_task_params={'compute_vals_fun': compute_vals, 'init_o': o})

    # Setup model saving
    logger.setup_tf_saver(sess, inputs={'x': x_ph, 'a': a_ph}, 
                                outputs={'mu': mu, 'pi': pi, 'q1': q1, 'q2': q2, 'v': v})

    def get_action(o, deterministic=False):
        """Return the policy's action for a single observation ``o``."""
        act_op = mu if deterministic else pi
        return sess.run(act_op, feed_dict={x_ph: o.reshape(1,-1)})[0]

    def test_agent(n=10):
        """Run ``n`` deterministic episodes on the test env and log the results."""
        for j in range(n):
            o, r, d, ep_ret, ep_len = test_env.reset(random_goal=True), 0, False, 0, 0
            # o = Teacher.set_test_env_params(test_env, o)
            while not(d or (ep_len == max_ep_len)):
                # Take deterministic actions at test time 
                o, r, d, info = test_env.step(get_action(o, True))
                ep_ret += r
                ep_len += 1
            logger.store(TestEpRet=ep_ret, TestEpLen=ep_len, TestEpSucc=info['is_success'])
            if Teacher is not None:
                Teacher.record_test_episode(ep_ret, ep_len)

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(random_goal=True), 0, False, 0, 0
    episode = dict(obs1=[], acts=[], rews=[], done=[])
    total_steps = steps_per_epoch * epochs

    # Probability of a uniform-random action once start_steps have elapsed.
    RANDOM_ACT_EPS = 0.3

    # Main loop: collect experience in env and update/log each epoch
    for t in range(total_steps):
        # Until start_steps have elapsed, randomly sample actions
        # from a uniform distribution for better exploration. Afterwards,
        # use the learned policy, except for a RANDOM_ACT_EPS fraction of
        # steps that stay uniform-random.
        if t > start_steps and np.random.uniform(0, 1) > RANDOM_ACT_EPS:
            a = get_action(o)
        else:
            a = env.env.action_space.sample()

        # Step the env
        o2, r, d, info = env.step(a)
        ep_ret += r
        ep_len += 1

        # Ignore the "done" signal if it comes from hitting the time
        # horizon (that is, when it's an artificial terminal signal
        # that isn't based on the agent's state)
        d = False if ep_len==max_ep_len else d

        # Store experience to replay buffer
        if not use_her:
            replay_buffer.store(o, a, r, o2, d)
        else:
            # HER stores whole episodes; collect the steps until the episode ends.
            episode['obs1'].append(o)
            episode['acts'].append(a)
            episode['rews'].append(r)
            episode['done'].append(d)

        # Super critical, easy to overlook step: make sure to update 
        # most recent observation!
        o = o2

        if d or (ep_len == max_ep_len):
            if use_her:
                # The final observation completes the T+1 observations of the episode.
                # NOTE(review): HERReplayBuffer rows hold exactly
                # env.spec.max_episode_steps steps; an episode that ends early
                # (d is True) or a different max_ep_len would not fit -- confirm
                # this cannot happen for the environments used.
                episode['obs1'].append(o)
                episode = {k: np.stack(v) for k, v in episode.items()}
                replay_buffer.store(episode)
                episode = dict(obs1=[], acts=[], rews=[], done=[])

            # Perform all SAC updates at the end of the trajectory.
            # This is a slight difference from the SAC specified in the
            # original paper.
            for j in range(np.ceil(ep_len/train_freq).astype('int')):
                _t = time.time()
                batch = replay_buffer.sample_batch(batch_size)
                logger.store(TimePolicyBatch=time.time()-_t)
                _t = time.time()
                feed_dict = {x_ph: batch['obs1'],
                             x2_ph: batch['obs2'],
                             a_ph: batch['acts'],
                             r_ph: batch['rews'],
                             d_ph: batch['done'],
                            }
                outs = sess.run(step_ops, feed_dict)
                logger.store(TimePolicyTrain=time.time()-_t)
                # logger.store(LossPi=outs[0], LossQ1=outs[1], LossQ2=outs[2],
                #              LossV=outs[3], Q1Vals=outs[4], Q2Vals=outs[5],
                #              VVals=outs[6], LogPi=outs[7])

                # Value-ensemble update: one larger batch is sampled and cut
                # into consecutive batch_size slices, one per ensemble member.
                # Skipped entirely when there is no ensemble.
                if size_ensemble > 0:
                    _t = time.time()
                    batch = replay_buffer.sample_batch(batch_size*size_ensemble)
                    logger.store(TimeVEBatch=time.time()-_t)
                    _t = time.time()
                    feed_dict = {}
                    for e in range(size_ensemble):
                        feed_dict.update({
                            ve[e]['x_ph']: batch['obs1'][e*batch_size:(e+1)*batch_size],
                            ve[e]['x2_ph']: batch['obs2'][e*batch_size:(e+1)*batch_size],
                            ve[e]['a_ph']: batch['acts'][e*batch_size:(e+1)*batch_size],
                            ve[e]['r_ph']: batch['rews'][e*batch_size:(e+1)*batch_size],
                            ve[e]['d_ph']: batch['done'][e*batch_size:(e+1)*batch_size],
                        })
                    outs = sess.run(sum([ve[e]['step_ops'] for e in range(size_ensemble)], []), feed_dict)
                    logger.store(TimeVETrain=time.time() - _t)
                    # outs is the concatenation of every member's step_ops
                    # results; entries 0 and 1 of each group are loss and value.
                    for e in range(size_ensemble):
                        logger.store(
                            LossVE=outs[e*len(ve[e]['step_ops'])+0],
                            VE=outs[e*len(ve[e]['step_ops'])+1],
                        )

            logger.store(EpRet=ep_ret, EpLen=ep_len, EpSucc=info['is_success'])

            if Teacher is not None:
                Teacher.record_train_episode(ep_ret, ep_len)
            o, r, d, ep_ret, ep_len = env.reset(random_goal=False), 0, False, 0, 0
            if Teacher is not None:
                o = Teacher.set_env_params(env, sample_task_params={'compute_vals_fun': compute_vals, 'init_o': o})

        # End of epoch wrap-up
        if t > 0 and (t + 1) % steps_per_epoch == 0:
            epoch = (t + 1) // steps_per_epoch

            # Save model
            # if (epoch % save_freq == 0) or (epoch == epochs-1):
            #     logger.save_state({'env': env}, None)#itr=epoch)

            # Test the performance of the deterministic version of the agent.
            test_agent(n=nb_test_episodes)
            # Log info about epoch
            logger.log_tabular('Epoch', epoch)
            logger.log_tabular('EpSucc')
            logger.log_tabular('TestEpSucc')
            logger.log_tabular('EpRet', with_min_and_max=True)
            logger.log_tabular('TestEpRet', with_min_and_max=True)
            logger.log_tabular('EpLen', average_only=True)
            logger.log_tabular('TestEpLen', average_only=True)
            logger.log_tabular('TotalEnvInteracts', t+1)
            #logger.log_tabular('Q1Vals', with_min_and_max=True)
            #logger.log_tabular('Q2Vals', with_min_and_max=True)
            #logger.log_tabular('VVals', with_min_and_max=True)
            #logger.log_tabular('LogPi', with_min_and_max=True)
            #logger.log_tabular('LossPi', average_only=True)
            #logger.log_tabular('LossQ1', average_only=True)
            #logger.log_tabular('LossQ2', average_only=True)
            #logger.log_tabular('LossV', average_only=True)
            logger.log_tabular('TimePolicyBatch', average_only=True)
            logger.log_tabular('TimePolicyTrain', average_only=True)
            if size_ensemble > 0:
                logger.log_tabular('LossVE', average_only=True)
                logger.log_tabular('VE')
                logger.log_tabular('TimeVEBatch', average_only=True)
                logger.log_tabular('TimeVETrain', average_only=True)
            logger.log_tabular('Time', time.time()-start_time)
            logger.dump_tabular()

            # Pickle parameterized env data
            #print(logger.output_dir+'/env_params_save.pkl')
            if Teacher is not None:
                Teacher.dump(logger.output_dir+'/env_params_save.pkl')