

import os
import json
import math
import torch
from torch.utils.data import Dataset
from config.utils import *
from config.options import *
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from tqdm import tqdm
from transformers import BertTokenizer
import clip

# Resolve the bicubic interpolation constant: use torchvision's
# InterpolationMode enum when it can be imported from torchvision.transforms,
# otherwise fall back to the equivalent PIL constant.
try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

def _convert_image_to_rgb(image):
    """Return ``image`` converted to the "RGB" mode via its ``convert`` method."""
    rgb_image = image.convert("RGB")
    return rgb_image

# def _transform(n_px):
#     return Compose([
#         Resize(n_px, interpolation=BICUBIC),
#         _convert_image_to_rgb,
#         ToTensor(),
#         Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
#     ])

def init_tokenizer():
    """Build the ``bert-base-uncased`` tokenizer with two extra special tokens.

    ``[DEC]`` is registered as the BOS token and ``[ENC]`` as an additional
    special token. The id of the first additional special token is stored on
    the tokenizer as ``enc_token_id``.

    Returns:
        The configured ``BertTokenizer`` instance.
    """
    bert_tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

    # Registration order is kept: BOS token first, then the extra token.
    special_token_specs = (
        {'bos_token': '[DEC]'},
        {'additional_special_tokens': ['[ENC]']},
    )
    for spec in special_token_specs:
        bert_tokenizer.add_special_tokens(spec)

    extra_token_ids = bert_tokenizer.additional_special_tokens_ids
    bert_tokenizer.enc_token_id = extra_token_ids[0]
    return bert_tokenizer

class InpaintRewardDataset(Dataset):
    """Dataset of single inpainted images with their score and rank labels.

    The annotation file ``os.path.join(config['data_base'], f"{dataset}.json")``
    is read as a mapping from an image key to a sequence whose first entry is
    used as the score and whose second entry is used as the rank. For a key
    such as ``"<id>_<suffix>"`` the inpainted image is read from
    ``<root_path>/inpaint/<key>.png`` and its mask from
    ``<root_path>/masks_color/<id>_mask_color.png``.
    """

    def __init__(self, dataset, root_path):
        """Load the annotation file and precompute all file paths.

        Args:
            dataset: Base name (without ``.json``) of the annotation file
                located under ``config['data_base']``.
            root_path: Directory holding the ``inpaint`` and ``masks_color``
                sub-directories.
        """
        self.inpaint_imgs = None
        self.mask_imgs = None
        self.scores = None
        self.ranks = None
        self.root_path = root_path

        # Fall back to CPU so the dataset can still be built on hosts without
        # CUDA; behaviour on CUDA hosts is unchanged.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.clip, self.preprocess = clip.load("ViT-B/32", device=device)

        self.dataset_path = os.path.join(config['data_base'], f"{dataset}.json")
        with open(self.dataset_path, "r") as f:
            self.data_json = json.load(f)
        self.data = [
            {
                'img_id': key.rsplit('_', 1)[0],
                'img_id_v123': key,
                'inpaint_img': key,
                'score': value[0],
                'rank': value[1],
            }
            for key, value in self.data_json.items()
        ]

        self.get_fileinfo()

        self.iters_per_epoch = int(math.ceil(len(self.data) / opts.batch_size))

    def get_fileinfo(self):
        """Populate the per-sample path, label and id lists from ``self.data``.

        Paths are stored without the ``.png`` extension; ``get_file`` appends it.
        """
        names = [item['inpaint_img'] for item in self.data]

        self.inpaint_imgs = [
            os.path.join(self.root_path, 'inpaint', name) for name in names
        ]
        self.scores = [item['score'] for item in self.data]
        self.ranks = [item['rank'] for item in self.data]
        self.img_ids = [item['img_id'] for item in self.data]
        self.img_ids_v123 = [item['img_id_v123'] for item in self.data]
        # The mask is shared by all variants of an image, so the trailing
        # "_<suffix>" of the inpaint name is stripped before building its path.
        self.mask_imgs = [
            os.path.join(
                self.root_path,
                'masks_color',
                name.rsplit('_', 1)[0] + '_mask_color',
            )
            for name in names
        ]

    def _load_image(self, path_stem):
        """Open ``path_stem + '.png'``, preprocess it and add a leading axis."""
        # The context manager releases the file handle as soon as the tensor
        # has been produced.
        with Image.open(path_stem + '.png') as img:
            return self.preprocess(img).unsqueeze(0)

    def get_file(self, index):
        """Load and preprocess the sample stored at ``index``.

        Args:
            index: Position of the sample in the precomputed lists.

        Returns:
            A dict with keys ``img_id``, ``img_id_v123``, ``inpaint``,
            ``mask_rgb``, ``score`` and ``rank``. The two images are the
            outputs of ``self.preprocess`` with ``unsqueeze(0)`` applied;
            score and rank are one-element float32 tensors.
        """
        inpaint = self._load_image(self.inpaint_imgs[index])
        mask_rgb = self._load_image(self.mask_imgs[index])
        score = torch.tensor([self.scores[index]], dtype=torch.float32)
        rank = torch.tensor([self.ranks[index]], dtype=torch.float32)
        return {
            'img_id': self.img_ids[index],
            'img_id_v123': self.img_ids_v123[index],
            "inpaint": inpaint,
            'mask_rgb': mask_rgb,
            'score': score,
            'rank': rank,
        }

    def __getitem__(self, index):
        """Return the sample dict for ``index`` (see ``get_file``)."""
        return self.get_file(index)

    def __len__(self):
        """Return the number of annotated images."""
        return len(self.data)
    
    
