import os
import sys
import math
import pickle
import argparse
import numpy as np
import pandas as pd
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import Dataset, DataLoader, RandomSampler, SequentialSampler
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    GenerationConfig,
    SchedulerType,
    get_scheduler,
)

import deepspeed
from deepspeed import get_accelerator
from deepspeed.ops.adam import DeepSpeedCPUAdam, FusedAdam

# Make the current working directory and the parent of this script's directory
# importable, so the local `dataset` and `local_utils` imports below resolve.
sys.path.append('.')
sys.path.append(
    os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))

from dataset import TrainDataset, EvaluationDataset

from local_utils.utils import print_rank_0, to_device, save_hf_format, set_random_seed, get_all_reduce_mean, get_optimizer_grouped_parameters, save_zero_three_model, load_hf_tokenizer, map_answer_to_label
from local_utils.ds_utils import get_train_ds_config
from local_utils.module.lora import convert_linear_layer_to_lora, convert_lora_to_linear_layer, only_optimize_lora_parameters, make_model_gradient_checkpointing_compatible

from torch.utils.tensorboard import SummaryWriter
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score, cohen_kappa_score, classification_report

# Presumably the label id to be masked out of the loss (it matches PyTorch's
# CrossEntropyLoss default ignore_index of -100); not referenced elsewhere in this file.
IGNORE_TOKEN_ID = -100

