import re
import os
import sys
import random
import argparse

import wandb
import torch
import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

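# HACK: append the parent directory (os.pardir, resolved relative to the current
# working directory) to sys.path so that the `src` package below is importable.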
import sys
sys.path.append(os.pardir)
from src.pruner import Pruner
from src.utils.random import fix_seed

from data_utils import get_loaders
from engine import eval_perplexity


def main():
    """Parse command-line arguments and hand them to ``run``.

    All options are read from ``sys.argv``. Option names, types, defaults and
    choices are part of the script's interface and are kept stable; only the
    descriptive texts shown by ``--help`` are informational.
    """
    parser = argparse.ArgumentParser(
        description="Iterative pruning of causal language models with perplexity evaluation."
    )
    # Model params
    parser.add_argument(
        '--model_name_or_path',
        type=str,
        required=True,
        help="The name or path to the model being pruned",
    )
    # Data params
    parser.add_argument(
        '--dataset_name_or_path',
        type=str,
        required=True,
        help="The name of or path to the dataset used for calibration.",
    )
    parser.add_argument(
        '--sequence_length',
        default=512,
        type=int,
        help="length of extracted sequences."
    )
    # Sparsification params
    parser.add_argument(
        '--iterations',
        default=10,
        type=int,
        help="Number of prune-then-evaluate iterations."
    )
    parser.add_argument(
        '--pruning_method',
        default="FastOBC",
        choices=["FastOBC", "OBC"],
        type=str,
        help="Pruning method handed to the pruner."
    )
    parser.add_argument(
        '--sparsity',
        default=0.5,
        type=float,
        help="Sparsity value passed to every pruning call."
    )
    parser.add_argument(
        '--alpha',
        default=0.0,
        type=float,
        help="Alpha value passed to every pruning call."
    )
    parser.add_argument(
        '--module_regex',
        type=str,
        required=True,
        help="Regular expression searched in module names to select the modules to prune",
    )
    parser.add_argument(
        '--decoder_blocks',
        required=True,
        type=str,
        help="Decoder blocks specification forwarded to the pruner and to evaluation."
    )
    parser.add_argument(
        '--pre_decoder_modules',
        required=True,
        nargs="+",
        type=str,
        help="Pre-decoder modules forwarded to the pruner and to evaluation."
    )
    parser.add_argument(
        '--post_decoder_modules',
        required=True,
        nargs="+",
        type=str,
        help="Post-decoder modules forwarded to evaluation."
    )
    parser.add_argument(
        '--calibration_dataset_size',
        default=None,
        type=int,
        help="Size of calibration dataset."
    )
    parser.add_argument(
        '--block_size',
        default=64,
        type=int,
        help="Block size option of the FastOBC pruning method."
    )
    parser.add_argument(
        '--rel_damp',
        default=1e-2,
        type=float,
        help="rel_damp value handed to the pruner."
    )
    parser.add_argument(
        '--rows_in_parallel',
        default=None,
        type=int,
        help="rows_in_parallel option of the OBC pruning method."
    )
    parser.add_argument(
        '--perturbation',
        default='gradient',
        choices=['gradient', 'interpolation'],
        type=str
    )
    parser.add_argument(
        '--sequential',
        action='store_true',
        help='Whether to prune sequentially'
    )
    parser.add_argument(
        '--cpu_offload',
        action='store_true',
        help='Whether to offload model to CPU.'
    )
    # Misc params
    parser.add_argument(
        '--seed',
        default=0,
        type=int,
        help="random seed."
    )
    parser.add_argument(
        '--output_dir',
        type=str,
        default=None,
        help='Output directory where model checkpoints and results are stored.'
    )
    parser.add_argument(
        '--save_model',
        action='store_true',
        help='Whether to save pruned model'
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="auto",
        choices=["auto", "bfloat16", "float16", "float32"],
        help="dtype to load the model.",
    )
    parser.add_argument(
        '--low_cpu_mem_usage',
        action='store_true',
        help='Whether to load model with the use of `low_cpu_mem_usage`'
    )
    parser.add_argument(
        '--load_model_on_device',
        action='store_true',
        help='Whether to load model on device on pruning'
    )
    args = parser.parse_args()
    run(args)


def run(args):
    """Prune a causal language model iteratively, evaluating perplexity each round.

    The model and tokenizer are loaded from ``args.model_name_or_path``,
    calibration data is fetched through ``get_loaders``, and a ``Pruner`` is
    built over the modules whose names match ``args.module_regex``. For each
    of ``args.iterations`` rounds the pruner is invoked and perplexity is
    measured on 'wikitext2', 'ptb' and 'c4'. When ``args.output_dir`` is set,
    the collected perplexities are stored in ``eval_results.pth`` and, if
    ``args.save_model`` is set, the pruned model is written there as well.

    Args:
        args: Namespace with the attributes defined by the command-line
            parser of this script.

    Raises:
        ValueError: If ``save_model`` is requested without ``output_dir``,
            if ``pruning_method`` is not one of the supported methods, or if
            ``sequence_length`` exceeds the maximum length found in the
            model config.
    """
    # Check option consistency before any expensive loading. Explicit raises
    # are used instead of `assert` so the checks survive `python -O`.
    if args.save_model and args.output_dir is None:
        raise ValueError("--save_model requires --output_dir to be set.")

    if args.pruning_method == "FastOBC":
        obc_util_kwargs = {"block_size": args.block_size}
    elif args.pruning_method == "OBC":
        obc_util_kwargs = {"rows_in_parallel": args.rows_in_parallel}
    else:
        raise ValueError(f"Unknown pruning method: {args.pruning_method!r}")

    fix_seed(args.seed)
    # get device
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # model: 'auto' is passed through as-is, other names are resolved to torch dtypes
    torch_dtype = args.dtype
    if torch_dtype != 'auto':
        torch_dtype = getattr(torch, args.dtype)

    model = AutoModelForCausalLM.from_pretrained(
        args.model_name_or_path,
        trust_remote_code=True,
        torch_dtype=torch_dtype,
        low_cpu_mem_usage=args.low_cpu_mem_usage
    )

    # The maximum length is stored under different config attribute names;
    # the first one present wins, and no limit is assumed if none is found.
    max_sequence_length = float('inf')
    for length_attr in ('max_sequence_length', 'max_position_embeddings', 'max_seq_len'):
        if hasattr(model.config, length_attr):
            max_sequence_length = getattr(model.config, length_attr)
            break
    if args.sequence_length > max_sequence_length:
        raise ValueError(
            f"sequence_length ({args.sequence_length}) exceeds the maximum "
            f"supported by the model ({max_sequence_length})."
        )
    model.sequence_length = args.sequence_length

    # get tokenizer
    tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path)

    calibration_data = get_loaders(
        args.dataset_name_or_path,
        args.calibration_dataset_size,
        args.seed,
        model.sequence_length,
        False,
        tokenizer
    )
    # Each sample becomes an (args, kwargs) pair with empty positional args.
    data_loader = [([], {'input_ids': input_ids}) for input_ids in calibration_data]

    # make dirs if needed
    if args.output_dir is not None:
        os.makedirs(args.output_dir, exist_ok=True)

    # Keep a CPU copy of the original weights of every module selected for pruning.
    weights_orig = {}
    for module_name, module in model.named_modules():
        if re.search(args.module_regex, module_name):
            weights_orig[module_name] = module.weight.cpu().clone()

    # create pruner
    pruner = Pruner(
        model,
        data_loader=data_loader,
        module_regex=args.module_regex,
        weights_orig=weights_orig,
        pruning_method=args.pruning_method,
        rel_damp=args.rel_damp,
        obc_util_kwargs=obc_util_kwargs,
        sequential=args.sequential,
        device=device,
        cpu_offload=args.cpu_offload,
        blocks=args.decoder_blocks,
        pre_modules=args.pre_decoder_modules,
        max_samples=args.calibration_dataset_size
    )

    eval_stats = {}
    print(f'{args.output_dir=}')
    for i in range(args.iterations):
        print(f"Iteration {i}/{args.iterations}")
        pruner.prune(args.sparsity, args.alpha)
        print('---Evaluation after pruning---')

        # NOTE(review): the evaluation sets are re-fetched every iteration.
        # This is kept inside the loop because `get_loaders` receives the seed
        # and may touch global RNG state; hoist only after confirming it is pure.
        for eval_dataset_name in ['wikitext2', 'ptb', 'c4']:
            test_data = get_loaders(
                eval_dataset_name,
                0,
                args.seed,
                args.sequence_length,
                True,
                tokenizer
            )
            test_loader = [([], {'input_ids': input_ids}) for input_ids in test_data]
            ppl = eval_perplexity(
                model,
                test_loader,
                args.decoder_blocks,
                args.pre_decoder_modules,
                args.post_decoder_modules,
                device,
                cpu_offload=True
            )
            print(f'Dataset: {eval_dataset_name}\nPerplexity: {ppl:.2f}')
            eval_stats[f'eval/iteration_{i}/{eval_dataset_name}'] = ppl

    if args.output_dir is not None:
        torch.save(eval_stats, os.path.join(args.output_dir, 'eval_results.pth'))
        # Honour --save_model: previously the flag was validated but never acted on.
        if args.save_model:
            model.save_pretrained(args.output_dir)

if __name__ == "__main__":
    # Script entry point: use main()'s return value as the process exit status.
    raise SystemExit(main())  # pragma: no cover
