from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import scipy.io as sio
import numpy as np
import pickle
import os
import torch
# from MeshConvertPytorch3D import MeshConverter
from MeshMemoryMap import MeshConverter
from datasets.CalculatePointDirection import cal_point_weight, cal_point_weight2, direction_calculator
import BboxTools as bbt

# Module-level defaults.
# PASCAL_ROOT: default dataset root; the loaders in this module take their
# root from the ``rootpath`` keyword argument instead.
PASCAL_ROOT = '../PASCAL3D/PASCAL3D'
# Default list-file names (``<subtype>.txt``) read by the loaders when no
# ``subtypes`` keyword argument is supplied.
subtypes = 'hatchback mini minivan race sedan SUV truck wagon others'.split()


def get_name_list(anno_path, category, vp, dataset, set_type):
    """Load the pickled image-name list for one category / dataset / split.

    Reads ``<anno_path>/full_<category>_<dataset>_<set_type>.pkl``.

    Args:
        anno_path: directory holding the pickled annotation files.
        category: object category used in the file name.
        vp: unused; kept so all ``get_*`` helpers share one signature.
        dataset: dataset tag used in the file name.
        set_type: split tag used in the file name.

    Returns:
        Whatever object was pickled in the file (the image-name list).

    NOTE: ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    path = os.path.join(anno_path, 'full_{}_{}_{}.pkl'.format(category, dataset, set_type))
    with open(path, 'rb') as f:
        return pickle.load(f)


# anno_path = '/export/ccvl12a/yutong/3DRepresentation'
def get_vp_list(anno_path, category, vp, dataset, set_type):
    """Return ``record[1:3]`` for every record in the pickled annotation list.

    Reads ``<anno_path>/full_annotations_<category>_<dataset>_<set_type>.pkl``.
    ``Pascal3DPlusOld`` unpacks each entry as ``azi, ele``.
    ``vp`` is unused; it is kept for signature parity with the other helpers.

    NOTE: ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    path = os.path.join(anno_path, 'full_annotations_{}_{}_{}.pkl'.format(category, dataset, set_type))
    with open(path, 'rb') as f:
        annotations = pickle.load(f)
    return [ano[1:3] for ano in annotations]


def get_kp_list(anno_path, category, vp, dataset, set_type):
    """Return ``record[5]`` (the keypoints) for every pickled annotation record.

    Reads ``<anno_path>/full_annotations_<category>_<dataset>_<set_type>.pkl``.
    ``vp`` is unused; it is kept for signature parity with the other helpers.

    NOTE: ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    path = os.path.join(anno_path, 'full_annotations_{}_{}_{}.pkl'.format(category, dataset, set_type))
    with open(path, 'rb') as f:
        annotations = pickle.load(f)
    return [ano[5] for ano in annotations]


def get_kp_status_list(anno_path, category, vp, dataset, set_type):
    """Return ``record[4]`` (keypoint status flags) for every pickled record.

    Reads ``<anno_path>/full_annotations_<category>_<dataset>_<set_type>.pkl``.
    ``Pascal3DPlusOld`` treats a flag equal to 1 as "visible".
    ``vp`` is unused; it is kept for signature parity with the other helpers.

    NOTE: ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    path = os.path.join(anno_path, 'full_annotations_{}_{}_{}.pkl'.format(category, dataset, set_type))
    with open(path, 'rb') as f:
        annotations = pickle.load(f)
    return [ano[4] for ano in annotations]


def get_bbox_list(anno_path, category, vp, dataset, set_type):
    """Return ``record[3]`` (the bounding box) for every pickled annotation record.

    Reads ``<anno_path>/full_annotations_<category>_<dataset>_<set_type>.pkl``.
    ``Pascal3DPlusOld`` reads each box as ``[x0, y0, x1, y1]``-style indices
    (it uses ``box[2] - box[0]`` and ``box[3] - box[1]`` as width/height).
    ``vp`` is unused; it is kept for signature parity with the other helpers.

    NOTE: ``pickle.load`` can execute arbitrary code; only use on trusted files.
    """
    path = os.path.join(anno_path, 'full_annotations_{}_{}_{}.pkl'.format(category, dataset, set_type))
    with open(path, 'rb') as f:
        annotations = pickle.load(f)
    return [ano[3] for ano in annotations]