def parse_args():
    """Parse the command-line arguments for fine-tuning / evaluating a causal LM.

    Besides the options declared here, DeepSpeed's own configuration arguments
    are registered via ``deepspeed.add_config_arguments``.

    Returns:
        argparse.Namespace with all parsed options. ``cache_dir`` has ``~``
        expanded to the user's home directory.
    """
    parser = argparse.ArgumentParser(
        description=
        "Finetune a transformers model on a causal language modeling task")
    # increased parameters based on DeepSpeedExample
    parser.add_argument("--train",
                        action="store_true",
                        help="decide to train or inference model"
                        )
    parser.add_argument("--load_local_ckpt",
                        action="store_true",
                        help="Load a local checkpoint for the model")
    parser.add_argument("--cache_dir",
                        type=str,
                        default="~/.cache/huggingface/",
                        help="cache dir path")
    parser.add_argument("--dataset",
                        default="moral_choice",
                        type=str,
                        help="Dataset to evaluate (beavertails, denevil, moral_choice, value_fulcra)")
    parser.add_argument("--data_version",
                        default="original",
                        type=str,
                        help="Data version to use (original, pairwise, augmented)")
    parser.add_argument("--train_split",
                        default="small",
                        type=str,
                        help="Train split to use, small / large")
    parser.add_argument("--remove_value",
                        default=None,
                        type=str,
                        help="Remove a specific value from the training data")
    parser.add_argument("--prompt_method",
                        default="vanilla",
                        type=str,
                        help="Prompt method to use, vanilla means just prompts.")
    # The following flags are parsed as strings ("True"/"False"), not booleans;
    # e.g. --with_result is embedded verbatim in the output directory name.
    parser.add_argument("--with_definition",
                        default="True",
                        type=str,
                        help="Whether to include the basic/prior definition in the prompt.")
    parser.add_argument("--with_result",
                        default="True",
                        type=str,
                        help="Whether use the result when encoding key concepts.")
    parser.add_argument("--use_inferred",
                        default="False",
                        type=str,
                        help="Whether to use the inferred features for testing data, instead of the mapped one")
    parser.add_argument("--max_new_tokens",
                        type=int,
                        default=512,
                        help="Max number of tokens to generate.")

    parser.add_argument(
        "--model_name",
        type=str,
        help=
        "Path to pretrained model or model identifier from huggingface.co/models.",
        required=True,
    )
    parser.add_argument(
        "--per_device_train_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument(
        "--per_device_eval_batch_size",
        type=int,
        default=16,
        help="Batch size (per device) for the evaluation dataloader.",
    )
    parser.add_argument(
        "--max_seq_len",
        type=int,
        default=512,
        help="The maximum sequence length.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-3,
        help=
        "Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument("--weight_decay",
                        type=float,
                        default=0.,
                        help="Weight decay to use.")
    parser.add_argument("--num_train_epochs",
                        type=int,
                        default=1,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--lr_scheduler_type",
        type=SchedulerType,
        default="cosine",
        help="The scheduler type to use.",
        choices=[
            "linear", "cosine", "cosine_with_restarts", "polynomial",
            "constant", "constant_with_warmup"
        ],
    )
    parser.add_argument(
        "--num_warmup_steps",
        type=int,
        default=0,
        help="Number of steps for the warmup in the lr scheduler.")
    parser.add_argument("--output_dir",
                        type=str,
                        default="../../models",
                        help="Where to store the model.")
    parser.add_argument("--seed",
                        type=int,
                        default=1234,
                        help="A seed for reproducible training.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--gradient_checkpointing',
                        action='store_true',
                        help='Enable HF gradient checkpointing for model.')
    parser.add_argument(
        "--dropout",
        type=float,
        default=None,
        help="If dropout configured, use it. "
        "Otherwise, keep the default dropout configuration of the model.")

    # deepspeed features
    parser.add_argument('--offload',
                        action='store_true',
                        help='Enable ZeRO Offload techniques.')
    parser.add_argument('--dtype',
                        type=str,
                        default='fp16',
                        choices=['fp16', 'bf16', 'fp32'],
                        help='Training data type')
    parser.add_argument(
        '--zero_stage',
        type=int,
        default=0,
        help='ZeRO optimization stage for Actor model (and clones).')

    ## LoRA for efficient training setting
    parser.add_argument("--lora_dim",
                        type=int,
                        default=0,
                        help="If > 0, use LoRA for efficient training.")
    parser.add_argument("--lora_module_name",
                        type=str,
                        default="decoder.layers.",
                        help="The scope of LoRA.")
    parser.add_argument('--only_optimize_lora',
                        action='store_true',
                        help='Only optimize the LoRA parameters.')
    parser.add_argument(
        "--lora_learning_rate",
        type=float,
        default=5e-4,
        help=
        "Initial LoRA learning rate (after the potential warmup period) to use."
    )
    ## Print loss
    parser.add_argument('--print_loss',
                        action='store_true',
                        help='Print the training loss every --print_every_n_step steps.')
    parser.add_argument('--print_every_n_step',
                        type=int,
                        default=50,
                        help='Print loss every n steps.')
    parser = deepspeed.add_config_arguments(parser)
    args = parser.parse_args()

    # argparse never expands "~" (neither in defaults nor in "--cache_dir=~/..."),
    # and the value is handed verbatim to from_pretrained(cache_dir=...); expand it
    # so it is not treated as a relative path with a literal "~" directory.
    args.cache_dir = os.path.expanduser(args.cache_dir)

    return args

