r"""This is an example colab which demonstrates how to load a book or article
from a dataset, run two different models on the article, compare the results,
and inspect the context in which tokens occur. All cells should be run in order.

The citc client, gin configurations and pre-trained model directories should
be set to the appropriate values by the user.
"""

from http.client import UnimplementedFileMode
import sys
from datetime import datetime


from typing import Sequence
import jax.numpy as jnp
from flax.training import common_utils

from absl import app
from absl import flags
from absl import logging
import argparse

import gin
import jax
import numpy as np
import tensorflow.compat.v2 as tf

# import matplotlib.pyplot as plt
import bokeh.plotting as bplt

# ---- Change this to the appropriate user and client. ----

from transformer import decoder_stack
from transformer import inference_utils
from transformer import text_dataset

from transformer import tasks
import json

from absl import flags
flags.FLAGS([''])
import jax
import numpy as np


import seqio
import functools

from bigbench.bbseqio import task_api
from bigbench.bbseqio import tasks as bbtasks
logging.set_verbosity(logging.WARNING)



def load_model_configs(model_name, configs, sample_method = "greedy", batch_size = 2, sequence_length = 512):
    """Set up gin for one model and return its vocabulary and checkpoint directory.

    Clears any previously parsed gin configuration, then parses the model's gin
    files together with a set of inference-time overrides. Finally it reads an
    article through `inference_utils.read_article` (to obtain the vocabulary) and
    prints the first article block.

    Args:
      model_name: Key into `configs`.
      configs: Mapping from model name to a dict with "gin_files" and "load_dir".
      sample_method: Bound to `DecoderOnlyLanguageModel.sample_method`.
      batch_size: Bound to `TransformerTaskConfig.batch_size`.
      sequence_length: Bound to both `TransformerTaskConfig.sequence_length`
        and `TransformerLayer.window_length`.

    Returns:
      A `(vocab, load_dir)` tuple.
    """
    print("loading model ", model_name)
    # Start from a clean slate (constants included) so settings from a model
    # loaded earlier do not leak into this one.
    gin.clear_config(clear_constants=True)
    gin.enter_interactive_mode()  # Avoid re-registration errors on repeated calls.

    model_entry = configs[model_name]
    gin_files = model_entry["gin_files"]
    print("gin_files: ", gin_files)
    load_dir = model_entry["load_dir"]
    print("load_dir: ", load_dir)

    # Inference overrides: sampling method, per-token losses and logits in the
    # output, and the requested batch/sequence/window sizes. The batch and
    # window sizes may require fresh Transformer-XL state, so the pre-trained
    # state variables are not restored.
    gin_overrides = [
        f'DecoderOnlyLanguageModel.sample_method="{sample_method}"',
        "DecoderOnlyLanguageModel.output_token_losses=True",
        "DecoderOnlyLanguageModel.output_logits=True",
        f"TransformerTaskConfig.batch_size={batch_size}",
        f"TransformerTaskConfig.sequence_length={sequence_length}",
        f"TransformerLayer.window_length={sequence_length}",
        "Trainer.restore_state_variables=False",
    ]

    # Without unlocking, parsing fails when this runs a second time.
    with gin.unlock_config():
        inference_utils.parse_gin_configuration(
            gin_files, gin_overrides, gin_paths=["transformer/configs"])

    article_blocks, vocab = inference_utils.read_article(verbose=True)
    first_block = article_blocks[0]
    print(text_dataset.pretty_print_article(first_block, {"targets": vocab}, 32768))

    return vocab, load_dir

from tqdm import tqdm
import copy

def _sequence_log_likelihood(logits, target_tokens, loss_weights):
    """Sum each example's target-token log-probabilities over positions weighted 1.

    Args:
      logits: Unnormalized model outputs indexed as [example, position, vocab].
      target_tokens: Integer token ids indexed as [example, position].
      loss_weights: Per-position weights; only positions equal to 1 contribute.

    Returns:
      A 1-D numpy array with one summed log-likelihood per example.
    """
    log_probs = jax.nn.log_softmax(jnp.asarray(logits), axis=-1)
    token_log_probs = jnp.take_along_axis(
        log_probs, jnp.asarray(target_tokens)[..., None], axis=-1)[..., 0]
    keep = jnp.asarray(loss_weights) == 1
    # `where` instead of multiplying by the mask so -inf * 0 cannot produce NaN.
    return np.asarray(jnp.where(keep, token_log_probs, 0.0).sum(-1))


