# Imports
import torch
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
import os
from PIL import Image
import torchvision.transforms as transforms
import argparse
_ = torch.manual_seed(123)

# Command-line interface and device selection


def _default_device_spec():
    """Return the default value of ``--device`` for the current machine.

    The historical default is GPU index 1.  That index only exists on
    multi-GPU hosts, so fall back to GPU 0 when a single GPU is visible and
    to the CPU when CUDA is unavailable.
    """
    if not torch.cuda.is_available():
        return 'cpu'
    return '1' if torch.cuda.device_count() > 1 else '0'


def _resolve_device(spec):
    """Turn a ``--device`` value into a ``torch.device``.

    Accepted values are ``'cpu'``, ``'cuda'``, ``'cuda:N'`` and a bare GPU
    index ``'N'`` (the form this script originally required).  Whatever is
    requested, the CPU is used when CUDA is not available.

    Args:
        spec: Device specification as given on the command line.

    Returns:
        The ``torch.device`` to run on.

    Raises:
        RuntimeError: if ``spec`` is not a device string torch understands.
    """
    spec = str(spec).strip().lower()
    if spec == 'cpu' or not torch.cuda.is_available():
        return torch.device('cpu')
    if spec.isdigit():
        # A bare index keeps meaning "that CUDA device", as before.
        return torch.device('cuda:' + spec)
    # 'cuda' / 'cuda:N' are passed through; torch rejects anything invalid.
    return torch.device(spec)


parser = argparse.ArgumentParser(description='Get LPIPS scores')

parser.add_argument(
    '--folder_ddim',
    type=str,
    default='./example/A_3D_model_of_an_adorable_cottage_with_a_thatched_roof@20240507-161523/save/it9000-test',
    help='Path to the folder containing DDIM ground truth images',
)
parser.add_argument(
    '--folder_sds',
    type=str,
    default='./example/A_baby_bunny_sitting_on_top_of_a_stack_of_pancakes@20240507-182038/save/it9000-test',
    help='Path to the folder containing SDS images',
)

# Device on which the LPIPS network is evaluated
parser.add_argument(
    '--device',
    type=str,
    default=_default_device_spec(),
    help="Device to run inference on: 'cpu', 'cuda', 'cuda:N' or a bare GPU index N",
)

args = parser.parse_args()
device = _resolve_device(args.device)

# Function to load and preprocess images
def load_images_from_folder(folder):
    """Load every file in ``folder`` as one stacked image tensor.

    Files are visited in sorted filename order.  ``os.listdir`` returns
    entries in arbitrary order, and the result of this function is compared
    element-wise against a batch loaded from another folder, so a
    deterministic order is required for the i-th images to correspond.
    Sub-directories are skipped.

    Each image is converted to RGB, cropped to the box (0, 0, 512, 512),
    resized to 100x100 and converted with ``ToTensor`` (float values in
    [0, 1]).
    NOTE(review): the 512x512 crop presumably selects the RGB panel of a
    wider composite render -- confirm against the renderer's output layout.

    Args:
        folder: Path of the directory to read images from.

    Returns:
        torch.Tensor of shape (num_images, 3, 100, 100).

    Raises:
        RuntimeError: if ``folder`` contains no files.
    """
    candidates = (os.path.join(folder, name) for name in sorted(os.listdir(folder)))
    paths = [path for path in candidates if os.path.isfile(path)]
    if not paths:
        # Same exception type torch.stack([]) raised, with a usable message.
        raise RuntimeError(f"No image files found in folder: {folder}")

    # Build the preprocessing pipeline once rather than once per image.
    preprocess = transforms.Compose([
        transforms.Lambda(lambda img: img.crop((0, 0, 512, 512))),
        transforms.Resize((100, 100)),
        transforms.ToTensor(),
    ])

    images = []
    for path in paths:
        # The context manager releases the file handle once the pixels
        # have been converted into a tensor.
        with Image.open(path) as img:
            images.append(preprocess(img.convert('RGB')))
    return torch.stack(images)

def _load_lpips_batch(folder):
    """Load the images in ``folder`` as an LPIPS-ready batch on ``device``.

    ``load_images_from_folder`` yields values in [0, 1]; the metric is used
    with its default settings, which expect inputs in [-1, 1], so the batch
    is rescaled before being moved to the target device.
    """
    batch = load_images_from_folder(folder)
    return (batch * 2 - 1).to(device)


# SqueezeNet-backed LPIPS metric, evaluated on the selected device
lpips = LearnedPerceptualImagePatchSimilarity(net_type='squeeze')
lpips = lpips.to(device)

ddim_images = _load_lpips_batch(args.folder_ddim)
sds_images = _load_lpips_batch(args.folder_sds)

# LPIPS compares the two batches image by image, so they must have the same
# length.  Fail early with a readable message instead of an opaque tensor
# shape error from inside the metric.
if ddim_images.shape[0] != sds_images.shape[0]:
    raise ValueError(
        f"Image count mismatch: {ddim_images.shape[0]} image(s) in "
        f"{args.folder_ddim} vs {sds_images.shape[0]} image(s) in "
        f"{args.folder_sds}"
    )

# Compute LPIPS
lpips_score = lpips(ddim_images, sds_images)

print("LPIPs Score:", lpips_score.item())