def draw_labelmap(img, pt, sigma, type='Gaussian'):
    """Write an unnormalized 2D peak centred at ``pt`` into a copy of ``img``.

    Adopted from https://github.com/anewell/pose-hg-train/blob/master/src/pypose/draw.py

    Args:
        img: 2D array-like accepted by ``np.array`` (e.g. a torch tensor),
            indexed as ``img[row, col]``.
        pt: peak centre; ``pt[0]`` indexes columns and ``pt[1]`` rows.
        sigma: spread of the peak; the pasted patch is ``6 * sigma + 1`` wide.
        type: ``'Gaussian'`` (centre value exactly 1) or ``'Cauchy'``.

    Returns:
        ``(torch.Tensor, int)``. The tensor is the image with the patch
        written over the overlapping region; values are replaced, not added.
        The int is 0 when the patch's bounding box lies outside the image,
        in which case the image is returned unchanged, and 1 otherwise.

    Raises:
        ValueError: if ``type`` is neither ``'Gaussian'`` nor ``'Cauchy'``.
            Previously this surfaced as an ``UnboundLocalError`` on ``g``.
    """
    if type not in ('Gaussian', 'Cauchy'):
        raise ValueError("type must be 'Gaussian' or 'Cauchy', got {!r}".format(type))
    img = np.array(img)

    # Check that any part of the peak is in-bounds.
    ul = [int(pt[0] - 3 * sigma), int(pt[1] - 3 * sigma)]
    br = [int(pt[0] + 3 * sigma + 1), int(pt[1] + 3 * sigma + 1)]
    if (ul[0] >= img.shape[1] or ul[1] >= img.shape[0] or
            br[0] < 0 or br[1] < 0):
        # If not, just return the image as is.
        return torch.Tensor(img), 0

    # Build the (size x size) patch with its maximum at the centre pixel.
    size = 6 * sigma + 1
    x = np.arange(0, size, 1, float)
    y = x[:, np.newaxis]
    x0 = y0 = size // 2
    sq_dist = (x - x0) ** 2 + (y - y0) ** 2
    if type == 'Gaussian':
        # Not normalized: we want the centre value to equal 1.
        g = np.exp(- sq_dist / (2 * sigma ** 2))
    else:  # 'Cauchy'
        g = sigma / ((sq_dist + sigma ** 2) ** 1.5)

    # Part of the patch that overlaps the image.
    g_x = max(0, -ul[0]), min(br[0], img.shape[1]) - ul[0]
    g_y = max(0, -ul[1]), min(br[1], img.shape[0]) - ul[1]
    # Corresponding image range.
    img_x = max(0, ul[0]), min(br[0], img.shape[1])
    img_y = max(0, ul[1]), min(br[1], img.shape[0])

    img[img_y[0]:img_y[1], img_x[0]:img_x[1]] = g[g_y[0]:g_y[1], g_x[0]:g_x[1]]
    return torch.Tensor(img), 1


