import sys
import os
import mmap
import pandas
from io import BytesIO
from time import sleep, time
import numpy
from multiprocessing import Pool
import numpy as np
import torch
import cv2
from event_utils import VoxelGrid
from torchvision.transforms import Resize
from tqdm import tqdm

# Module-level dict, initially empty.  Nothing in the visible part of this
# file reads or writes it -- presumably a leftover cache of opened files;
# TODO confirm against the rest of the project before removing.
opened_files = {}


def mkdir(pth):
    """Create directory ``pth`` (and any missing parents) if it is absent.

    ``os.makedirs(..., exist_ok=True)`` performs the existence check and the
    creation in one call.  A separate check-then-create sequence can raise
    ``FileExistsError`` when another process creates the directory between
    the two steps.

    Args:
        pth: Path of the directory to create.

    Raises:
        FileExistsError: If ``pth`` already exists but is not a directory.
    """
    os.makedirs(pth, exist_ok=True)

def preprocess(train_pth):
    """Print the full path of every entry directly under ``train_pth``.

    Entries are visited in sorted order; each is printed after a fixed label
    string.  Nothing is converted here.

    Args:
        train_pth: Directory whose entries are enumerated.
    """
    entry_paths = sorted(
        os.path.join(train_pth, entry) for entry in os.listdir(train_pth)
    )
    for entry_path in entry_paths:
        print("/project_pth/utils/mmap_loader.py", entry_path)
    # NOTE: a parallel fan-out, Pool(64).map(convert, <entry paths>), used to
    # follow this loop and is currently disabled.

def single_process(train_pth):
    """Run ``save_voxel`` sequentially on every subset under ``train_pth``.

    A subset is skipped when its output is already on disk.  ``save_voxel``
    writes ``<name>_voxel_crop.npy``, so that is the file checked and the
    path reported in the "save to" message.  Checking only
    ``<name>_voxel.npy`` (a file ``save_voxel`` does not write) would make
    every run recompute subsets that are already finished.
    ``<name>_voxel.npy`` -- the input ``change_voxel`` reads -- is still
    honoured as a skip marker so that directories containing it keep being
    skipped.

    Args:
        train_pth: Directory containing one sub-directory per subset.
    """
    # os.listdir yields bare names, so the name doubles as the subset name
    # and sorting the names gives the same order as sorting the joined paths.
    for subset_name in sorted(os.listdir(train_pth)):
        subset_dir = os.path.join(train_pth, subset_name)
        crop_path = os.path.join(subset_dir, f"{subset_name}_voxel_crop.npy")
        legacy_path = os.path.join(subset_dir, f"{subset_name}_voxel.npy")

        existing = next(
            (p for p in (crop_path, legacy_path) if os.path.exists(p)), None
        )
        if existing is not None:
            print(f'file {existing} exists!')
            continue

        print(f'save to {crop_path}')
        save_voxel(subset_dir)

def convert(subset_dir):
    """Convert a subset's CSV file into an ``int32`` ``.npy`` array.

    Reads ``<subset_dir>/<name>.csv``, where ``<name>`` is the last path
    component of ``subset_dir``, casts every value to ``int32`` and stores
    the array next to the CSV as ``<name>.npy``.  The elapsed wall-clock time
    is printed when done.

    Args:
        subset_dir: Directory holding the subset's CSV file.
    """
    started = time()

    stem = os.path.basename(subset_dir)
    source_csv = os.path.join(subset_dir, stem + ".csv")
    target_npy = os.path.join(subset_dir, stem + ".npy")

    table = pandas.read_csv(source_csv)
    events = table.values.astype(numpy.int32)
    numpy.save(target_npy, events)

    elapsed = time() - started
    print("Done:", target_npy, f"Time: {elapsed:.3f}s")


