# Must be run with OMP_NUM_THREADS=1

import logging
import os
import pprint
import threading
import time
import timeit
import traceback
from collections import Counter, defaultdict

import coolname
import hydra
import numpy as np
import pandas as pd
import plotly.express as px
import torch
import wandb
from omegaconf import DictConfig, OmegaConf
from tokenizers import BertWordPieceTokenizer
from torch import multiprocessing as mp
from torch import nn
from torch.nn import functional as F

from . import buffers as B
from . import envs, losses, models, optimizers, utils
from .torchbeast.core import prof

# Some Global Variables
# Module-level mutable state shared by the learner-side functions below.

# Accumulate items across learn() calls until a full generator / grounder
# batch is available (see learn_generator).
generator_batcher = utils.Batcher()
grounder_batcher = utils.Batcher()
# Number of consecutive generator batches whose mean reward met
# FLAGS.generator_threshold; reset on a miss or when the target is raised
# (see learn_generator_policy).
generator_count = 0
# Not referenced in the visible part of this file.
message_counts = None
# Counters of optimizer steps taken by learn(), learn_generator_policy() and
# learn_generator_grounder() respectively; reported in the stats dicts.
optimizer_steps = 0
generator_optimizer_steps = 0
grounder_optimizer_steps = 0
# Lazily created Counter of how often each goal has been seen
# (see compute_novelty_rewards).
goal_counts = None
# Lazily created BertWordPieceTokenizer used when plotting message novelty
# (see learn).
bert_tokenizer = None


# Named locks, created on first access; used under the keys "get_batch" and
# "learn" in this file.
LOCKS = defaultdict(threading.Lock)


def check_goal_completion(env_output, initial_env_state, action, goal, raw_goal):
    """Return whether the current intrinsic goal has been reached.

    Three cases are handled, selected by FLAGS:

    * BabyAI: delegate to ``_check_goal_completion_babyai`` on the initial and
      current frames, each flattened over dims 2-3.
    * MiniHack with XY goals (``FLAGS.language_goals is None``): the goal id is
      decoded into (row, col) and compared with the agent position read from
      ``blstats``.
    * MiniHack with language goals: reached when any entry of
      ``env_output["split_messages"]`` equals ``raw_goal`` along the last axis.
    """
    if FLAGS.is_babyai:
        return _check_goal_completion_babyai(
            torch.flatten(initial_env_state["frame"], 2, 3),
            torch.flatten(env_output["frame"], 2, 3),
            env_output,
            action,
            goal,
        )

    if FLAGS.language_goals is not None:
        # Check if message equals any of those given.
        matches = env_output["split_messages"] == raw_goal
        return matches.all(-1).any()

    grid_rows, grid_cols = env_output["chars"].shape[-2:]

    # Current location - note blstats coords are reversed (column first).
    blstats = env_output["blstats"]
    agent_col = blstats[0, 0, 0].item()
    agent_row = blstats[0, 0, 1].item()

    # Goal location, decoded from the flat goal index.
    target_row, target_col = divmod(goal.item(), grid_cols)

    # Make sure we're within the bounds.
    assert 0 <= agent_col < grid_cols, agent_col
    assert 0 <= agent_row < grid_rows, agent_row
    assert 0 <= target_col < grid_cols, target_col
    assert 0 <= target_row < grid_rows, target_row

    return (agent_col == target_col) and (agent_row == target_row)


def _check_goal_completion_babyai(old_frame, new_frame, env_output, action, goal):
    """Decide whether a BabyAI goal has been completed.

    For language goals (anything other than ``"xy"`` / ``None``) the answer is
    looked up in ``env_output["subgoal_done"]`` at the goal index. Otherwise
    the goal is treated as a cell index: the goal is reached when the entries
    of that cell differ between ``old_frame`` and ``new_frame``.

    Raises:
        RuntimeError: if language goals are enabled and ``goal`` does not have
            shape ``(1, 1)``.
    """
    goal_mode = FLAGS.language_goals

    # Verify completion of linguistic subgoal
    if goal_mode is not None and goal.shape != (1, 1):
        raise RuntimeError(f"Invalid goal shape: {goal.shape}")
    if goal_mode not in {"xy", None}:
        return env_output["subgoal_done"][goal.squeeze()]

    unchanged = new_frame == old_frame
    # A cell counts as changed unless all three of its entries (dim 3) match.
    cell_changed = unchanged.sum(3) != 3
    goal_index = goal.long().unsqueeze(2)
    return cell_changed.gather(2, goal_index).squeeze()


def _null_generator_output(env_output):
    """Build the fixed goal used when no generator is running.

    The goal id is always a ``(1, 1)`` int64 zero. The raw goal is that same
    tensor for BabyAI and for MiniHack XY goals; for MiniHack language goals it
    is a zero tensor shaped like ``env_output["message"]``.
    """
    goal = torch.zeros((1, 1), dtype=torch.int64)
    if FLAGS.is_babyai or FLAGS.language_goals is None:
        # BabyAI, or XY goals in MiniHack: raw goal is the goal id itself.
        raw_goal = goal
    else:
        raw_goal = torch.zeros_like(env_output["message"])
    return {"goal": goal, "raw_goal": raw_goal}


