import math
import argparse
from tqdm import tqdm
import random
import os
import pickle
import re
import uuid 
from typing import Optional
from collections import OrderedDict

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from transformers import AutoTokenizer, PreTrainedTokenizerFast, TrainerCallback, TrainingArguments, logging
from transformers import GPT2LMHeadModel, GPT2Config, Trainer, TrainerCallback, TrainerControl, TrainerState

from transformer_lens import HookedTransformer

from utils import *
from model import CustomGPT2LMHeadModel
from generate import get_test_activation, generate_and_print, run_test

logging.set_verbosity_info()

def choose_random_pos(
        args,
        tokens_batch: torch.LongTensor,
        tokenizer: PreTrainedTokenizerFast,
    ):
    """Pick one random position per sequence, skipping pad/eos tokens.

    Every position whose token equals ``tokenizer.pad_token_id`` or
    ``tokenizer.eos_token_id`` is excluded. When ``args.probed_task`` is
    ``"counting"``, candidates are further restricted to the two positions
    immediately after the first ``"|"`` token of each row. When
    ``args.pos_bias`` is not None, a linear ramp from 0 to ``args.pos_bias``
    across the sequence is added to the random scores, which favours later
    positions.

    Returns:
        ``(row_idx, pos)``: two LongTensors of length ``batch_size`` suitable
        for advanced indexing such as ``acts[row_idx, pos]``.
    """
    n_rows, n_cols = tokens_batch.size()
    dev = tokens_batch.device
    row_idx = torch.arange(n_rows, dtype=torch.long, device=dev)

    excluded = torch.logical_or(
        tokens_batch == tokenizer.pad_token_id,
        tokens_batch == tokenizer.eos_token_id,
    )

    if args.probed_task == "counting":
        # Index of the first "|" per row (argmax of an all-zero row is 0).
        bar_pos = (tokens_batch == tokenizer.vocab["|"]).float().argmax(dim=1, keepdim=True)
        col_idx = torch.arange(n_cols, dtype=torch.long, device=dev).unsqueeze(0).expand(n_rows, -1)
        # Only the two positions right after the separator stay eligible.
        outside_window = (col_idx <= bar_pos) | (col_idx > bar_pos + 2)
        excluded = excluded | outside_window

    scores = torch.rand_like(tokens_batch, dtype=torch.float)
    scores.masked_fill_(excluded, -1e6)
    if args.pos_bias is not None:
        ramp = torch.linspace(0, args.pos_bias, n_cols, device=dev).unsqueeze(0)
        scores = scores + ramp
    return row_idx, scores.argmax(dim=1)

def collect_data(
        args,
        hooked_model: HookedTransformer,
        tokenizer: PreTrainedTokenizerFast,
        act_dataset: ActivationDataset,
        rollouts: rolloutManagerIOI,
        probed_acts: list[str],
        prob_weight: Optional[dict[str, float]] = None
    ):
    """Collect activations and their corresponding rollouts into ``act_dataset``.

    Calls ``act_dataset.re_init()``, then repeats until
    ``len(act_dataset) >= act_dataset.data_per_epoch``:

    1. Draw a text batch from ``rollouts`` and append ``tokenizer.eos_token``.
    2. Tokenize it and pick one position per sequence with ``choose_random_pos``.
    3. Sample one hook name for the whole batch from a Categorical whose
       logits are the values of ``prob_weight``.
    4. Run the model only up to the block that owns that hook, and store the
       activation vectors at the chosen positions.

    Args:
        args: parsed CLI namespace. Reads ``caching_batch_size`` plus the
            fields used by ``choose_random_pos``.
        hooked_model: model whose activations are recorded.
        tokenizer: tokenizer whose eos/bos/pad ids are used here.
        act_dataset: dataset that is re-initialised and filled in place.
        rollouts: source of text batches via ``next_batch``.
        probed_acts: allowed hook names. Every sampled name must be in it.
        prob_weight: optional mapping from hook name to logit. Defaults to
            equal logits over ``probed_acts``, i.e. uniform sampling.
    """
    batch_size = args.caching_batch_size
    act_dataset.re_init()
    num_head = hooked_model.cfg.n_heads
    device = hooked_model.embed.W_E.device

    if prob_weight is None:
        prob_weight = OrderedDict({k: 0.0 for k in probed_acts})
    # Keys and logits come from the same dict, so Categorical index i always
    # maps to the i-th key. The key list is loop-invariant.
    act_names = list(prob_weight.keys())
    prob_weight_v = torch.tensor(list(prob_weight.values()), device=device)
    random_choice = torch.distributions.Categorical(logits=prob_weight_v)

    with torch.no_grad():
        print("collecting activations...")
        # Context manager closes the progress bar even if collection fails.
        with tqdm(total=act_dataset.data_per_epoch) as pbar:
            while len(act_dataset) < act_dataset.data_per_epoch:
                text_batch = rollouts.next_batch(batch_size)
                text_batch = [text + tokenizer.eos_token for text in text_batch]

                tokens_batch = hooked_model.to_tokens(text_batch)
                assert (tokens_batch[:, 0] == tokenizer.bos_token_id).all()

                arange_idx, sel_pos = choose_random_pos(args, tokens_batch, tokenizer)

                probed_act = act_names[random_choice.sample().item()]
                assert probed_act in probed_acts

                # Stop the forward pass right after the block that owns the
                # probed hook. Later blocks are not needed.
                block_idx = int(re.search(r"blocks\.(\d+)\.", probed_act).group(1))
                cache = add_necessary_hooks(hooked_model, [probed_act])
                try:
                    hooked_model(tokens_batch, stop_at_layer=block_idx + 1)
                    activation = retrieve_act(probed_act, cache, num_head)
                finally:
                    # Always run end_caching (presumably it removes the hooks
                    # added above), even if the forward pass raises.
                    end_caching(hooked_model)

                selected_act = activation[arange_idx, sel_pos]
                assert selected_act.dim() == 2

                added_num = act_dataset.add_data(selected_act.cpu(), tokens_batch.tolist(), probed_act)
                pbar.update(added_num)

    print("dataset is created ")
    
    
