import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import os
import glob

from metric import Retrieval_metrics, MRR


class SycDataset(Dataset):
    """Pair two map-style datasets index by index.

    Item ``i`` of this dataset is item ``i`` of ``datasetA`` followed by
    item ``i`` of ``datasetB``. Each underlying item must unpack into exactly
    two values.

    Args:
        datasetA: Indexable, sized collection whose items are 2-tuples.
        datasetB: Indexable, sized collection whose items are 2-tuples.
    """

    def __init__(self, datasetA, datasetB):
        self.datasetA = datasetA
        self.datasetB = datasetB

    def __getitem__(self, index):
        """Return ``(CLIP_A, CLAP_A, CLIP_B, CLAP_B)`` for ``index``."""
        CLIP_A, CLAP_A = self.datasetA[index]
        CLIP_B, CLAP_B = self.datasetB[index]
        return CLIP_A, CLAP_A, CLIP_B, CLAP_B

    def __len__(self):
        """Return the number of indices that are valid in both datasets."""
        # The same index is used for both datasets, so advertising only
        # len(datasetA) would lead to an IndexError when datasetB is shorter.
        return min(len(self.datasetA), len(self.datasetB))

class TextDataset(Dataset):
    """Paired CLIP / CLAP text embeddings concatenated from named corpora.

    For every recognised name in ``dataset_list`` the matching CLIP and CLAP
    embedding files are loaded, L2-normalised along the last dimension and
    concatenated along dimension 0, in the order the names are given.

    Args:
        dataset_list: Iterable of corpus names. Recognised names are the keys
            of ``_EMBEDDING_FILES``. Unrecognised names are skipped with a
            warning. A name listed twice is loaded twice.

    Raises:
        RuntimeError: If no recognised corpus name was supplied.
    """

    # Corpus name -> (CLIP embedding file, CLAP embedding file).
    _EMBEDDING_FILES = {
        'COCO': ('./embedding/CLIP_COCO_train-val_text.pt',
                 './embedding/CLAP_COCO_train-val_text.pt'),
        'AudioCap': ('./embedding/CLIP_AudioCap_text.pt',
                     './embedding/CLAP_AudioCap_text.pt'),
        'Clotho': ('./embedding/CLIP_Clotho_text.pt',
                   './embedding/CLAP_Clotho_text.pt'),
        'CC1M': ('./embedding/CLIP_CC1M_text.pt',
                 './embedding/CLAP_CC1M_text.pt'),
        'MSRVTT': ('./embedding/CLIP_MSRVTT_text.pt',
                   './embedding/CLAP_MSRVTT_text.pt'),
        'MAD': ('./embedding/CLIP_MAD_text.pt',
                './embedding/CLAP_MAD_text.pt'),
        'AVS100': ('./embedding/CLIP_AVS100.pt',
                   './embedding/CLAP_AVS100.pt'),
    }

    def __init__(self, dataset_list):
        import warnings

        CLIP_text_data = []
        CLAP_text_data = []
        for name in dataset_list:
            paths = self._EMBEDDING_FILES.get(name)
            if paths is None:
                # A misspelled name used to be dropped without any trace.
                warnings.warn(
                    f"TextDataset: unknown dataset name {name!r} ignored",
                    stacklevel=2,
                )
                continue
            clip_path, clap_path = paths
            CLIP_text_data.append(self._load_normalized(clip_path))
            CLAP_text_data.append(self._load_normalized(clap_path))

        if not CLIP_text_data:
            raise RuntimeError(
                "TextDataset: no known dataset name in dataset_list; "
                f"expected any of {sorted(self._EMBEDDING_FILES)}"
            )

        self.CLIP_text_embs = torch.cat(CLIP_text_data, dim=0)
        self.CLAP_text_embs = torch.cat(CLAP_text_data, dim=0)

    @staticmethod
    def _load_normalized(path):
        """Load the tensor stored at ``path`` onto the CPU and L2-normalise it."""
        # map_location='cpu' lets files saved from a GPU load on CPU-only
        # hosts and keeps the normalisation off the GPU.
        return F.normalize(torch.load(path, map_location='cpu'), dim=-1)

    def __len__(self):
        """Return the number of rows in the concatenated CLIP embeddings."""
        return self.CLIP_text_embs.shape[0]

    def __getitem__(self, idx):
        """Return ``(CLIP_emb, CLAP_emb)`` at position ``idx``."""
        CLIP_emb = self.CLIP_text_embs[idx]
        CLAP_emb = self.CLAP_text_embs[idx]

        return CLIP_emb, CLAP_emb

