

import av
import ffmpeg
from joblib import Parallel, delayed
from multiprocessing import Manager
import numpy as np
import os
import pickle
import random
import torch
import torch.utils.data
import glob

# The sibling modules are imported by bare name first (works when this file is
# run directly as a script) and by relative import otherwise (works when the
# file is imported as part of a package). Only ImportError triggers the
# fallback so that unrelated failures raised while importing the sibling
# module are not hidden behind a second, misleading import attempt.
try:
    from decoder import decode
except ImportError:
    from .decoder import decode

try:
    from transforms_SF import (
        random_short_side_scale_jitter,
        random_crop,
        horizontal_flip,
        grayscale,
        color_jitter,
        uniform_crop,
    )
except ImportError:
    from .transforms_SF import (
        random_short_side_scale_jitter,
        random_crop,
        horizontal_flip,
        grayscale,
        color_jitter,
        uniform_crop,
    )


# Enable multi-threaded decoding (passed to get_video_container).
ENABLE_MULTI_THREAD_DECODE = True
# Input videos may have different frame rates; they are converted to this
# target fps before frame sampling.
TARGET_FPS = 30
# Decoding backend; options are `pyav` or `torchvision`.
DECODING_BACKEND = "pyav"

# Testing: number of crops to sample spatially from a frame when aggregating
# the prediction results.
NUM_SPATIAL_CROPS = 3
# Testing: number of clips to sample uniformly from a video when aggregating
# the prediction results.
NUM_ENSEMBLE_VIEWS = 10

# Spatial crop size for training.
TRAIN_CROP_SIZE = 224
# Spatial augmentation jitter scales for training.
TRAIN_JITTER_SCALES = [256, 320]
# Mean of the raw video pixels across the R, G, B channels.
# Alternative values noted by the original author: (0.43216, 0.394666, 0.37645)
MEAN = [0.45, 0.45, 0.45]
# Standard deviation of the raw video pixels across the R, G, B channels.
# Alternative values noted by the original author: (0.22803, 0.22145, 0.216989)
STD = [0.225, 0.225, 0.225]


# Dataset name -> root directory of that dataset on disk.
ROOT_DIR = dict(
    kinetics='./data/kinetics',
    kinetics600='./data/kinetics600',
    kinetics_sound='./data/kinetics',
    ave='./data/ave/',
    audioset='./data/audioset',
    vggsound='./data/vggsound',
)

# Dataset name -> {split name -> sub-directory below the dataset root}.
# Each dataset gets its own inner dict so entries can be edited independently.
MODE_DIR = {
    'kinetics': dict(train='train_avi-480p', val='val_avi-480p'),
    'kinetics600': dict(train='train', val='val'),
    'kinetics_sound': dict(train='train_avi-480p', val='val_avi-480p'),
    'ave': dict(train='train', val='val'),
    'audioset': dict(
        train='unbalanced_train_segments/video',
        val='eval_segments/video',
    ),
    'vggsound': dict(train='train', val='test'),
}


def valid_video(vid_idx, vid_path, min_duration=1.1):
    """
    Check whether a video file has both a video and an audio stream that are
    each longer than `min_duration` seconds, and print the verdict.
    Args:
        vid_idx (int): index of the video; only used in the printed message.
        vid_path (str): path to the video file passed to `ffmpeg.probe`.
        min_duration (float): both stream durations reported by the probe
            must be strictly greater than this value (default 1.1).
    Returns:
        (bool): True if both streams exist and are long enough. False if a
            stream is missing or too short, or if probing/parsing fails for
            any reason (missing file, missing `duration` field, ...).
    """
    try:
        streams = ffmpeg.probe(vid_path)['streams']
        video_stream = next(
            (stream for stream in streams if stream['codec_type'] == 'video'),
            None,
        )
        audio_stream = next(
            (stream for stream in streams if stream['codec_type'] == 'audio'),
            None,
        )
        is_valid = bool(
            audio_stream
            and video_stream
            and float(video_stream['duration']) > min_duration
            and float(audio_stream['duration']) > min_duration
        )
    except Exception:
        # Best-effort check: any probing/parsing failure marks the video as
        # invalid. `Exception` (rather than a bare except) lets
        # KeyboardInterrupt / SystemExit propagate so a long filtering run
        # can still be interrupted.
        print(f"{vid_idx}: False", flush=True)
        return False

    if is_valid:
        print(f"{vid_idx}: True", flush=True)
        return True
    print(f"{vid_idx}: False (duration short/ no audio)", flush=True)
    return False