class InpaintRewardDatasetGroup(Dataset):
    """Dataset of (better, worse) inpainting pairs taken from ranked groups.

    The annotation file ``os.path.join(config['data_base'], f"{dataset}.json")``
    is read as a mapping from a group id to a mapping of image name to a
    sequence whose second entry is the rank. Within each group, images with
    rank ``1`` are treated as "better" and images with rank ``3`` as "worse".
    Each sample is one better/worse pair drawn from the same group.
    """

    def __init__(self, dataset, root_path):
        """Load the annotation file and precompute the paired file paths.

        Args:
            dataset: Base name (without ``.json``) of the annotation file
                located under ``config['data_base']``.
            root_path: Directory holding the ``inpaint`` and ``masks_color``
                sub-directories.
        """
        self.better_inpaint_imgs = None
        self.better_mask_imgs = None
        self.worse_inpaint_imgs = None
        self.worse_mask_imgs = None
        self.scores = None
        self.root_path = root_path

        # Fall back to CPU so the dataset can still be built on hosts without
        # CUDA; behaviour on CUDA hosts is unchanged.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.clip, self.preprocess = clip.load("ViT-B/32", device=device)

        self.dataset_path = os.path.join(config['data_base'], f"{dataset}.json")
        with open(self.dataset_path, "r") as f:
            self.data_json = json.load(f)
        self.data = [
            {'img_id': key, 'generations': value}
            for key, value in self.data_json.items()
        ]

        self.get_fileinfo()

        # Sized by the number of usable pairs, which equals the number of
        # groups whenever every group has one rank-1 and one rank-3 image.
        self.iters_per_epoch = int(math.ceil(len(self) / opts.batch_size))

    def get_fileinfo(self):
        """Build index-aligned better/worse path lists from ``self.data_json``.

        Pairing is done inside each group so that a group lacking a rank-1 or
        a rank-3 image cannot shift the alignment of later groups; such a
        group contributes no pair. Paths are stored without the ``.png``
        extension; ``get_file`` appends it.
        """
        better_names = []
        worse_names = []

        for generations in self.data_json.values():
            group_better = [
                name for name, value in generations.items() if value[1] == 1
            ]
            group_worse = [
                name for name, value in generations.items() if value[1] == 3
            ]
            # zip stops at the shorter list, so unmatched images are dropped
            # instead of being paired with an image from another group.
            for better_name, worse_name in zip(group_better, group_worse):
                better_names.append(better_name)
                worse_names.append(worse_name)

        self.better_inpaint_imgs = [
            os.path.join(self.root_path, 'inpaint', name) for name in better_names
        ]
        self.worse_inpaint_imgs = [
            os.path.join(self.root_path, 'inpaint', name) for name in worse_names
        ]
        self.better_mask_imgs = [
            self._mask_path(name) for name in better_names
        ]
        self.worse_mask_imgs = [
            self._mask_path(name) for name in worse_names
        ]

    def _mask_path(self, inpaint_name):
        """Return the mask path stem for an inpaint name of form ``<id>_<suffix>``."""
        mask_name = inpaint_name.rsplit('_', 1)[0] + '_mask_color'
        return os.path.join(self.root_path, 'masks_color', mask_name)

    def _load_image(self, path_stem):
        """Open ``path_stem + '.png'``, preprocess it and add a leading axis."""
        # The context manager releases the file handle as soon as the tensor
        # has been produced.
        with Image.open(path_stem + '.png') as img:
            return self.preprocess(img).unsqueeze(0)

    def get_file(self, index):
        """Load and preprocess the better/worse pair stored at ``index``.

        Args:
            index: Position of the pair in the precomputed lists.

        Returns:
            A dict with keys ``better_inpt``, ``better_msk``, ``worse_inpt``
            and ``worse_msk``, each the output of ``self.preprocess`` with
            ``unsqueeze(0)`` applied.
        """
        better_inpaint = self._load_image(self.better_inpaint_imgs[index])
        better_mask = self._load_image(self.better_mask_imgs[index])
        worse_inpaint = self._load_image(self.worse_inpaint_imgs[index])
        worse_mask = self._load_image(self.worse_mask_imgs[index])

        return {
            "better_inpt": better_inpaint,
            'better_msk': better_mask,
            "worse_inpt": worse_inpaint,
            'worse_msk': worse_mask,
        }

    def __getitem__(self, index):
        """Return the pair dict for ``index`` (see ``get_file``)."""
        return self.get_file(index)

    def __len__(self):
        """Return the number of better/worse pairs that can be indexed."""
        return len(self.better_inpaint_imgs)
    