class AudioImageDataset(Dataset):
    """Paired CLAP audio and CLIP image embeddings for one named dataset.

    Both embedding tensors are L2-normalised along the last dimension when
    loaded. Pairing is positional: item ``i`` is row ``i`` of each tensor.

    Args:
        dataset: Either ``'Flickr'`` or ``'AVE'``.

    Raises:
        ValueError: If ``dataset`` is not one of the supported names.
    """

    def __init__(self, dataset):
        CLAP_Flickr = './embedding/CLAP_Flickr_Audio.pt'
        CLIP_Flickr = './embedding/CLIP_Flickr_Image.pt'
        CLAP_AVE = './embedding/CLAP_AVE_Audio.pt'
        CLIP_AVE = './embedding/CLIP_Full-AVE_Image.pt'

        if dataset == 'Flickr':
            audio_path, image_path = CLAP_Flickr, CLIP_Flickr
        elif dataset == 'AVE':
            audio_path, image_path = CLAP_AVE, CLIP_AVE
        else:
            # Without this the instance would be created with no embeddings
            # and only fail later with an AttributeError on first use.
            raise ValueError(
                f"unsupported dataset {dataset!r}; expected 'Flickr' or 'AVE'"
            )

        self.audios_embs = F.normalize(torch.load(audio_path), dim=-1)
        self.images_embs = F.normalize(torch.load(image_path), dim=-1)

    def __len__(self):
        """Return the number of audio embeddings."""
        return len(self.audios_embs)

    def __getitem__(self, idx):
        """Return ``(audio_emb, image_emb)`` at position ``idx``."""
        audio_emb = self.audios_embs[idx]
        image_emb = self.images_embs[idx]
        return audio_emb, image_emb

class ImageNet_AudioSet_Dataset(Dataset):
    """CLIP ImageNet-1K and CLAP AudioSet embeddings resized to one length.

    Both embedding tensors are L2-normalised along the last dimension and
    then independently tiled or truncated to ``length`` rows, so the image
    and audio returned for an index are paired by position only.

    Args:
        length: Number of rows each embedding tensor is resampled to.

    Raises:
        ValueError: If ``length`` is negative, or if ``length`` is positive
            and one of the embedding tensors is empty.
    """

    def __init__(self, length):
        ImageNet_path = './embedding/CLIP_ImageNet1K.pt'
        AudioSet_path = './embedding/CLAP_AudioSet.pt'

        self.images_embs = F.normalize(torch.load(ImageNet_path), dim=-1)
        self.audios_embs = F.normalize(torch.load(AudioSet_path), dim=-1)

        self.images_embs = self.resample_fix_length(self.images_embs, length)
        self.audios_embs = self.resample_fix_length(self.audios_embs, length)

    def resample_fix_length(self, embs, length):
        """Return ``embs`` tiled and/or truncated to exactly ``length`` rows.

        Longer inputs are cut to their first ``length`` rows; shorter inputs
        are repeated whole along dimension 0 and then cut.

        Args:
            embs: Tensor whose first dimension is resized.
            length: Target number of rows; must be non-negative.

        Returns:
            ``embs`` itself when it already has ``length`` rows, otherwise a
            tensor with ``length`` rows.

        Raises:
            ValueError: If ``length`` is negative, or if ``embs`` is empty
                while ``length`` is positive.
        """
        if length < 0:
            raise ValueError(f"length must be non-negative, got {length}")

        rows = len(embs)
        if rows == length:
            return embs
        if rows > length:
            return embs[0:length]
        if rows == 0:
            # Repeating an empty tensor can never reach a positive length.
            raise ValueError("cannot resample an empty tensor to a positive length")

        # Ceiling division: the fewest whole copies that cover `length` rows.
        copies = -(-length // rows)
        return torch.cat([embs] * copies, dim=0)[0:length]

    def __len__(self):
        """Return the number of audio embeddings."""
        return len(self.audios_embs)

    def __getitem__(self, idx):
        """Return ``(image_emb, audio_emb)`` at position ``idx``."""
        audio_emb = self.audios_embs[idx]
        image_emb = self.images_embs[idx]
        return image_emb, audio_emb
