import os

import numpy as np
import torch
from datasets import load_dataset, Dataset, DatasetDict
from evaluate import load
from transformers import pipeline, LlamaTokenizerFast

from eval.dataset_manager import ConversationContextManager, ArxivContextManager, GSMContextManager, \
    BBHContextManager
from eval.task_manager import ContinueConversation, Reasoning, OriginalContextReconsutrction
from utils.text_generation import Ranker


# TODO: Requirements:
#  1. Each dataset corresponds to one task: arxiv -> summarization, sharegpt -> conversation, GSM8K -> reasoning, etc.
#  2. Create another class that takes in a dataset, compressor, tokenizer, and LLM, and implements the task
#     corresponding to that dataset.


def _preprocess_alpaca(sample):
    new_examples = {
        "query": None,
        "output": None
    }

    instruction, output = sample["instruction"], sample["output"]

    new_examples["query"] = instruction
    new_examples["output"] = output
    # new_examples["input_ids"] = tokenizer.encode(prompt)
    return new_examples


class Evaluator:
    """Evaluate prompt compressors / generators on a set of benchmark datasets.

    Each supported dataset name maps to a task manager (produces answers from a context)
    and a dataset manager (builds original and compressed contexts). Call
    :meth:`update_llm` to attach the compressor agent, text generator and tokenizer
    before running any evaluation.
    """

    def __init__(self, datasets,
                 rouge_path="/home/trl/trl/hf_hub/metrics/rouge/rouge.py",
                 bleu_path="/home/trl/trl/hf_hub/metrics/bleu/bleu.py"):
        """Register the datasets and load the ROUGE / BLEU metrics.

        Args:
            datasets: mapping from dataset name to a ``Dataset``, ``DatasetDict`` or ``dict``
                of splits (for the latter two, the ``'test'`` split is evaluated).
            rouge_path: path or hub id passed to ``evaluate.load`` for ROUGE.
            bleu_path: path or hub id passed to ``evaluate.load`` for BLEU.
        """
        # Every entry must use the same keys: evaluate_compressor() looks up
        # "task_managers" and "dataset_managers" for whichever dataset is requested.
        self.prefixed_manager = {
            "ShareGPT": {"task_managers": ContinueConversation, "dataset_managers": ConversationContextManager},
            "Arxiv": {"task_managers": OriginalContextReconsutrction, "dataset_managers": ArxivContextManager},
            "GSM8K": {"task_managers": Reasoning, "dataset_managers": GSMContextManager},
            "BBH": {"task_managers": Reasoning, "dataset_managers": BBHContextManager},
        }
        self.datasets = datasets
        self.rouge = load(rouge_path)
        self.bleu = load(bleu_path)
        # Populated by update_llm(); initialised here so a missing call fails with a clear
        # message instead of an AttributeError deep inside an evaluation.
        self.agent = None
        self.generator = None
        self.tokenizer = None

    def update_llm(self, agent, generator, tokenizer=None):
        """Attach the models used by the evaluation methods.

        Args:
            agent: the compressor; passed to the dataset managers in
                :meth:`evaluate_compressor` and used via ``agent.generate(...)`` in
                :meth:`evaluate_alpaca`.
            generator: the answering model; passed to the task managers and to ``Ranker``.
            tokenizer: tokenizer used for encoding/decoding and token counting.
        """
        self.agent = agent
        self.generator = generator
        self.tokenizer = tokenizer

    def _require_llm(self):
        """Raise ``RuntimeError`` unless :meth:`update_llm` supplied agent, generator and tokenizer."""
        missing = [name for name in ("agent", "generator", "tokenizer") if getattr(self, name, None) is None]
        if missing:
            raise RuntimeError("Call update_llm() before evaluating; missing: {}".format(", ".join(missing)))

    def _get_test_split(self, dataset_name):
        """Return the split of *dataset_name* to evaluate.

        ``DatasetDict`` / ``dict`` entries yield their ``'test'`` split; any other object
        (e.g. a single ``Dataset``) is returned unchanged.
        """
        dataset = self.datasets[dataset_name]
        if isinstance(dataset, (DatasetDict, dict)):
            return dataset["test"]
        return dataset

    def calculate_compression_ratio(self, original_text, compressed_text):
        """Return the token-level compression ratio ``original tokens / compressed tokens``.

        Tokens are counted as ``len(self.tokenizer(text)["input_ids"])``, which includes any
        special tokens the tokenizer adds by default.

        Args:
            original_text: one uncompressed prompt (``str``) or a list/tuple of prompts.
            compressed_text: the compressed counterpart(s), of the same kind. For sequences,
                items are paired by position and both must have the same length.

        Returns:
            float: total original tokens divided by total compressed tokens.

        Raises:
            ValueError: if paired sequences differ in length, or the compressed side has no tokens.
            TypeError: if the inputs are not both strings or both lists/tuples.
        """
        if isinstance(original_text, str) and isinstance(compressed_text, str):
            # A single prompt: the compressed text is expected to be shorter, so no
            # length check applies here.
            pairs = [(original_text, compressed_text)]
        elif isinstance(original_text, (list, tuple)) and isinstance(compressed_text, (list, tuple)):
            if len(original_text) != len(compressed_text):
                raise ValueError("Length of original and compressed text should be the same")
            pairs = list(zip(original_text, compressed_text))
        else:
            raise TypeError("Unsupported type :{} & {}".format(type(original_text), type(compressed_text)))

        total_original_tokens = 0
        total_compressed_tokens = 0
        for original, compressed in pairs:
            total_original_tokens += len(self.tokenizer(original)["input_ids"])
            total_compressed_tokens += len(self.tokenizer(compressed)["input_ids"])

        if total_compressed_tokens == 0:
            raise ValueError("Compressed text contains no tokens; compression ratio is undefined")
        return total_original_tokens / total_compressed_tokens

    def evaluate_compressor(self, dataset_name, compressor_name, num_samples=10):
        """Compare answers produced from original vs. compressed contexts.

        Args:
            dataset_name: a key of both ``self.datasets`` and ``self.prefixed_manager``.
            compressor_name: forwarded to the dataset manager's ``generate_context``.
            num_samples: first positional argument of ``generate_context`` (previously
                hard-coded to 10).

        Returns:
            dict with ``"rouge"`` and ``"bleu"`` (compressed-context answers scored against
            original-context answers) and ``"compression_ratio"`` of the prompts.

        Raises:
            RuntimeError: if :meth:`update_llm` has not been called.
            KeyError: if no managers are registered for *dataset_name*.
        """
        self._require_llm()
        if dataset_name not in self.prefixed_manager:
            raise KeyError("No task/dataset managers registered for {!r}; expected one of {}".format(
                dataset_name, sorted(self.prefixed_manager)))

        dataset = self._get_test_split(dataset_name)
        managers = self.prefixed_manager[dataset_name]
        task_manager = managers["task_managers"]
        dataset_manager = managers["dataset_managers"]

        # note: rearrange prompt in raw dataset
        context = dataset_manager(dataset, self.agent, self.tokenizer).generate_context(
            num_samples, compressor_name=compressor_name)
        answer = task_manager(context, self.generator, self.tokenizer).get_answer()

        original_prompt, compressed_prompt = context['orig'], context['compressed']
        original_response, compressed_response = answer['orig'], answer['compressed']

        # Answers from the original context serve as references for the compressed ones.
        rouge_score = self.rouge.compute(predictions=compressed_response, references=original_response)
        bleu_score = self.bleu.compute(predictions=compressed_response, references=original_response)
        compression_ratio = self.calculate_compression_ratio(original_prompt, compressed_prompt)

        return {
            "rouge": rouge_score,
            "bleu": bleu_score,
            "compression_ratio": compression_ratio
        }

    def evaluate_alpaca(self, dataset_name, num_samples=20,
                        ranker_prompt_path="/home/trl/trl/alpaca_train/alpaca_eval_llama.txt"):
        """Compute a win rate of ``self.agent`` generations on a random Alpaca sample.

        Args:
            dataset_name: key of ``self.datasets``; rows need ``"instruction"`` and ``"output"``.
            num_samples: number of rows sampled without replacement (capped at dataset size).
            ranker_prompt_path: file whose text is used as the ``Ranker`` prompt.

        Returns:
            dict ``{"win_rate": mean of the rewards returned by Ranker.get_rouge_reward}``.

        Raises:
            RuntimeError: if :meth:`update_llm` has not been called.
            ValueError: if the selected split is empty.
        """
        self._require_llm()
        dataset = self._get_test_split(dataset_name)
        if len(dataset) == 0:
            raise ValueError("Dataset {!r} is empty; cannot compute a win rate".format(dataset_name))

        # Sample without replacement, never asking for more rows than exist; select before
        # mapping so only the sampled rows are preprocessed.
        sample_size = min(num_samples, len(dataset))
        random_indices = np.random.choice(len(dataset), sample_size, replace=False)
        dataset = dataset.select(random_indices).map(_preprocess_alpaca, remove_columns=dataset.column_names)

        llama_generation_kwargs = {
            "do_sample": False,  # greedy decoding, so outputs are deterministic for a given sample
            "pad_token_id": self.tokenizer.eos_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
            "repetition_penalty": 1.1,
            "max_new_tokens": 256,  # upper bound on generated tokens per response
        }

        queries = dataset['query']
        ref_responses = dataset['output']
        query_ids = self.tokenizer(queries)['input_ids']
        query_tensors = [torch.tensor(q) for q in query_ids]
        outputs = self.agent.generate(query_tensors,
                                      return_prompt=False,
                                      **llama_generation_kwargs)

        outputs = [self.tokenizer.decode(r.squeeze(), skip_special_tokens=True)
                   for r in outputs]

        # Calculate win rate
        with open(ranker_prompt_path, 'r', encoding='utf-8') as prompt_file:
            ranker_prompt = prompt_file.read()
        ranker = Ranker(self.generator, llama_generation_kwargs, ranker_prompt)
        rewards = ranker.get_rouge_reward(queries, ref_responses, outputs)
        win_rate = sum(rewards) / len(rewards)

        return {"win_rate": win_rate}