def evaluation(args, model, tokenizer, eval_dataloader, device, epoch=-1, writer=None):
    """Evaluate ``model`` with greedy free-form generation and print classification metrics.

    For each batch, up to 20 new tokens are generated greedily (no sampling,
    one beam). Only the newly generated tokens are decoded. The decoded answers
    and ``batch["labels"]`` are both mapped through ``map_answer_to_label``.
    Overall accuracy and micro precision/recall/F1 are printed, together with a
    classification report. The same metrics are also printed per entry of
    ``batch["values"]``, using the part of each value before the first ":".

    Only a prefix of the dataloader is evaluated. For intermediate evaluations
    (``epoch < args.num_train_epochs - 1``, which includes the pre-training call
    with ``epoch=-1``), the loop stops after step 200. Otherwise it stops after
    step 480.

    Args:
        args: parsed CLI namespace; ``num_train_epochs`` is read.
        model: causal LM (or a DeepSpeed engine wrapping one) exposing ``generate``.
        tokenizer: tokenizer used to decode the generated ids.
        eval_dataloader: yields dict batches with ``input_ids``, ``attention_mask``,
            ``labels`` and ``values``.
        device: device each batch is moved to via ``to_device``.
        epoch: current epoch index; -1 for the evaluation before training.
        writer: accepted for API compatibility; not used.

    Returns:
        ``(prediction, ground_truth)`` as produced by ``map_answer_to_label``, or
        ``([], [])`` when no batch was evaluated.
    """
    model.eval()
    labels, predictions, values = [], [], []
    # Greedy decoding; the config does not depend on the batch, so build it once.
    generation_config = GenerationConfig(do_sample=False, num_beams=1)
    # Equivalent to breaking when (epoch < n - 1 and step >= 200) or step >= 480.
    last_step = 200 if epoch < args.num_train_epochs - 1 else 480

    for step, batch in tqdm(enumerate(eval_dataloader), desc="evaluating with free output"):
        batch = to_device(batch, device)

        with torch.no_grad():
            output_ids = model.generate(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                max_new_tokens=20,
                pad_token_id=tokenizer.eos_token_id,
                generation_config=generation_config,
            )
        # The generated sequences start with the (padded) prompt ids. Cut them off at
        # the token level. Slicing the decoded text by len(batch["prompts"]) breaks
        # whenever decoding does not reproduce the prompt string character-for-character.
        prompt_len = batch["input_ids"].shape[1]
        answers = tokenizer.batch_decode(output_ids[:, prompt_len:], skip_special_tokens=True)
        predictions.extend(answer.strip() for answer in answers)
        labels.extend(batch["labels"])
        values.extend(batch["values"])

        if step >= last_step:
            break

    print("predictions: ", predictions[:30])
    if not predictions:
        print("No evaluation batches were processed; skipping metrics.")
        return [], []

    prediction = map_answer_to_label(predictions)
    ground_truth = map_answer_to_label(labels)

    accuracy = accuracy_score(ground_truth, prediction)
    precision = precision_score(ground_truth, prediction, average="micro")
    recall = recall_score(ground_truth, prediction, average="micro")
    f1 = f1_score(ground_truth, prediction, average="micro")
    report = classification_report(ground_truth, prediction)
    print(f"classification report: {report}")
    print(f"Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1: {f1}")

    # Group (prediction, ground truth) pairs by value in a single pass.
    grouped = {}
    for pred, gold, value in zip(prediction, ground_truth, values):
        value_preds, value_golds = grouped.setdefault(value, ([], []))
        value_preds.append(pred)
        value_golds.append(gold)

    # compute metrics for each value
    value_results = {}
    for value in sorted(grouped):
        value_prediction, value_ground_truth = grouped[value]

        accuracy = accuracy_score(value_ground_truth, value_prediction)
        precision = precision_score(value_ground_truth, value_prediction, average="micro")
        recall = recall_score(value_ground_truth, value_prediction, average="micro")
        f1 = f1_score(value_ground_truth, value_prediction, average="micro")
        # Report under the part of the value before the first ":".
        value_name = value.split(":", 1)[0]

        value_results[value_name] = {
            "accuracy": round(accuracy, 4),
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "f1": round(f1, 4)
        }

        print(f"Metrics for value {value_name}: {value_results[value_name]}")

    return prediction, ground_truth

def first_token_ids(tokenizer, words):
    """Return the id of the first content sub-token of each string in ``words``.

    Special tokens such as a BOS token are not added. The result therefore does
    not depend on whether the tokenizer prepends one.
    """
    return [tokenizer.encode(word, add_special_tokens=False)[0] for word in words]


def select_last_token_logits(logits, attention_mask):
    """Return ``logits`` at the last attended position of every row.

    Works for both left- and right-padded batches. The selected position is the
    largest index whose ``attention_mask`` entry is non-zero; an all-zero row
    selects index 0.

    Args:
        logits: tensor of shape (batch, seq_len, vocab).
        attention_mask: tensor of shape (batch, seq_len) with 0/1 entries.

    Returns:
        Tensor of shape (batch, vocab).
    """
    seq_len = attention_mask.shape[1]
    positions = torch.arange(seq_len, device=attention_mask.device)
    last_index = (attention_mask.long() * positions).max(dim=1).values
    rows = torch.arange(logits.shape[0], device=logits.device)
    return logits[rows, last_index.to(logits.device)]