@torch.no_grad()
def act(
    actor_index: int,
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    model: torch.nn.Module,
    generator_model,
    buffers: B.Buffers,
    initial_agent_state_buffers,
    all_proposed_goals,
    achieved_proposed_goals,
    generator_current_target,
):
    """Run one IMPALA actor: roll out the policy and fill rollout buffers.

    The actor repeatedly takes a buffer index from ``free_queue``, writes the
    last step of the previous rollout at position 0, unrolls
    ``FLAGS.unroll_length`` environment steps into positions 1..T, and hands
    the index back through ``full_queue``. A ``None`` index terminates the
    loop.

    Whenever the intrinsic goal is reached or the environment episode ends, a
    new goal is proposed (by ``generator_model`` when ``FLAGS.generator`` is
    set, otherwise a fixed null goal) and counted in ``all_proposed_goals``;
    reached goals are counted in ``achieved_proposed_goals``.

    Args:
        actor_index: Index of this actor; used for logging and seeding.
        free_queue: Queue of buffer indices available for writing.
        full_queue: Queue receiving indices of completed rollouts.
        model: Policy model; called as ``model(env_output, state, raw_goal)``.
        generator_model: Goal generator (only called if ``FLAGS.generator``).
        buffers: Rollout storage, written through ``buffers.update``.
        initial_agent_state_buffers: Per-index storage for the recurrent state
            at the start of each rollout.
        all_proposed_goals: Mapping goal id -> number of times proposed.
        achieved_proposed_goals: Mapping goal id -> number of times reached.
        generator_current_target: Unused here; kept for interface
            compatibility.

    Raises:
        Exception: any exception from the rollout is logged and re-raised.
            ``KeyboardInterrupt`` is swallowed so the worker exits quietly.
    """
    num_frames = 0

    try:
        logging.info(f"Actor {actor_index} started.")
        timings = prof.Timings()  # Keep track of how fast things are.
        gym_env = envs.create_env(FLAGS)
        seed = actor_index ^ int.from_bytes(os.urandom(4), byteorder="little")
        gym_env.seed(seed)

        if FLAGS.is_babyai:
            env = envs.babyai.Observation_WrapperSetup(gym_env)
        else:
            env = envs.minihack.TBWrapper(gym_env)

        env_output = env.reset()
        initial_env_state = env.get_initial_env_state(env_output)

        agent_state = model.initial_state(batch_size=1)
        if FLAGS.generator:
            # TODO: fng for xy goals.
            if FLAGS.force_new_goals and FLAGS.language_goals is not None:
                prev_goals = torch.zeros_like(generator_model.goals_mask).unsqueeze(0)
            else:
                prev_goals = None
            generator_output = generator_model(env_output, prev_goals=prev_goals)
        else:
            generator_output = _null_generator_output(env_output)
        # goal: numeric id (for compactness). raw_goal: the actual goal format
        # (for BabyAI, it's just the same)
        goal = generator_output["goal"]
        raw_goal = generator_output["raw_goal"]

        intrinsic_done = False
        all_proposed_goals[goal.item()] += 1

        agent_output, _ = model(env_output, agent_state, raw_goal)

        reached_condition = check_goal_completion(
            env_output, initial_env_state, agent_output["action"], goal, raw_goal
        )
        intrinsic_done = reached_condition

        while True:
            timings.reset()
            timings.time("get_target")
            index = free_queue.get()
            timings.time("get_queue")
            if index is None:
                print(f"Got None index in worker process {actor_index}, exiting")
                break

            # Write old rollout end.
            initial_env_state_output = utils.map_dict(
                lambda k: f"initial_{k}", initial_env_state, map_keys=True
            )
            buffers.update(
                index,
                0,
                **env_output,
                **agent_output,
                **generator_output,
                **initial_env_state_output,
                intrinsic_done=intrinsic_done,
                reached=reached_condition,
            )
            for i, tensor in enumerate(agent_state):
                initial_agent_state_buffers[index][i][...] = tensor
            timings.time("write")

            # Do new rollout
            for t in range(FLAGS.unroll_length):
                timings.reset()

                agent_output, agent_state = model(env_output, agent_state, raw_goal)

                timings.time("model")

                env_output = env.step(agent_output["action"])
                num_frames += 1

                timings.time("step")

                # Did we reach the intrinsic goal?
                reached_condition = check_goal_completion(
                    env_output,
                    initial_env_state,
                    agent_output["action"],
                    goal,
                    raw_goal,
                )

                # Is env done (either extrinsic reward or max time steps)?
                env_done = env_output["done"][0] == 1

                intrinsic_done = reached_condition or env_done

                # The step is recorded with the goal and initial state that
                # were active when the action was taken, i.e. before any
                # reset / re-proposal below.
                initial_env_state_output = utils.map_dict(
                    lambda k: f"initial_{k}", initial_env_state, map_keys=True
                )
                buffers.update(
                    index,
                    t + 1,
                    **env_output,
                    **agent_output,
                    **generator_output,
                    **initial_env_state_output,
                    intrinsic_done=intrinsic_done,
                    reached=reached_condition,
                )

                if intrinsic_done:
                    # Propose a new goal
                    if reached_condition:
                        # Successful completion of intrinsic goal.
                        achieved_proposed_goals[goal.item()] += 1
                        env.intrinsic_episode_step = 0
                    if env_done:
                        # Reached extrinsic goal, or max steps.
                        env_output = env.reset()
                        initial_env_state = env.get_initial_env_state(env_output)

                    # Re-propose goal.
                    if FLAGS.generator:
                        if FLAGS.force_new_goals and FLAGS.language_goals is not None:
                            if env_done:
                                # New episode: forget previously proposed goals.
                                prev_goals = torch.zeros_like(
                                    generator_model.goals_mask
                                ).unsqueeze(0)
                            else:
                                # Same episode: mark the finished goal as used.
                                prev_goals[0, goal.item()] = True
                        generator_output = generator_model(
                            env_output, prev_goals=prev_goals
                        )
                    else:
                        generator_output = _null_generator_output(env_output)
                    goal = generator_output["goal"]
                    raw_goal = generator_output["raw_goal"]

                    all_proposed_goals[goal.item()] += 1

                timings.time("write")
            full_queue.put(index)

    except KeyboardInterrupt:
        print(f"Caught KeyboardInterrupt in worker process {actor_index}")
        pass  # Return silently.
    except Exception:
        logging.error(f"Exception in worker process {actor_index}")
        traceback.print_exc()
        # Bare raise keeps the original traceback intact.
        raise


def get_batch(
    free_queue: mp.SimpleQueue,
    full_queue: mp.SimpleQueue,
    buffers: B.Buffers,
    initial_agent_state_buffers,
    timings,
):
    """Collect ``FLAGS.batch_size`` finished rollouts into one learner batch.

    Indices of completed rollouts are taken from ``full_queue`` under the
    ``"get_batch"`` lock, the corresponding buffers and initial agent states
    are gathered, and the indices are then returned to ``free_queue`` so that
    actors can reuse them.

    Args:
        free_queue: Queue that receives indices once their data has been read.
        full_queue: Queue providing indices of completed rollouts.
        buffers: Rollout storage; read through ``buffers.get_batch``.
        initial_agent_state_buffers: Per-index initial recurrent states.
        timings: Profiler object; ``timings.time(label)`` is called after each
            phase.

    Returns:
        ``(batch, initial_agent_state)`` where ``initial_agent_state`` is a
        tuple of tensors on ``FLAGS.device``, concatenated along dim 1 across
        the selected indices.
    """

    with LOCKS["get_batch"]:
        timings.time("lock")
        indices = [full_queue.get() for _ in range(FLAGS.batch_size)]
        timings.time("dequeue")
    batch = buffers.get_batch(indices, device=FLAGS.device)
    # Materialise the concatenation NOW. A lazy generator here would defer
    # torch.cat until after the indices are handed back to the actors below,
    # at which point an actor may already be overwriting these state buffers.
    initial_agent_state = tuple(
        torch.cat(ts, dim=1)
        for ts in zip(*[initial_agent_state_buffers[m] for m in indices])
    )
    timings.time("batch")
    for m in indices:
        free_queue.put(m)
    timings.time("enqueue")
    initial_agent_state = tuple(
        t.to(device=FLAGS.device, non_blocking=True) for t in initial_agent_state
    )
    timings.time("device")

    return batch, initial_agent_state


