import os
import csv
import json
import torch
import hydra
from hydra.core.hydra_config import HydraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel
from glob import glob

from src.utils import init_script
from src.lightningutil.modelmodule import *
from src.lightningutil.datamodule import create_datamod
from src.lightningutil.util import create_log_dir, init_default_callbacks
from src.lightningutil.strategy import MyDeepSpeedStrategy
from src.lightningutil.callbacks import TransformerCheckpointIO

from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.utilities import rank_zero_info
from omegaconf import OmegaConf


def main(
    *,
    data_config_path="configs/data/tofu.yaml",
    data_mode_pattern="configs/data_mode/*.yaml",
    split="forget10_perturbed",
    batch_size=16,
    tokenizer_name="locuslab/tofu_ft_llama2-7b",
    num_decoded=4,
    interactive=True,
):
    """Build a datamodule for every data-mode config and print one batch from each.

    Loads the base data config and overrides its ``split`` and ``batch_size``
    keys. Then, for each YAML file matching ``data_mode_pattern`` (in sorted
    order), it:

    - builds a datamodule with ``create_datamod``;
    - runs ``prepare_data()`` and ``setup('fit')``;
    - pulls the first training batch;
    - prints the datamodule's ``stats()`` as JSON, the batch keys, the first
      ``num_decoded`` decoded rows of ``input_ids`` / ``prefer_input_ids``
      (when present) and ``retainlabels`` (when present).

    Args:
        data_config_path: Base data config loaded with ``OmegaConf.load``.
        data_mode_pattern: Glob selecting the data-mode configs to iterate.
        split: Value written to the base config's ``split`` key.
        batch_size: Value written to the base config's ``batch_size`` key.
        tokenizer_name: Name/path passed to ``AutoTokenizer.from_pretrained``.
        num_decoded: Number of rows of each token tensor to decode and print.
        interactive: If true, enter the debugger after each mode via the
            builtin ``breakpoint()``. This honours ``PYTHONBREAKPOINT``: set it
            to ``0`` to run non-interactively, or to ``ipdb.set_trace`` to use
            ipdb.

    Raises:
        FileNotFoundError: If no file matches ``data_mode_pattern`` (for
            example when the script is launched from another directory,
            since the default paths are relative).
    """
    dataconf = OmegaConf.load(data_config_path)
    dataconf['split'] = split
    dataconf['batch_size'] = batch_size

    # glob() order depends on the filesystem; sort so runs are reproducible.
    all_configs = sorted(glob(data_mode_pattern))
    if not all_configs:
        # Without this check a wrong cwd makes the script silently do nothing.
        raise FileNotFoundError(
            f"No data-mode configs match {data_mode_pattern!r} (cwd: {os.getcwd()})"
        )

    tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
    # NOTE(review): right padding is presumably what the datamodule's collation /
    # label handling expects — confirm against create_datamod.
    tokenizer.padding_side = "right"

    for conf_path in all_configs:
        print("Mode:", conf_path)
        conf = OmegaConf.load(conf_path)

        data_module = create_datamod(configs=dataconf, tokenizer=tokenizer, data_mode_config=conf)
        data_module.prepare_data()
        data_module.setup('fit')

        # Only the first batch is needed for inspection.
        batch = next(iter(data_module.train_dataloader()))
        stats = data_module.stats()
        print(json.dumps(stats, indent=2))
        print(batch.keys())
        for key in ('input_ids', 'prefer_input_ids'):
            if key in batch:
                print(tokenizer.batch_decode(batch[key][:num_decoded], skip_special_tokens=True))
        if 'retainlabels' in batch:
            print(batch['retainlabels'])

        if interactive:
            # Pause so this mode's batch / data_module can be examined by hand.
            breakpoint()


if __name__ == "__main__":
    # Run the data-mode inspection only when executed as a script, not on import.
    main()