import torch
from typing import Sequence
import urllib.request
import tarfile
import os
import numpy as np
import random
import os
import json
from .base_video_dataset import BaseVideoDataset
from omegaconf import DictConfig
from pathlib import Path
from tqdm import tqdm
from zipfile import ZipFile


class MinerlVideoDataset(BaseVideoDataset):
    """
    MineRL navigation video dataset.

    On disk, videos are read from ``<save_dir>/<split>/videos`` (``*.mp4``) and
    the matching normalized action arrays are written to
    ``<save_dir>/<split>/conditions`` (``*.npy``, one file per video).

    NOTE(review): ``self.save_dir`` is not assigned in this class; it is
    presumably set by ``BaseVideoDataset`` -- confirm there.
    """

    # Archive holding the raw videos and their per-split ``metadata.json``.
    _DOWNLOAD_URL = "https://archive.org/download/minerl_navigate/minerl_navigate.zip"

    # Split name used by this class -> folder name inside the extracted archive.
    _ARCHIVE_SPLIT_DIRS = {"training": "train", "validation": "test"}

    # define the range of each action: https://minerl.readthedocs.io/en/latest/environments/index.html
    _ACTION_RANGES = {
        "camera_y": [-180.0, 180.0],
        "left": [0, 1],
        "sneak": [0, 1],
        "forward": [0, 1],
        "back": [0, 1],
        "camera_x": [-180.0, 180.0],
        "jump": [0, 1],
        "place": [0, 1],
        "right": [0, 1],
        "sprint": [0, 1],
        "compassAngle": [-180.0, 180.0],
        "attack": [0, 1],
    }

    def __init__(self, cfg: DictConfig, split: str = "training"):
        """
        :param cfg: configuration; ``cfg.save_dir`` is the dataset root directory
        :param split: dataset split; ``"test"`` is treated as ``"validation"``
        """
        if split == "test":
            split = "validation"
        # These attributes are assigned before the base initializer runs, as in
        # the original code -- presumably the base class may use them; confirm.
        split_dir = Path(cfg.save_dir) / f"{split}"
        self.video_dir = split_dir / "videos"
        self.condition_dir = split_dir / "conditions"
        super().__init__(cfg, split)

    def get_data_paths(self, split):
        """
        List the videos of a split.

        :param split: split folder name under ``self.save_dir``
        :return: paths of all ``*.mp4`` files in the split, sorted by file name
        """
        video_dir = self.save_dir / split / "videos"
        return sorted(video_dir.glob("*.mp4"), key=lambda path: path.name)

    def download_dataset(self) -> None:
        """
        Download the raw archive (if needed) and prepare both splits.

        Steps:
          1. create ``<save_dir>/<split>/conditions`` for every split;
          2. download and extract the archive into ``<save_dir>/downloads``
             unless that directory already exists;
          3. point ``<save_dir>/<split>/videos`` at the extracted split folder
             through a symlink;
          4. write one normalized action array per video into ``conditions``.

        The method is safe to call again: existing symlinks are replaced
        instead of failing.

        :raises OSError: if ``videos`` exists as a non-empty real directory, or
            a filesystem operation fails
        :raises KeyError: if a video has no entry in ``metadata.json`` or an
            entry lacks one of the expected actions
        """
        for split in self._ARCHIVE_SPLIT_DIRS:
            # parents=True also creates <save_dir>/<split>, the parent of the
            # "videos" symlink created below.
            (self.save_dir / split / "conditions").mkdir(exist_ok=True, parents=True)

        download_dir = self.save_dir / "downloads"
        if not download_dir.exists():
            self._fetch_archive(download_dir)

        extracted_root = download_dir.absolute() / "minerl_navigate"
        for split, archive_name in self._ARCHIVE_SPLIT_DIRS.items():
            self._force_dir_symlink(
                target=extracted_root / archive_name,
                link=self.save_dir.absolute() / split / "videos",
            )

        for split in self._ARCHIVE_SPLIT_DIRS:
            self._export_conditions(split)

    def _fetch_archive(self, download_dir: Path) -> None:
        """Download the dataset archive and extract it into ``download_dir``."""
        url = self._DOWNLOAD_URL
        print(f"Downloading from {url}...")
        tmp_file, _headers = urllib.request.urlretrieve(url)
        try:
            print("Extracting...")
            # Extract next to the final location and rename afterwards, so an
            # interrupted extraction never leaves a half-filled ``download_dir``
            # that a later run would mistake for a complete download.
            staging_dir = download_dir.with_name(download_dir.name + ".partial")
            with ZipFile(tmp_file, "r") as zip_ref:
                zip_ref.extractall(staging_dir)
            staging_dir.rename(download_dir)
        finally:
            # urlretrieve() without a filename stores the payload in a temp
            # file that nothing else cleans up.
            if os.path.exists(tmp_file):
                os.remove(tmp_file)

    @staticmethod
    def _force_dir_symlink(target: Path, link: Path) -> None:
        """Make ``link`` a symlink to directory ``target``, replacing an old link or empty directory."""
        if link.is_symlink():
            link.unlink()
        elif link.exists():
            # Only an empty directory is removed; rmdir() raises for anything
            # else, so real data is never deleted silently.
            link.rmdir()
        link.symlink_to(target, target_is_directory=True)

    def _export_conditions(self, split: str) -> None:
        """Write the normalized actions of every video in ``split`` to ``conditions/<stem>.npy``."""
        video_dir = self.save_dir / split / "videos"
        condition_dir = self.save_dir / split / "conditions"

        # read the JSON file with the actions
        with open(video_dir / "metadata.json", "r") as f:
            all_actions = json.load(f)

        for video_path in video_dir.glob("*.mp4"):
            actions = self.normalize_actions(all_actions[video_path.name], self._ACTION_RANGES)
            # One row per metadata key before the transpose; columns of the
            # saved array therefore follow the key order of the metadata entry.
            # Assumes every action holds a sequence of the same length -- TODO confirm.
            stacked = np.array(list(actions.values())).T
            np.save(condition_dir / f"{video_path.stem}.npy", stacked)

    @staticmethod
    def normalize_actions(actions: dict, range_per_action: dict):
        """
        Normalize actions to [-1, 1]

        Each action listed in ``range_per_action`` is mapped linearly so that the
        lower bound of its range becomes -1 and the upper bound becomes 1. Keys of
        ``actions`` that have no range are passed through unchanged. The input
        dictionary is not modified; a new dictionary with the same key order is
        returned.

        :param actions: dictionary of actions
        :param range_per_action: dictionary of ranges (list) for each action
        :return: normalized actions as dictionary of numpy arrays
        :raises KeyError: if an action in ``range_per_action`` is missing from ``actions``
        """
        normalized = dict(actions)
        for action, range_ in range_per_action.items():
            low, high = range_[0], range_[1]
            normalized[action] = (np.array(actions[action]) - low) / (high - low) * 2.0 - 1.0
        return normalized


if __name__ == "__main__":
    # Manual smoke test: build the training split from a mocked config and
    # iterate once over every batch, showing a progress bar.
    import torch
    from unittest.mock import MagicMock
    import tqdm

    # Keyword arguments to MagicMock are set as attributes on the mock.
    cfg = MagicMock(
        resolution=64,
        external_cond_dim=0,
        n_frames=64,
        save_dir="data/minerl",
        validation_multiplier=1,
    )

    dataset = MinerlVideoDataset(cfg, "training")
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=4,
        shuffle=True,
        num_workers=16,
    )

    # The batches themselves are discarded; only loading is exercised.
    for batch in tqdm.tqdm(dataloader):
        continue