def collect_test_data(
        args,
        hooked_model: HookedTransformer,
        rollouts: rolloutManagerIOI,
        probed_acts: list[str],
        tokenizer: Optional[PreTrainedTokenizerFast] = None,
    ):
    """Build a fixed held-out activation set for every probed hook.

    Runs ``args.num_test_rollout // 10`` batches of 10 rollouts through the
    full model, with hooks on all ``probed_acts``. For each batch it picks one
    position per sequence via ``choose_random_pos`` and stores, for every hook,
    the activation at that position.

    Torch randomness (batch skipping, position choice) is seeded with 99
    inside ``torch.random.fork_rng``. The test set is therefore reproducible,
    and the caller's global RNG state (seeded from ``--seed``) is restored
    afterwards instead of being overwritten.

    Args:
        args: parsed CLI namespace. Reads ``num_test_rollout`` plus the fields
            used by ``choose_random_pos``.
        hooked_model: model whose activations are recorded.
        rollouts: source of text batches via ``next_batch``.
        probed_acts: hook names to record.
        tokenizer: tokenizer whose ``eos_token`` is appended to each rollout
            and whose pad/eos ids are excluded when choosing positions.
            Defaults to ``hooked_model.tokenizer``.

    Returns:
        dict mapping each name in ``probed_acts`` to an ``ActivationDataset``
        constructed with size ``(num_test_rollout // 10) * 10``. No batches
        are collected when ``num_test_rollout < 10``.
    """
    if tokenizer is None:
        # Use the tokenizer that ``to_tokens`` below uses, rather than a
        # module-level ``tokenizer`` that exists only when run as a script.
        tokenizer = hooked_model.tokenizer

    batch_size = 10
    num_step = args.num_test_rollout // batch_size
    total_num = num_step * batch_size

    cache = add_necessary_hooks(hooked_model, probed_acts)
    try:
        test_act_dataset = {probed_act: ActivationDataset(total_num, probed_acts) for probed_act in probed_acts}
        # fork_rng restores the global CPU/CUDA RNG state on exit, so the
        # fixed test seed does not override the training seed for whatever
        # runs afterwards.
        with torch.random.fork_rng(), torch.no_grad():
            torch.manual_seed(99)
            print("collecting activations...")
            for _ in tqdm(range(num_step)):
                text_batch = rollouts.next_batch(batch_size)
                # Randomly skip ahead in the rollout stream. Each extra draw
                # happens with probability 0.8.
                while torch.rand(()) < 0.8:
                    text_batch = rollouts.next_batch(batch_size)

                text_batch = [text + tokenizer.eos_token for text in text_batch]
                tokens_batch = hooked_model.to_tokens(text_batch)

                # Output is discarded. The hooks added above presumably fill
                # ``cache`` during this pass.
                hooked_model(tokens_batch, return_type=None)

                # The same positions are used for every hook.
                arange_idx, sel_pos = choose_random_pos(args, tokens_batch, tokenizer)

                for probed_act in probed_acts:
                    activation = retrieve_act(probed_act, cache, None)
                    selected_act = activation[arange_idx, sel_pos]
                    assert selected_act.dim() == 2
                    test_act_dataset[probed_act].add_data(selected_act.cpu(), tokens_batch.tolist(), probed_act)

        print("test dataset is created ")
    finally:
        # Run cleanup even if collection fails part-way.
        end_caching(hooked_model)

    return test_act_dataset



