# coding=utf-8
# Copyright 2020-present the HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
The Trainer class, to easily train a 🤗 Transformers from scratch or finetune it on a new task.
"""

import collections
import gc
import inspect
import math
import os
import re
import shutil
import sys
import time
import warnings
from logging import StreamHandler
from pathlib import Path
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
from numpy.lib.arraysetops import isin
from numpy.lib.shape_base import column_stack
import wandb

# Integrations must be imported before ML frameworks:
from transformers.integrations import (  # isort: split
    default_hp_search_backend,
    get_reporting_integration_callbacks,
    hp_params,
    is_fairscale_available,
    is_optuna_available,
    is_ray_tune_available,
    run_hp_search_optuna,
    run_hp_search_ray,
    init_deepspeed,
)

import numpy as np
import torch
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data.sampler import RandomSampler, SequentialSampler

from transformers.data.data_collator import (
    DataCollator,
    DataCollatorWithPadding,
    default_data_collator,
)
from transformers.file_utils import (
    WEIGHTS_NAME,
    is_apex_available,
    is_datasets_available,
    is_in_notebook,
    is_sagemaker_distributed_available,
    is_torch_tpu_available,
    is_training_run_on_sagemaker,
)
from transformers.modeling_utils import PreTrainedModel, unwrap_model
from transformers.optimization import Adafactor, AdamW, get_scheduler
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
from transformers.trainer_callback import (
    CallbackHandler,
    DefaultFlowCallback,
    PrinterCallback,
    ProgressCallback,
    TrainerCallback,
    TrainerControl,
    TrainerState,
)
from transformers.trainer_pt_utils import (
    DistributedLengthGroupedSampler,
    DistributedSamplerWithLoop,
    DistributedTensorGatherer,
    LabelSmoother,
    LengthGroupedSampler,
    SequentialDistributedSampler,
    distributed_broadcast_scalars,
    distributed_concat,
    get_parameter_names,
    nested_concat,
    nested_detach,
    nested_numpify,
    nested_xla_mesh_reduce,
    reissue_pt_warnings,
)
from transformers.trainer_utils import (
    PREFIX_CHECKPOINT_DIR,
    BestRun,
    EvalPrediction,
    HPSearchBackend,
    PredictionOutput,
    ShardedDDPOption,
    TrainerMemoryTracker,
    TrainOutput,
    default_compute_objective,
    default_hp_space,
    denumpify_detensorize,
    get_last_checkpoint,
    set_seed,
    speed_metrics,
)
from transformers.training_args import ParallelMode, TrainingArguments
from transformers.utils import logging
from transformers.utils.modeling_auto_mapping import (
    MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES,
)
from transformers import Trainer
from transformers.integrations import WandbCallback, rewrite_logs
from models.utils import topk

# Flipped to True below when the installed torch ships `torch.cuda.amp`.
_is_native_amp_available = False

# Callbacks the trainer starts with; the progress callback is swapped for the
# notebook-friendly one when running inside a notebook.
DEFAULT_CALLBACKS = [DefaultFlowCallback]
DEFAULT_PROGRESS_CALLBACK = ProgressCallback

if is_in_notebook():
    # Absolute import: this module lives outside the `transformers` package
    # (every other import here is absolute), so a relative `.utils.notebook`
    # import cannot resolve to the transformers notebook helpers.
    from transformers.utils.notebook import NotebookProgressCallback

    DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback

if is_apex_available():
    from apex import amp

# Native automatic mixed precision is available from torch 1.6 onwards.
if version.parse(torch.__version__) >= version.parse("1.6"):
    _is_native_amp_available = True
    from torch.cuda.amp import autocast

if is_datasets_available():
    import datasets

if is_torch_tpu_available():
    import torch_xla.core.xla_model as xm
    import torch_xla.debug.metrics as met
    import torch_xla.distributed.parallel_loader as pl

if is_fairscale_available():
    import fairscale
    from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
    from fairscale.optim import OSS
    from fairscale.optim.grad_scaler import ShardedGradScaler

    # FullyShardedDataParallel is only imported for fairscale >= 0.3;
    # older versions get a `None` placeholder instead.
    if version.parse(fairscale.__version__) >= version.parse("0.3"):
        from fairscale.nn.data_parallel import (
            FullyShardedDataParallel as FullyShardedDDP,
        )
        from fairscale.nn.wrap import auto_wrap
    else:
        FullyShardedDDP = None

# `dist` is bound to SageMaker's distributed module when available and to
# `torch.distributed` otherwise.
if is_sagemaker_distributed_available():
    import smdistributed.dataparallel.torch.distributed as dist
    from smdistributed.dataparallel.torch.parallel.distributed import (
        DistributedDataParallel as DDP,
    )
else:
    import torch.distributed as dist

# On SageMaker, additionally send log records to stdout.
if is_training_run_on_sagemaker():
    logging.add_handler(StreamHandler(sys.stdout))


if TYPE_CHECKING:
    import optuna

logger = logging.get_logger(__name__)


class WandbCallbackViz(WandbCallback):
    """
    W&B callback that forwards single-entry log payloads (such as the
    ``{step: table}`` payload emitted by the prediction loop) untouched.
    """

    def on_log(self, args, state, control, model=None, logs=None, **kwargs):
        """
        Send :obj:`logs` to Weights & Biases.

        A payload with exactly one entry is logged as-is and without an explicit step; any other payload is passed
        through :obj:`rewrite_logs` and logged at ``state.global_step``. Only the world-zero process logs.
        """
        if self._wandb is None:
            return
        if not self._initialized:
            self.setup(args, state, model, reinit=False)

        # Determined before the process check, exactly like the payload-size test has always been.
        single_entry_payload = len(logs) == 1

        if not state.is_world_process_zero:
            return

        if single_entry_payload:
            self._wandb.log(logs)
            return

        self._wandb.log(rewrite_logs(logs), step=state.global_step)

class VisualizationTrainer(Trainer):
    def __init__(self, *args, **kwargs):
        """
        Build the trainer.

        Consumes three extra keyword arguments before delegating to the parent constructor: ``table_viz`` (optional,
        defaults to ``True``), ``train_collator`` and ``eval_collator`` (both required). Afterwards the stock
        :obj:`WandbCallback` is replaced by :obj:`WandbCallbackViz`.
        """
        self.table_viz = kwargs.pop("table_viz", True)

        # Both collators are mandatory keyword arguments.
        for required_key in ("train_collator", "eval_collator"):
            assert required_key in kwargs

        self.train_data_collator = kwargs.pop("train_collator")
        self.eval_data_collator = kwargs.pop("eval_collator")
        self.data_collator = None

        super().__init__(*args, **kwargs)

        # Swap the default W&B reporter for the visualization-aware one.
        self.pop_callback(WandbCallback)
        self.add_callback(WandbCallbackViz)

    def prediction_step(
        self,
        model: nn.Module,
        inputs: Dict[str, Union[torch.Tensor, Any]],
        prediction_loss_only: bool,
        ignore_keys: Optional[List[str]] = None,
        wandb_table=None,
        k_list=None,
    ) -> Tuple[Any, ...]:
        """
        Perform an evaluation step on :obj:`model` using :obj:`inputs`.

        On top of the loss/logits/labels, this step computes top-k accuracies over the labelled (``!= -100``)
        positions and, when :obj:`wandb_table` is given, appends one row per labelled position to that table.

        Args:
            model (:obj:`nn.Module`):
                The model to evaluate.
            inputs (:obj:`Dict[str, Union[torch.Tensor, Any]]`):
                The inputs and targets of the model. Must contain ``"input_ids"``; top-k accuracy and table rows
                are only produced when ``"labels"`` is present as well.

                The dictionary will be unpacked before being fed to the model. Most models expect the targets under the
                argument :obj:`labels`. Check your model's documentation for all accepted arguments.
            prediction_loss_only (:obj:`bool`):
                Whether or not to return the loss only.
            ignore_keys (:obj:`List[str]`, `optional`):
                A list of keys in the output of your model (if it is a dictionary) that should be ignored when
                gathering predictions.
            wandb_table (`optional`):
                Table object exposing ``add_data``; when not :obj:`None`, a row (context, snippet, top-10 predicted
                tokens, label token) is added for every labelled position.
            k_list (:obj:`List[int]`, `optional`):
                The values of k forwarded to :obj:`topk`. Defaults to ``[1, 5]``.

        Return:
            A 7-tuple ``(loss, logits, labels, top_k_acc_list, lm_loss, sent_loss, retrieval_loss)``. Every element
            may be :obj:`None`; ``logits`` and ``labels`` are always :obj:`None` when :obj:`prediction_loss_only` is
            set.
        """
        if k_list is None:
            k_list = [1, 5]

        has_labels = all(inputs.get(k) is not None for k in self.label_names)
        inputs = self._prepare_inputs(inputs)
        if ignore_keys is None:
            if hasattr(self.model, "config"):
                ignore_keys = getattr(
                    self.model.config, "keys_to_ignore_at_inference", []
                )
            else:
                ignore_keys = []

        # labels may be popped when computing the loss (label smoothing for instance) so we grab them first.
        if has_labels:
            labels = nested_detach(tuple(inputs.get(name) for name in self.label_names))
            if len(labels) == 1:
                labels = labels[0]
        else:
            labels = None

        lm_loss = None
        sent_loss = None
        retrieval_loss = None

        input_ids = inputs["input_ids"].clone()
        # Snapshot of the "labels" entry taken before the forward pass; stays None for label-free prediction
        # instead of raising a KeyError.
        full_labels = inputs["labels"].clone() if inputs.get("labels") is not None else None
        if full_labels is not None:
            labels = full_labels

        with torch.no_grad():
            if has_labels:
                loss, outputs = self.compute_loss(model, inputs, return_outputs=True)
                loss = loss.mean().detach()
                if isinstance(outputs, dict):
                    logits = tuple(
                        v for k, v in outputs.items() if k not in ignore_keys + ["loss"]
                    )
                else:
                    logits = outputs[1:]
                if "lm_loss" in outputs:
                    lm_loss = outputs["lm_loss"].mean().detach()
                if "sent_loss" in outputs:
                    sent_loss = outputs["sent_loss"].mean().detach()
                if "retrieval_loss" in outputs:
                    retrieval_loss = outputs["retrieval_loss"].mean().detach()
            else:
                loss = None
                if self.use_amp:
                    with autocast():
                        outputs = model(**inputs)
                else:
                    outputs = model(**inputs)
                if isinstance(outputs, dict):
                    logits = tuple(
                        v for k, v in outputs.items() if k not in ignore_keys
                    )
                else:
                    logits = outputs
                # TODO: this needs to be fixed and made cleaner later.
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index - 1]

        def first_logits():
            # First entry of the model outputs (or the outputs themselves), detached.
            head = logits[0] if isinstance(logits, tuple) else logits
            return head.detach()

        def masked_topk(flat_logits, flat_labels):
            # Top-k accuracy restricted to positions whose label is not -100.
            keep = flat_labels != -100
            return topk(flat_logits[keep, :], flat_labels[keep], k_list)

        num_sentences = self.args.num_sentences
        batch_size, seq_length = input_ids.size()
        top_k_acc_list = None
        # Logits used for scoring; also read by the table code below, so it must exist on every path.
        token_logits = None

        if full_labels is not None and batch_size > num_sentences:
            modified_batch_size = batch_size // num_sentences

            # TODO batch size < num sentence if condition; check if labels change

            # Drop the trailing rows that do not fill a complete group of `num_sentences`.
            labels = labels[: (modified_batch_size * num_sentences)]

            position_labels = labels.clone()
            position_labels = position_labels.view(modified_batch_size, num_sentences, seq_length)
            position_labels = position_labels.permute(0, 2, 1)  # (B' x L x S); B' = B / S
            # A position keeps its label only when exactly one sentence of the group labels it.
            single_label = torch.sum(position_labels != -100, dim=2) == 1
            position_labels[position_labels == -100] = 0
            position_labels = torch.sum(position_labels, dim=2)
            position_labels[~single_label] = -100

            token_logits = first_logits()
            logits_flat = token_logits.view(-1, token_logits.shape[-1]).detach()
            labels_flat = position_labels.view(-1).detach()

            if logits_flat.shape[0] == labels_flat.shape[0]:
                top_k_acc_list = masked_topk(logits_flat, labels_flat)
            elif token_logits.shape[1] == position_labels.shape[1] + num_sentences + 1:
                # Logits carry `num_sentences + 1` extra leading positions; drop them.
                token_logits = token_logits[:, num_sentences + 1:, :]
                top_k_acc_list = masked_topk(
                    token_logits.reshape(-1, token_logits.shape[-1]), labels_flat
                )
            elif token_logits.shape[1] == position_labels.shape[1] + num_sentences:
                # Logits carry `num_sentences` extra leading positions; drop them.
                token_logits = token_logits[:, num_sentences:, :]
                top_k_acc_list = masked_topk(
                    token_logits.reshape(-1, token_logits.shape[-1]), labels_flat
                )
            elif token_logits.shape[0] == labels.shape[0]:
                # One logits row per input row: score the positions holding token id 50264
                # (presumably the tokenizer's mask token id - TODO confirm).
                logits_flat = token_logits.reshape(-1, token_logits.shape[-1])
                mask_positions = input_ids.view(-1) == 50264
                top_k_acc_list = topk(
                    logits_flat[mask_positions, :],
                    labels.view(-1)[mask_positions],
                    k_list,
                )
            elif token_logits.shape[1] == position_labels.shape[1] + 1:
                # Logits carry one extra leading position; drop it.
                token_logits = token_logits[:, 1:, :]
                top_k_acc_list = masked_topk(
                    token_logits.reshape(-1, token_logits.shape[-1]), labels_flat
                )
            else:
                logger.warning("skipping last batch, batch size not divisible by number of sentences")
        elif full_labels is not None:
            logger.warning("Not including really small batch in top-k acc. calculation")

        if wandb_table is not None and full_labels is not None:
            # The table works on the complete, untruncated batch.
            labels = full_labels
            tokenizer = self.tokenizer
            labelled = labels != -100

            def decode_clean(token_ids):
                # Decode and turn angle brackets into square ones so tokens are not read as HTML tags.
                return tokenizer.decode(token_ids).replace(">", "]").replace("<", "[")

            # skip batch if we don't have equally sized chunks
            if batch_size % num_sentences == 0:
                if token_logits is None:
                    # The scoring branch above did not run (batch_size <= num_sentences).
                    token_logits = first_logits()

                for batch_id in range(batch_size):
                    group = batch_id // num_sentences
                    group_rows = range(group * num_sentences, (group + 1) * num_sentences)
                    row_ids = input_ids[batch_id]
                    positions = labelled[batch_id].nonzero(as_tuple=True)[0].tolist()

                    for pos in positions:
                        window_start = max(pos - 10, 0)
                        window_end = min(pos + 10, seq_length - 1)

                        # Every sentence of the group, with the window around `pos` highlighted.
                        cur_sentences = [
                            decode_clean(input_ids[j][:window_start])
                            + " <strong> %s </strong> "
                            % decode_clean(input_ids[j][window_start:window_end])
                            + decode_clean(input_ids[j][window_end:])
                            for j in group_rows
                        ]
                        # Short snippet of the current sentence with only the token at `pos` highlighted.
                        cur_snippet = (
                            decode_clean(row_ids[window_start:pos])
                            + " <strong> %s </strong> "
                            % decode_clean(row_ids[pos : pos + 1])
                            + decode_clean(
                                row_ids[min(pos + 1, seq_length - 1) : pos + 10]
                            )
                        )

                        # figure out top k predictions
                        top_k_word_indices = torch.topk(
                            token_logits[group][pos], 10
                        ).indices
                        top_k_words = tokenizer.convert_ids_to_tokens(
                            top_k_word_indices.tolist()
                        )
                        label_word = tokenizer.convert_ids_to_tokens(
                            [labels[batch_id][pos].item()]
                        )

                        try:
                            cur_sentences_html = wandb.Html(
                                " <p> %s </p> " % " <br> <br> ".join(cur_sentences)
                            )
                            cur_snippet_html = wandb.Html(" <p> %s </p> " % cur_snippet)
                        except Exception:
                            # Best effort: fall back to empty cells rather than aborting evaluation.
                            cur_sentences_html = wandb.Html("<p>  </p>")
                            cur_snippet_html = wandb.Html("<p>  </p>")
                        wandb_table.add_data(
                            cur_sentences_html,
                            cur_snippet_html,
                            ",".join(top_k_words),
                            label_word[0],
                        )

        if prediction_loss_only:
            return (loss, None, None, top_k_acc_list, lm_loss, sent_loss, retrieval_loss)

        logits = nested_detach(logits)
        if len(logits) == 1:
            logits = logits[0]

        return (loss, logits, labels, top_k_acc_list, lm_loss, sent_loss, retrieval_loss)

    def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch):
        """
        Log, evaluate and/or checkpoint, depending on the flags in ``self.control``.

        The running training loss is only logged (and reset in place) when logging is requested and no evaluation is
        due in the same call. Evaluation metrics, if any, are reported to the hyper-parameter search and passed on to
        the checkpoint saver.
        """
        log_requested = self.control.should_log
        if log_requested and not self.control.should_evaluate:
            accumulated_loss = tr_loss.item()
            # reset tr_loss to zero (in place: the caller keeps using the same tensor)
            tr_loss -= tr_loss

            steps_since_last_log = (
                self.state.global_step - self._globalstep_last_logged
            )
            logs: Dict[str, float] = {
                "loss": round(accumulated_loss / steps_since_last_log, 4),
                "learning_rate": self._get_learning_rate(),
            }

            self._total_loss_scalar += accumulated_loss
            self._globalstep_last_logged = self.state.global_step

            self.log(logs)

        metrics = None
        if self.control.should_evaluate:
            metrics = self.evaluate()
            self._report_to_hp_search(trial, epoch, metrics)

        if not self.control.should_save:
            return

        self._save_checkpoint(model, trial, metrics=metrics)
        self.control = self.callback_handler.on_save(
            self.args, self.state, self.control
        )

    def prediction_loop(
        self,
        dataloader: DataLoader,
        description: str,
        prediction_loss_only: Optional[bool] = None,
        ignore_keys: Optional[List[str]] = None,
        metric_key_prefix: str = "eval",
    ) -> PredictionOutput:
        """
        Prediction/evaluation loop, shared by :obj:`Trainer.evaluate()` and :obj:`Trainer.predict()`.

        Besides the overall loss, the loop aggregates the top-k accuracies and the optional ``lm``/``sent``/
        ``retrieval`` losses returned by :obj:`prediction_step` and reports them as
        ``{metric_key_prefix}_top{k}_acc`` and ``{metric_key_prefix}_{lm,sent,retrieval}_loss``. When
        ``self.table_viz`` is set, a W&B table of per-position predictions is built and sent through the callback
        handler's ``on_log``.

        Args:
            dataloader (:obj:`DataLoader`):
                Loader to iterate over; its dataset must implement ``__len__``.
            description (:obj:`str`):
                Label used in the log messages.
            prediction_loss_only (:obj:`bool`, `optional`):
                Falls back to ``self.args.prediction_loss_only`` when :obj:`None`.
            ignore_keys (:obj:`List[str]`, `optional`):
                Forwarded to :obj:`prediction_step`.
            metric_key_prefix (:obj:`str`, `optional`, defaults to :obj:`"eval"`):
                Prefix applied to every metric key.

        Return:
            :obj:`PredictionOutput` with predictions and label ids (both :obj:`None` when only the loss is
            requested) and the metrics dictionary.

        Raises:
            ValueError: if ``dataloader.dataset`` has no length.
        """
        if not isinstance(dataloader.dataset, collections.abc.Sized):
            raise ValueError("dataset must implement __len__")
        prediction_loss_only = (
            prediction_loss_only
            if prediction_loss_only is not None
            else self.args.prediction_loss_only
        )

        if self.args.deepspeed and not self.args.do_train:
            # no harm, but flagging to the user that deepspeed config is ignored for eval
            # flagging only for when --do_train wasn't passed as only then it's redundant
            logger.info(
                "Detected the deepspeed argument but it will not be used for evaluation"
            )
        model = self._wrap_model(self.model, training=False)

        # if full fp16 is wanted on eval and this ``evaluation`` or ``predict`` isn't called while
        # ``train`` is running, half it first and then put on device
        if not self.is_in_train and self.args.fp16_full_eval:
            model = model.half().to(self.args.device)

        batch_size = dataloader.batch_size
        num_examples = self.num_examples(dataloader)
        num_examples = (num_examples // batch_size) * batch_size if self.args.dataloader_drop_last else num_examples
        logger.info("***** Running %s *****", description)
        logger.info("  Num examples = %d", num_examples)
        logger.info("  Batch size = %d", batch_size)

        world_size = max(1, self.args.world_size)

        def new_gatherer(multiple_of):
            return DistributedTensorGatherer(
                world_size, num_examples, make_multiple_of=multiple_of
            )

        k_list = [1, 5]
        aux_loss_names = ("lm", "sent", "retrieval")

        eval_losses_gatherer = new_gatherer(batch_size)
        top_k_gatherer_list = [new_gatherer(batch_size) for _ in k_list]
        aux_gatherers = {name: new_gatherer(batch_size) for name in aux_loss_names}
        preds_gatherer = None
        labels_gatherer = None
        if not prediction_loss_only:
            # The actual number of eval_sample can be greater than num_examples in distributed settings (when we pass
            # a batch size to the sampler)
            make_multiple_of = None
            if hasattr(dataloader, "sampler") and isinstance(
                dataloader.sampler, SequentialDistributedSampler
            ):
                make_multiple_of = dataloader.sampler.batch_size
            preds_gatherer = new_gatherer(make_multiple_of)
            labels_gatherer = new_gatherer(make_multiple_of)

        # Per-accumulation-window buffers; reset after every gather.
        losses_host: Optional[torch.Tensor] = None
        preds_host: Union[torch.Tensor, List[torch.Tensor], None] = None
        labels_host: Union[torch.Tensor, List[torch.Tensor], None] = None
        top_k_acc_host = None
        aux_hosts = dict.fromkeys(aux_loss_names)
        # Top-k values of the whole loop. The reported accuracy is averaged from these (not from the gatherers),
        # so they must survive the per-window resets triggered by `eval_accumulation_steps`.
        top_k_acc_all = None

        def concat(host, values):
            return values if host is None else torch.cat((host, values), dim=0)

        def concat_each(hosts, values):
            if hosts is None:
                return list(values)
            return [torch.cat([host, value]) for host, value in zip(hosts, values)]

        def gather_pending():
            # Move everything buffered on the host side into the gatherers (buffers are read at call time).
            eval_losses_gatherer.add_arrays(
                self._gather_and_numpify(losses_host, "eval_losses")
            )
            if top_k_acc_host is not None:
                for k, gatherer, host in zip(k_list, top_k_gatherer_list, top_k_acc_host):
                    gatherer.add_arrays(
                        self._gather_and_numpify(host, f"eval_top{k}_acc")
                    )
            for name in aux_loss_names:
                if aux_hosts[name] is not None:
                    aux_gatherers[name].add_arrays(
                        self._gather_and_numpify(aux_hosts[name], f"eval_{name}_losses")
                    )
            if not prediction_loss_only:
                preds_gatherer.add_arrays(
                    self._gather_and_numpify(preds_host, "eval_preds")
                )
                labels_gatherer.add_arrays(
                    self._gather_and_numpify(labels_host, "eval_label_ids")
                )

        model.eval()

        if is_torch_tpu_available():
            dataloader = pl.ParallelLoader(
                dataloader, [self.args.device]
            ).per_device_loader(self.args.device)

        if self.args.past_index >= 0:
            self._past = None

        self.callback_handler.eval_dataloader = dataloader
        column_names = ["all context", "snippet", "top k words", "GT Label"]
        # Only build (and later log) the table when visualization is enabled.
        table = wandb.Table(columns=column_names) if self.table_viz else None

        for step, inputs in enumerate(dataloader):
            # Stop adding rows once the table holds more than 5000 of them.
            step_table = table if table is not None and len(table.data) <= 5000 else None

            (
                loss,
                logits,
                labels,
                top_k_acc_list,
                lm_loss,
                sent_loss,
                retrieval_loss,
            ) = self.prediction_step(
                model,
                inputs,
                prediction_loss_only,
                ignore_keys=ignore_keys,
                wandb_table=step_table,
                k_list=k_list,
            )

            if loss is not None:
                # Per-batch scalars are repeated `batch_size` times so that each sample carries its batch value.
                losses_host = concat(losses_host, loss.repeat(batch_size))
                if top_k_acc_list is not None:
                    top_k_repeated = [v.repeat(batch_size) for v in top_k_acc_list]
                    top_k_acc_host = concat_each(top_k_acc_host, top_k_repeated)
                    top_k_acc_all = concat_each(top_k_acc_all, top_k_repeated)
                for name, value in zip(aux_loss_names, (lm_loss, sent_loss, retrieval_loss)):
                    if value is not None:
                        aux_hosts[name] = concat(aux_hosts[name], value.repeat(batch_size))

            if logits is not None:
                preds_host = (
                    logits
                    if preds_host is None
                    else nested_concat(preds_host, logits, padding_index=-100)
                )
            if labels is not None:
                labels_host = (
                    labels
                    if labels_host is None
                    else nested_concat(labels_host, labels, padding_index=-100)
                )
            self.control = self.callback_handler.on_prediction_step(
                self.args, self.state, self.control
            )

            # Gather all tensors and put them back on the CPU if we have done enough accumulation steps.
            if (
                self.args.eval_accumulation_steps is not None
                and (step + 1) % self.args.eval_accumulation_steps == 0
            ):
                gather_pending()

                # Set back to None to begin a new accumulation
                losses_host, preds_host, labels_host, top_k_acc_host = None, None, None, None
                aux_hosts = dict.fromkeys(aux_loss_names)

        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of the evaluation loop
            delattr(self, "_past")

        if table is not None:
            self.callback_handler.on_log(
                self.args, self.state, self.control, {str(self.state.global_step): table}
            )

        # Gather all remaining tensors and put them back on the CPU
        gather_pending()

        eval_loss = eval_losses_gatherer.finalize()
        # `finalize` yields None for a gatherer that never received data, so auxiliary losses the model does not
        # produce are simply left out of the metrics.
        aux_losses = {name: aux_gatherers[name].finalize() for name in aux_loss_names}

        preds = preds_gatherer.finalize() if not prediction_loss_only else None
        label_ids = labels_gatherer.finalize() if not prediction_loss_only else None

        if (
            self.compute_metrics is not None
            and preds is not None
            and label_ids is not None
        ):
            metrics = self.compute_metrics(
                EvalPrediction(predictions=preds, label_ids=label_ids)
            )
        else:
            metrics = {}

        # To be JSON-serializable, we need to remove numpy types or zero-d tensors
        metrics = denumpify_detensorize(metrics)

        if eval_loss is not None:
            metrics[f"{metric_key_prefix}_loss"] = eval_loss.mean().item()
            if top_k_acc_all is not None:
                for k, accuracy in zip(k_list, top_k_acc_all):
                    metrics[f"{metric_key_prefix}_top{k}_acc"] = accuracy.mean().item()
            for name in aux_loss_names:
                if aux_losses[name] is not None:
                    metrics[f"{metric_key_prefix}_{name}_loss"] = aux_losses[name].mean().item()

        # Prefix all keys with metric_key_prefix + '_'
        for key in list(metrics.keys()):
            if not key.startswith(f"{metric_key_prefix}_"):
                metrics[f"{metric_key_prefix}_{key}"] = metrics.pop(key)

        return PredictionOutput(predictions=preds, label_ids=label_ids, metrics=metrics)


    def get_train_dataloader(self) -> DataLoader:
        """
        Build the :class:`~torch.utils.data.DataLoader` used for training.

        The sampler is whatever :obj:`self._get_train_sampler()` returns; batch size, ``drop_last``, worker count
        and memory pinning are read from :obj:`self.args`, and batches are collated with
        :obj:`self.train_data_collator`.

        Subclass and override this method if you want to inject some custom behavior.

        Returns:
            :class:`~torch.utils.data.DataLoader`: A loader over :obj:`self.train_dataset`.

        Raises:
            ValueError: If :obj:`self.train_dataset` is :obj:`None`.
        """
        if self.train_dataset is None:
            raise ValueError("Trainer: training requires a train_dataset.")

        # The sampler is resolved first, before any of the other options are read.
        loader_options = {
            "sampler": self._get_train_sampler(),
            "batch_size": self.args.train_batch_size,
            "collate_fn": self.train_data_collator,
            "drop_last": self.args.dataloader_drop_last,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }
        return DataLoader(self.train_dataset, **loader_options)

    def get_eval_dataloader(self, eval_dataset: Optional[Dataset] = None) -> DataLoader:
        """
        Build the :class:`~torch.utils.data.DataLoader` used for evaluation.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            eval_dataset (:obj:`torch.utils.data.dataset.Dataset`, `optional`):
                Dataset to evaluate on; when omitted, :obj:`self.eval_dataset` is used instead. An explicitly
                passed :obj:`datasets.Dataset` is handed to :obj:`self._remove_unused_columns` first. It must
                implement :obj:`__len__`.

        Returns:
            :class:`~torch.utils.data.DataLoader`: A loader using the sampler from
            :obj:`self._get_eval_sampler` and :obj:`self.eval_data_collator` as collate function.

        Raises:
            ValueError: If neither :obj:`eval_dataset` nor :obj:`self.eval_dataset` is set, or if the passed
                :obj:`eval_dataset` does not implement :obj:`__len__`.
        """
        # Validate the argument before touching anything else.
        if eval_dataset is None:
            if self.eval_dataset is None:
                raise ValueError("Trainer: evaluation requires an eval_dataset.")
        elif not isinstance(eval_dataset, collections.abc.Sized):
            raise ValueError("eval_dataset must implement __len__")

        # Only an explicitly passed `datasets.Dataset` gets its columns pruned here;
        # the fallback `self.eval_dataset` is used as-is.
        if is_datasets_available() and isinstance(eval_dataset, datasets.Dataset):
            self._remove_unused_columns(eval_dataset, description="evaluation")

        dataset = self.eval_dataset if eval_dataset is None else eval_dataset
        sampler = self._get_eval_sampler(dataset)

        loader_options = {
            "sampler": sampler,
            "batch_size": self.args.eval_batch_size,
            "collate_fn": self.eval_data_collator,
            "drop_last": self.args.dataloader_drop_last,
            "num_workers": self.args.dataloader_num_workers,
            "pin_memory": self.args.dataloader_pin_memory,
        }
        return DataLoader(dataset, **loader_options)

    def get_test_dataloader(self, test_dataset: Dataset) -> DataLoader:
        """
        Returns the test :class:`~torch.utils.data.DataLoader`.

        Subclass and override this method if you want to inject some custom behavior.

        Args:
            test_dataset (:obj:`torch.utils.data.dataset.Dataset`):
                The test dataset to use. If it is an :obj:`datasets.Dataset`, columns not accepted by the
                ``model.forward()`` method are automatically removed. It must implement :obj:`__len__`.

        Returns:
            :class:`~torch.utils.data.DataLoader`: A loader over :obj:`test_dataset` configured like the
            evaluation loader (same batch size, collator, ``drop_last``, worker count and memory pinning).

        Raises:
            ValueError: If :obj:`test_dataset` does not implement :obj:`__len__`.
        """
        if not isinstance(test_dataset, collections.abc.Sized):
            raise ValueError("test_dataset must implement __len__")
        elif is_datasets_available() and isinstance(test_dataset, datasets.Dataset):
            self._remove_unused_columns(test_dataset, description="test")
        test_sampler = self._get_eval_sampler(test_dataset)

        # We use the same batch_size as for eval.
        return DataLoader(
            test_dataset,
            sampler=test_sampler,
            batch_size=self.args.eval_batch_size,
            collate_fn=self.eval_data_collator,
            drop_last=self.args.dataloader_drop_last,
            # Honor the configured worker count, as the train and eval loaders do;
            # previously this was omitted, so test loading ignored the setting.
            num_workers=self.args.dataloader_num_workers,
            pin_memory=self.args.dataloader_pin_memory,
        )
