# Imports
import torch
from torchmetrics.image.inception import InceptionScore
import os
from PIL import Image
import torchvision.transforms as transforms
import argparse
_ = torch.manual_seed(123)  # Seed torch's global RNG so repeated runs are reproducible.

# Add argparse setup
parser = argparse.ArgumentParser(description='Get Inception Score (IS) scores')

# This score is only computed with respect to the generation
parser.add_argument(
    '--folder_sds',
    type=str,
    default='./example/A_baby_bunny_sitting_on_top_of_a_stack_of_pancakes@20240507-182038/save/it9000-test',
    help='Path to the folder containing SDS images',
)

# Inference device: a bare integer is a CUDA index, anything else is passed
# to torch.device as-is.
parser.add_argument(
    '--device',
    type=str,
    default='1' if torch.cuda.is_available() else 'cpu',
    help='Device to run inference on: a CUDA index such as "1", or a torch '
         'device string such as "cuda", "cuda:0" or "cpu"',
)

args = parser.parse_args()


def _resolve_device(spec: str) -> torch.device:
    """Turn the ``--device`` argument into a ``torch.device``.

    Without CUDA the CPU is always used. With CUDA, a purely numeric *spec*
    (e.g. ``"1"``) selects that CUDA device index, and any other value
    (``"cuda"``, ``"cuda:0"``, ``"cpu"``) is handed to ``torch.device``
    unchanged. Previously every value was prefixed with ``"cuda:"``, which
    made ``--device cpu`` or ``--device cuda`` fail.
    """
    if not torch.cuda.is_available():
        return torch.device("cpu")
    if spec.isdigit():
        return torch.device("cuda:" + spec)
    return torch.device(spec)


device = _resolve_device(args.device)

# Function to load and preprocess images
def load_images_from_folder(folder, *, crop_size=512, image_size=299):
    """Load every image in *folder* as one float tensor batch.

    Each file is opened with PIL, converted to RGB, cropped to its top-left
    ``crop_size`` x ``crop_size`` square, resized to
    ``image_size`` x ``image_size`` and converted with ``ToTensor``.

    Args:
        folder: Directory to read. Regular files are read in sorted filename
            order; subdirectories and hidden (dot-prefixed) files are skipped.
        crop_size: Side length of the top-left square that is kept.
            NOTE(review): presumably this isolates the first panel of a
            side-by-side test render — confirm against the renderer.
        image_size: Side length the crop is resized to (299 matches the
            InceptionV3 input size).

    Returns:
        Tensor of shape ``(N, 3, image_size, image_size)`` with values in
        ``[0, 1]``, one row per loaded file.

    Raises:
        ValueError: If *folder* contains no files to load.
    """
    # Build the preprocessing pipeline once instead of once per file.
    transform = transforms.Compose([
        transforms.Lambda(lambda im: im.crop((0, 0, crop_size, crop_size))),
        transforms.Resize((image_size, image_size)),  # Resize to InceptionV3 input size
        transforms.ToTensor(),  # PIL image -> float tensor in [0, 1]
    ])

    images = []
    # os.listdir order is arbitrary; sort so the batch order is deterministic.
    for filename in sorted(os.listdir(folder)):
        path = os.path.join(folder, filename)
        if filename.startswith('.') or not os.path.isfile(path):
            continue
        # The context manager closes the file handle; convert() returns a new,
        # fully loaded image, so it remains usable after the file is closed.
        with Image.open(path) as img:
            rgb = img.convert('RGB')
        images.append(transform(rgb))

    if not images:
        raise ValueError(f'No images found in {folder!r}')
    return torch.stack(images)

# Inception Score metric (torchmetrics, default settings), fed uint8 images in [0, 255].
inception = InceptionScore()
inception = inception.to(device)

# Load the images from the folder. load_images_from_folder returns floats in
# [0, 1]; scale back to [0, 255] and ROUND before the uint8 cast: the cast
# truncates, and float32 round-off in x / 255 * 255 can land just below an
# integer (e.g. 254.99998 -> 254). Casting on the CPU first means only the
# smaller uint8 tensor is copied to the device.
sds_images = (load_images_from_folder(args.folder_sds) * 255).round().to(torch.uint8)
sds_images = sds_images.to(device)

# Compute Inception Score
inception.update(sds_images)
mean_inception, std_inception = inception.compute()

print("Inception Score mean: ", mean_inception.item(),', standard deviation: ',std_inception.item())