def learn_generator_policy(
    generator_model,
    generator_optimizer,
    generator_scheduler,
    generator_batch,
    generator_current_target,
    max_steps,
    stats,
):
    """Take one optimisation step on the goal generator's policy.

    The generator reward starts at ``FLAGS.generator_reward_negative`` and
    gains +1 for goals that were reached after at least the current target
    number of steps, or ``FLAGS.easy_goal_reward`` for goals reached faster.
    Optional terms (null-goal penalty, novelty, extrinsic reward) are then
    applied, and the result is fed to ``losses.compute_actor_losses`` together
    with an entropy term.

    Side effects: updates the module-level ``generator_count`` and
    ``generator_optimizer_steps``, may raise ``generator_current_target.value``
    by one, steps the optimizer and scheduler, and writes metrics into
    ``stats``.
    """
    global generator_count
    global generator_optimizer_steps

    target_steps = generator_current_target.value

    outputs = generator_model(generator_batch, prev_goals=None)
    baseline = outputs["generator_baseline"]
    bootstrap_value = baseline[-1]

    steps_taken = generator_batch["intrinsic_episode_step"]
    reached = generator_batch["reached"]
    extrinsic = generator_batch["reward"]
    proposed = generator_batch["goal"]

    # Base (negative) reward everywhere, then bonuses for reached goals.
    rewards = FLAGS.generator_reward_negative * torch.ones(steps_taken.shape).to(
        device=FLAGS.device
    )
    hard_enough = (steps_taken >= target_steps).float()
    rewards += hard_enough * reached  # +1 reward for reaching and difficult
    rewards += ((1 - hard_enough) * reached) * FLAGS.easy_goal_reward

    # Give negative reward for the null goal
    if FLAGS.language_goals not in {None, "xy"}:
        rewards -= (proposed == 0).float()

    # Track how many batches in a row cleared the reward threshold.
    above_threshold = torch.mean(rewards).item() >= FLAGS.generator_threshold
    generator_count = generator_count + 1 if above_threshold else 0

    # After enough good batches in a row, make the difficulty target harder.
    if (
        generator_count >= FLAGS.generator_counts
        and target_steps <= FLAGS.generator_maximum
    ):
        target_steps += 1
        generator_current_target.value = target_steps
        generator_count = 0

    if FLAGS.novelty:
        novelty = compute_novelty_rewards(proposed)
        stats["novelty_rewards"] = novelty.mean().item()
        stats["novelty_rewards_min"] = novelty.min().item()
        stats["novelty_rewards_max"] = novelty.max().item()
        rewards += novelty

    clipped = (
        torch.clamp(rewards, -2, 2) if FLAGS.reward_clipping == "abs_one" else rewards
    )

    if not FLAGS.no_extrinsic_rewards:
        # Where the environment paid out, the reward is exactly 1; elsewhere
        # the generator reward is kept.
        got_extrinsic = (extrinsic > 0).float()
        no_extrinsic = (extrinsic <= 0).float()
        clipped = 1.0 * got_extrinsic + clipped * no_extrinsic
        if FLAGS.combine_rewards:
            # Add an additional +1 reward if intrinsic and extrinsic goal are reached.
            both_reached = ((reached > 0) & (extrinsic > 0)).float()
            clipped += both_reached

    discounts = torch.zeros_like(steps_taken)

    gg_loss, baseline_loss = losses.compute_actor_losses(
        behavior_policy_logits=generator_batch["generator_logits"],
        target_policy_logits=outputs["generator_logits"],
        actions=proposed,
        discounts=discounts,
        rewards=clipped,
        values=baseline,
        bootstrap_value=bootstrap_value,
        baseline_cost=FLAGS.baseline_cost,
    )
    entropy_loss = FLAGS.generator_entropy_cost * losses.compute_entropy_loss(
        outputs["generator_logits"]
    )

    total_loss = gg_loss + entropy_loss + baseline_loss

    # Logged only: reached goals, discounted by how long they took.
    time_discounted = reached * (1 - 0.9 * (steps_taken.float() / max_steps))

    stats["reached_goal"] = reached.float().mean().item()
    stats["gen_rewards"] = torch.mean(clipped).item()
    stats["gg_loss"] = gg_loss.item()
    stats["generator_baseline_loss"] = baseline_loss.item()
    stats["generator_entropy_loss"] = entropy_loss.item()
    stats["generator_intrinsic_rewards"] = time_discounted.mean().item()
    stats["mean_intrinsic_episode_steps"] = torch.mean(steps_taken.float()).item()
    stats["ex_reward"] = torch.mean(extrinsic).item()
    stats["generator_current_target"] = target_steps

    generator_optimizer.zero_grad()
    total_loss.backward()

    nn.utils.clip_grad_norm_(generator_model.parameters(), 40.0)
    generator_optimizer.step()
    generator_scheduler.step()

    generator_optimizer_steps += 1
    stats["generator_optimizer_steps"] = generator_optimizer_steps
    stats["generator_lr"] = generator_optimizer.param_groups[0]["lr"]


def compute_novelty_rewards(goals):
    """Return a count-based novelty penalty for every goal in ``goals``.

    Goals are visited in flattened order. Each visit increments that goal's
    entry in the module-level ``goal_counts`` and yields a reward of
    ``-FLAGS.novelty_reward * count``. Afterwards
    ``FLAGS.novelty_reward * min(all counts)`` is added to every entry, so the
    least-seen goal sits at zero and the values stay bounded as counts grow.

    Args:
        goals: Tensor of goal ids; any shape, including empty.

    Returns:
        A float32 tensor on ``goals.device`` with the same shape as ``goals``.
    """
    global goal_counts
    if goal_counts is None:
        goal_counts = Counter()

    scale = FLAGS.novelty_reward

    # Accumulate in a Python list and build the tensor once, instead of
    # writing element by element into a (possibly on-device) tensor.
    # reshape (not view) so non-contiguous inputs are accepted too.
    penalties = []
    for goal in goals.reshape(-1).tolist():
        goal_counts[goal] += 1
        penalties.append(-scale * goal_counts[goal])

    novelty_rewards = torch.tensor(
        penalties, dtype=torch.float32, device=goals.device
    )

    # Add the minimum to prevent insane numbers (TODO: or, normalize).
    # default=0 covers an empty batch arriving before any goal was counted.
    novelty_rewards += scale * min(goal_counts.values(), default=0)

    return novelty_rewards.reshape(goals.shape)