class Pascal3DPlus(Dataset):
    """PASCAL3D+ images with per-image ``.npz`` annotations.

    Files are resolved relative to ``rootpath`` using the per-class directory
    ``<imgclass><data_pendix>/``:
      * images:      ``<rootpath>/<img_path>/<class_dir>/<name>``
      * annotations: ``<rootpath>/<anno_path>/<class_dir>/<name up to first '.'>.npz``
      * lists:       ``<rootpath>/<list_path>/<class_dir>/<subtype>.txt``

    Args:
        transforms: optional callable applied to each sample dict.
        enable_cache: keep loaded images and annotations in memory, keyed by
            image file name.

    Keyword args (required): ``rootpath``, ``imgclass``.
    Keyword args (optional, default in parentheses):
        data_pendix (''): suffix appended to ``imgclass`` for the class dir.
        for_test (False): if True, keypoint visibility is NOT intersected
            with the image bounds.
        anno_path ('annotations'), img_path ('images'), list_path ('lists').
        real_mask (False): use the stored ``obj_mask`` thresholded at 128
            instead of a mask built from ``box_obj``.
        subtypes (module-level ``subtypes``): list files to read.
        weighted (False): multiply keypoint visibility by ``kp_weights``.
    """

    def __init__(self, transforms, enable_cache=True, **kwargs):
        self.root_path = kwargs['rootpath']
        self.img_class = kwargs['imgclass']
        data_pendix = kwargs.get('data_pendix', '')
        self.for_test = kwargs.get('for_test', False)
        anno_path = kwargs.get('anno_path', 'annotations')
        img_path = kwargs.get('img_path', 'images')
        list_path = kwargs.get('list_path', 'lists')
        self.real_mask = kwargs.get('real_mask', False)

        class_dir = '%s/' % (self.img_class + data_pendix)
        self.image_path = os.path.join(self.root_path, img_path, class_dir)
        self.annotation_path = os.path.join(self.root_path, anno_path, class_dir)
        list_dir = os.path.join(self.root_path, list_path, class_dir)

        self.transforms = transforms

        # Only touch the module-level default when the caller gave none.
        self.subtypes = kwargs['subtypes'] if 'subtypes' in kwargs else subtypes
        self.file_list = self._read_file_list(list_dir, self.subtypes)

        self.weighted = kwargs.get('weighted', False)

        self.enable_cache = enable_cache
        self.cache_img = dict()
        self.cache_anno = dict()

    @staticmethod
    def _read_file_list(list_dir, subtype_names):
        """Concatenate the stripped lines of ``<list_dir>/<name>.txt`` for each name, in order."""
        file_list = []
        for name in subtype_names:
            with open(os.path.join(list_dir, name + '.txt')) as f:
                file_list.extend(line.strip() for line in f)
        return file_list

    @staticmethod
    def _keypoint_visibility(visible, kp, img_hw, weights=None, check_bounds=True):
        """Return per-keypoint visibility (bool array, or float when ``weights`` is given).

        A keypoint counts as visible when ``visible == 1`` and, if
        ``check_bounds``, BOTH of its coordinates lie in ``[0, img_hw)``.
        ``weights`` is multiplied in after the bounds test so that
        ``np.logical_and`` does not collapse it back to bool.
        """
        vis = np.asarray(visible) == 1
        if check_bounds:
            kp = np.asarray(kp)
            in_bounds = np.all((kp >= 0) & (kp < np.asarray(img_hw)), axis=1)
            vis = np.logical_and(vis, in_bounds)
        if weights is not None:
            vis = vis * weights
        return vis

    def __len__(self):
        return len(self.file_list)

    def _load(self, name_img):
        """Return ``(img, annotation)`` for ``name_img``, reading through the cache.

        ``annotation`` is a dict when served from the cache and the object
        returned by ``np.load`` otherwise; both support ``annotation[key]``.
        """
        if name_img in self.cache_anno:
            return self.cache_img[name_img], self.cache_anno[name_img]

        img = Image.open(os.path.join(self.image_path, name_img))
        # Tackle gray (and other non-RGB) images.
        if img.mode != 'RGB':
            img = img.convert('RGB')
        # NOTE: allow_pickle=True lets the file run arbitrary code on load;
        # only point this at trusted annotation files.
        annotation = np.load(os.path.join(self.annotation_path, name_img.split('.')[0] + '.npz'),
                             allow_pickle=True)

        if self.enable_cache:
            self.cache_anno[name_img] = dict(annotation)
            self.cache_img[name_img] = img
        return img, annotation

    def __getitem__(self, item):
        """Load sample ``item`` and apply ``self.transforms`` if set.

        The sample dict has keys:
            img: RGB PIL image.
            kp: ``cropped_kp_list`` clamped to the image.
            iskpvisible: per-keypoint visibility (see ``_keypoint_visibility``).
            this_name: file name up to the first '.'.
            obj_mask: float32 object mask.
            box_obj: ``box_obj.shape``.
            pose: float32 ``[distance, elevation, azimuth, theta]``.
        """
        name_img = self.file_list[item]
        img, annotation_file = self._load(name_img)

        box_obj = bbt.from_numpy(annotation_file['box_obj'])
        if self.real_mask:
            obj_mask = (np.array(annotation_file['obj_mask']) > 128).astype(np.float32)
        else:
            # Mask filled via box_obj.assign(..., 1) (presumably ones inside the box).
            obj_mask = np.zeros(box_obj.boundary, dtype=np.float32)
            box_obj.assign(obj_mask, 1)

        # PIL size is (width, height); kp columns are bounded by (height, width).
        img_hw = np.array(img.size[::-1])
        kp = annotation_file['cropped_kp_list']
        weights = annotation_file['kp_weights'] if self.weighted else None
        iskpvisible = self._keypoint_visibility(annotation_file['visible'], kp, img_hw,
                                                weights=weights, check_bounds=not self.for_test)
        kp = np.clip(kp, 0, img_hw - 1)

        this_name = name_img.split('.')[0]

        # NOTE: distance is hard-coded to 5; annotation_file['distance'] is not used.
        pose_ = np.array([5, annotation_file['elevation'], annotation_file['azimuth'],
                          annotation_file['theta']], dtype=np.float32)

        sample = {'img': img, 'kp': kp, 'iskpvisible': iskpvisible, 'this_name': this_name,
                  'obj_mask': obj_mask, 'box_obj': box_obj.shape, 'pose': pose_}

        if self.transforms:
            sample = self.transforms(sample)
        return sample

    def get_image_size(self):
        """Return ``(height, width)`` of the first image in ``file_list``."""
        with Image.open(os.path.join(self.image_path, self.file_list[0])) as img:
            return img.size[::-1]


