import sys,os

import math
from tkinter import E
import torch
import numpy as np
from skimage import io, transform
from torch.utils.data import Dataset
import json
import open3d as o3d
import pickle
import random
import struct
from timeit import default_timer as timer
from tqdm import tqdm
import trimesh
import open3d as o3d
from nnutils.laplacian import read_vertex_laplacian
import random

class MeshDataset(Dataset):
    """Deformation-pair dataset over frames of mesh motion sequences.

    Expected layout under ``dataset_dir`` (as read by ``_load``/``_load_data``):

    * ``<iden_split>.lst``: template model directory names, one per line.
      The identity of a model is the part of its name before the first ``_``.
    * ``<split>.lst``: motion-sequence model directory names, one per line.
    * ``<model>/<frame>/``: numeric frame directories (looked up as ``%04d``)
      holding ``surface_points.npz`` (keys ``points``/``normals``),
      ``flow.npz`` (key ``points``), ``orig_to_gaps.txt`` (a 4x4 matrix) and,
      when ``load_obj`` is set, ``mesh_orig.obj``.

    Every item combines three frames: the canonical frame (frame 0 of the
    identity's template model), a source frame and a target frame.
    """

    def __init__(
        self,
        dataset_dir, num_point_samples,
        cache_data=False, use_augmentation=False, augmentation_type='rotation',
        partial_range = 0.1, iden_split = 'template_seq_identity', split = 'train', interval=1,
        load_obj=False, inverse = False, num_data = 0 , num_surf_samples = None,
        normalize_target=True,
    ):
        """
        Args:
            dataset_dir: Root directory of the dataset.
            num_point_samples: Number of flow points kept per item, and the
                number of surface points kept when ``num_surf_samples`` is None.
            cache_data, use_augmentation, augmentation_type: Stored on the
                instance; not read anywhere else in this class.
            partial_range: Band width of the head/tail/foot handle regions,
                measured inward from the canonical bounding box.
            iden_split: Name (without ``.lst``) of the template list file.
            split: Name (without ``.lst``) of the sequence list file. Names
                starting with ``"train"`` build all same-identity frame pairs
                and reshuffle; any other split pairs frame 0 with later frames.
            interval: Keep only frames whose number is a multiple of this.
            load_obj: Also load ``mesh_orig.obj`` and return mesh arrays.
            inverse: Swap source and target (pose1 -> pose0).
            num_data: For non-train splits, keep only this many pairs after a
                seeded shuffle (0 keeps all).
            num_surf_samples: Number of surface points kept per item; falls
                back to ``num_point_samples`` when None.
            normalize_target: If True, the target frame's surface samples are
                used as loaded from its own files (the same treatment flow
                samples receive). If False they are re-expressed in the source
                frame's coordinates via the target's ``world2orig`` followed by
                the source's ``orig2world``.
        """
        self.dataset_dir = dataset_dir
        self.num_point_samples = num_point_samples
        self.cache_data = cache_data
        self.use_augmentation = use_augmentation
        self.augmentation_type = augmentation_type
        print('use_augmentation', use_augmentation, 'augmentation_type', augmentation_type)
        self.partial_range = partial_range
        print('partial range', partial_range)
        self.iden_split = iden_split
        self.split = split
        self.interval = interval
        self.load_obj = load_obj
        self.inverse = inverse
        if inverse:
            print('arbitrary pose1-to-pose0 transformation')
        else:
            print('arbitrary pose0-to-pose1 transformation')

        self.num_data = num_data
        self.num_surf_samples = num_surf_samples
        # Read by __getitem__ when building the target surface samples.
        self.normalize_target = normalize_target

        # Filled by _load().
        self.labels = []
        self.labels_pairs = []
        self._load()

    def _read_split_list(self, list_name):
        """Return the raw lines of ``<dataset_dir>/<list_name>.lst``."""
        with open(os.path.join(self.dataset_dir, list_name + '.lst'), 'r') as f:
            return f.read().split('\n')

    def _is_model_dir(self, name):
        """True if ``name`` is non-empty and names a directory under ``dataset_dir``."""
        return name != '' and os.path.isdir(os.path.join(self.dataset_dir, name))

    def _list_frames(self, model_name):
        """Sorted frame directory names of ``model_name`` whose number is a multiple of ``interval``."""
        frame_names = sorted(os.listdir(os.path.join(self.dataset_dir, model_name)))
        return [m for m in frame_names if int(m) % self.interval == 0]

    def _load(self):
        """Build ``models_dict``, ``labels``, ``labels_pairs`` and ``labels_sample_pairs``.

        Each pair is the tuple ``(identity_idx, identity_model_name, 0,
        model_name0, frame0, model_name1, frame1)``.
        """
        # Template models: identity name -> (index in sorted order, model dir name).
        template_dirs = sorted(
            os.path.join(self.dataset_dir, name)
            for name in self._read_split_list(self.iden_split)
            if self._is_model_dir(name)
        )
        self.models_dict = {}
        for idx, sample_dir in enumerate(template_dirs):
            model_name = os.path.basename(sample_dir)
            self.models_dict[model_name.split('_')[0]] = (idx, model_name)

        # Sequence models of this split, in file order.
        seq_models = [m for m in self._read_split_list(self.split) if self._is_model_dir(m)]
        # List each model's frames once rather than once per model pair.
        frames = {m: self._list_frames(m) for m in seq_models}

        for model in seq_models:
            for frame in frames[model]:
                self.labels.append({
                    "data_dir": os.path.join(self.dataset_dir, model, frame)
                })

        deform_pairs = []
        if self.split.startswith("train"):
            # All frame pairs between sequences of the same identity.
            for model0 in seq_models:
                identity_name0 = model0.split('_')[0]
                identity_idx, identity_model_name = self.models_dict[identity_name0]
                for model1 in seq_models:
                    if model1.split('_')[0] != identity_name0:
                        continue
                    for m in frames[model0]:
                        for n in frames[model1]:
                            deform_pairs.append((identity_idx, identity_model_name, 0, model0, int(m), model1, int(n)))
            print('total number of deformation pairs w/o aug of train dataset is :', len(deform_pairs))
        else:
            # Frame 0 of each sequence paired with every later frame.
            for model in seq_models:
                identity_idx, identity_model_name = self.models_dict[model.split('_')[0]]
                for m in frames[model]:
                    if int(m) > 0:
                        deform_pairs.append((identity_idx, identity_model_name, 0, model, 0, model, int(m)))
            print('total number of deformation pairs w/o aug of test dataset is :', len(deform_pairs))

        self.labels_pairs.extend({"pair_info": pair} for pair in deform_pairs)

        if self.split.startswith("train"):
            self.random_shuffle_samples()
        elif self.num_data != 0:
            random.Random(100).shuffle(self.labels_pairs)
            self.labels_sample_pairs = self.labels_pairs[:self.num_data]
        else:
            self.labels_sample_pairs = self.labels_pairs

    def random_shuffle_samples(self, num_samples=64000):
        """Shuffle ``labels_pairs`` in place (seeded RNG) and keep the first ``num_samples``."""
        random.Random(100).shuffle(self.labels_pairs)
        self.labels_sample_pairs = self.labels_pairs[:num_samples]

    @staticmethod
    def load_pts_file(path):
        """Load a raw float32 ``.sdf`` (4 values per row) or ``.pts`` (6 per row) file."""
        _, ext = os.path.splitext(path)
        assert ext in ['.sdf', '.pts']
        l = 4 if ext == '.sdf' else 6
        with open(path, 'rb') as f:
            points = np.fromfile(f, dtype=np.float32)
        points = np.reshape(points, [-1, l])
        return points

    @staticmethod
    def load_grid(path):
        """Load a binary grid: 3 native ints (resolution), 16 floats (4x4 transform), then the grid values.

        Returns:
            (grid, transform) as float32 arrays of shape ``res`` and (4, 4).
        """
        with open(path, 'rb') as f:
            content = f.read()
        res = struct.unpack('iii', content[:4 * 3])
        vcount = res[0] * res[1] * res[2]
        content = content[4 * 3:]
        tx = struct.unpack('16f', content[:4 * 16])
        tx = np.array(tx).reshape([4, 4]).astype(np.float32)
        content = content[4 * 16:]
        # Repeat count instead of 'f' * vcount, which builds a huge format string.
        grd = struct.unpack(f'{vcount}f', content[:4 * vcount])
        grd = np.array(grd).reshape(res).astype(np.float32)
        return grd, tx

    @staticmethod
    def load_npz_normals(path):
        """Return float32 ``points`` and ``normals`` arrays from an ``.npz`` file."""
        # The context manager closes the NpzFile handle; astype() copies the data out.
        with np.load(path) as flow_dict:
            points = flow_dict['points'].astype(np.float32)
            normals = flow_dict['normals'].astype(np.float32)
        return points, normals

    @staticmethod
    def load_npz(path):
        """Return the float32 ``points`` array from an ``.npz`` file."""
        with np.load(path) as flow_dict:
            return flow_dict['points'].astype(np.float32)

    def __len__(self):
        return len(self.labels_sample_pairs)

    def unpack(self, x):
        """Merge the first two dimensions (batch and sample) of tensor ``x``."""
        return x.view((x.shape[0] * x.shape[1],) + tuple(x.shape[2:]))

    def _load_data(self, data_dir):
        """Load one frame directory into a dict of numpy arrays.

        ``uniform_samples``, ``near_surface_samples``, ``bbox_*``, ``grid`` and
        ``world2grid`` are constant placeholders; ``world2orig`` is the inverse
        of the matrix in ``orig_to_gaps.txt``. With ``load_obj`` the mesh
        vertices, normals, bidirectional edges, faces, face adjacency and
        ``read_vertex_laplacian`` output are added under ``flow_*`` keys.
        """
        uniform_samples = np.zeros((100000, 3), dtype=np.float32)
        near_surface_samples = np.zeros((100000, 3), dtype=np.float32)
        bbox_upper = np.zeros((3), dtype=np.float32)
        bbox_lower = np.zeros((3), dtype=np.float32)

        surface_samples, surface_normals = MeshDataset.load_npz_normals(f'{data_dir}/surface_points.npz')
        grid, world2grid = np.zeros(1, dtype=np.float32), np.eye(4).astype(np.float32)
        orig2world = np.reshape(np.loadtxt(f'{data_dir}/orig_to_gaps.txt'), [4, 4]).astype(np.float32)
        world2orig = np.linalg.inv(orig2world).astype(np.float32)
        flow_samples = MeshDataset.load_npz(f'{data_dir}/flow.npz')

        data = {
            'uniform_samples':          uniform_samples,
            'surface_samples':          surface_samples,
            'surface_normals':          surface_normals,
            'near_surface_samples':     near_surface_samples,
            'grid':                     grid,
            'world2grid':               world2grid,
            'world2orig':               world2orig,
            'bbox_lower':               bbox_lower,
            'bbox_upper':               bbox_upper,
            'flow_samples':             flow_samples,
        }

        if self.load_obj:
            mesh = trimesh.load_mesh(f'{data_dir}/mesh_orig.obj', process=False)
            edges = np.array(mesh.edges).astype(np.int32)
            # Store every edge in both directions.
            edges = np.concatenate([edges, edges[:, ::-1]], axis=0)
            faces = np.array(mesh.faces).astype(np.int32)
            data.update({
                'flow_vertices':            np.array(mesh.vertices).astype(np.float32),
                'flow_vertex_normals':      np.array(mesh.vertex_normals).astype(np.float32),
                'flow_edges':               edges,
                'flow_faces':               faces,
                'flow_faces_adjacency':     trimesh.graph.face_adjacency(faces=faces, mesh=mesh).astype(np.int32),
                'flow_vertices_adjacency':  read_vertex_laplacian(mesh),
            })
        return data

    @staticmethod
    def _transform_points(mat, points):
        """Apply the affine part of 4x4 ``mat`` to the rows of ``points`` (shape (N, 3))."""
        return (np.matmul(mat[:3, :3], points.T) + mat[:3, 3:4]).T

    def _handle_mask(self, points, bbox_min, bbox_max):
        """Boolean mask of ``points`` inside the head (low y), tail (high y) or foot (low z) bands."""
        head = points[:, 1] < bbox_min[1] + self.partial_range
        tail = points[:, 1] > bbox_max[1] - self.partial_range
        foot = points[:, 2] < bbox_min[2] + self.partial_range
        return head | tail | foot

    def __getitem__(self, index):
        """Load the (canonical, source, target) frame triple for pair ``index``.

        Returns, when ``load_obj`` is False, the 12-tuple
        ``(surface_samples, flow_samples, surface_samples, flow_samples, grid,
        surface_normals, surface_normals, rotated2gaps, bbox_lower, bbox_upper,
        identity_idx, index)``; when ``load_obj`` is True,
        ``(flow_faces, flow_edges, flow_vertices, flow_vertices_adjacency,
        surface_samples, surface_normals, surface_normals, rotated2gaps,
        bbox_lower, bbox_upper, identity_idx, index)``. Array entries carry a
        leading singleton axis. ``surface_samples`` stacks five (N, 3) arrays:
        source, masked target, handle mask, canonical and target (frame 1 acts
        as source when ``inverse``); ``flow_samples`` stacks source, target and
        canonical flow points.
        """
        identity_idx, identity_model_name, identity_seq_idx, model_name0, seq_idx0, model_name1, seq_idx1 \
            = self.labels_sample_pairs[index]["pair_info"]

        # Training: once the last index is requested, draw a new subset of pairs.
        # NOTE(review): this mutates per-process state; with DataLoader workers
        # or a shuffling sampler it is not a clean per-epoch reshuffle — confirm.
        if self.split.startswith("train") and index == len(self.labels_sample_pairs) - 1:
            self.random_shuffle_samples()
            print("random shuffle deformation pairs in train dataset, and sample again")

        data_cano = self._load_data(os.path.join(self.dataset_dir, identity_model_name, "%04d" % identity_seq_idx))
        data0 = self._load_data(os.path.join(self.dataset_dir, model_name0, "%04d" % seq_idx0))
        data1 = self._load_data(os.path.join(self.dataset_dir, model_name1, "%04d" % seq_idx1))

        world2orig1 = data1['world2orig']
        orig2world_cano = np.linalg.inv(data_cano['world2orig'])
        orig2world0 = np.linalg.inv(data0['world2orig'])
        orig2world1 = np.linalg.inv(world2orig1)

        surface_samples_cano = data_cano['surface_samples']
        surface_samples0, surface_normals0 = data0['surface_samples'], data0['surface_normals']
        surface_samples1, surface_normals1 = data1['surface_samples'], data1['surface_normals']
        # The handle masks are evaluated on canonical points, so measure the
        # bands on the canonical shape (as the mesh-vertex path does below).
        bbox_min, bbox_max = surface_samples_cano.min(axis=0), surface_samples_cano.max(axis=0)

        # Subsample to the requested surface count (random for train, leading points otherwise).
        num_surf = self.num_point_samples if self.num_surf_samples is None else self.num_surf_samples
        if surface_samples0.shape[0] > num_surf:
            if self.split.startswith("train"):
                surface_samples_idxs = np.random.permutation(surface_samples0.shape[0])[:num_surf]
            else:
                surface_samples_idxs = np.arange(num_surf)
            surface_samples_cano = surface_samples_cano[surface_samples_idxs, :]
            surface_samples0 = surface_samples0[surface_samples_idxs, :]
            surface_samples1 = surface_samples1[surface_samples_idxs, :]
            surface_normals0 = surface_normals0[surface_samples_idxs, :]
            surface_normals1 = surface_normals1[surface_samples_idxs, :]

        if self.inverse:
            surface_normals = np.stack([surface_normals1, surface_normals0], axis=0)
        else:
            surface_normals = np.stack([surface_normals0, surface_normals1], axis=0)
        surface_normals = surface_normals.astype(np.float32)

        # Target samples: as loaded, or re-expressed in frame 0's coordinates.
        if self.normalize_target:
            surface_samples_from0_to1 = surface_samples1
        else:
            surface_samples_from0_to1 = self._transform_points(
                orig2world0, self._transform_points(world2orig1, surface_samples1))

        handle_sample_idx = self._handle_mask(surface_samples_cano, bbox_min, bbox_max)
        handle_sample_mask3 = handle_sample_idx[:, None].repeat(3, axis=1)
        if self.inverse:
            surface_samples0_masked = surface_samples0 * handle_sample_idx[:, None]
            surface_samples = np.stack([surface_samples_from0_to1, surface_samples0_masked, handle_sample_mask3, surface_samples_cano, surface_samples0], axis=0)
        else:
            surface_samples_from0_to1_masked = surface_samples_from0_to1 * handle_sample_idx[:, None]
            surface_samples = np.stack([surface_samples0, surface_samples_from0_to1_masked, handle_sample_mask3, surface_samples_cano, surface_samples_from0_to1], axis=0)
        surface_samples = surface_samples.astype(np.float32)

        flow_samples0 = data0['flow_samples']
        flow_samples1 = data1['flow_samples']
        flow_samples_cano = data_cano['flow_samples']
        # NOTE(review): flow points are subsampled randomly for every split,
        # including evaluation splits — confirm this is intended.
        if flow_samples0.shape[0] > self.num_point_samples:
            flow_samples_idxs = np.random.permutation(flow_samples0.shape[0])[:self.num_point_samples]
            flow_samples0 = flow_samples0[flow_samples_idxs, :]
            flow_samples1 = flow_samples1[flow_samples_idxs, :]
            flow_samples_cano = flow_samples_cano[flow_samples_idxs, :]

        flow_samples_from0_to1 = flow_samples1
        if self.inverse:
            flow_samples = np.stack([flow_samples_from0_to1, flow_samples0, flow_samples_cano], axis=0)
        else:
            flow_samples = np.stack([flow_samples0, flow_samples_from0_to1, flow_samples_cano], axis=0)
        flow_samples = flow_samples.astype(np.float32)

        if self.load_obj:
            # Mesh topology is taken from the canonical frame.
            flow_edges = data_cano['flow_edges']
            flow_faces = data_cano['flow_faces']
            flow_vertices_adjacency = data_cano['flow_vertices_adjacency']

            flow_vertices_cano_normalize = self._transform_points(orig2world_cano, data_cano['flow_vertices'])
            flow_vertices0_normalize = self._transform_points(orig2world0, data0['flow_vertices'])
            flow_vertices1_normalize = self._transform_points(orig2world1, data1['flow_vertices'])

            handle_vert_idx = self._handle_mask(
                flow_vertices_cano_normalize,
                flow_vertices_cano_normalize.min(axis=0),
                flow_vertices_cano_normalize.max(axis=0),
            )
            handle_vert_mask3 = handle_vert_idx[:, None].repeat(3, axis=1)
            if self.inverse:
                flow_vertices0_normalize_masked = flow_vertices0_normalize * handle_vert_idx[:, None]
                flow_vertices = np.stack([flow_vertices1_normalize, flow_vertices0_normalize_masked, handle_vert_mask3, flow_vertices_cano_normalize, flow_vertices0_normalize], axis=0)
            else:
                flow_vertices1_normalize_masked = flow_vertices1_normalize * handle_vert_idx[:, None]
                flow_vertices = np.stack([flow_vertices0_normalize, flow_vertices1_normalize_masked, handle_vert_mask3, flow_vertices_cano_normalize, flow_vertices1_normalize], axis=0)
            flow_vertices = flow_vertices.astype(np.float32)

        data_src = data1 if self.inverse else data0
        grid = data_src['grid']
        bbox_lower = data_src['bbox_lower']
        bbox_upper = data_src['bbox_upper']
        # Always the identity: no augmentation is applied here, whatever use_augmentation says.
        rotated2gaps = np.eye(4).astype(np.float32)

        if self.load_obj:
            return flow_faces[np.newaxis, ...], flow_edges[np.newaxis, ...], \
                flow_vertices[np.newaxis, ...], flow_vertices_adjacency[np.newaxis, ...], surface_samples[np.newaxis, ...], \
                surface_normals[np.newaxis, ...], surface_normals[np.newaxis, ...], \
                rotated2gaps[np.newaxis, ...], \
                bbox_lower[np.newaxis, ...], bbox_upper[np.newaxis, ...], \
                identity_idx, index
        return surface_samples[np.newaxis, ...], flow_samples[np.newaxis, ...], \
            surface_samples[np.newaxis, ...], flow_samples[np.newaxis, ...], grid[np.newaxis, ...], \
            surface_normals[np.newaxis, ...], surface_normals[np.newaxis, ...], \
            rotated2gaps[np.newaxis, ...], \
            bbox_lower[np.newaxis, ...], bbox_upper[np.newaxis, ...], \
            identity_idx, index

    def get_metadata(self, index):
        """Return the label dict (``{"pair_info": ...}``) of the sampled pair at ``index``."""
        return self.labels_sample_pairs[index]