def learn_generator_grounder(
    generator_model, grounder_optimizer, grounder_scheduler, grounder_batch, stats
):
    """Take one optimisation step on the generator's grounding head.

    The grounder is trained with a positively re-weighted binary cross-entropy
    against multi-hot targets. For BabyAI the targets come straight from
    ``grounder_batch["targets"]``; otherwise the split messages are converted
    to byte strings and mapped to indices through
    ``generator_model.lang_hashes`` (unknown messages are ignored).

    Side effects: steps ``grounder_optimizer`` / ``grounder_scheduler``,
    increments the module-level ``grounder_optimizer_steps`` and writes
    accuracy, loss, F1, step count and learning rate into ``stats``.
    """
    global grounder_optimizer_steps

    # Train the grounding model
    outputs = generator_model.forward_grounder(grounder_batch)
    logits = outputs["logits"].unsqueeze(0)
    preds = outputs["preds"].unsqueeze(0)
    keep = torch.ones_like(logits, dtype=torch.bool)

    if FLAGS.is_babyai:
        targets = grounder_batch["targets"].float()
    else:
        # Targets is split messages. Convert to onehot targets
        targets = torch.zeros_like(keep, dtype=torch.float32)
        # Convert to indices for targets
        message_strs = models.minihack.language.messages_to_bytes(
            grounder_batch["targets"].cpu().numpy(), trim=False
        )
        msg_iter = np.nditer(message_strs, flags=["multi_index"])
        for entry in msg_iter:
            key = entry.item()
            if key in generator_model.lang_hashes:
                slot = generator_model.lang_hashes[key]
                # Drop the trailing (per-message) axis to address the sample.
                sample_index = msg_iter.multi_index[:-1]
                targets[sample_index][slot] = 1.0

    # The mask is all-True, so this only flattens the three tensors.
    logits = logits[keep]
    targets = targets[keep]
    preds = preds[keep]

    # Weight positives by the inverse of their frequency, when there are any.
    positive_rate = targets.mean()
    pos_weight = 1 / positive_rate if positive_rate != 0 else None

    loss = F.binary_cross_entropy_with_logits(
        logits,
        targets,
        pos_weight=pos_weight,
    )
    accuracy = (preds == targets).float().mean()
    stats.update(
        {
            "grounding_acc": accuracy.item(),
            "grounding_loss": loss.item(),
            "grounding_f1": utils.f1_score(
                y_true=targets, y_pred=preds.float()
            ).item(),
        }
    )

    grounder_optimizer.zero_grad()
    loss.backward()
    nn.utils.clip_grad_norm_(generator_model.parameters(), 40.0)
    grounder_optimizer.step()
    grounder_scheduler.step()

    grounder_optimizer_steps += 1
    stats["grounder_optimizer_steps"] = grounder_optimizer_steps
    stats["grounder_lr"] = grounder_optimizer.param_groups[0]["lr"]


