import pandas as pd
from datasets import load_dataset
import hpsv2
from tqdm import tqdm

def save_pickapic_captions():
    """Export the Pick-a-Pic v2 captions of the unique eval splits to CSV.

    For each of the ``validation_unique`` and ``test_unique`` splits of
    ``yuvalkirstain/pickapic_v2_no_images`` the ``caption`` field of every
    sample is collected, and the result is written to
    ``pickapic_v2_<split>.csv`` in the current working directory as a single
    ``Prompt`` column (no index column).

    Both splits are loaded and checked before any file is written. Each split
    is asserted to contain exactly 500 captions.
    """
    captions_by_split = {}
    for split_name in ('validation_unique', 'test_unique'):
        records = iter(
            load_dataset("yuvalkirstain/pickapic_v2_no_images", split=split_name)
        )
        captions = [record['caption'] for record in tqdm(records, desc=split_name)]
        # Sanity check on the downloaded data: each split must hold 500 rows.
        assert len(captions) == 500
        captions_by_split[split_name] = captions

    for split_name, captions in captions_by_split.items():
        out_path = f"pickapic_v2_{split_name}.csv"
        pd.DataFrame({"Prompt": captions}).to_csv(out_path, index=False)
        print(f"Pick-a-Pic v2 prompts saved to {out_path}")

def save_pickapic_captions_v1():
    """Export the Pick-a-Pic v1 captions of the unique eval splits to CSV.

    For each of the ``validation_unique`` and ``test_unique`` splits of
    ``yuvalkirstain/pickapic_v1_no_images`` the ``caption`` field of every
    sample is collected, and the result is written to
    ``pickapic_v1_<split>.csv`` in the current working directory as a single
    ``Prompt`` column (no index column).

    Both splits are loaded and checked before any file is written.

    Raises:
        AssertionError: if a split does not contain exactly 500 captions.
    """
    splits = ['validation_unique', 'test_unique']
    prompts = {}
    for split in splits:
        # The dataset is handed to tqdm directly (no iter() wrapper) so the
        # progress bar can pick up a total when the dataset exposes a length.
        dataset = load_dataset("yuvalkirstain/pickapic_v1_no_images", split=split)
        prompts[split] = [sample['caption'] for sample in tqdm(dataset, desc=split)]
        # Sanity check on the downloaded data: each split must hold 500 rows.
        assert len(prompts[split]) == 500, (
            f"expected 500 captions in split {split!r}, got {len(prompts[split])}"
        )

    for split, captions in prompts.items():
        fname = f"pickapic_v1_{split}.csv"
        pd.DataFrame({"Prompt": captions}).to_csv(fname, index=False)
        # Report the dataset version that was actually exported (v1, not v2).
        print(f"Pick-a-Pic v1 prompts saved to {fname}")

def save_partiprompts():
    """Dump the ``train`` split of ``nateraw/parti-prompts`` to a CSV file.

    The split is written to ``partiprompts.csv`` in the current working
    directory using the dataset object's own ``to_csv`` method, and a
    confirmation line is printed afterwards.
    """
    out_path = "partiprompts.csv"
    train_split = load_dataset("nateraw/parti-prompts")['train']
    train_split.to_csv(out_path)
    print(f"Saved PartiPrompts to {out_path}")

def save_hps():
    """Flatten the HPS v2 benchmark prompts into one CSV file.

    ``hpsv2.benchmark_prompts('all')`` is treated as a mapping from style name
    to a collection of prompts. Every (prompt, style) pair becomes one row of
    ``hps_v2_prompts.csv`` with the columns ``Prompt`` and ``Style`` (no index
    column). The table is asserted to hold exactly 3200 rows before writing.
    """
    by_style = hpsv2.benchmark_prompts('all')

    # One (prompt, style) pair per output row, in style-then-prompt order.
    pairs = [
        (text, style_name)
        for style_name in tqdm(by_style, desc="Saving HPS prompts")
        for text in by_style[style_name]
    ]

    table = pd.DataFrame(
        {
            "Prompt": [text for text, _ in pairs],
            "Style": [style_name for _, style_name in pairs],
        }
    )
    assert len(table) == 3200

    out_path = "hps_v2_prompts.csv"
    table.to_csv(out_path, index=False)
    print(f"Saved HPS prompts to {out_path}.")
    

if __name__ == "__main__":
    # Run every exporter in sequence; each one writes its CSV file(s) into the
    # current working directory.
    for export in (
        save_pickapic_captions,
        save_partiprompts,
        save_hps,
        save_pickapic_captions_v1,
    ):
        export()