class Pascal3DPlus3D(Dataset):
    """PASCAL3D+ dataset whose keypoints come from ``MeshConverter.get_one``.

    For each image, ``self.converter.get_one(annotation)`` supplies the keypoint
    locations and visibility. When it returns ``kp is None``, the next index is
    served instead (wrapping around).

    Keyword args (required): ``rootpath``, ``imgclass``.
    Keyword args (optional, default in parentheses):
        data_pendix (''): suffix appended to ``imgclass`` for sub-directories.
        img_path ('images'), anno_path ('annotations'), list_path ('lists'):
            sub-directories of ``rootpath``.
        subtypes (module-level ``subtypes``): list files ``<subtype>.txt`` to read.
        mesh_path: directory passed to ``MeshConverter(path=...)``.
        re_align (False): remap keypoints via
            ``bbt.projection_function_by_boxes(<kp bbox>, box_obj)``.
        weighted (False): multiply visibility by ``|cal_point_weight(...)|``.
    """

    def __init__(self, transforms, **kwargs):
        self.root_path = kwargs['rootpath']
        self.img_class = kwargs['imgclass']
        class_dir = '%s/' % (self.img_class + kwargs.get('data_pendix', ''))
        self.image_path = os.path.join(self.root_path, kwargs.get('img_path', 'images'), class_dir)
        self.annotation_path = os.path.join(self.root_path, kwargs.get('anno_path', 'annotations'), class_dir)
        list_dir = os.path.join(self.root_path, kwargs.get('list_path', 'lists'), class_dir)

        self.transforms = transforms
        # Only touch the module-level default when the caller gave none.
        self.subtypes = kwargs['subtypes'] if 'subtypes' in kwargs else subtypes
        mesh_path = kwargs.get('mesh_path', '../PASCAL3D/PASCAL3D+_release1.1/CAD_d4/car/')
        self.re_align = kwargs.get('re_align', False)

        self.file_list = self._read_file_list(list_dir, self.subtypes)
        self.converter = MeshConverter(path=mesh_path)

        self.weighted = bool(kwargs.get('weighted', False))
        if self.weighted:
            self.init_direction_dicts()

    @staticmethod
    def _read_file_list(list_dir, subtype_names):
        """Concatenate the stripped lines of ``<list_dir>/<name>.txt`` for each name, in order."""
        file_list = []
        for name in subtype_names:
            with open(os.path.join(list_dir, name + '.txt')) as f:
                file_list.extend(line.strip() for line in f)
        return file_list

    def init_direction_dicts(self):
        """Compute ``direction_calculator(*t)`` for each entry ``t`` of ``self.converter.loader``."""
        self.direction_dicts = [direction_calculator(*t) for t in self.converter.loader]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, item):
        """Return sample ``item`` after ``self.transforms`` (if set).

        Sample keys: ``img`` (RGB PIL image), ``kp`` (torch tensor; long when
        ``re_align``), ``iskpvisible`` (torch tensor), ``this_name``,
        ``obj_mask`` and ``box_obj`` (``box_obj.shape``).
        """
        name_img = self.file_list[item]
        img = Image.open(os.path.join(self.image_path, name_img))
        if img.mode != 'RGB':
            img = img.convert('RGB')
        annotation_file = np.load(os.path.join(self.annotation_path, name_img.split('.')[0] + '.npz'),
                                  allow_pickle=True)

        box_obj = bbt.from_numpy(annotation_file['box_obj'])
        obj_mask = np.zeros(box_obj.boundary, dtype=np.float32)
        box_obj.assign(obj_mask, 1)

        kp, iskpvisible = self.converter.get_one(annotation_file)
        if kp is None:
            # Must run before torch.from_numpy(iskpvisible): a None visibility
            # array would raise there and the skip below would never happen.
            return self.__getitem__((item + 1) % len(self))
        iskpvisible = torch.from_numpy(iskpvisible)

        this_name = name_img.split('.')[0]

        if self.weighted:
            # cad_index appears to be 1-based (hence the -1).
            cad_idx = annotation_file['cad_index'] - 1
            weights = torch.from_numpy(cal_point_weight(self.direction_dicts[cad_idx],
                                                        self.converter.loader[cad_idx][0],
                                                        annotation_file)).type(torch.float32)
            iskpvisible = iskpvisible * torch.abs(weights)

        if self.re_align:
            pixels = kp
            box_ = bbt.Bbox2D(
                [(np.min(pixels[:, 0]), np.max(pixels[:, 0])), (np.min(pixels[:, 1]), np.max(pixels[:, 1]))])
            box_obj = bbt.from_numpy(annotation_file['box_obj'])

            foo_proj = bbt.projection_function_by_boxes(box_, box_obj)
            kp = foo_proj(pixels)

            # Clamp into [0, box_obj[4:] - 1].
            kp = np.max([np.zeros_like(kp), kp], axis=0)
            kp = np.min([np.ones_like(kp) * (np.array([annotation_file['box_obj'][4::]]) - 1), kp], axis=0)
            kp = torch.from_numpy(kp).type(torch.long)
        else:
            kp = torch.from_numpy(kp)

        sample = {'img': img, 'kp': kp, 'iskpvisible': iskpvisible, 'this_name': this_name,
                  'obj_mask': obj_mask, 'box_obj': box_obj.shape}

        if self.transforms:
            sample = self.transforms(sample)
        return sample