@utils.require_lock(LOCKS, "learn")
def learn(
    actor_model,
    model,
    actor_generator_model,
    generator_model,
    rnd_model,
    batch,
    initial_agent_state,
    optimizer,
    generator_optimizer,
    grounder_optimizer,
    scheduler,
    generator_scheduler,
    grounder_scheduler,
    generator_current_target,
    max_steps=100.0,
):
    """Performs a learning (optimization) step for the policy, and for the generator whenever the generator batch is full.

    The intrinsic reward is assembled from up to three sources (goal reached,
    naive message reward, NovelD novelty), combined with the extrinsic reward
    unless ``FLAGS.int.twoheaded`` is set, and used for an actor-critic update
    of ``model``. The updated weights are then copied into ``actor_model``.
    When ``FLAGS.generator`` is set, ``learn_generator`` is called afterwards.

    Args:
        actor_model: Model used by the actors; receives the learner weights.
        model: Learner model being optimised.
        actor_generator_model / generator_model: Actor-side and learner-side
            goal generators (forwarded to ``learn_generator``).
        rnd_model: Novelty model; only called when ``FLAGS.noveld`` is set.
        batch: Mapping of rollout tensors, time-major.
        initial_agent_state: Recurrent state at the start of the rollout.
        optimizer / scheduler: Optimiser and LR scheduler for ``model``.
        generator_optimizer, grounder_optimizer, generator_scheduler,
        grounder_scheduler, generator_current_target: forwarded to
            ``learn_generator``.
        max_steps: Normaliser for the time penalty on intrinsic rewards.

    Returns:
        dict of scalar statistics. ``pg_loss``, ``baseline_loss`` and
        ``entropy_loss`` are ``None`` when ``FLAGS.no_extrinsic_rewards`` is
        set, since the extrinsic loss is not computed in that case.

    Raises:
        ValueError: if ``FLAGS.reward_clipping`` is neither ``"abs_one"`` nor
            ``"none"``.
    """
    global optimizer_steps
    global bert_tokenizer

    stats = {}

    # Loading Batch
    subgoal_done = batch["subgoal_done"][1:].to(device=FLAGS.device)
    subgoal_achievable = batch["subgoal_achievable"][1:].to(device=FLAGS.device)
    goal = batch["goal"][1:].to(device=FLAGS.device)
    reached = batch["reached"][1:].to(device=FLAGS.device)
    intrinsic_done = batch["intrinsic_done"][1:].to(device=FLAGS.device)

    if FLAGS.language_goals not in {"xy", None}:
        goal_was_achievable = subgoal_achievable.gather(-1, goal.unsqueeze(-1)).squeeze(
            -1
        )
        goal_was_achievable = goal_was_achievable.float().mean().item()
    else:
        goal_was_achievable = None

    if FLAGS.generator:
        intrinsic_rewards = FLAGS.intrinsic_reward_coef * reached.float()
        # Zero where the goal was not reached; otherwise penalised by the
        # fraction of max_steps it took to get there.
        intrinsic_rewards = intrinsic_rewards * (
            intrinsic_rewards
            - 0.9 * (batch["intrinsic_episode_step"][1:].float() / max_steps)
        )
    else:
        # Still may use other forms of intrinsic rewards.
        intrinsic_rewards = torch.zeros_like(reached, dtype=torch.float32)

    if FLAGS.naive_message_reward > 0:
        if FLAGS.is_babyai:
            encountered_message = batch["subgoal_done"].any(-1).float()
        else:
            # NOTE(review): 101/102 look like the tokenizer's start/end ids,
            # i.e. an empty message - confirm against the vocab.
            encountered_message = (
                ~((batch["message"][..., 0] == 101) & (batch["message"][..., 1] == 102))
            ).float()

        if FLAGS.naive_message_reward_format == "learn":
            # LEARN paper, page 4:
            # let phi(s) be potential function, probability of related over
            # unrelated. (This is encountered_message). Then reward is
            # discount_factor * phi(s) - phi({s - 1})
            learn_reward = (
                FLAGS.discounting * encountered_message[1:]
            ) - encountered_message[:-1]
            intrinsic_rewards += learn_reward * FLAGS.naive_message_reward
        else:
            # Standard reward
            intrinsic_rewards += encountered_message[1:] * FLAGS.naive_message_reward

    learner_outputs, _ = model(
        batch, initial_agent_state, batch["raw_goal"].squeeze(-1)
    )
    if FLAGS.noveld:
        novelty, phi_loss, message_novelty, message_phi_loss = rnd_model(
            batch, optimize=True
        )

        # Zero out novelty reward for initial states
        not_initial_state = batch["extrinsic_episode_step"][1:] > 1
        # Zero out novelty reward for already-visited states in an episode
        first_visit = batch["state_visits"][1:].squeeze(-1) <= 1
        novelty = losses.compute_noveld(
            novelty, mask=(not_initial_state & first_visit), alpha=FLAGS.noveld_alpha
        )

        stats["noveld_novelty"] = novelty.mean().item()
        stats["noveld_loss"] = phi_loss.item()

        intrinsic_rewards += FLAGS.noveld_novelty_coef * novelty

        if FLAGS.separate_message_novelty:
            # Zero out novelty reward for already-visited states in an episode
            first_visit_m = batch["state_visits_m"][1:].squeeze(-1) <= 1
            message_novelty = losses.compute_noveld(
                message_novelty,
                mask=(not_initial_state & first_visit_m),
                alpha=FLAGS.noveld_alpha,
            )

            stats["message_noveld_novelty"] = message_novelty.mean().item()
            stats["message_noveld_loss"] = message_phi_loss.item()

            intrinsic_rewards += FLAGS.separate_message_novelty_coef * message_novelty

            if FLAGS.plot_novelty:
                novelty_records = defaultdict(list)
                if FLAGS.is_babyai:
                    mn_flat = message_novelty.view(-1).cpu().numpy()
                    sg_flat = (
                        batch["subgoal_done"][1:]
                        .view(-1, batch["subgoal_done"].shape[-1])
                        .cpu()
                        .numpy()
                    )
                    for sgs, mn in zip(sg_flat[mn_flat > 0], mn_flat[mn_flat > 0]):
                        where = np.where(sgs)[0]
                        # Add this to novelty
                        for sg in where:
                            novelty_records[sg].append(mn)
                else:
                    # NOTE(review): the novelty is expanded along a new dim 1
                    # while the messages below are flattened with the
                    # per-step message axis second-to-last; the two flat
                    # orderings may not line up - TODO confirm tensor layout.
                    mn_flat = (
                        message_novelty.unsqueeze(1)
                        .expand(-1, batch["split_messages"].shape[-2], -1)
                        .contiguous()
                        .view(-1)
                        .cpu()
                        .numpy()
                    )
                    if bert_tokenizer is None:
                        bert_tokenizer = BertWordPieceTokenizer(
                            FLAGS.minihack.msg.word.vocab_file, lowercase=True
                        )
                    sm_flat = (
                        batch["split_messages"][1:]
                        .view(-1, batch["split_messages"].shape[-1])
                        .cpu()
                        .numpy()
                    )
                    sm_flat = bert_tokenizer.decode_batch(sm_flat)
                    for sm, mn in zip(sm_flat, mn_flat):
                        if sm and mn > 0:
                            novelty_records[sm].append(mn)
                stats["novelty_records"] = novelty_records

    # Align targets (steps 1..T) with predictions (steps 0..T-1).
    batch = utils.map_dict(lambda t: t[1:], batch)
    learner_outputs = utils.map_dict(lambda t: t[:-1], learner_outputs)
    rewards = batch["reward"]

    if not FLAGS.int.twoheaded:
        # Add intrinsic rewards to extrinsic rewards
        total_rewards = rewards + intrinsic_rewards
    else:
        total_rewards = rewards

    if FLAGS.reward_clipping == "abs_one":
        clipped_rewards = torch.clamp(total_rewards, -1, 1)
    elif FLAGS.reward_clipping == "none":
        clipped_rewards = total_rewards
    else:
        raise ValueError(f"Unknown reward_clipping: {FLAGS.reward_clipping!r}")

    discounts = (
        ~batch["done"]
    ).float() * FLAGS.discounting  # Account for "done" episodes
    # Out-of-place on purpose: with reward_clipping == "none" and a two-headed
    # model, clipped_rewards is the very tensor stored in batch["reward"], and
    # an in-place add would corrupt the batch used for stats and the generator.
    clipped_rewards = clipped_rewards + 1.0 * (rewards > 0.0).float()

    total_loss = 0
    # Only computed when extrinsic rewards are in use; reported as None otherwise.
    pg_loss = baseline_loss = entropy_loss = None
    # ==== STUDENT LOSS ====
    if not FLAGS.no_extrinsic_rewards:
        pg_loss, baseline_loss = losses.compute_actor_losses(
            behavior_policy_logits=batch["policy_logits"],
            target_policy_logits=learner_outputs["policy_logits"],
            actions=batch["action"],
            discounts=discounts,
            rewards=clipped_rewards,
            values=learner_outputs["baseline"],
            bootstrap_value=learner_outputs["baseline"][-1],
            baseline_cost=FLAGS.baseline_cost,
        )
        entropy_loss = FLAGS.entropy_cost * losses.compute_entropy_loss(
            learner_outputs["policy_logits"]
        )

        total_loss += pg_loss + baseline_loss + entropy_loss

    # ==== INTRINSIC LOSS ====
    if FLAGS.int.twoheaded or FLAGS.no_extrinsic_rewards:
        int_pg_loss, int_baseline_loss = losses.compute_actor_losses(
            behavior_policy_logits=batch["policy_logits"],
            target_policy_logits=learner_outputs["policy_logits"],
            actions=batch["action"],
            discounts=discounts,  # intrinsic discounts
            rewards=intrinsic_rewards,  # intrinsic reward
            values=learner_outputs["int_baseline"],
            bootstrap_value=learner_outputs["int_baseline"][-1],  # intrinsic bootstrap
            baseline_cost=FLAGS.int.baseline_cost,
        )
        stats.update(
            {
                "int_pg_loss": int_pg_loss.item(),
                "int_baseline_loss": int_baseline_loss.item(),
            }
        )
        total_loss += int_pg_loss + int_baseline_loss

    optimizer.zero_grad()
    total_loss.backward()
    nn.utils.clip_grad_norm_(model.parameters(), 40.0)
    optimizer.step()
    scheduler.step()
    actor_model.load_state_dict(model.state_dict())

    # ==== LOG STATS ====
    optimizer_steps += 1
    episode_returns = batch["episode_return"][batch["done"]]
    episode_reward = batch["reward"][batch["done"]]
    # Measure goal timeout rate.
    goal_reached_rate = reached[intrinsic_done].float().mean().nan_to_num(0).item()

    stats.update(
        {
            "mean_episode_return": episode_returns.mean().nan_to_num(0).item(),
            "mean_episode_final_reward": episode_reward.mean().nan_to_num(0).item(),
            "intrinsic_rewards": intrinsic_rewards.mean().item(),
            "total_loss": total_loss.item(),
            "pg_loss": None if pg_loss is None else pg_loss.item(),
            "baseline_loss": None if baseline_loss is None else baseline_loss.item(),
            "entropy_loss": None if entropy_loss is None else entropy_loss.item(),
            "mean_subgoal_done": subgoal_done.float().mean().item(),
            "goal_achievable": goal_was_achievable,
            "goal_reached_rate": goal_reached_rate,
            "optimizer_steps": optimizer_steps,
            "lr": optimizer.param_groups[0]["lr"],
        }
    )

    # ==== OPTIMIZE GENERATOR ====
    if FLAGS.generator:
        learn_generator(
            generator_model,
            actor_generator_model,
            generator_optimizer,
            grounder_optimizer,
            generator_scheduler,
            grounder_scheduler,
            batch,
            generator_current_target,
            max_steps,
            stats,
        )

    return stats