def evaluation_limited(args, model, tokenizer, eval_dataloader, device, epoch=-1, writer=None):
    """Evaluate ``model`` by scoring a fixed set of candidate answers.

    A single forward pass is run per batch. The next-token logits at the last
    prompt position are compared across the first sub-tokens of "Yes", "No"
    and "Not Related". The highest-scoring candidate becomes the prediction.
    "Not Related" can only be chosen when ``args.dataset == "value_fulcra"``.

    Predictions and ``batch["labels"]`` are mapped with ``map_answer_to_label``.
    Overall and per-value (``batch["values"]``) accuracy and micro
    precision/recall/F1 are then printed. The early-stop rule is the same as in
    ``evaluation``: stop after step 200 for intermediate epochs, otherwise after
    step 480.

    Args:
        args: parsed CLI namespace; ``dataset`` and ``num_train_epochs`` are read.
        model: causal LM returning an output with ``.logits``.
        tokenizer: tokenizer used to look up the candidate answer tokens.
        eval_dataloader: yields dict batches with ``input_ids``, ``attention_mask``,
            ``labels`` and ``values``.
        device: device each batch is moved to via ``to_device``.
        epoch: current epoch index; -1 for the evaluation before training.
        writer: accepted for API compatibility; not used.

    Returns:
        ``(prediction, ground_truth)`` as produced by ``map_answer_to_label``, or
        ``([], [])`` when no batch was evaluated.
    """
    model.eval()
    labels, predictions, values = [], [], []
    # Each candidate is scored by the logit of its first sub-token. Special tokens
    # are excluded, so this works whether or not the tokenizer adds a BOS token.
    answer_tokens = first_token_ids(tokenizer, ["Yes", "No", "Not Related"])
    # NOTE(review): the candidate label is "Not related" (lower-case r) while the
    # scored text is "Not Related"; kept as-is since map_answer_to_label consumes it.
    answer_candidates = ["Yes", "No", "Not related"]
    last_step = 200 if epoch < args.num_train_epochs - 1 else 480

    for step, batch in tqdm(enumerate(eval_dataloader), desc="evaluating with limited output"):
        batch = to_device(batch, device)

        with torch.no_grad():
            output = model(
                input_ids=batch["input_ids"],
                attention_mask=batch["attention_mask"],
                use_cache=False,
            )
            # Next-token logits after the last real prompt token, whether the batch
            # is left- or right-padded.
            last_logits = select_last_token_logits(output.logits, batch["attention_mask"])  # [batch_size, vocab]
            answer_logits = last_logits[:, answer_tokens]  # [batch_size, 3]
            if args.dataset != "value_fulcra":
                # "Not Related" is only a valid prediction for value_fulcra.
                answer_logits[:, -1] = float("-inf")
            choice_indices = answer_logits.argmax(dim=-1).tolist()

        predictions.extend(answer_candidates[idx] for idx in choice_indices)
        labels.extend(batch["labels"])
        values.extend(batch["values"])

        if step >= last_step:
            break

    print("predictions: ", predictions[:30])
    if not predictions:
        print("No evaluation batches were processed; skipping metrics.")
        return [], []

    prediction = map_answer_to_label(predictions)
    ground_truth = map_answer_to_label(labels)

    accuracy = accuracy_score(ground_truth, prediction)
    precision = precision_score(ground_truth, prediction, average="micro")
    recall = recall_score(ground_truth, prediction, average="micro")
    f1 = f1_score(ground_truth, prediction, average="micro")
    report = classification_report(ground_truth, prediction)
    print(f"classification report: {report}")
    print(f"Accuracy: {accuracy}, Precision: {precision}, Recall: {recall}, F1: {f1}")

    # Group (prediction, ground truth) pairs by value in a single pass.
    grouped = {}
    for pred, gold, value in zip(prediction, ground_truth, values):
        value_preds, value_golds = grouped.setdefault(value, ([], []))
        value_preds.append(pred)
        value_golds.append(gold)

    # compute metrics for each value
    value_results = {}
    for value in sorted(grouped):
        value_prediction, value_ground_truth = grouped[value]

        accuracy = accuracy_score(value_ground_truth, value_prediction)
        precision = precision_score(value_ground_truth, value_prediction, average="micro")
        recall = recall_score(value_ground_truth, value_prediction, average="micro")
        f1 = f1_score(value_ground_truth, value_prediction, average="micro")
        # Report under the part of the value before the first ":".
        value_name = value.split(":", 1)[0]

        value_results[value_name] = {
            "accuracy": round(accuracy, 4),
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "f1": round(f1, 4)
        }

        print(f"Metrics for value {value_name}: {value_results[value_name]}")

    return prediction, ground_truth

