import os
from pathlib import Path
from typing import Union, Tuple

import pandas as pd
import torch
from torch.utils.data import Dataset

import PIL
from PIL import Image
from PIL.ImageOps import exif_transpose
import argparse
from tqdm import tqdm

from transformers import CLIPProcessor, CLIPModel
import ImageReward as reward

class ImagePromptDataset(Dataset):
    """
    Dataset of prompts & images read from a CSV file.

    Each CSV row must provide a ``'Prompt'`` column (the text prompt) and a
    ``'Gen Image Path'`` column (path to the generated image). Indexing the
    dataset returns the prompt of that row together with the loaded image.

    Parameters:
        root (Union[str, os.PathLike]): Path to CSV of prompts and images
    """
    def __init__(self, root: Union[str, os.PathLike]):
        # Only CSV input is supported; reject anything else up front.
        assert Path(root).suffix.lower() == ".csv", "Expected csv file"
        self.root = Path(root)
        self.df = pd.read_csv(root)

    def __len__(self) -> int:
        """Return the number of rows (prompt/image pairs) in the CSV."""
        return len(self.df)

    def __getitem__(self, index: int) -> Tuple[str, PIL.Image.Image]:
        """
        Return the ``(prompt, image)`` pair stored at row ``index``.

        The image is opened from the row's ``'Gen Image Path'``, rotated
        according to its EXIF orientation, and converted to RGB.
        """
        # Positional lookup of the requested row (previously always row 0).
        row = self.df.iloc[index]
        prompt = row['Prompt']
        path = row['Gen Image Path']

        # The context manager closes the file handle; transpose()/copy() and
        # convert() both return fully loaded images, so ``img`` stays valid
        # after the file is closed.
        with Image.open(path) as raw:
            # Apply EXIF orientation to the image as opened from disk, then
            # convert the (possibly rotated) result to RGB.
            img = exif_transpose(raw).convert("RGB")

        return (prompt, img)
    
def build_clip_model(device='cuda', model_name="openai/clip-vit-base-patch32"):
    """
    Load a CLIP model and its matching processor.

    Args:
        device: Device the model weights are moved to (default ``'cuda'``).
        model_name: Checkpoint identifier passed to ``from_pretrained`` for
            both the model and the processor, so the two always match.
            Defaults to ``"openai/clip-vit-base-patch32"``.

    Returns:
        dict with keys ``'device'``, ``'model'`` (already moved to ``device``)
        and ``'processor'``.
    """
    model = CLIPModel.from_pretrained(model_name)
    processor = CLIPProcessor.from_pretrained(model_name)
    return dict(
        device=device,
        model=model.to(device),
        processor=processor
    )

@torch.no_grad()
def clip_score(prompt,image,model_dict):
    """
    Return the CLIP image-text logit for a single prompt/image pair.

    Args:
        prompt: Text prompt. Tokenised with padding and truncation to a
            ``max_length`` of 77 tokens.
        image: Image accepted by the CLIP processor (e.g. a PIL image).
        model_dict: dict with ``'processor'``, ``'model'`` and ``'device'``
            entries, such as the one built by ``build_clip_model``.

    Returns:
        float: the single value held in the model output's
        ``logits_per_image``.

    Raises:
        RuntimeError: raised by ``Tensor.item`` if more than one logit is
        produced (e.g. several prompts or images were passed in).
    """
    inputs = model_dict['processor'](text=prompt, images=image, return_tensors="pt", padding=True,truncation=True, max_length=77).to(model_dict['device'])
    # Exactly one prompt and one image are expected, hence one logit.
    return model_dict['model'](**inputs).logits_per_image.item()


def build_image_reward_model(device='cuda'):
    """
    Load the ``ImageReward-v1.0`` model onto ``device``.

    Args:
        device: Device to build and place the model on (default ``'cuda'``).

    Returns:
        dict with keys ``'model'`` and ``'device'``.
    """
    # Pass ``device`` to load() itself rather than relying only on .to():
    # the ImageReward model records the device it was constructed for and
    # moves score() inputs there, so a later .to() alone can leave weights and
    # inputs on different devices. NOTE(review): based on the ImageReward
    # package's load()/score() implementation - confirm for the pinned version.
    model = reward.load("ImageReward-v1.0", device=device).to(device)
    return dict(model=model, device=device)

@torch.no_grad()
def get_imagereward_score(prompt,image,model_dict):
    """
    Score one prompt/image pair with a loaded ImageReward model.

    Args:
        prompt: Text prompt the image is scored against.
        image: Image handed unchanged to the model's ``score`` method.
        model_dict: dict holding the model under the ``'model'`` key.

    Returns:
        Whatever ``model_dict['model'].score(prompt, image)`` returns.
    """
    return model_dict['model'].score(prompt, image)