def predict_fn(gen_task, gen_task_state, batch_size, rep, sequence_length, do_generate, ds: tf.data.Dataset):
    """Run the model over a seqio evaluation dataset.

    Used through `functools.partial` both as seqio's `predict_fn`
    (do_generate=True) and as its `score_fn` (do_generate=False).

    Args:
      gen_task: Task object passed to `inference_utils.run_model`.
      gen_task_state: Initial task state. It is deep-copied, so the caller's
        object is not modified.
      batch_size: Examples per replica; each step consumes `batch_size * rep`
        examples.
      rep: Number of replicas.
      sequence_length: Input tokens and loss masks are truncated to this length.
      do_generate: If True, return the generated tokens selected by
        `tasks.get_masked_tokens`. Otherwise return one log-likelihood score
        per example.
      ds: Dataset of `(index, features)` pairs as supplied by seqio.

    Returns:
      A list of `(index, prediction)` pairs. Padding examples are excluded.

    Note:
      This function reads the module-level `vocab` that is assigned in the
      `__main__` section.
    """
    task_state = copy.deepcopy(gen_task_state)
    global_batch = batch_size * rep

    try:
        num_examples = len(ds)
    except TypeError as e:
        if str(e) == 'dataset length is unknown.':
            logging.warning(
                'The following error is likely due to the use of TensorFlow v1 in '
                'your dataset pipeline. Verify you are not importing from '
                '`tf.compat.v1` as part of your pipeline.')
        raise
    print(f'length of dataset = {num_examples}')

    if num_examples == 0:
        # Nothing to predict; np.concatenate below would fail on empty lists.
        return []

    remainder = num_examples % global_batch  # pytype:disable=wrong-arg-types
    if remainder:
        pad_amount = global_batch - remainder
        print(f'Padding infer dataset with {pad_amount} examples for even per-replica shards.')
        # Pad with copies of the first example tagged with index -1; they are
        # dropped before returning.
        pad_ds = ds.take(1).map(lambda i, x: (np.int64(-1), x)).repeat(pad_amount)
        ds = ds.concatenate(pad_ds)

    # These model inputs do not depend on the batch contents.
    start_of_sequence = np.ones(global_batch, dtype=np.int32)
    epoch = np.array([0] * rep)[:, None]
    nucleus_cutoff = np.array([0.9] * rep)[:, None]
    temperature = np.array([1] * rep)[:, None]

    all_inferences = []
    all_indices = []
    for batch_indices, x in tqdm(ds.batch(global_batch, drop_remainder=True)):
        prompt_tokens = x['decoder_input_tokens'].numpy()[:, :sequence_length]
        loss_mask = x['decoder_loss_weights'].numpy()[:, :sequence_length]
        model_x = {
            "targets": prompt_tokens,
            "start_of_sequence": start_of_sequence,
            "loss_mask": loss_mask,
            "epoch": epoch,
            "nucleus_cutoff": nucleus_cutoff,
            "temperature": temperature,
        }
        out, tstate = inference_utils.run_model(
            gen_task, task_state, ([model_x], vocab), verbose=False, return_tstate=True)
        # Carry the updated state into the next batch; the second element is kept.
        task_state = (tstate, task_state[1])
        out = out[0]

        if do_generate:
            gen_tokens = out["gen_tokens"].reshape(global_batch, -1)
            pred = jax.vmap(tasks.get_masked_tokens)(gen_tokens, loss_mask)
        else:
            logits = out["logits"]
            # Flatten any replica axis the same way as the generate branch
            # (a no-op when the logits are already [global_batch, T, vocab]).
            logits = logits.reshape(global_batch, -1, logits.shape[-1])
            num_positions = logits.shape[1]
            # Align targets/weights with the model's (possibly truncated) time axis.
            target_tokens = x['decoder_target_tokens'].numpy()[:, :num_positions]
            loss_weights = x['decoder_loss_weights'].numpy()[:, :num_positions]
            # seqio's score_fn expects log-likelihoods, so normalize the logits
            # instead of summing the raw values.
            pred = _sequence_log_likelihood(logits, target_tokens, loss_weights)

        all_inferences.append(np.asarray(pred))
        all_indices.append(batch_indices.numpy())

    all_inferences = np.concatenate(all_inferences)
    all_indices = np.concatenate(all_indices)

    # Drop the padding examples (index -1).
    keep = all_indices >= 0
    return list(zip(all_indices[keep], all_inferences[keep]))