if __name__ == "__main__":
    # Local dataset snapshots. Each is loaded with split="train", so Evaluator evaluates
    # the loaded split as-is (it only picks a 'test' split for DatasetDict/dict entries).
    datasets = {
        "Alpaca": load_dataset('/home/hmp/hmp-mh/trl/hf_hub/datasets/alpaca-gpt4', split="train"),
        "ShareGPT": load_dataset('/home/hmp/hmp-mh/trl/hf_hub/datasets/sharegpt-500', split="train"),
        "Arxiv": load_dataset('/home/hmp/hmp-mh/trl/hf_hub/datasets/arxiv-march-2023', split="train"),
        "GSM8K": load_dataset('/home/hmp/hmp-mh/trl/hf_hub/datasets/qwedsacf-grade-school-math-instructions',
                              split="train"),
        # "BBH": load_dataset('/home/hmp/hmp-mh/trl/hf_hub/datasets/lukaemon-bbh/bbh.py', 'boolean_expressions'),
    }

    # Llama-2 chat is the answer generator; its tokenizer pads on the left with EOS
    # (it has no dedicated pad token configured here).
    generator_model = '/home/hmp/hmp-mh/trl/hf_hub/models/Llama-2-7b-chat-hf'
    tokenizer = LlamaTokenizerFast.from_pretrained(generator_model)
    tokenizer.padding_side = 'left'
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.pad_token_id = tokenizer.eos_token_id

    text_generator = pipeline(task='text-generation',
                              model=generator_model,
                              device_map='auto',
                              tokenizer=tokenizer,
                              torch_dtype=torch.float16,
                              model_kwargs={'load_in_8bit': True},
                              )

    # Example 1: Evaluate baseline 1 - selective-context compressor
    from selective_context import SelectiveContext

    sc = SelectiveContext(model_type='gpt2', lang='en')

    evaluator = Evaluator(datasets)
    evaluator.update_llm(agent=sc, generator=text_generator, tokenizer=tokenizer)
    results = evaluator.evaluate_compressor(dataset_name="Arxiv", compressor_name='selective_content')
    print(results)

    # # Example 2: Evaluate our method - self-play compressor
    # ppo_trainer = None  # replace with your PPO trainer
    # evaluator.update_llm(agent=ppo_trainer, generator=text_generator, tokenizer=tokenizer)
    # results = evaluator.evaluate_compressor(dataset_name="Arxiv", compressor_name='sp-compressor')
    # print(results)