def learn_generator(
    generator_model,
    actor_generator_model,
    generator_optimizer,
    grounder_optimizer,
    generator_scheduler,
    grounder_scheduler,
    batch,
    generator_current_target,
    max_steps,
    stats,
):
    """Feed a learner batch to the goal generator and train it when possible.

    Steps, in order:

    1. For online language goals, let the generator update its set of known
       goals from ``batch``.
    2. If ``FLAGS.train_generator`` is set: queue grounder samples (steps
       where a message was seen) and policy samples (steps where
       ``intrinsic_done`` is set) in the module-level batchers, and run a
       grounder / policy update whenever the respective batcher holds a full
       batch.
    3. Copy the learner generator's weights into ``actor_generator_model``.
    """
    global generator_batcher
    global grounder_batcher

    def strip_initial(key):
        # Remove initial_ prefix from keys.
        return key.split("initial_")[-1]

    initial_keys = [name for name in batch.keys() if name.startswith("initial")]

    # ==== UPDATE GOALS ====
    if FLAGS.language_goals in {"online_naive", "online_grounding"}:
        # Update possible goals for teacher.
        generator_model.update_goals(batch)
        stats["generator_goals_seen"] = generator_model.goals_mask.sum().item()

    # When not training, neither the grounder nor the policy net learns, but
    # the (possibly updated) goals are still pushed to the actors below.
    if FLAGS.train_generator:
        # ==== LEARN GROUNDER ====
        if FLAGS.language_goals == "online_grounding":
            if FLAGS.is_babyai:
                has_message = batch["subgoal_done"].any(-1).bool()
                targets = batch["subgoal_done"]
            else:
                # FIXME: check word gru here.
                has_message = batch["message"][..., 0] != 0
                targets = batch["split_messages"]

            # Add environments with messages to grounder batcher
            grounder_items = {name: batch[name] for name in initial_keys}
            grounder_items["targets"] = targets
            grounder_items = utils.map_dict(
                strip_initial, grounder_items, map_keys=True
            )

            grounder_batcher.append(
                grounder_items, mask=has_message, device=FLAGS.device
            )

            if grounder_batcher.ready(FLAGS.grounder_batch_size):
                grounder_batch = utils.map_dict(
                    lambda x: x.unsqueeze(0),
                    grounder_batcher.get_batch(FLAGS.grounder_batch_size),
                )
                learn_generator_grounder(
                    generator_model,
                    grounder_optimizer,
                    grounder_scheduler,
                    grounder_batch,
                    stats,
                )

        # ==== LEARN POLICY ====
        policy_keys = [
            "goal",
            "subgoal_achievable",
            "subgoal_done",
            "intrinsic_episode_step",
            "generator_logits",
            "reached",
            "reward",
            "carried_obj",
            "carried_col",
            # All initial env reprs
            *initial_keys,
        ]
        policy_items = utils.map_dict(
            strip_initial,
            {name: batch[name] for name in policy_keys},
            map_keys=True,
        )
        generator_batcher.append(
            policy_items, mask=batch["intrinsic_done"], device=FLAGS.device
        )

        # Perform update
        if generator_batcher.ready(FLAGS.generator_batch_size):
            generator_batch = utils.map_dict(
                lambda x: x.unsqueeze(0),
                generator_batcher.get_batch(FLAGS.generator_batch_size),
            )
            learn_generator_policy(
                generator_model,
                generator_optimizer,
                generator_scheduler,
                generator_batch,
                generator_current_target,
                max_steps,
                stats,
            )

    # Update generator
    actor_generator_model.load_state_dict(generator_model.state_dict())