def filter_videos(vid_paths, n_jobs=30):
    """
    Run `valid_video` over all paths in parallel and return the positions of
    the videos that passed the check.
    Args:
        vid_paths (list): paths to the video files to check.
        n_jobs (int): number of joblib workers (default 30, the value that
            was previously hard-coded).
    Returns:
        valid_indices (list): indices into `vid_paths` of the videos for
            which `valid_video` returned a truthy value, in ascending order.
    """
    is_valid_flags = Parallel(n_jobs=n_jobs)(
        delayed(valid_video)(vid_idx, vid_path)
        for vid_idx, vid_path in enumerate(vid_paths)
    )
    return [vid_idx for vid_idx, is_valid in enumerate(is_valid_flags) if is_valid]


def get_video_container(path_to_vid, multi_thread_decode=False, backend="pyav"):
    """
    Given the path to the video, return the video container.
    Args:
        path_to_vid (str): path to the video.
        multi_thread_decode (bool): if True, perform multi-thread decoding
            (only applies to the `pyav` backend).
        backend (str): decoder backend, options include `pyav` and
            `torchvision`, default is `pyav`.
    Returns:
        container: for `torchvision`, the raw bytes of the file; for `pyav`,
            the container returned by `av.open`.
    Raises:
        NotImplementedError: if `backend` is neither `pyav` nor `torchvision`.
    """
    if backend == "torchvision":
        with open(path_to_vid, "rb") as fp:
            return fp.read()

    if backend == "pyav":
        try:
            container = av.open(path_to_vid)
        except Exception:
            # Retry once, telling PyAV to ignore metadata decoding errors.
            # `Exception` (not a bare except) keeps KeyboardInterrupt and
            # SystemExit from being swallowed by the retry.
            container = av.open(path_to_vid, metadata_errors="ignore")
        if multi_thread_decode:
            # Enable multiple threads for decoding.
            container.streams.video[0].thread_type = "AUTO"
        return container

    raise NotImplementedError("Unknown backend {}".format(backend))


