import os
from collections import namedtuple

import torch
import pickle
import argparse
import yaml

from torchtyping import TensorType as TT

import torch
from utils.cspa_main import (
    get_cspa_per_checkpoint
)
from utils.data_processing import get_ckpts

# Module-level default device: "cuda" when torch reports CUDA as available,
# otherwise "cpu".
device = "cuda" if torch.cuda.is_available() else "cpu"


def get_args():
    """Parse the command-line options of the CSPA-per-checkpoint script.

    Options:
        -c / --config: path to the YAML config file
            (default ``./configs/cspa/160m-canonical.yml``).
        --device: optional device string. The ``device`` attribute exists on
            the returned namespace only when the flag is given, so code that
            checks for its presence keeps falling back to auto-detection
            when the flag is omitted.

    Returns:
        argparse.Namespace: the options parsed from ``sys.argv``.
    """
    parser = argparse.ArgumentParser(
        description="Get CSPA per checkpoint and attention head"
    )
    parser.add_argument(
        "-c",
        "--config",
        default="./configs/cspa/160m-canonical.yml",
        help="Path to config file",
    )
    parser.add_argument(
        "--device",
        # SUPPRESS leaves the attribute off the namespace entirely unless the
        # user passes --device, instead of storing a None placeholder.
        default=argparse.SUPPRESS,
        help="Device to run on, e.g. 'cuda' or 'cpu' (default: auto-detect)",
    )
    return parser.parse_args()


def read_config(config_path):
    """Load the YAML file at ``config_path`` and return its parsed content.

    Args:
        config_path: path of the YAML config file to read.

    Returns:
        The object ``yaml.load`` builds from the document. ``main`` indexes
        the result by key, so it expects a mapping; an empty file yields
        ``None``.
    """
    # Pin UTF-8 so parsing does not depend on the platform's locale encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        # NOTE(security): FullLoader resolves more tags than the safe loader.
        # Only load config files you trust; consider yaml.safe_load if the
        # configs are plain mappings/lists/scalars.
        return yaml.load(f, Loader=yaml.FullLoader)


def main(args):
    """Run ``get_cspa_per_checkpoint`` with the settings from the config file.

    Args:
        args: object carrying the parsed options. ``args.config`` is the path
            of the YAML config file. An optional, non-empty ``args.device``
            overrides the automatically detected device.

    The loaded config is indexed with the keys ``base_model``, ``variant``,
    ``cache``, ``checkpoint_schedule``, ``start_layer`` and ``overwrite``.
    """
    # getattr works for any attribute container (not only argparse.Namespace),
    # and a missing, None or empty device falls back to auto-detection.
    device = getattr(args, "device", None)
    if not device:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Turn off autograd gradient tracking for everything that follows.
    torch.set_grad_enabled(False)
    print(f"Using device: {device}")

    config = read_config(args.config)

    # NOTE(review): the schedule is passed through exactly as written in the
    # config. A commented-out alternative here called
    # get_ckpts(config['checkpoint_schedule']) instead -- confirm which form
    # get_cspa_per_checkpoint expects.
    checkpoints = config['checkpoint_schedule']
    print(config)

    get_cspa_per_checkpoint(
        config['base_model'],
        config['variant'],
        config['cache'],
        device,
        checkpoints,
        start_layer=config["start_layer"],
        overwrite=config["overwrite"],
        display_all=False,
    )


if __name__ == "__main__":
    # Script entry point: parse the CLI options and hand them to main().
    main(get_args())