def train():
    """Run the full asynchronous actor/learner training loop.

    Reads all configuration from the module-level ``FLAGS``. The function:

    1. creates the environment, the actor/learner models, the rollout buffers
       and the optimizers/schedulers;
    2. starts ``FLAGS.num_actors`` actor workers running ``act`` (threads when
       ``FLAGS.debug`` is set, forked processes otherwise);
    3. starts ``FLAGS.num_threads`` learner threads that repeatedly call
       ``get_batch`` and ``learn``;
    4. monitors progress from the calling thread: periodic checkpoints,
       per-template goal statistics, optional plots and wandb logging.

    Returns:
        ``(frames, model.state_dict())`` once the monitoring loop ends, either
        normally or after a non-keyboard exception. ``None`` if the loop is
        interrupted with ``KeyboardInterrupt`` (no final checkpoint is written
        in that case).

    Raises:
        ValueError: if ``FLAGS.num_actors >= FLAGS.num_buffers``.
    """
    checkpointpath = os.path.expandvars(
        os.path.expanduser(f"{FLAGS.savedir}/{FLAGS.name}/model.tar")
    )

    if FLAGS.num_buffers is None:  # Set sensible default for num_buffers.
        FLAGS.num_buffers = max(4 * FLAGS.num_actors, 2 * FLAGS.batch_size)
    if FLAGS.num_actors >= FLAGS.num_buffers:
        raise ValueError("num_buffers should be larger than num_actors")

    # Number of frames accounted for by each learner step. Named explicitly
    # (instead of the usual T/B locals) so the module-level alias ``B`` for
    # the buffers module is not shadowed inside this function.
    frames_per_batch = FLAGS.unroll_length * FLAGS.batch_size

    logging.info(f"Using device {FLAGS.device}")

    env = envs.create_env(FLAGS)

    (
        model,
        generator_model,
        learner_model,
        learner_generator_model,
        buffers,
    ) = models.create_models_and_buffers(env, FLAGS)

    if FLAGS.noveld:
        rnd_model = models.create_rnd_model(env, FLAGS)
    else:
        rnd_model = None

    # The actor-side models are handed to the actor workers below, so their
    # parameters are moved to shared memory first.
    model.share_memory()
    generator_model.share_memory()

    (
        optimizer,
        generator_optimizer,
        grounder_optimizer,
        scheduler,
        generator_scheduler,
        grounder_scheduler,
    ) = optimizers.create_optimizers(
        learner_model, learner_generator_model, FLAGS.total_frames, FLAGS
    )

    # Initialize tracking of proposed goals. Manager-backed containers are
    # used because these objects are passed to the actor workers.
    manager = mp.Manager()
    proposed_goals = {
        "all": manager.dict(),
        "achieved": manager.dict(),
    }
    generator_current_target = manager.Value("i", int(FLAGS.generator_start_target))

    if FLAGS.is_babyai:
        num_possible_goals = generator_model.logits_size
    else:
        if FLAGS.language_goals is None:
            num_possible_goals = generator_model.num_actions
        else:
            num_possible_goals = FLAGS.max_online_goals + 1
    for i in range(num_possible_goals):
        proposed_goals["all"][i] = 0
        proposed_goals["achieved"][i] = 0

    # Add initial RNN state: one shared-memory state per rollout buffer.
    initial_agent_state_buffers = []
    for _ in range(FLAGS.num_buffers):
        state = model.initial_state(batch_size=1)
        for t in state:
            t.share_memory_()
        initial_agent_state_buffers.append(state)

    actor_processes = []
    ctx = mp.get_context("fork")
    # free_queue carries buffer indices towards the actors, full_queue is read
    # by get_batch on the learner side.
    free_queue = ctx.SimpleQueue()
    full_queue = ctx.SimpleQueue()

    for i in range(FLAGS.num_actors):
        # Debug mode keeps the actors in-process (threads) instead of forking.
        if FLAGS.debug:
            procfn = threading.Thread
        else:
            procfn = ctx.Process
        actor = procfn(
            target=act,
            args=(
                i,
                free_queue,
                full_queue,
                model,
                generator_model,
                buffers,
                initial_agent_state_buffers,
                proposed_goals["all"],
                proposed_goals["achieved"],
                generator_current_target,
            ),
        )
        actor.start()
        actor_processes.append(actor)

    stat_keys = [
        "total_loss",
        "mean_episode_return",
        "pg_loss",
        "baseline_loss",
        "entropy_loss",
        "gen_rewards",
        "gg_loss",
        "generator_entropy_loss",
        "generator_baseline_loss",
        "mean_intrinsic_rewards",
        "mean_intrinsic_episode_steps",
        "ex_reward",
        "generator_current_target",
    ]
    logging.info("# Step\t{}".format("\t".join(stat_keys)))

    # Shared between the learner threads (writers) and the monitoring loop
    # below (reader) through ``nonlocal``.
    frames = 0
    stats = {}

    def batch_and_learn(i, generator_current_target):
        """Thread target for the learning process.

        Repeatedly fetches a batch with ``get_batch`` and runs ``learn`` on it
        until ``FLAGS.total_frames`` frames have been counted. ``i`` is only
        used to label the timing log lines.
        """
        nonlocal frames, stats
        timings = prof.Timings()
        index = 0
        while frames < FLAGS.total_frames:
            timings.reset()
            batch, agent_state = get_batch(
                free_queue,
                full_queue,
                buffers,
                initial_agent_state_buffers,
                timings,
            )
            stats = learn(
                model,
                learner_model,
                generator_model,
                learner_generator_model,
                rnd_model,
                batch,
                agent_state,
                optimizer,
                generator_optimizer,
                grounder_optimizer,
                scheduler,
                generator_scheduler,
                grounder_scheduler,
                generator_current_target,
                max_steps=env.max_steps if FLAGS.is_babyai else env._max_episode_steps,
            )

            timings.time("learn")
            # The frame counter is shared by all learner threads, so the
            # read-modify-write is serialized.
            with LOCKS["batch_and_learn"]:
                frames += frames_per_batch
            index += 1
            if FLAGS.verbose and index % 15 == 0:
                logging.info(f"Batch and learn {i}: {timings.summary()}")
                timings.reset()

    # Seed the free queue with every buffer index.
    for m in range(FLAGS.num_buffers):
        free_queue.put(m)

    threads = []
    for i in range(FLAGS.num_threads):
        thread = threading.Thread(
            target=batch_and_learn,
            name="batch-and-learn-%d" % i,
            args=(i, generator_current_target),
        )
        thread.start()
        threads.append(thread)

    def checkpoint():
        """Save model/optimizer/scheduler state unless checkpointing is disabled."""
        if FLAGS.disable_checkpoint:
            return
        logging.info(f"Saving checkpoint to {checkpointpath}")
        # torch.save does not create missing parent directories.
        checkpoint_dir = os.path.dirname(checkpointpath)
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "generator_model_state_dict": generator_model.state_dict(),
                "optimizer_state_dict": optimizer.state_dict(),
                "generator_optimizer_state_dict": generator_optimizer.state_dict(),
                "scheduler_state_dict": scheduler.state_dict(),
                "generator_scheduler_state_dict": generator_scheduler.state_dict(),
                "FLAGS": vars(FLAGS),
            },
            checkpointpath,
        )

    timer = timeit.default_timer

    # Plot data accumulated over the whole run.
    template_plot_data = []
    template_plot_data_norm = []
    noveld_plot_data = []
    all_templates_plot_norm = None
    achieved_templates_plot_norm = None
    noveld_plot = None
    # "logged" flags start True so nothing is sent to wandb until a plot has
    # actually been built; building a plot flips the flag to False.
    logged_templates_plot = True
    logged_noveld_plot = True
    percent_goals_achieved = None
    novelty_records = defaultdict(list)

    try:
        last_checkpoint_time = timer()
        last_template_time = timer()
        last_plot_update_time = timer()

        while frames < FLAGS.total_frames:
            frame_interval = frames
            time_interval = timer()
            if FLAGS.debug:
                sleep_time = 5 * 60
            else:
                sleep_time = 5
            time.sleep(sleep_time)
            # Snapshot used as the x-coordinate of every data point recorded
            # during this iteration.
            this_frames = frames
            if timer() - last_checkpoint_time > 10 * 60:  # Save every 10 min.
                checkpoint()
                last_checkpoint_time = timer()

            if FLAGS.plot_novelty and "novelty_records" in stats:
                # Update novelty running averages
                for template, novelties in stats["novelty_records"].items():
                    novelty_records[template].extend(novelties)

            if timer() - last_template_time > 1 * 60:  # Update goal data every 1 min.
                if FLAGS.plot_novelty:
                    # Compute means and reset.
                    novelty_norm = {k: np.mean(v) for k, v in novelty_records.items()}
                    for t, nvlty in novelty_norm.items():
                        if FLAGS.is_babyai:
                            template = learner_generator_model.lang_templates[t]
                        else:
                            template = t
                        noveld_plot_data.append(
                            {
                                "num_frames": this_frames,
                                "novelty": nvlty,
                                "template": template,
                            }
                        )
                    novelty_records = defaultdict(list)

                # Combine proposed goals by templates.
                if FLAGS.language_goals is not None:
                    proposed_goals_templates = {"all": Counter(), "achieved": Counter()}
                    for i, template in enumerate(
                        learner_generator_model.lang_templates
                    ):
                        proposed_goals_templates["all"][template] += proposed_goals[
                            "all"
                        ][i]
                        proposed_goals_templates["achieved"][
                            template
                        ] += proposed_goals["achieved"][i]
                        # Reset counts for the next interval
                        proposed_goals["all"][i] = 0
                        proposed_goals["achieved"][i] = 0

                    # Raw counts per template (templates never proposed in
                    # this interval are skipped).
                    for template in proposed_goals_templates["all"]:
                        n_all = proposed_goals_templates["all"][template]
                        n_achieved = proposed_goals_templates["achieved"][template]
                        if n_all == 0:
                            continue
                        template_plot_data.append(
                            {
                                "num_frames": this_frames,
                                "n_all": n_all,
                                "n_achieved": n_achieved,
                                "template": template,
                            }
                        )

                    # Compute proportion of goals achieved (epsilon avoids a
                    # division by zero when nothing was proposed).
                    percent_goals_achieved = sum(
                        proposed_goals_templates["achieved"].values()
                    ) / (sum(proposed_goals_templates["all"].values()) + 1e-10)

                    # Normalize the goals
                    all_sum = sum(proposed_goals_templates["all"].values())
                    achieved_sum = sum(proposed_goals_templates["achieved"].values())
                    proposed_goals_templates["all"] = utils.map_dict(
                        lambda v: v / (all_sum + 1e-10), proposed_goals_templates["all"]
                    )
                    proposed_goals_templates["achieved"] = utils.map_dict(
                        lambda v: v / (achieved_sum + 1e-10),
                        proposed_goals_templates["achieved"],
                    )

                    for template in proposed_goals_templates["all"]:
                        n_all = proposed_goals_templates["all"][template]
                        n_achieved = proposed_goals_templates["achieved"][template]
                        if n_all == 0:
                            continue
                        template_plot_data_norm.append(
                            {
                                "num_frames": this_frames,
                                "n_all": n_all,
                                "n_achieved": n_achieved,
                                "template": template,
                            }
                        )

                last_template_time = timer()

            # Rebuild the wandb plots from the accumulated data.
            if timer() - last_plot_update_time > 5 * 60:  # Update plots every 5 min.
                if FLAGS.wandb and FLAGS.plot_novelty:
                    noveld_df = pd.DataFrame(noveld_plot_data)
                    noveld_plot = px.line(
                        noveld_df, x="num_frames", y="novelty", color="template"
                    )
                    logged_noveld_plot = False

                if FLAGS.wandb and (
                    FLAGS.language_goals not in {"xy", None} or not FLAGS.is_babyai
                ):
                    templates_df_norm = pd.DataFrame(template_plot_data_norm)
                    # An empty frame has no columns to plot.
                    if len(templates_df_norm.columns) != 0:
                        all_templates_plot_norm = px.line(
                            templates_df_norm,
                            x="num_frames",
                            y="n_all",
                            color="template",
                        )
                        achieved_templates_plot_norm = px.line(
                            templates_df_norm,
                            x="num_frames",
                            y="n_achieved",
                            color="template",
                        )
                        logged_templates_plot = False

                last_plot_update_time = timer()

            fps = (frames - frame_interval) / (timer() - time_interval)
            if stats.get("episode_returns", None):
                mean_return = (
                    "Return per episode: %.1f. " % stats["mean_episode_return"]
                )
            else:
                mean_return = ""
            total_loss = stats.get("total_loss", float("inf"))
            # Filter out stats that are none
            stats_not_none = utils.filter_dict(
                lambda v: v is not None and not isinstance(v, tuple),
                stats,
            )
            stats_not_none = utils.filter_dict(
                lambda k: k not in {"novelty_records"},
                stats_not_none,
                filter_keys=True,
            )
            if FLAGS.wandb:
                metrics_to_log = {
                    "num_frames": frames,
                    "fps": fps,
                    **stats_not_none,
                    "percent_goals_achieved": percent_goals_achieved,
                }
                if not logged_templates_plot:
                    metrics_to_log.update(
                        {
                            "all_templates_norm": all_templates_plot_norm,
                            "achieved_templates_norm": achieved_templates_plot_norm,
                        }
                    )
                if not logged_noveld_plot:
                    metrics_to_log.update(
                        {
                            "noveld_plot": noveld_plot,
                        }
                    )
                wandb.log(metrics_to_log, step=frames)
                logged_templates_plot = True
                logged_noveld_plot = True

            logging.info(
                f"After {frames} frames: loss {total_loss:f} @ {fps:.1f} fps. {mean_return}Stats: {pprint.pformat(stats_not_none)}"
            )
    except KeyboardInterrupt:
        # The ``finally`` clause below still runs before returning.
        return  # Try joining actors then quit.
    except Exception:
        traceback.print_exc()
        logging.info("Got exception in main process, exiting")
    else:
        for thread in threads:
            thread.join()
        logging.info(f"Learning finished after {frames} frames.")
    finally:
        # One ``None`` per actor is pushed onto the free queue (presumably the
        # shutdown sentinel expected by ``act`` -- confirm there), then each
        # actor is given up to one second to exit.
        for _ in range(FLAGS.num_actors):
            free_queue.put(None)
        for actor in actor_processes:
            actor.join(timeout=1)

    checkpoint()
    return frames, model.state_dict()


