import glob
import h5py
import hdf5plugin
import numpy as np
import torch
import os
# from event_utils import VoxelGrid
from torch.utils.data import Dataset, DataLoader
import pandas as pd
from einops import rearrange, repeat
import cv2
from dv import AedatFile
# from .event_utils import VoxelGrid
from .event_utils import VoxelGrid



def normalize_input(
        item,
        mean=(0.5, 0.5, 0.5),  # Imagenet (0.485, 0.456, 0.406)
        std=(0.5, 0.5, 0.5),  # Imagenet (0.229, 0.224, 0.225)
        use_simple_norm=False,
):
    """Normalize a video tensor laid out as (frames, channels, height, width).

    For ``uint8`` input (unless ``use_simple_norm`` is set) pixels are scaled
    to [0, 1] and then standardized per channel with ``mean`` / ``std``.
    Any other input (or ``use_simple_norm=True``) is mapped with
    ``item / 127.5 - 1``, i.e. [0, 255] -> [-1, 1].

    Args:
        item: tensor shaped (f, c, h, w).
        mean: per-channel means (sequence of length c).
        std: per-channel standard deviations (sequence of length c).
        use_simple_norm: force the ``/ 127.5 - 1`` mapping even for uint8 input.

    Returns:
        A floating-point tensor with the same shape as ``item``.
    """
    if item.dtype == torch.uint8 and not use_simple_norm:
        item = item.float() / 255.0
        # Build the statistics on the input's device/dtype and shape them as
        # (c, 1, 1) so they broadcast over the channel axis. This avoids both a
        # CPU/GPU device mismatch and permuting the tensor back and forth.
        mean_t = torch.as_tensor(mean, dtype=item.dtype, device=item.device).view(-1, 1, 1)
        std_t = torch.as_tensor(std, dtype=item.dtype, device=item.device).view(-1, 1, 1)
        return (item - mean_t) / std_t
    # Element-wise mapping, so no axis permutation is needed.
    return item / 127.5 - 1.0


