import os
import cv2
import sys
import torch
import random
import itertools
import numpy as np
import pandas as pd
import ujson as json
from PIL import Image
from torchvision import transforms
from collections import defaultdict
from modules.basic_utils import load_json
from torch.utils.data import Dataset
from config.base_config import Config
from datasets.video_capture import VideoCapture


class MSRVTTDataset(Dataset):
    """
    MSRVTT video/caption dataset.

    For the 'train' split every (video, caption) pair listed in the
    annotation JSON for the selected training videos is one sample.  For any
    other split each row of the test CSV is one sample.

        config: Config object. Attributes read here: videos_dir, platform,
            diffusion_test_mode, msrvtt_train_file, num_frames,
            video_sample_type, use_gen, genframe_dir, genframe_dir_tst
        split_type: 'train' selects the training pairs; any other value
            selects the test CSV
        img_transforms: optional callable applied to the loaded frames
    """

    # Test CSV file name for each supported ``config.diffusion_test_mode``.
    _TEST_FILES = {
        'benchmark': 'MSRVTT_JSFUSION_test.csv',
        'debug': 'MSRVTT_JSFUSION_test_debug.csv',
    }
    # Train CSV file name for each recognised ``config.msrvtt_train_file``;
    # any other value falls back to the 9k file.
    _TRAIN_FILES = {
        '7k': 'MSRVTT_train.7k.csv',
        '12': 'MSRVTT_train.0.0012k.csv',
    }
    _DEFAULT_TRAIN_FILE = 'MSRVTT_train.9k.csv'

    def __init__(self, config: 'Config', split_type='train', img_transforms=None):
        """
        Resolve the annotation files and build the index for ``split_type``.

        Raises:
            NotImplementedError: if ``config.platform`` is not 'XX' or
                ``config.diffusion_test_mode`` is neither 'benchmark' nor
                'debug' (checked for every split, as before).
        """
        super().__init__()
        self.config = config
        self.videos_dir = config.videos_dir
        self.img_transforms = img_transforms
        self.split_type = split_type
        self.platform = config.platform
        self.diffusion_test_mode = config.diffusion_test_mode

        if self.platform != 'XX':
            raise NotImplementedError(
                f'unsupported platform: {self.platform!r}')
        data_root = '/data/MSRVTT'

        if self.diffusion_test_mode not in self._TEST_FILES:
            raise NotImplementedError(
                f'unsupported diffusion_test_mode: {self.diffusion_test_mode!r}')

        db_file = f'{data_root}/MSRVTT_data.json'
        test_csv = f'{data_root}/{self._TEST_FILES[self.diffusion_test_mode]}'
        train_name = self._TRAIN_FILES.get(config.msrvtt_train_file,
                                           self._DEFAULT_TRAIN_FILE)
        train_csv = f'{data_root}/{train_name}'

        self.db = load_json(db_file)
        if split_type == 'train':
            train_df = pd.read_csv(train_csv)
            self.train_vids = train_df['video_id'].unique()
            self._compute_vid2caption()
            self._construct_all_train_pairs()
        else:
            self.test_df = pd.read_csv(test_csv)

    def __getitem__(self, index):
        """
        Load one sample.

        Returns a dict with keys 'video_id', 'video' and 'text'; when
        ``config.use_gen`` is truthy it also contains 'gen_frame'.  Generated
        frames are looked up by sentence id for the train split and by video
        id otherwise.
        """
        is_train = self.split_type == 'train'
        sen_id = None
        if is_train:
            video_path, caption, video_id, sen_id = \
                self._get_vidpath_and_caption_by_index(index)
        else:
            video_path, caption, video_id = \
                self._get_vidpath_and_caption_by_index(index)

        # The sampled frame indices returned alongside the frames are unused.
        imgs, _ = VideoCapture.load_frames_from_video(video_path,
                                                      self.config.num_frames,
                                                      self.config.video_sample_type)

        use_gen = self.config.use_gen
        gen_frame = None
        if use_gen:
            if is_train:
                gen_frame = self._get_genframe_by_senid(sen_id)
            else:
                gen_frame = self._get_genframe_tst_by_vid(video_id)

        # process images of video
        if self.img_transforms is not None:
            imgs = self.img_transforms(imgs)

        sample = {
            'video_id': video_id,
            'video': imgs,
            'text': caption,
        }
        if use_gen:
            sample['gen_frame'] = gen_frame
        return sample

    def __len__(self):
        """Number of training pairs ('train') or test CSV rows (otherwise)."""
        if self.split_type == 'train':
            return len(self.all_train_pairs)
        return len(self.test_df)

    def _get_vidpath_and_caption_by_index(self, index):
        """
        Return the video path and caption for sample ``index``.

        The tuple is (video_path, caption, video_id, sen_id) for the train
        split and (video_path, caption, video_id) otherwise.
        """
        if self.split_type == 'train':
            vid, caption, senid = self.all_train_pairs[index]
            video_path = os.path.join(self.videos_dir, vid + '.mp4')
            return video_path, caption, vid, senid

        # Fetch the row once instead of once per column.
        row = self.test_df.iloc[index]
        vid = row.video_id
        video_path = os.path.join(self.videos_dir, vid + '.mp4')
        return video_path, row.sentence, vid

    @staticmethod
    def _load_genframe(genframe_pth):
        """
        Read ``['data'].item()['data']`` from the .npz archive at
        ``genframe_pth`` and close the archive afterwards.
        """
        # NOTE: allow_pickle=True unpickles the archive content; only point
        # the genframe directories at trusted files.
        # np.load returns an NpzFile holding an open file handle; the context
        # manager releases it instead of leaving it to garbage collection.
        with np.load(genframe_pth, allow_pickle=True) as archive:
            return archive['data'].item()['data']

    def _get_genframe_by_senid(self, senid):
        """Load the generated frame stored as '<senid>.npz' in ``config.genframe_dir``."""
        genframe_pth = os.path.join(self.config.genframe_dir, str(senid) + '.npz')
        return self._load_genframe(genframe_pth)

    def _get_genframe_tst_by_vid(self, vid):
        """
        Load the generated frame stored as '<vid>.npz' in
        ``config.genframe_dir_tst``; ``vid`` is a string such as 'video9773'.
        """
        genframe_pth = os.path.join(self.config.genframe_dir_tst, vid + '.npz')
        return self._load_genframe(genframe_pth)

    def _construct_all_train_pairs(self):
        """
        Build ``self.all_train_pairs`` as a list of [video_id, caption, sen_id]
        entries for every training video; left empty for non-train splits.
        """
        self.all_train_pairs = []
        if self.split_type != 'train':
            return
        for vid in self.train_vids:
            self.all_train_pairs.extend(
                [vid, caption, senid]
                for caption, senid in zip(self.vid2caption[vid],
                                          self.vid2senid[vid])
            )

    def _compute_vid2caption(self):
        """
        Group the entries of ``self.db['sentences']`` by video id into
        ``self.vid2caption`` (captions) and ``self.vid2senid`` (sentence ids),
        keeping both lists in annotation order so they stay aligned.
        """
        self.vid2caption = defaultdict(list)
        self.vid2senid = defaultdict(list)

        for annotation in self.db['sentences']:
            vid = annotation['video_id']
            self.vid2caption[vid].append(annotation['caption'])
            self.vid2senid[vid].append(annotation['sen_id'])