class AVideoDataset(torch.utils.data.Dataset):
    """
    Audio-video loader. Construct the video loader, then sample
    clips from the videos. For training and validation, `num_train_clips`
    clips are randomly sampled from every video with random cropping, scaling,
    and flipping (or a fixed crop when `center_crop` is set). For testing,
    `num_ensemble_views * num_spatial_crops` entries are created per video and
    clips are sampled uniformly in time; the spatial crop index is currently
    fixed to 1 in test mode. For uniform cropping, indices 0, 1, 2 denote the
    left, center, and right crop if the width is larger than height, or the
    top, center, and bottom crop if the height is larger than the width.
    """

    def __init__(
        self, 
        ds_name='kinetics',
        mode='train',
        num_frames=30,
        sample_rate=1,
        train_crop_size=112,
        test_crop_size=112,
        num_spatial_crops=3,
        num_ensemble_views=10,
        num_train_clips=1,
        path_to_data_dir='datasets/data',
        num_retries=1,
        seed=None,
        num_data_samples=None,
        colorjitter=False,
        temp_jitter=True,
        center_crop=False,
        fold=1,
        target_fps=30,
        decode_audio=True,
        aug_audio=None,
        num_sec=1,
        aud_sample_rate=48000,
        aud_spec_type=1,
        use_volume_jittering=False,
        use_temporal_jittering=False,
        z_normalize=False,
        mean=None,
        std=None,
        corruption=1
    ):
        """
        Construct the video loader from a list file located at
        `<path_to_data_dir>/<ds_name>_<mode>.txt`. The format of the file is
        one video path per line:
        ```
        path_to_video_1
        path_to_video_2
        ...
        path_to_video_N
        ```
        If the file does not exist it is generated from the dataset directory
        (except for `audioset`).
        Args:
            ds_name (string): dataset name, a key of `ROOT_DIR` / `MODE_DIR`.
            mode (string): Options includes `train`, `val`, or `test` mode.
                For the train and val mode, `num_train_clips` clips are
                sampled per video. For the test mode,
                `num_ensemble_views * num_spatial_crops` entries are created
                per video.
            num_frames (int): forwarded to `decode`.
            sample_rate (int): forwarded to `decode`.
            train_crop_size (int): spatial crop size. A value of 112 selects
                jitter scales (128, 160); any other value selects (256, 320).
            test_crop_size (int): stored as `self.test_crop_size`; not read
                elsewhere in this class.
            num_spatial_crops (int): see `mode`; also used to derive the
                temporal index in test mode.
            num_ensemble_views (int): see `mode`; also forwarded to `decode`.
            num_train_clips (int): see `mode`.
            path_to_data_dir (string): directory holding the list file and
                the cached valid-index pickles.
            num_retries (int): number of attempts made in `__getitem__`.
            seed: accepted for compatibility; not used by this class.
            num_data_samples (int or None): if given, only the first
                `num_data_samples` valid indices are kept.
            colorjitter (bool): if True, apply `color_jitter` to the frames.
            temp_jitter (bool): if False, `decode` receives the fixed
                temporal index 500 out of 1000 clips instead of the sampled
                index and `num_ensemble_views`.
            center_crop (bool): in train/val mode, use spatial index 1 and
                scale the short side to exactly `train_crop_size`.
            fold (int): stored as `self.fold`; not read in this class.
            target_fps (int): forwarded to `decode`.
            decode_audio (bool): forwarded to `decode`; also selects whether
                the audio is part of the tuple returned by `__getitem__`.
            aug_audio (list or None): forwarded to `decode`; None (default)
                means an empty list.
            num_sec, aud_sample_rate, aud_spec_type, use_volume_jittering,
            use_temporal_jittering, z_normalize: forwarded to `decode`.
            mean (list or None): per-channel mean subtracted from the frames;
                None (default) means [0.45, 0.45, 0.45].
            std (list or None): per-channel divisor applied to the frames;
                None (default) means [0.225, 0.225, 0.225].
            corruption (number): if greater than 1, cropped frames are
                resized down to `crop_size / corruption` and back up again.
        """
        # Only support train, val, and test mode.
        assert mode in [
            "train",
            "val",
            "test",
        ], "Split '{}' not supported for '{}'".format(mode, ds_name)

        # Fresh lists per instance instead of mutable default arguments that
        # would be shared between all instances.
        if aug_audio is None:
            aug_audio = []
        if mean is None:
            mean = [0.45, 0.45, 0.45]
        if std is None:
            std = [0.225, 0.225, 0.225]

        self.ds_name = ds_name
        self.mode = mode
        self.num_frames = num_frames
        self.sample_rate = sample_rate
        self.train_crop_size = train_crop_size
        # Previously this was assigned from `train_crop_size`, silently
        # discarding the `test_crop_size` argument.
        self.test_crop_size = test_crop_size
        # The attribute name (including its spelling) is kept for
        # compatibility with external code.
        self.train_jitter_scles = (128, 160) if train_crop_size == 112 else (256, 320)
        self.num_ensemble_views = num_ensemble_views
        self.num_spatial_crops = num_spatial_crops
        self.num_train_clips = num_train_clips
        self.data_prefix = os.path.join(ROOT_DIR[ds_name], MODE_DIR[ds_name][mode])
        self.path_to_data_dir = path_to_data_dir
        self.num_data_samples = num_data_samples
        self.colorjitter = colorjitter
        self.temp_jitter = temp_jitter
        self.center_crop = center_crop
        self.fold = fold
        self.target_fps = target_fps
        self.decode_audio = decode_audio
        self.aug_audio = aug_audio
        self.num_sec = num_sec
        self.aud_sample_rate = aud_sample_rate
        self.aud_spec_type = aud_spec_type
        self.use_volume_jittering = use_volume_jittering
        self.use_temporal_jittering = use_temporal_jittering
        self.z_normalize = z_normalize
        self.mean = mean
        self.std = std
        self.corruption = corruption

        self._video_meta = {}
        self._num_retries = num_retries

        # Get classes: every entry directly below the split directory is a
        # class; indices follow the sorted order of the entry names.
        if ds_name != 'audioset':
            class_dirs = sorted(glob.glob(os.path.join(self.data_prefix, '*')))
            classes = [os.path.basename(class_dir) for class_dir in class_dirs]
            self.class_to_idx = {name: i for i, name in enumerate(classes)}
        
        self.sound_only_classes_kinetics = ["blowing_nose", "blowing_out_candles", "bowling", "chopping_wood", 
                "dribbling_basketball",  "laughing", "mowing_lawn", "playing_accordion",
                "playing_bagpipes", "playing_bass_guitar", "playing_clarinet", "playing_drums",
                "playing_guitar", "playing_harmonica", "playing_keyboard", "playing_organ",
                "playing_piano", "playing_saxophone", "playing_trombone", "playing_trumpet",
                "playing_violin", "playing_xylophone", "ripping_paper", "shoveling_snow",
                "shuffling_cards", "singing", "stomping_grapes", "strumming_guitar",
                "tap_dancing", "tapping_guitar", "tapping_pen", "tickling"]
        print(f"Number of Sound Classes: {len(self.sound_only_classes_kinetics)}")

        # For training or validation mode, num_train_clips clips are sampled from every video. 
        # For testing, num_ensemble_views clips are sampled from every video. 
        # For every clip, num_spatial_crops is cropped spatially from the frames.
        if self.mode in ["train", "val"]:
            self._num_clips = self.num_train_clips
        elif self.mode in ["test"]:
            self._num_clips = (
                self.num_ensemble_views * self.num_spatial_crops
            )
        
        self.manager = Manager()
        print(f"Constructing {self.ds_name} {self.mode}...")
        self._construct_loader()

    def _construct_loader(self):
        """
        Construct the video loader: read (or first generate) the list file,
        build the per-clip path/label/index tables, and load (or compute and
        cache) the indices of the videos that passed `filter_videos`.
        """
        path_to_file = os.path.join(
            self.path_to_data_dir, f"{self.ds_name}_{self.mode}.txt"
        )
        if not os.path.exists(path_to_file) and self.ds_name != 'audioset':
            files = sorted(glob.glob(os.path.join(self.data_prefix, '*', '*')))
            with open(path_to_file, 'w') as f:
                for item in files:
                    if self.ds_name == 'kinetics_sound':
                        class_name = item.split('/')[-2]
                        print(class_name)
                        if class_name not in self.sound_only_classes_kinetics:
                            continue
                    # Write absolute paths: the loader below joins every line
                    # with `data_prefix`, and the glob results already start
                    # with `data_prefix`. With a relative prefix that used to
                    # produce a doubled, non-existent path; joining with an
                    # absolute path simply yields that absolute path.
                    f.write("%s\n" % os.path.abspath(item))

        self._path_to_videos = []
        self._labels = []
        self._spatial_temporal_idx = []
        self._vid_indices = []
        with open(path_to_file, "r") as f:
            listed_paths = f.read().splitlines()
        for clip_idx, path in enumerate(listed_paths):
            video_path = os.path.join(self.data_prefix, path).replace(
                'datasets01_101', 'datasets01'
            )
            for idx in range(self._num_clips):
                if self.ds_name == 'audioset':
                    label = 0 # todo?
                else:
                    # The class is the name of the video's parent directory.
                    label = self.class_to_idx[path.split('/')[-2]]
                self._path_to_videos.append(video_path)
                self._labels.append(int(label))
                self._spatial_temporal_idx.append(idx)
                self._vid_indices.append(clip_idx)
                self._video_meta[clip_idx * self._num_clips + idx] = {}
        print(f"Number of Unique Labels: {len(np.unique(self._labels))}")
        # The message used to reference the non-existent `self._split_idx`,
        # which turned this failure into an AttributeError.
        assert (
            len(self._path_to_videos) > 0
        ), "Failed to load {} split {} from {}".format(
            self.ds_name, self.mode, path_to_file
        )
        print(
            "Constructing {} dataloader (size: {}) from {}".format(
                self.ds_name, len(self._path_to_videos), path_to_file
            )
        )

        # Create / Load valid indices (has audio)
        if self.ds_name in ['audioset', 'kinetics', 'kinetics600', 'vggsound', 'kinetics_sound', 'ave']:
            if self.mode == 'train':
                vid_valid_file = f'{self.path_to_data_dir}/{self.ds_name}_valid.pkl'
            else:
                vid_valid_file = f'{self.path_to_data_dir}/{self.ds_name}_{self.mode}_valid.pkl'
            if os.path.exists(vid_valid_file):
                # NOTE: unpickling can execute arbitrary code; only load cache
                # files that were produced by this code on a trusted machine.
                with open(vid_valid_file, 'rb') as handle:
                    self.valid_indices = pickle.load(handle)
            else:
                self.valid_indices = filter_videos(self._path_to_videos)
                with open(vid_valid_file, 'wb') as handle:
                    pickle.dump(
                        self.valid_indices, 
                        handle, 
                        protocol=pickle.HIGHEST_PROTOCOL
                    )
            if self.num_data_samples is not None:
                self.valid_indices = self.valid_indices[:self.num_data_samples]
            print(f"Total number of videos: {len(self._path_to_videos)}, Valid videos: {len(self.valid_indices)}", flush=True)
        else:
            # Raised explicitly (instead of `assert(False)`) so the check
            # also fires when Python runs with -O.
            raise AssertionError(
                "Valid-index filtering is not supported for dataset '{}'".format(
                    self.ds_name
                )
            )
 
    def _sampling_params(self, index):
        """
        Return `(temporal_sample_index, spatial_sample_index, min_scale,
        max_scale, crop_size)` for the clip at `index` according to the mode.
        """
        if self.mode in ["train", "val"]:
            if self.center_crop:
                # Fixed crop index 1; the short side is scaled to exactly
                # the crop size.
                crop = self.train_crop_size
                return -1, 1, crop, crop, crop
            # -1 indicates random sampling.
            return (
                -1,
                -1,
                self.train_jitter_scles[0],
                self.train_jitter_scles[1],
                self.train_crop_size,
            )
        if self.mode in ["test"]:
            temporal_sample_index = (
                self._spatial_temporal_idx[index] // self.num_spatial_crops
            )
            # spatial_sample_index is in [0, 1, 2]. Corresponding to left,
            # center, or right if width is larger than height, and top, middle,
            # or bottom if height is larger than width. It is currently fixed
            # to 1 rather than
            # self._spatial_temporal_idx[index] % self.num_spatial_crops.
            if self.train_crop_size == 112:
                return temporal_sample_index, 1, 128, 128, 112
            return temporal_sample_index, 1, 256, 320, 224
        raise NotImplementedError(
            "Does not support {} mode".format(self.mode)
        )

    def __getitem__(self, index):
        """
        Given the dataset index, decode one clip and return it with its
        audio, label and indices. If the video container cannot be opened, a
        random replacement video is tried, up to `num_retries` attempts in
        total.
        Args:
            index (int): the index provided by the pytorch sampler; it is
                mapped through `self.valid_indices` to a clip entry.
        Returns:
            frames (tensor): the frames sampled from the video, permuted to
                `channel` x `num frames` x `height` x `width` and normalized.
            audio: second value returned by `decode`; only part of the tuple
                when `decode_audio` is True.
            label (int): the label of the decoded video.
            vid_idx (int): line index of the decoded video in the list file.
            index (int): the index that was passed in (unchanged).
        Raises:
            RuntimeError: if no video container could be opened within
                `num_retries` attempts.
        """
        index_capped = index
        index = self.valid_indices[index_capped]
        (
            temporal_sample_index,
            spatial_sample_index,
            min_scale,
            max_scale,
            crop_size,
        ) = self._sampling_params(index)

        # Try to decode and sample a clip from a video. 
        # If the video can not be opened, repeatedly pick a random video replacement.
        for retry_ix in range(self._num_retries):
            video_container = None
            try:
                video_container = get_video_container(
                    self._path_to_videos[index],
                    ENABLE_MULTI_THREAD_DECODE,
                    DECODING_BACKEND,
                )
            except Exception as e:
                print(
                    "Failed to load video from {} with error {}".format(
                        self._path_to_videos[index], e
                    )
                )
            # Select a random video if the current video was not able to access.
            if video_container is None:
                print(f"Retrying {retry_ix} / {self._num_retries}: video_container is None", flush=True)
                index = random.randint(0, len(self._path_to_videos) - 1)
                continue

            # Decode video. Meta info is used to perform selective decoding.
            frames, spec = decode(
                self._path_to_videos[index],
                video_container,
                self.sample_rate,
                self.num_frames,
                temporal_sample_index if self.temp_jitter else 500,
                self.num_ensemble_views if self.temp_jitter else 1000,
                video_meta=self._video_meta[index],
                target_fps=self.target_fps,
                backend=DECODING_BACKEND,
                max_spatial_scale=max_scale,
                decode_audio=self.decode_audio,
                aug_audio=self.aug_audio,
                num_sec=self.num_sec,
                aud_sample_rate=self.aud_sample_rate,
                aud_spec_type=self.aud_spec_type,
                use_volume_jittering=self.use_volume_jittering,
                use_temporal_jittering=self.use_temporal_jittering,
                z_normalize=self.z_normalize,
            )

            frames = frames.float()
            frames = frames / 255.0

            # T H W C -> C T H W.
            frames = frames.permute(3, 0, 1, 2)
            # Perform data augmentation.
            frames = self.spatial_sampling(
                frames,
                spatial_idx=spatial_sample_index,
                min_scale=min_scale,
                max_scale=max_scale,
                crop_size=crop_size,
            )

            if self.colorjitter:
                frames = color_jitter(frames, 0.4, 0.4, 0.4)

            # Perform color normalization; mean/std are reshaped so they
            # broadcast over the trailing T, H, W dimensions.
            frames = frames - torch.tensor(self.mean).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
            frames = frames / torch.tensor(self.std).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)

            label = self._labels[index]
            vid_idx = self._vid_indices[index]
            if self.decode_audio:
                return frames, spec, label, vid_idx, index_capped
            return frames, label, vid_idx, index_capped

        raise RuntimeError(
            "Failed to fetch video after {} retries.".format(
                self._num_retries
            )
        )

    def __len__(self):
        """
        Returns:
            (int): the number of valid entries in the dataset.
        """
        return len(self.valid_indices)

    def spatial_sampling(
        self,
        frames,
        spatial_idx=-1,
        min_scale=256,
        max_scale=320,
        crop_size=224,
    ):
        """
        Perform spatial sampling on the given video frames. If spatial_idx is
        -1, perform random scale, random crop, and random flip on the given
        frames. If spatial_idx is 0, 1, or 2, perform random short-side
        scaling between min_scale and max_scale followed by a uniform crop
        with the given spatial_idx. If `self.corruption` is greater than 1,
        the cropped frames are additionally resized to
        `int(crop_size / corruption)` and back to `crop_size`.
        Args:
            frames (tensor): frames of images sampled from the video. In
                `__getitem__` they are permuted to `channel` x `num frames` x
                `height` x `width` before being passed here.
            spatial_idx (int): if -1, perform random spatial sampling. If 0, 1,
                or 2, perform left, center, right crop if width is larger than
                height, and perform top, center, bottom crop if height is larger
                than width.
            min_scale (int): the minimal size of scaling.
            max_scale (int): the maximal size of scaling.
            crop_size (int): the size of height and width used to crop the
                frames.
        Returns:
            frames (tensor): spatially sampled frames.
        """
        assert spatial_idx in [-1, 0, 1, 2]
        frames, _ = random_short_side_scale_jitter(
            frames, min_scale, max_scale
        )
        if spatial_idx == -1:
            frames, _ = random_crop(frames, crop_size)
            frames, _ = horizontal_flip(0.5, frames)
        else:
            # min_scale, max_scale, and crop_size are not required to be
            # equal here; the short-side scaling above still applies.
            frames, _ = uniform_crop(frames, crop_size, spatial_idx)

        if self.corruption > 1:
            # Downsample then upsample back to the crop size.
            new_size = int(crop_size / self.corruption)
            frames = torch.nn.functional.interpolate(frames, size=new_size)
            frames = torch.nn.functional.interpolate(frames, size=crop_size)
        return frames


if __name__ == '__main__':
    # Smoke test: build the `kinetics_sound` training split and iterate over
    # it once, printing tensor sizes and the time spent per batch.
    # Only the imports this block needs are kept; `random` and `torch` are
    # already imported at module level.
    import time
    from torch.utils.data import DataLoader

    print("="*60)
    print('Testing AVideoDataset')
    val_dataset = AVideoDataset(
        ds_name='kinetics_sound',
        mode='train',
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=16,
        num_workers=0,
        shuffle=True,
        collate_fn=None
    )
    tic = time.time()
    for batch_idx, batch in enumerate(val_loader):
        # With the default decode_audio=True each item is
        # (frames, audio, label, vid_idx, index).
        video, spec, label, vid_idx, idx = batch
        print(len(video),video[0].size(),flush=True)
        print(
            batch_idx,
            video.size(),
            spec.size(),
            label,
            idx,
            vid_idx,
            time.time() - tic
        )
        print(f'Batch time (s): {time.time() - tic}')
        tic = time.time()