def is_babyai(env_name):
    """Return True when ``env_name`` begins with the ``"BabyAI"`` prefix.

    The comparison is case-sensitive.
    """
    babyai_prefix = "BabyAI"
    return env_name.startswith(babyai_prefix)


# Register custom OmegaConf resolvers at import time so config files can use
# them in interpolations:
#   - "is_babyai": calls is_babyai() on its argument.
#   - "uid": returns a 3-word slug generated by coolname.
# NOTE(review): the "uid" resolver is registered without use_cache, so each
# resolution may yield a different slug -- confirm this is intended wherever
# the config interpolates it.
OmegaConf.register_new_resolver("is_babyai", is_babyai)
OmegaConf.register_new_resolver("uid", lambda: coolname.generate_slug(3))


@hydra.main(config_path="conf", config_name="config")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: validate the config, set up wandb and run training.

    Stores ``cfg`` in the module-level ``FLAGS`` global, which the rest of the
    module reads its configuration from.

    Args:
        cfg: Configuration composed by Hydra from ``conf/config``.

    Raises:
        NotImplementedError: if ``noveld`` is requested for a non-BabyAI
            environment.
        ValueError: if ``language_goals`` is not one of the supported values,
            or if a generator is used on a non-BabyAI environment with a
            ``language_goals`` mode other than the online ones or ``None``.
    """
    global FLAGS
    FLAGS = cfg

    if FLAGS.noveld and not FLAGS.is_babyai:
        raise NotImplementedError(
            "Use minihack codebase to run NovelD/L-NovelD on MiniHack."
        )

    # Explicit raises instead of ``assert`` so the validation still runs when
    # Python is started with -O.
    valid_language_goals = {
        None,
        "xy",
        "online_naive",
        "online_grounding",
        "achievable",
    }
    if FLAGS.language_goals not in valid_language_goals:
        raise ValueError(
            f"Unsupported language_goals value: {FLAGS.language_goals!r}"
        )
    if not FLAGS.is_babyai:
        minihack_language_goals = {
            "online_naive",
            "online_grounding",
            None,
        }
        if FLAGS.generator and FLAGS.language_goals not in minihack_language_goals:
            raise ValueError("Must use online methods or no language for Minihack")

    if FLAGS.wandb:
        wandb.init(
            project=str(FLAGS.project),
            group=str(FLAGS.group),
            name=str(FLAGS.name),
            # vars() on a DictConfig exposes its internal attributes rather
            # than the configuration values; convert to a plain container.
            config=OmegaConf.to_container(FLAGS, resolve=True),
        )

    train()


# Script entry point: invoke the Hydra-decorated main() only when this module
# is executed directly, not when it is imported.
if __name__ == "__main__":
    main()