def evaluate_one_task(task_name, gen_task, gen_task_state, task, task_state, batch_size, sequence_length, replicate_mode):
    """Evaluate one seqio task or mixture with a generation model and a scoring model.

    Args:
      task_name: Registered seqio task or mixture name.
      gen_task, gen_task_state: Model/state used for generation (`predict_fn`).
      task, task_state: Model/state used for scoring (`score_fn`).
      batch_size: Examples per replica passed to `predict_fn`.
      sequence_length: Truncation length passed to `predict_fn`.
      replicate_mode: If True, use one replica per local jax device; else 1.

    Returns:
      The resolved metrics, i.e. `all_metrics.result()` from `Evaluator.evaluate`.
    """
    evaluator = seqio.Evaluator(
        mixture_or_task_name=task_name,
        feature_converter=seqio.DecoderFeatureConverter(pack=False),  # pytype:disable=not-instantiable
        eval_split="all",
        sequence_length=None)

    # Print the size of every sub-task dataset (useful for spotting empty tasks).
    for ds_name, ds in evaluator.cached_task_datasets.items():
        print(ds_name, len(ds))

    rep = jax.local_device_count() if replicate_mode else 1
    all_metrics, _, _ = evaluator.evaluate(
        compute_metrics=jax.process_index() == 0,
        predict_fn=functools.partial(
            predict_fn, gen_task, gen_task_state, batch_size, rep, sequence_length, True),
        score_fn=functools.partial(
            predict_fn, task, task_state, batch_size, rep, sequence_length, False),
    )
    return all_metrics.result()

def define_bigbench_eval_tasks(sequence_length, vocab, vocab_name):
  """Register BIG-bench eval tasks/mixtures and return the mixture names to evaluate.

  Args:
    sequence_length: Suffix used in the GSM8K mixture name and the returned
      mixture names; also the length passed to `tasks.filter_too_long`.
    vocab: Vocabulary passed to `tasks.register_all_bigbench_eval_tasks`.
    vocab_name: Vocabulary name embedded in registered task names
      (e.g. "t5_default" or "gpt2").

  Returns:
    A list of registered mixture names, each suffixed with `_{sequence_length}`.
  """
  all_bigbench_tasks_2shots, all_bigbench_tasks_3shots = tasks.register_all_bigbench_eval_tasks(sequence_length, vocab, vocab_name)

  # GSM8K: post-process outputs/targets with tasks.extract_answer
  # (presumably isolating the final answer).
  def GSM8K_postprocessor(output_or_target, example=None, is_target=False):
    return tasks.extract_answer(output_or_target)
  for t in seqio.MixtureRegistry.get(f"bigbench:GSM8K_mix_{sequence_length}").tasks:
    t._postprocess_fn = GSM8K_postprocessor

  mul_emergence_tasks = [
    "bigbench:analytic_entailment.mul",
    "bigbench:common_morpheme.mul",
    "bigbench:fact_checker.mul",
    "bigbench:figure_of_speech_detection.mul",
    "bigbench:hindu_knowledge.mul",
    "bigbench:irony_identification.mul",
    "bigbench:logical_args.mul",
    "bigbench:logical_deduction.mul",
    "bigbench:misconceptions.mul",
    "bigbench:phrase_relatedness.mul",
    "bigbench:physical_intuition.mul",
    "bigbench:social_iqa.mul",
    "bigbench:sports_understanding.mul",
    "bigbench:strange_stories.mul",
    "bigbench:swahili_english_proverbs.mul",
    "bigbench:strategyqa.mul",
    ]

  # NOTE(review): the trailing "s" in this mixture name differs from the
  # "_<length>" pattern used elsewhere; kept as-is in case it is referenced.
  tasks.register_subset_tasks(
    all_bigbench_tasks_2shots + all_bigbench_tasks_3shots,
    mul_emergence_tasks,
    "bigbench:mul_emergence_tasks_mix_512s",
    512,
    modify = False
    )

  additional_individual_test_tasks = [
    f"bigbench:code_translation.gen.{vocab_name}_vocab.1_shot.all_examples.TransCoder_cpp_to_python",
    f"bigbench:code_translation.gen.{vocab_name}_vocab.1_shot.all_examples.CodeXGLUE_cs_to_java",
    f"bigbench:code_translation.gen.{vocab_name}_vocab.1_shot.all_examples.CodeXGLUE_java_to_cs",
  ]
  tasks.register_subset_tasks(
    additional_individual_test_tasks,
    additional_individual_test_tasks,
    "bigbench:additional_individual_test_tasks_mix_512",
    512,
    modify = False
    )

  eval_tasks = [
    # "bigbench:additional_individual_test_tasks_mix",
    # "bigbench:2shot_list_functions_mix",
    # "bigbench:3shot_list_functions_mix",
    # "bigbench:all_emergence_tasks_mix",
    # "bigbench:wmt_mix",
    "bigbench:code_translation_mix",
    # "bigbench:GSM8K_mix",
    # "bigbench:conlang_translation_mix",
    # "bigbench:linguistics_puzzles_mix",
    # "bigbench:language_games_mix",
  ]
  # Use the parameter, not the script-level `args`, so this works when imported.
  eval_tasks = [s + "_" + str(sequence_length) for s in eval_tasks]

  for eval_task_name in eval_tasks:
    for t in seqio.MixtureRegistry.get(eval_task_name).tasks:
      original_ps = list(t.preprocessors)
      if "loss_mask" in t.output_features:
        # Undo the extra preprocessing: drop the last four preprocessors and the
        # four extra output features. Assumes these were appended together at
        # registration time — TODO confirm against tasks.register_*.
        t._preprocessors = tuple(original_ps[:-4])
        del t.output_features["loss_mask"]
        del t.output_features["task_idx"]
        del t.output_features["start_of_sequence"]
        del t.output_features["epoch"]
      else:
        # Drop examples longer than the evaluation sequence length.
        original_ps.append(functools.partial(
          tasks.filter_too_long,
          sequence_length
          )
        )
        t._preprocessors = tuple(original_ps)
  return eval_tasks