def train_model(args, model, tokenizer, ds_config, device):
    """Fine-tune ``model`` with DeepSpeed, evaluate it, and save the final weights.

    Setup:
    - Builds the train and valid datasets and their dataloaders.
    - Builds an Adam optimizer: DeepSpeedCPUAdam when ``args.offload`` is set,
      FusedAdam otherwise. Bias and LayerNorm weights get no weight decay.
    - Builds a ``get_scheduler`` learning-rate schedule.
    - Wraps the model, optimizer and scheduler with ``deepspeed.initialize``.

    During the run:
    - The model is evaluated once before training and again after every epoch.
    - The per-step loss is logged to TensorBoard under
      ``<args.output_dir>/tensorboard_logs``, on rank 0 only.
    - At the end the model is saved via ``save_model`` into
      ``epoch_<args.num_train_epochs>``.

    Args:
        args: parsed CLI namespace. ``global_rank`` must already be set. As a side
            effect, ``args.vocab_size`` is written.
        model: the (possibly LoRA-converted) causal LM to train.
        tokenizer: tokenizer passed to the datasets, evaluation and saving.
        ds_config: DeepSpeed configuration dict.
        device: device each batch is moved to via ``to_device``.
    """
    train_dataset = TrainDataset(args, tokenizer, split="train", max_length=args.max_seq_len)
    eval_dataset = EvaluationDataset(args, tokenizer, split="valid", max_length=args.max_seq_len)
    print(f"#Train Dataset: {len(train_dataset)}, #Eval Dataset: {len(eval_dataset)}")
    print("train_dataset: ", train_dataset)
    args.vocab_size = model.config.vocab_size

    # The same samplers are used with or without a distributed launch
    # (DistributedSampler was deliberately disabled).
    # NOTE(review): if set_random_seed seeds every rank identically, all ranks
    # presumably draw the same shuffled order. Confirm this is intended.
    train_sampler = RandomSampler(train_dataset)
    eval_sampler = SequentialSampler(eval_dataset)

    train_dataloader = DataLoader(train_dataset,
                                  collate_fn=train_dataset.collator,
                                  sampler=train_sampler,
                                  batch_size=args.per_device_train_batch_size)
    eval_dataloader = DataLoader(eval_dataset,
                                 collate_fn=eval_dataset.collator,
                                 sampler=eval_sampler,
                                 batch_size=args.per_device_eval_batch_size)

    # Split trainable weights in two groups: bias/LayerNorm weights get no weight decay.
    no_decay = ["bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay) and p.requires_grad],
            "weight_decay": args.weight_decay,
        },
        {
            "params": [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay) and p.requires_grad],
            "weight_decay": 0.0,
        },
    ]

    AdamOptimizer = DeepSpeedCPUAdam if args.offload else FusedAdam
    optimizer = AdamOptimizer(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              betas=(0.9, 0.95))

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps)
    lr_scheduler = get_scheduler(
        name=args.lr_scheduler_type,
        optimizer=optimizer,
        num_warmup_steps=args.num_warmup_steps,
        num_training_steps=args.num_train_epochs * num_update_steps_per_epoch,
    )

    model, optimizer, _, lr_scheduler = deepspeed.initialize(
        model=model,
        optimizer=optimizer,
        args=args,
        config=ds_config,
        lr_scheduler=lr_scheduler,
        dist_init_required=True)

    if args.gradient_checkpointing:
        model.gradient_checkpointing_enable()

    # Only rank 0 writes TensorBoard events; otherwise every rank would write
    # into the same log directory.
    writer = None
    if args.global_rank <= 0:
        writer = SummaryWriter(os.path.join(args.output_dir, "tensorboard_logs"))

    try:
        evaluation(args, model, tokenizer, eval_dataloader, device, epoch=-1, writer=writer)
        for epoch in range(args.num_train_epochs):
            print_rank_0(
                f"Beginning of Epoch {epoch+1}/{args.num_train_epochs}, Total Micro Batches {len(train_dataloader)}",
                args.global_rank)
            model.train()
            for step, batch in enumerate(train_dataloader):
                batch = to_device(batch, device)
                outputs = model(**batch, use_cache=False)
                loss = outputs.loss
                if args.print_loss and step % args.print_every_n_step == 0:
                    print(f"Epoch: {epoch}, Step: {step}, Rank: {torch.distributed.get_rank()}, loss = {loss}")
                model.backward(loss)
                model.step()
                if writer is not None:
                    writer.add_scalar('Loss', loss.item(), epoch * len(train_dataloader) + step)

            # Evaluating on the validation set.
            print_rank_0(
                f"***** Evaluating, Epoch {epoch+1}/{args.num_train_epochs} *****",
                args.global_rank)
            evaluation(args, model, tokenizer, eval_dataloader, device, epoch, writer)
            model.tput_timer.update_epoch_count()

        # save_model writes to f"epoch_{epoch+1}", so this targets
        # epoch_<num_train_epochs>. The loop variable would be undefined
        # (NameError) when num_train_epochs == 0.
        save_model(model, tokenizer, args, args.num_train_epochs - 1)
    finally:
        if writer is not None:
            writer.close()