# Legacy loader: provides viewpoint, keypoints and keypoint visibility from the pickled annotation files.
class Pascal3DPlusOld(Dataset):
    """Legacy PASCAL3D+ loader built on the pickled ``full_*`` annotation lists.

    Keyword args (all required): ``rootpath``, ``imgclass``, ``vp``,
    ``dataset``, ``mode`` (split tag used in the pickle file names) and
    ``sigma`` (spread passed to ``draw_labelmap``).

    Each sample holds the image as an RGB numpy array, the rescaled keypoints
    with their two coordinate columns swapped, the indices of keypoints whose
    status is 1, and one heatmap per keypoint at ``ceil(size / 32)``
    resolution.
    """

    def __init__(self, transforms, **kwargs):
        self.root_path = kwargs['rootpath']
        self.img_class = kwargs['imgclass']
        self.target_vp = kwargs['vp']
        self.dataset = kwargs['dataset']
        self.mode = kwargs['mode']
        self.sigma = kwargs['sigma']
        self.transforms = transforms
        self.anno_path = os.path.join(self.root_path, '3DRepresentation')
        self.image_path = os.path.join(self.root_path, 'PASCAL3D+_cropped', '{}_imagenet'.format(self.img_class))

        list_args = (self.anno_path, self.img_class, self.target_vp, self.dataset, self.mode)
        self.name_list = get_name_list(*list_args)
        self.vp_list = get_vp_list(*list_args)
        self.kp_status_list = get_kp_status_list(*list_args)
        self.kp_list = get_kp_list(*list_args)
        self.bbox_list = get_bbox_list(*list_args)

    def __getitem__(self, index):
        this_name = self.name_list[index]
        img_path = os.path.join(self.image_path, this_name + '.JPEG')
        # Force 3 channels: grayscale images would break the (h, w, c) unpack
        # below and CMYK images would yield 4 channels.
        with Image.open(img_path) as pil_img:
            img = np.array(pil_img.convert('RGB'))

        # Indices of keypoints whose status flag is 1.
        iskpvisible = [i for i, t in enumerate(self.kp_status_list[index]) if t == 1]

        # Rescale keypoints by 224 / min(bbox width, bbox height).
        bbox = self.bbox_list[index]
        w = bbox[2] - bbox[0]
        h = bbox[3] - bbox[1]
        ratio = min(w, h) / 224
        kp = np.array(self.kp_list[index]) / ratio

        # Ground-truth heatmaps at (x - 1) // 32 + 1 (i.e. ceil(x / 32)) resolution.
        h_, w_, _ = img.shape
        w_gtmap = (w_ - 1) // 32 + 1
        h_gtmap = (h_ - 1) // 32 + 1
        kp_gtmap = (kp - 1) // 32 + 1
        nparts = len(kp)
        target = torch.zeros(nparts, int(h_gtmap), int(w_gtmap))

        # Swap the two coordinate columns.
        kp = torch.Tensor(kp)[:, [1, 0]]
        kp_gtmap = torch.Tensor(kp_gtmap)[:, [1, 0]]
        for i in range(nparts):
            target[i], _ = draw_labelmap(target[i], kp_gtmap[i], self.sigma, type='Gaussian')

        sample = {'img': img, 'kp': kp, 'iskpvisible': iskpvisible, 'this_name': this_name, 'gt': target}
        if self.transforms:
            sample = self.transforms(sample)
        return sample

    def __len__(self):
        return len(self.name_list)