if __name__ == "__main__":
    print(jax.devices())
    text_dataset.set_default_data_directory()

    with open("model_configs.json") as f:
        configs = json.load(f)
     
    parser = argparse.ArgumentParser()
    parser.add_argument('--vocab', type=str, default="t5")
    parser.add_argument('--batch_size', type=int, default=4)
    parser.add_argument('--no_replicate_mode', type=bool, default=False)
    parser.add_argument('--sequence_length', type=int, default=512)
    args = parser.parse_args()

    
    # models_to_eval = ["mix-baseline", "mix-shuffle-0.5"]
    models_to_eval = ["mix-shuffle-0.5-gpt2", "mix-baseline-gpt2"]

    if args.vocab == "t5":
      eval_tasks = define_bigbench_eval_tasks(args.sequence_length, tasks.T5_DEFAULT_VOCABULARY, "t5_default")
    elif args.vocab == "gpt2":
      eval_tasks = define_bigbench_eval_tasks(args.sequence_length, tasks.GPT2_VOCABULARY, "gpt2")
    else:
      raise NotImplementedError
   
    
    for model_name in models_to_eval:
        print("==========", model_name,"==========")
        vocab, load_dir = load_model_configs(model_name, configs, batch_size = args.batch_size, sequence_length = args.sequence_length)    
        gen_task, gen_task_state, _ = inference_utils.create_model_and_task(vocab, load_dir=load_dir, task_mode="generate", replicate_mode = not args.no_replicate_mode)
        task, task_state, _ = inference_utils.create_model_and_task(vocab, load_dir=load_dir, replicate_mode = not args.no_replicate_mode)
        for eval_task_name in eval_tasks:
          result = evaluate_one_task(eval_task_name, gen_task, gen_task_state, task, task_state, batch_size = args.batch_size, sequence_length = args.sequence_length, replicate_mode = not args.no_replicate_mode)
          print(result)
          with open(model_name + "_"  +  datetime.now().strftime("%d_%m_%Y_%H_%M_%S") + "_" + eval_task_name + ".json", "w") as outfile:
              json.dump(result, outfile)
        del gen_task, gen_task_state, task, task_state

        
    print("FINISHED!!!")