class COEVideoDataset_test(Dataset):
    """Test-time dataset of event voxel grids and frames from COE ``.aedat4`` recordings.

    Each item corresponds to one entry of ``sorted(os.listdir(data_pth))``; the
    recording ``<data_pth>/<seq>/<seq>.aedat4`` is opened with ``dv.AedatFile``.
    For the first ``n_sample_frames`` frames, the events nearest to the frame's
    exposure window are converted to a ``num_bins``-channel voxel grid.
    Frames and voxels are padded with zeros up to ``n_sample_frames``, scaled
    to [-1, 1] and center-cropped to ``crop_size``.
    """

    def __init__(
        self,
        width: int = 346,
        height: int = 260,
        n_sample_frames: int = 25,
        num_bins: int = 3,
        fps: int = 8,
        data_pth: str = "/home/user/proj/event_tracking/coe_example",
        seq_name: str = 'dvSave-2021_09_01_07_22_17',
        use_bucketing: bool = False,
        **kwargs
    ):
        self.use_bucketing = use_bucketing

        if not os.path.exists(data_pth):
            raise FileNotFoundError(f"The data path does not exist: {data_pth}")

        self.crop_size = [256, 256]  # [crop width, crop height]
        self.width = width
        self.height = height
        self.num_bins = num_bins
        self.n_sample_frames = n_sample_frames
        self.fps = fps
        # Upper bound on events kept per frame window; larger windows are
        # trimmed to this many events around their middle.
        self.event_num = 40000
        self.data_pth = data_pth
        # Stored for backward compatibility; items are selected via seq_names.
        self.seq_name = seq_name
        self.seq_names = sorted(os.listdir(self.data_pth))

    def __getname__(self):
        return 'coe'

    def __len__(self):
        return len(self.seq_names)

    def _read_sequence(self, event_seq_file, use_mode='frame_exposure_time'):
        """Read up to ``n_sample_frames`` frames plus all events from one recording.

        Returns:
            (frames, frame_timestamps, events): a list of frame images, a list of
            ``[t_start, t_end]`` pairs (exposure or frame-interval window,
            depending on ``use_mode``), and the concatenated event array.

        Raises:
            ValueError: if ``use_mode`` is unknown.
        """
        frame_all = []
        frame_exposure_time = []
        frame_interval_time = []
        with AedatFile(event_seq_file) as f:
            for frame in f['frames']:
                frame_all.append(frame.image)
                frame_exposure_time.append([frame.timestamp_start_of_exposure,
                                            frame.timestamp_end_of_exposure])
                frame_interval_time.append([frame.timestamp_start_of_frame,
                                            frame.timestamp_end_of_frame])
                if len(frame_all) >= self.n_sample_frames:
                    break
            # Events must be read while the file is still open.
            events = np.hstack([packet for packet in f['events'].numpy()])

        if use_mode == 'frame_exposure_time':
            frame_timestamp = frame_exposure_time
        elif use_mode == 'frame_interval_time':
            frame_timestamp = frame_interval_time
        else:
            raise ValueError(f"Unknown use_mode: {use_mode}")
        return frame_all, frame_timestamp, events

    def _events_to_voxel(self, events, t_start, t_end, device):
        """Convert the events nearest to ``[t_start, t_end]`` into a voxel grid.

        The window boundaries are the events whose timestamps are closest to
        ``t_start`` / ``t_end``. Windows with more than ``self.event_num``
        events are trimmed to roughly ``self.event_num`` events around their
        middle.
        """
        timestamps = events['timestamp']
        start_idx = int(np.argmin(np.abs(timestamps - t_start)))
        end_idx = int(np.argmin(np.abs(timestamps - t_end)))
        event_num = end_idx - start_idx
        if event_num > self.event_num:
            half = self.event_num // 2
            events_chunk = events[start_idx + event_num // 2 - half:end_idx - event_num // 2 + half]
        else:
            events_chunk = events[start_idx:end_idx]

        t_all = torch.tensor(events_chunk['timestamp'].astype(np.float64)).to(device)
        x_all = torch.tensor(events_chunk['x'].astype(np.float32)).to(device)
        y_all = torch.tensor(events_chunk['y'].astype(np.float32)).to(device)
        p_all = torch.tensor(events_chunk['polarity'].astype(np.float32)).to(device)
        voxel_grid = VoxelGrid((self.num_bins, self.height, self.width), normalize=True, device=device)
        event_voxel = voxel_grid.convert({
            'x': x_all,
            'y': y_all,
            't': t_all,
            'p': p_all,
        })
        # Zero out the current maximum value five times in a row, removing
        # the strongest positive responses. Presumably this suppresses
        # outliers/hot pixels; confirm the intent.
        for _ in range(5):
            event_voxel[event_voxel == event_voxel.max()] = 0
        return event_voxel

    def __getitem__(self, index):
        """Return voxels, frames and metadata for the sequence at ``index``.

        Keys of the returned dict:
            pixel_values: stacked voxel grids divided by their global max |value|.
            image: ``(pixel_values[:1] + 1) / 2``.
            dataset: ``'coe'``.
            rgb_value: frames scaled with ``/ 127.5 - 1``, layout (f, c, h, w).
            rgb_first: ``(rgb_value[:3] + 1) / 2``.
            seqname: the sequence name, prefixed with ``trash_<k>_`` when the
                recording has fewer than ``n_sample_frames`` frames.

        Raises:
            IndexError: if ``index`` is out of range.
            ValueError: if the recording contains no frames.
        """
        device = 'cpu'
        seq_name = self.seq_names[index]
        event_seq_file = os.path.join(self.data_pth, seq_name, seq_name + '.aedat4')

        frame_all, frame_timestamp, events = self._read_sequence(event_seq_file)
        if not frame_all:
            raise ValueError(f"No frames found in {event_seq_file}")

        frame_num = len(frame_timestamp)
        if frame_num < self.n_sample_frames:
            # Mark short sequences; the number is the index of the last real frame.
            seq_name = 'trash_%d_' % (frame_num - 1) + seq_name
            zero_np = np.zeros_like(frame_all[0])
            frame_all.extend(zero_np for _ in range(self.n_sample_frames - frame_num))

        events_voxel_all = []
        max_norm = 0.0
        for t_start, t_end in frame_timestamp:
            event_voxel = self._events_to_voxel(events, t_start, t_end, device)
            max_norm = max(max_norm, event_voxel.abs().max().item())
            events_voxel_all.append(event_voxel)

        zero_torch = torch.zeros_like(events_voxel_all[0])
        events_voxel_all.extend(zero_torch for _ in range(self.n_sample_frames - frame_num))

        events_voxel_cat = torch.stack(events_voxel_all, dim=0)
        # Scale to [-1, 1]; skip when every voxel is zero to avoid 0/0 = NaN.
        if max_norm > 0:
            events_voxel_cat = events_voxel_cat / max_norm

        frame_cat = torch.from_numpy(np.stack(frame_all, axis=0)) / 127.5 - 1
        frame_cat = rearrange(frame_cat, "f h w c -> f c h w").to(device)

        crop_w, crop_h = self.crop_size
        crop_x = int((self.width - crop_w) / 2)
        crop_y = int((self.height - crop_h) / 2)
        events_voxel_cat = events_voxel_cat[:, :, crop_y:crop_y + crop_h, crop_x:crop_x + crop_w]
        frame_cat = frame_cat[:, :, crop_y:crop_y + crop_h, crop_x:crop_x + crop_w]

        event0 = (events_voxel_cat[:1] + 1) / 2
        # NOTE(review): takes the first three frames although the key is
        # "rgb_first" (event0 uses [:1]); confirm this is intended.
        rgb0 = (frame_cat[:3] + 1) / 2

        return {"pixel_values": events_voxel_cat, "image": event0, 'dataset': self.__getname__(),
                'rgb_value': frame_cat, 'rgb_first': rgb0, 'seqname': seq_name}
        # return {"pixel_values": normalize_input(video[0]), "image": image, 'dataset': self.__getname__()}


if __name__ == '__main__':
    # Manual check: crop columns 256 onward from a saved visualization image
    # and write the result to a new file.
    src_pth = '/project_pth/video_look/video_1_23.png'
    img = cv2.imread(src_pth)
    if img is None:
        # cv2.imread reports a missing or unreadable file by returning None.
        raise FileNotFoundError(f"Could not read image: {src_pth}")
    print(img.shape)
    img_event = img[:, 256:]
    cv2.imwrite('/project_pth/video_look/save_cor_24.png', img_event)
    