import argparse
import os
import pickle

import sg_utils
from sg_utils import get_default_device


def main():
    """Pickle the relevant layers of a model into ``<output-dir>/<name>.pkl``.

    Command-line arguments:
        --model: model identifier, passed to ``sg_utils.load_model`` and
            ``sg_utils.get_relevant_layers``.
        --output-dir: directory receiving the pickle file; created if missing.

    ``<name>`` is ``--model`` with ``/`` and ``:`` replaced by ``_`` (followed by
    one pass collapsing ``__`` to ``_``). If the target file already exists the
    model is not loaded and the function returns early, so re-running the
    script only processes models without output.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--model", type=str, required=True)
    parser.add_argument("--output-dir", type=str, required=True)
    args = parser.parse_args()

    device = get_default_device()

    model_name = args.model
    # NOTE: a single ``replace("__", "_")`` pass does not fully collapse runs of
    # three or more underscores; kept unchanged so existing filenames stay valid.
    simplified_model_name = (
        model_name.replace("/", "_").replace(":", "_").replace("__", "_")
    )
    target_fn = os.path.join(args.output_dir, simplified_model_name + ".pkl")

    if os.path.exists(target_fn):
        print(f"File {target_fn} already exists, skipping...")
        return

    # Create the output directory up front: without it the final open() fails,
    # and failing here avoids wasting a (potentially slow) model load.
    os.makedirs(args.output_dir, exist_ok=True)

    model = sg_utils.load_model(model_name, device=device)
    network_layers = sg_utils.get_relevant_layers(
        model, model_name, strict_mode=False
    )

    # Write to a per-process temp file next to the target, then atomically
    # rename it into place. A crash mid-dump must not leave a truncated
    # ``target_fn``, because the existence check above would skip it forever.
    tmp_fn = f"{target_fn}.{os.getpid()}.tmp"
    try:
        with open(tmp_fn, "wb") as f:
            pickle.dump(network_layers, f)
        os.replace(tmp_fn, target_fn)
    except BaseException:
        # Best-effort removal of the partial temp file; the original error
        # is what matters, so it is always re-raised.
        try:
            os.remove(tmp_fn)
        except OSError:
            pass
        raise


if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported.
    main()