def save_model(model, tokenizer, args, epoch):
    """Save the trained model under ``<args.output_dir>/epoch_<epoch + 1>``.

    The steps are:
    1. Fold LoRA layers back into plain linear layers with
       ``convert_lora_to_linear_layer``.
    2. Write the Hugging Face-format files with ``save_hf_format``, on rank 0 only.
    3. Under ZeRO stage 3, also call ``save_zero_three_model``.

    Does nothing when ``args.output_dir`` is None.

    Args:
        model: trained model (DeepSpeed engine or plain module).
        tokenizer: tokenizer saved alongside the model.
        args: parsed CLI namespace; ``output_dir``, ``global_rank`` and
            ``zero_stage`` are read.
        epoch: zero-based index of the last finished epoch.
    """
    if args.output_dir is None:
        return

    print('saving the final model ...')
    model = convert_lora_to_linear_layer(model)
    sub_folder = f"epoch_{epoch+1}"

    # Every rank calls this function with the same arguments, so without the guard
    # all ranks would write the same files concurrently.
    # NOTE(review): assumes save_hf_format performs no collective operations.
    if args.global_rank <= 0:
        save_hf_format(model, tokenizer, args, sub_folder=sub_folder)

    if args.zero_stage == 3:
        # For zero stage 3, each gpu only has a part of the model, so we need a special save function
        save_zero_three_model(model,
                              args.global_rank,
                              args.output_dir,
                              os.path.join(args.output_dir, sub_folder),
                              zero_stage=args.zero_stage)