def save_voxel(subset_dir, *, fps=300, num_bins=3, width=1280, height=720,
               device='cpu', crop_left=280, crop_width=720,
               out_size=(180, 180)):
    """Cut a subset's event CSV into fixed-rate voxel grids and save them.

    Reads ``<subset_dir>/<name>.csv`` (``<name>`` is the last path component
    of ``subset_dir``), splits the events into consecutive windows of
    ``1e6 / fps`` timestamp units, converts each window with ``VoxelGrid``,
    crops the last axis to ``[crop_left, crop_left + crop_width)``, resizes
    the result to ``out_size`` and writes the stack of all windows as
    ``float16`` to ``<subset_dir>/<name>_voxel_crop.npy``.  The output file
    name does not depend on the keyword arguments, so runs with different
    settings overwrite each other.  Elapsed wall-clock time is printed when
    done.

    Args:
        subset_dir: Directory holding the subset's CSV file.
        fps: Number of windows per ``1e6`` timestamp units.
        num_bins: First dimension passed to ``VoxelGrid``.
        width: Third dimension passed to ``VoxelGrid``.
        height: Second dimension passed to ``VoxelGrid``.
        device: Torch device for the event tensors and ``VoxelGrid``.  The
            result is converted with ``.numpy()``, which requires a CPU
            tensor.
        crop_left: First index kept on the last axis of each voxel grid.
        crop_width: Number of indices kept on the last axis.
        out_size: ``(height, width)`` handed to ``Resize``.

    Raises:
        IndexError: If the CSV contains no rows.
    """
    tic = time()

    subset_name = os.path.basename(subset_dir)
    csv_path = os.path.join(subset_dir, f"{subset_name}.csv")
    voxel_path = os.path.join(subset_dir, f"{subset_name}_voxel_crop.npy")

    df = pandas.read_csv(csv_path)
    # NOTE(review): values outside the int32 range would wrap in this cast --
    # confirm the timestamp range of the input files.
    event_data = df.values.astype(numpy.int32)

    # Columns are taken by position: 0 -> x, 1 -> y, 2 -> p, 3 -> t.
    x = event_data[:, 0]
    y = event_data[:, 1]
    p = event_data[:, 2]
    t = event_data[:, 3]
    t -= t[0]  # make timestamps relative to the first event (in place)

    # Window length; presumably timestamps are microseconds -- TODO confirm.
    delta_t = 1000 * 1000 / fps
    max_idx = int(t[-1] / delta_t)

    # The resize transform does not depend on the window, so build it once.
    torch_resize = Resize(list(out_size), antialias=True)
    crop_right = crop_left + crop_width
    event_list = []

    for idx in tqdm(range(max_idx)):
        start_t = idx * delta_t
        end_t = (idx + 1) * delta_t
        # Window bounds are the indices of the events nearest to the window
        # edges.  NOTE(review): these are two full scans per window; if the
        # timestamps are guaranteed sorted, np.searchsorted would be faster,
        # but it returns the insertion point rather than the nearest event
        # and would move boundary events between windows, so it is not used.
        start_idx = np.argmin(abs(t - start_t))
        end_idx = np.argmin(abs(t - end_t))
        x_e = x[start_idx:end_idx].astype(np.float32)
        y_e = y[start_idx:end_idx].astype(np.float32)
        p_e = p[start_idx:end_idx].astype(np.float32)
        t_e = t[start_idx:end_idx].astype(np.float32)

        events = np.stack([x_e, y_e, t_e, p_e], axis=1)
        events = torch.from_numpy(events).to(device)
        # VoxelGrid is a project type whose internal state is not visible
        # here, so a fresh instance is still created for every window.
        voxel_grid = VoxelGrid((num_bins, height, width), normalize=False, device=device)
        event_voxel = voxel_grid.convert({
            'x': events[:, 0],
            'y': events[:, 1],
            't': events[:, 2],
            'p': events[:, 3],
        })
        event_voxel = event_voxel[:, :, crop_left:crop_right]
        event_voxel = torch_resize(event_voxel).numpy()
        event_list.append(event_voxel)

    event_voxel_ls = np.array(event_list)
    numpy.save(voxel_path, event_voxel_ls.astype(numpy.float16))
    toc = time()
    print("Done:", voxel_path, f"Time: {toc - tic:.3f}s")





def change_voxel(subset_dir):
    """Crop and resize an existing voxel file into ``<name>_voxel_crop.npy``.

    Loads ``<subset_dir>/<name>_voxel.npy`` (``<name>`` is the last path
    component of ``subset_dir``), keeps indices 280..999 of the fourth axis,
    applies ``Resize([180, 180], antialias=True)`` and writes the result as
    ``float16`` to ``<name>_voxel_crop.npy`` in the same directory.  The
    elapsed wall-clock time is printed when done.

    Args:
        subset_dir: Directory holding the subset's ``_voxel.npy`` file.
    """
    started = time()

    stem = os.path.basename(subset_dir)
    source_path = os.path.join(subset_dir, stem + "_voxel.npy")
    target_path = os.path.join(subset_dir, stem + "_voxel_crop.npy")

    crop_start = 280
    crop_stop = crop_start + 720

    # The array must be 4-D for this indexing; presumably
    # (frames, bins, height, width) -- TODO confirm.
    cropped = torch.from_numpy(
        np.load(source_path)[:, :, :, crop_start:crop_stop]
    )
    downsample = Resize([180, 180], antialias=True)
    resized = downsample(cropped).numpy()
    # Drop the reference to the loaded data before the float16 copy is made.
    del cropped

    numpy.save(target_path, resized.astype(numpy.float16))

    elapsed = time() - started
    print("Done:", target_path, f"Time: {elapsed:.3f}s")


if __name__ == "__main__":