class ToTensor(object):
    """Convert a sample dict's image and keypoint fields to torch tensors.

    ``img`` is passed through ``torchvision.transforms.ToTensor``.
    ``iskpvisible`` and ``kp`` are wrapped with ``torch.Tensor`` (default
    float dtype) unless they already are tensors. Other keys are left
    untouched. The dict is modified in place and returned.
    """

    def __init__(self):
        self.trans = transforms.ToTensor()

    def __call__(self, sample):
        sample['img'] = self.trans(sample['img'])
        for key in ('iskpvisible', 'kp'):
            if not isinstance(sample[key], torch.Tensor):
                sample[key] = torch.Tensor(sample[key])
        return sample


class Normalize(object):
    """Channel-wise normalize ``sample['img']`` with ``torchvision.transforms.Normalize``.

    Args:
        mean: per-channel means (default: the usual ImageNet values).
        std: per-channel standard deviations (default: the usual ImageNet values).

    Only ``img`` is changed; the dict is modified in place and returned.
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        self.trans = transforms.Normalize(mean=list(mean), std=list(std))

    def __call__(self, sample):
        sample['img'] = self.trans(sample['img'])
        return sample


def hflip(sample):
    """Horizontally flip ``sample['img']`` and mirror the keypoints to match.

    Keypoint column 1 is the horizontal coordinate; it becomes
    ``width - x - 1``. ``sample['kp']`` is modified in place. Width comes from
    ``img.size[0]`` for PIL images and ``img.shape[-1]`` for tensors; the
    latter previously crashed because ``Tensor.size`` is a method.
    """
    img = transforms.functional.hflip(sample['img'])
    if isinstance(img, torch.Tensor):
        width = img.shape[-1]
    else:
        width = img.size[0]
    sample['img'] = img
    sample['kp'][:, 1] = width - sample['kp'][:, 1] - 1
    return sample


class RandomHorizontalFlip(object):
    """Apply ``hflip`` to a sample dict with probability ``p`` (default 0.5)."""

    def __init__(self, p=0.5):
        # Pass the module-level function itself rather than a lambda so the
        # transform stays picklable (needed by DataLoader worker processes
        # started with the 'spawn' method).
        self.trans = transforms.RandomApply([hflip], p=p)

    def __call__(self, sample):
        return self.trans(sample)


class ColorJitter(object):
    """Randomly jitter the colours of ``sample['img']`` with ``torchvision.transforms.ColorJitter``.

    Args:
        brightness, contrast, saturation, hue: forwarded to torchvision
            (defaults 0.2, 0.2, 0.4, 0).

    Only ``img`` is changed; the dict is modified in place and returned.
    """

    def __init__(self, brightness=0.2, contrast=0.2, saturation=0.4, hue=0):
        self.trans = transforms.ColorJitter(brightness=brightness, contrast=contrast,
                                            saturation=saturation, hue=hue)

    def __call__(self, sample):
        sample['img'] = self.trans(sample['img'])
        return sample