def main():
    """Script entry point: set up devices, load the model, optionally train, then evaluate.

    Steps:
    1. Parse the arguments and pick the device. When ``--local_rank`` is given,
       initialize the distributed backend through DeepSpeed.
    2. Build the DeepSpeed config and seed the RNGs.
    3. Load the tokenizer, config and model. Optionally load a local checkpoint
       from ``<output_dir>/epoch_<num_train_epochs>/pytorch_model.bin``.
    4. Optionally convert linear layers to LoRA and train.
    5. Evaluate on the validation split, both with free generation and with
       limited answers.
    6. Pickle ``[prediction, ground_truth]`` on rank 0.
    """
    args = parse_args()

    if args.local_rank == -1:
        device = torch.device(get_accelerator().device_name())
    else:
        get_accelerator().set_device(args.local_rank)
        device = torch.device(get_accelerator().device_name(), args.local_rank)
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        deepspeed.init_distributed()

    # Without a launcher (local_rank == -1) no process group exists here. Fall back to
    # single-process values instead of failing in torch.distributed.get_rank().
    is_distributed = torch.distributed.is_available() and torch.distributed.is_initialized()
    args.global_rank = torch.distributed.get_rank() if is_distributed else 0
    world_size = torch.distributed.get_world_size() if is_distributed else 1

    # NOTE(review): --dtype is not passed to get_train_ds_config, so training
    # precision is whatever that helper configures. Confirm this is intended.
    ds_config = get_train_ds_config(offload=args.offload, stage=args.zero_stage)
    ds_config['train_micro_batch_size_per_gpu'] = args.per_device_train_batch_size
    ds_config['train_batch_size'] = args.per_device_train_batch_size * world_size * args.gradient_accumulation_steps

    # If passed along, set the training seed now.
    set_random_seed(args.seed)

    if is_distributed:
        torch.distributed.barrier()

    # Models whose name contains "phi" are loaded with trust_remote_code.
    trust_remote_code = "phi" in args.model_name.lower()

    # The tokenizer is loaded directly; when it has no pad token, EOS is used for padding.
    tokenizer = AutoTokenizer.from_pretrained(
        args.model_name,
        cache_dir=args.cache_dir,
        model_max_length=4096,
        use_fast=True,
    )
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token
    llm_config = AutoConfig.from_pretrained(
        args.model_name,
        cache_dir=args.cache_dir,
        trust_remote_code=trust_remote_code,
    )
    llm_config.pad_token_id = tokenizer.pad_token_id
    llm_config.use_cache = False

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name,
        config=llm_config,
        cache_dir=args.cache_dir,
        trust_remote_code=trust_remote_code)

    para_path = f"dataset_{args.dataset}_split_{args.train_split}_prompt_method_{args.prompt_method}_with_result_{args.with_result}"
    args.output_dir = os.path.join(args.output_dir, args.model_name.split("/")[-1], para_path)
    os.makedirs(args.output_dir, exist_ok=True)

    if args.load_local_ckpt:
        save_path = os.path.join(args.output_dir, f"epoch_{args.num_train_epochs}", "pytorch_model.bin")
        print("loading the saved model from path: " + save_path)
        # Load onto CPU; the model is moved to `device` below. Without map_location,
        # tensors land on whatever device they were saved from, on every rank.
        state_dict = torch.load(save_path, map_location="cpu")
        # strict=False tolerates keys absent from the checkpoint (e.g. LoRA
        # parameters); report what was skipped so silent mismatches are visible.
        incompatible = model.load_state_dict(state_dict, strict=False)
        print(f"missing keys: {len(incompatible.missing_keys)}, "
              f"unexpected keys: {len(incompatible.unexpected_keys)}")
        del state_dict

    model = model.to(device)

    if args.lora_dim > 0 and args.train:
        model = convert_linear_layer_to_lora(model, args.lora_module_name,
                                             args.lora_dim)
        if args.only_optimize_lora:
            model = only_optimize_lora_parameters(model)
            model = make_model_gradient_checkpointing_compatible(model)

    if args.train:
        train_model(args, model, tokenizer, ds_config, device)

    eval_dataset = EvaluationDataset(args, tokenizer, split="valid", max_length=args.max_seq_len)
    # Each process evaluates the full validation set sequentially
    # (DistributedSampler is deliberately not used).
    eval_sampler = SequentialSampler(eval_dataset)
    eval_dataloader = DataLoader(eval_dataset,
                                 collate_fn=eval_dataset.collator,
                                 sampler=eval_sampler,
                                 batch_size=args.per_device_eval_batch_size)

    prediction, ground_truth = evaluation(args, model, tokenizer, eval_dataloader, device, epoch=args.num_train_epochs)
    evaluation_limited(args, model, tokenizer, eval_dataloader, device, epoch=args.num_train_epochs)

    # Every rank computes the same results; only one process writes the pickle.
    if args.global_rank <= 0:
        with open(os.path.join(args.output_dir, f"prediction_{args.data_version}.pkl"), "wb") as f:
            pickle.dump([prediction, ground_truth], f)


if __name__ == "__main__":
    main()