if __name__ == "__main__":
    # ``sys`` is needed for the sys.path edits below but is not imported at
    # the top of this file.
    import sys

    parser = argparse.ArgumentParser()
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--save_dir", type=str, default="temp") # choices=["temp", "random", "named"]
    parser.add_argument("--num_epoch", type=int, default=50)
    parser.add_argument("--data_per_epoch", type=int, default=1_000_000)
    parser.add_argument("--num_test_rollout", type=int, default=1_000)
    parser.add_argument("--pos_bias", type=float, default=None)
    parser.add_argument("--batch_size", type=int, default=256)
    parser.add_argument("--caching_batch_size", type=int, default=64)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument("--acc_steps", type=int, default=1)
    parser.add_argument("--fp16", action="store_true")
    parser.add_argument("--compile", action="store_true")
    parser.add_argument("--arch_h", type=int, default=4)
    parser.add_argument("--arch_d", type=int, default=256)
    parser.add_argument("--arch_l", type=int, default=2)
    parser.add_argument("--gen_num", type=int, default=1)
    parser.add_argument("--cross_attn", action="store_true")
    parser.add_argument("--rebalance", type=float, default=0.0) # 6.0
    # Required: the task-specific branches below define ``probed_acts`` and
    # the rollout managers only for these three values.
    parser.add_argument("--probed_task", type=str, choices=["ioi", "addition", "counting"], required=True)
    parser.add_argument("--cossim", action="store_true")
    parser.add_argument("--parallel_num", type=int, default=1)
    parser.add_argument("--no_saving", action="store_true")

    args = parser.parse_args()

    torch.manual_seed(args.seed)
    random.seed(args.seed)
    print("device num", torch.cuda.device_count())
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

    print(args)

    probed_model_path, data_path, exp_save_dir, max_len = get_paths(args)
    print("experiment directory:", exp_save_dir)

    # Addition/counting install their own custom tokenizer further down.
    # Note the comma: without it Python concatenates the two literals into
    # "additioncounting", and this branch would always run.
    if args.probed_task not in ["addition", "counting"]:
        tokenizer = AutoTokenizer.from_pretrained(probed_model_path, add_bos_token=True)
    else:
        tokenizer = None
    probed_model = GPT2LMHeadModel.from_pretrained(probed_model_path)

    if args.probed_task in ["ioi",]:
        tokenizer.add_special_tokens({"eos_token": "[EOS]", "pad_token": "[PAD]"})
        #  bos: <|endoftext|> 50256     eos: [EOS] 50257     pad: [PAD] 50258
        probed_model.resize_token_embeddings(probed_model.config.vocab_size+2)
        probed_model.config.eos_token_id = tokenizer.eos_token_id
        probed_model.config.pad_token_id = tokenizer.pad_token_id

    hooked_model = HookedTransformer.from_pretrained(
            "gpt2",
            hf_model=probed_model,
            tokenizer=tokenizer,
            n_embd=probed_model.config.n_embd,
            n_layer=probed_model.config.n_layer,
            n_head=probed_model.config.n_head,
            vocab_size=probed_model.config.vocab_size,
            n_ctx=probed_model.config.n_positions,
    )
    print(hooked_model)
    print(hooked_model.embed.W_E.device)

    if args.probed_task == "addition":
        sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        from addition.train import customTokenizer, make_dataset
        tokenizer = customTokenizer()
        hooked_model.tokenizer = tokenizer
    elif args.probed_task == "counting":
        sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
        from counting.train import customTokenizer, make_dataset
        tokenizer = customTokenizer()
        hooked_model.tokenizer = tokenizer
    hooked_model.eval()
    del probed_model

    # Hook names whose activations are probed, per task.
    if args.probed_task == "ioi":
        probed_acts = ["blocks.0.hook_resid_pre"]
        for i in range(hooked_model.cfg.n_layers):
            for j in range(hooked_model.cfg.n_heads):
                probed_acts.append(f"blocks.{i}.attn.hook_result.{j}")
            probed_acts.append(f"blocks.{i}.hook_mlp_out")
    elif args.probed_task == "counting":
        probed_acts = ["blocks.0.hook_resid_pre"]
        for i in range(hooked_model.cfg.n_layers):
            for j in range(hooked_model.cfg.n_heads):
                probed_acts.append(f"blocks.{i}.attn.hook_result.{j}")
            probed_acts.append(f"blocks.{i}.hook_resid_mid")
            probed_acts.append(f"blocks.{i}.hook_mlp_out")
            probed_acts.append(f"blocks.{i}.hook_resid_post")
    elif args.probed_task == "addition":
        probed_acts = ["blocks.0.hook_resid_pre"]
        for i in range(hooked_model.cfg.n_layers):
            for j in range(hooked_model.cfg.n_heads):
                probed_acts.append(f"blocks.{i}.attn.hook_result.{j}")
            probed_acts.append(f"blocks.{i}.hook_resid_mid")
            probed_acts.append(f"blocks.{i}.hook_resid_post")

    act_dataset = ActivationDataset(args.data_per_epoch, probed_acts)
    if args.probed_task == "ioi":
        rollouts = rolloutManagerIOI(data_path)
        test_rollouts = rolloutManagerIOI(data_path)
    elif args.probed_task == "addition":
        rollouts = rolloutManagerAddition(make_dataset(tokenizer, train_ratio=1.0)[0])
        test_rollouts = rolloutManagerAddition(make_dataset(tokenizer, train_ratio=1.0)[0])
    elif args.probed_task == "counting":
        rollouts = rolloutManagerCounting(make_dataset(tokenizer, train_ratio=1.0)[0])
        test_rollouts = rolloutManagerCounting(make_dataset(tokenizer, train_ratio=1.0)[0])

    data_collator = customCollator(tokenizer)

    collect_data(args, hooked_model, tokenizer, act_dataset, rollouts, probed_acts)
    test_act_dataset = collect_test_data(args, hooked_model, test_rollouts, probed_acts)
    test_loader = {
        probed_act: DataLoader(test_act_dataset[probed_act], batch_size=args.caching_batch_size, shuffle=False, collate_fn=data_collator)
            for probed_act in probed_acts
    }

    act_proc_cfg = {
        "act_dim": act_dataset.activations.size(1),
        "probed_acts": probed_acts,
        "cross_attn": args.cross_attn,
        "act_proc_resid_dim": act_dataset.activations.size(1) + len(probed_acts),
        "act_proc_mid_dim": act_dataset.activations.size(1),
        "parallel_num": args.parallel_num,
    }

    cfg = GPT2Config(vocab_size=len(tokenizer),
                    n_positions=max_len+args.parallel_num,  # for eos, parallel_num
                    n_embd=args.arch_d,
                    n_layer=args.arch_l,
                    n_head=args.arch_h,
                    bos_token_id=tokenizer.bos_token_id,
                    eos_token_id=tokenizer.eos_token_id,
                    pad_token_id=tokenizer.pad_token_id,
                    attn_pdrop=0.0,
                    )
    model = CustomGPT2LMHeadModel(cfg, act_proc_cfg=act_proc_cfg)
    print(model.config.bos_token_id)
    print(model.config.eos_token_id)
    print(model.config.pad_token_id)

    # No evaluation during training. This is the default, so the kwarg is
    # omitted: ``evaluation_strategy`` was renamed/removed in newer
    # transformers releases.
    training_args = TrainingArguments(
        output_dir=exp_save_dir,
        overwrite_output_dir=True,
        per_device_train_batch_size=args.batch_size,
        gradient_accumulation_steps=args.acc_steps,
        num_train_epochs=args.num_epoch,
        save_strategy="epoch" if not args.no_saving else "no",
        save_total_limit=1,
        logging_strategy="epoch",
        learning_rate=args.lr,
        weight_decay=0.01,
        optim='adamw_torch',
        fp16=args.fp16,
        lr_scheduler_type='constant',
        report_to="none",
        torch_compile=args.compile,
    )

    class RegenerateAndGenerate(TrainerCallback):
        """Refresh training activations each epoch; periodically sample and test."""

        def on_epoch_begin(self, train_args: TrainingArguments, state: TrainerState, control: TrainerControl, model: CustomGPT2LMHeadModel, optimizer, **kwargs):
            if state.epoch > 0:
                # Regenerate data in act_dataset, sampling hooks by the
                # model's current prob_weight.
                collect_data(args, hooked_model, tokenizer, act_dataset, rollouts, probed_acts, model.prob_weight)

        def on_epoch_end(self, train_args: TrainingArguments, state: TrainerState, control: TrainerControl, model: CustomGPT2LMHeadModel, **kwargs):
            if round(state.epoch) % 5 == 0:
                model.eval()
                probed_act = probed_acts[torch.randint(0, len(probed_acts), ()).item()]
                print("\n", probed_act, "is selected")

                test_act = get_test_activation(args, probed_act, hooked_model)
                generate_and_print(args, test_act, model, tokenizer)

                prob_weight = run_test(args, probed_acts, test_loader, model, tokenizer, hooked_model, max_len)
                model.prob_weight = prob_weight
                model.train()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=act_dataset,
        data_collator=data_collator,
        callbacks=[RegenerateAndGenerate],
    )

    trainer.train()
