import glob
import os
import random
import time
import sys
import cv2
import imageio
import lpips
import open3d as o3d
import mcubes
import numpy as np
import tensorboardX
import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import tqdm
import trimesh
from rich.console import Console
from skimage.metrics import structural_similarity
from torch_ema import ExponentialMovingAverage
from lidarnerf.nerf.custom_loss import chamfer_distance_low_capacity
from lidarnerf.nerf.custom_loss import chamfer_based_norm_loss_low_capacity
from lidarnerf.nerf.custom_loss import chamfer_distance
from lidarnerf.nerf.custom_loss import chamfer_based_norm_loss

from extern.chamfer3D.dist_chamfer_3D import chamfer_3DDist
from extern.fscore import fscore

from lidarnerf.dataset.base_dataset import custom_meshgrid

from lidarnerf.convert import pano_to_lidar

def is_ali_cluster():
    """Return True if the current hostname contains ``"auto-drive"``."""
    from socket import gethostname

    return "auto-drive" in gethostname()


@torch.jit.script
def linear_to_srgb(x):
    # Piecewise sRGB encoding: linear segment below 0.0031308, power curve above.
    # NOTE: 0.41666 is a truncated approximation of the exact exponent 1 / 2.4.
    is_dark = x < 0.0031308
    gamma_branch = 1.055 * x**0.41666 - 0.055
    return torch.where(is_dark, 12.92 * x, gamma_branch)


@torch.jit.script
def srgb_to_linear(x):
    # Inverse sRGB transfer curve: linear segment below 0.04045,
    # 2.4-power curve above it.
    linear_part = x / 12.92
    power_part = ((x + 0.055) / 1.055) ** 2.4
    return torch.where(x < 0.04045, linear_part, power_part)


def filter_bbox_dataset(pc, OBB_local):
    """Keep only the points of ``pc`` that fall inside an oriented bounding box.

    Points are first filtered by height (column 2 between the box's min and
    max z), then by the box footprint in the x-y plane, which is built from
    the four lowest-z corners of ``OBB_local``.

    Args:
        pc: point array whose columns 0, 1, 2 are x, y, z.
        OBB_local: array of box corners (columns x, y, z).

    Returns:
        The subset of rows of ``pc`` inside the box.
    """
    z_min, z_max = min(OBB_local[:, 2]), max(OBB_local[:, 2])
    heights = pc[:, 2]
    # Vectorised height test (previously a per-point Python loop). NaN heights
    # compare False on both sides and are therefore dropped, as before.
    in_height = (heights <= z_max) & (heights >= z_min)
    pc = pc[in_height]

    # The four lowest corners (by z) describe the footprint in the x-y plane.
    lowest_corners = sorted(OBB_local, key=lambda p: p[2])
    OBB_2D = np.array(lowest_corners)[:4, :2]
    return filter_poly(pc, OBB_2D)


def filter_poly(pcs, OBB_2D):
    """Return the rows of ``pcs`` whose (x, y) lie inside the quadrilateral ``OBB_2D``.

    The corners are first put into polygon order with ``sort_quadrilateral``;
    each point is then tested with the ray-casting ``is_in_poly``.
    """
    quad = sort_quadrilateral(OBB_2D)
    inside = [is_in_poly(point[0], point[1], quad) for point in pcs]
    return pcs[inside]


def sort_quadrilateral(points):
    """Order convex-polygon corners into a simple (non self-intersecting) loop.

    The result starts at the corner with the smallest ``x + y`` (ties broken by
    input order, as before) and then visits the remaining corners clockwise
    (y axis pointing up) around the centroid. For ordinary boxes this yields
    the same order as the previous min/max-of-``x + y`` heuristic. For boxes
    rotated about 45 degrees, that heuristic could treat two adjacent corners
    as opposite ones and return a bow-tie polygon; the angular sort does not.

    :param points: NumPy array of corner coordinates, shape (k, 2) (k == 4 for boxes).
    :return: list of ``[x, y]`` lists in traversal order.
    """
    import math

    points = points.tolist()
    cx = sum(p[0] for p in points) / len(points)
    cy = sum(p[1] for p in points) / len(points)
    # Descending polar angle around the centroid == clockwise traversal.
    ordered = sorted(
        points, key=lambda p: math.atan2(p[1] - cy, p[0] - cx), reverse=True
    )
    start = ordered.index(min(points, key=lambda p: p[0] + p[1]))
    return ordered[start:] + ordered[:start]


def is_in_poly(px, py, poly):
    """Ray-casting point-in-polygon test.

    :param px: x coordinate of the query point.
    :param py: y coordinate of the query point.
    :param poly: sequence of ``[x, y]`` vertices in traversal order.
    :return: True if the point is inside, coincides with a vertex, or lies on
        a non-horizontal edge; False otherwise (including for an empty poly).
    """
    count = len(poly)
    inside = False
    for i, (x1, y1) in enumerate(poly):
        x2, y2 = poly[(i + 1) % count]
        if (px == x1 and py == y1) or (px == x2 and py == y2):
            return True  # point coincides with a vertex
        # Only edges that straddle the horizontal line y == py can be crossed.
        # Horizontal edges never satisfy this, so the division below is safe.
        if not (min(y1, y2) < py <= max(y1, y2)):
            continue
        x_cross = x1 + (py - y1) * (x2 - x1) / (y2 - y1)
        if x_cross == px:
            return True  # point lies on this edge
        if x_cross > px:
            # The edge is crossed by a ray cast towards +x: flip parity.
            inside = not inside
    return inside


def seed_everything(seed):
    """Seed Python's ``random``, NumPy and PyTorch (CPU and current CUDA device).

    Also exports ``PYTHONHASHSEED``. cuDNN determinism flags are left untouched.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)


def torch_vis_2d(x, renormalize=False):
    """Display a 2-D image (torch tensor or NumPy array) with matplotlib.

    Accepted shapes: [3, H, W], [1, H, W] or [H, W]. A 3-D tensor is moved to
    channel-last and squeezed before plotting. The shape, dtype and value
    range are printed first.

    Args:
        x: image tensor or array.
        renormalize: if True, min-max scale along axis 0 (independently for
            each position on the remaining axes, not globally) before plotting.
    """
    import matplotlib.pyplot as plt
    import numpy as np
    import torch

    if isinstance(x, torch.Tensor):
        if x.dim() == 3:
            x = x.permute(1, 2, 0).squeeze()
        x = x.detach().cpu().numpy()

    print(f"[torch_vis_2d] {x.shape}, {x.dtype}, {x.min()} ~ {x.max()}")
    x = x.astype(np.float32)

    if renormalize:
        lo = x.min(axis=0, keepdims=True)
        span = x.max(axis=0, keepdims=True) - lo
        x = (x - lo) / (span + 1e-8)

    plt.imshow(x)
    plt.show()


def extract_fields(bound_min, bound_max, resolution, query_func, S=128):
    """Sample ``query_func`` on a regular ``resolution``^3 grid.

    The grid spans [bound_min, bound_max] on each axis and is evaluated in
    chunks of at most ``S`` samples per axis. Each chunk's points are passed
    to ``query_func`` as an [n, 3] tensor, and the result is reshaped back
    onto the chunk. Runs under ``torch.no_grad()``.

    Returns:
        np.float32 array of shape (resolution, resolution, resolution).
    """
    x_chunks, y_chunks, z_chunks = (
        torch.linspace(bound_min[axis], bound_max[axis], resolution).split(S)
        for axis in range(3)
    )

    field = np.zeros([resolution, resolution, resolution], dtype=np.float32)
    with torch.no_grad():
        for xi, xs in enumerate(x_chunks):
            x0, x1 = xi * S, xi * S + len(xs)
            for yi, ys in enumerate(y_chunks):
                y0, y1 = yi * S, yi * S + len(ys)
                for zi, zs in enumerate(z_chunks):
                    z0, z1 = zi * S, zi * S + len(zs)
                    grid = custom_meshgrid(xs, ys, zs)
                    pts = torch.cat([g.reshape(-1, 1) for g in grid], dim=-1)  # [n, 3]
                    chunk = query_func(pts).reshape(len(xs), len(ys), len(zs))
                    field[x0:x1, y0:y1, z0:z1] = chunk.detach().cpu().numpy()
    return field


def extract_geometry(bound_min, bound_max, resolution, threshold, query_func):
    """Extract an iso-surface of ``query_func`` with marching cubes.

    The field is sampled by ``extract_fields`` and polygonised at level
    ``threshold``. Vertex coordinates from the grid-index space are mapped
    back into [bound_min, bound_max].

    Returns:
        (vertices, triangles) as returned by ``mcubes.marching_cubes``, with
        vertices rescaled into world coordinates.
    """
    field = extract_fields(bound_min, bound_max, resolution, query_func)
    vertices, triangles = mcubes.marching_cubes(field, threshold)

    hi = bound_max.detach().cpu().numpy()
    lo = bound_min.detach().cpu().numpy()
    extent = (hi - lo)[None, :]
    vertices = vertices / (resolution - 1.0) * extent + lo[None, :]
    return vertices, triangles


class RaydropMeter:
    """Accumulate ray-drop metrics: RMSE plus accuracies at fixed thresholds.

    Each ``update`` appends ``[rmse, acc@0.1, acc@0.2, ..., acc@0.9]``. Here
    ``acc@t`` is the fraction of entries where ``(preds > t)`` equals
    ``truths``. ``ratio`` is stored, but ``update`` does not read it.
    """

    def __init__(self, ratio):
        self.ratio = ratio
        self.V = []
        self.N = 0

    def clear(self):
        self.N = 0
        self.V = []

    def prepare_inputs(self, *inputs):
        """Convert any torch tensors among ``inputs`` to NumPy arrays."""
        return [
            inp.detach().cpu().numpy() if torch.is_tensor(inp) else inp
            for inp in inputs
        ]

    def update(self, preds, truths):
        preds, truths = self.prepare_inputs(preds, truths)

        rmse = np.sqrt(((truths - preds) ** 2).mean())

        accuracies = []
        for step in range(1, 10):
            threshold = 0.1 * step
            dropped = np.where(preds > threshold, 1, 0)
            accuracies.append((dropped == truths).sum() / dropped.size)

        self.V.append([rmse] + accuracies)
        self.N += 1

    def measure(self):
        """Return the per-column mean over all updates."""
        assert self.N == len(self.V)
        return np.array(self.V).mean(0)

    def write(self, writer, global_step, prefix="", suffix=""):
        # Only the RMSE column is logged.
        writer.add_scalar(os.path.join(prefix, "raydrop error"), self.measure()[0], global_step)

    def report(self):
        return f"Raydrop_error (rmse, accuracy) = {self.measure()}"


class DepthMeter:
    """Accumulate per-frame depth metrics and report their running mean.

    Each ``update`` divides both inputs by ``scale`` and appends
    ``[rmse, medae, a1, a2, a3, lpips, ssim, psnr]`` (see
    ``compute_depth_errors``).

    Args:
        scale: divisor applied to predictions and ground truth.
        lpips_fn: optional LPIPS callable to reuse (e.g. shared between
            meters). When None, a new ``lpips.LPIPS(net="alex")`` in eval mode
            is built.
    """

    def __init__(self, scale, lpips_fn=None):
        self.V = []
        self.N = 0
        self.scale = scale
        # Previously the argument was ignored and a new network was always built.
        if lpips_fn is None:
            lpips_fn = lpips.LPIPS(net="alex").eval()
        self.lpips_fn = lpips_fn

    def clear(self):
        self.V = []
        self.N = 0

    def prepare_inputs(self, *inputs):
        """Convert any torch tensors among ``inputs`` to NumPy arrays."""
        outputs = []
        for inp in inputs:
            if torch.is_tensor(inp):
                inp = inp.detach().cpu().numpy()
            outputs.append(inp)
        return outputs

    def update(self, preds, truths):
        preds = preds / self.scale
        truths = truths / self.scale
        # NOTE(review): shapes presumably carry a leading singleton axis
        # (see squeeze(0) in compute_depth_errors) — confirm with the caller.
        preds, truths = self.prepare_inputs(preds, truths)

        depth_error = self.compute_depth_errors(truths, preds)
        self.V.append(list(depth_error))
        self.N += 1

    def compute_depth_errors(
        self, gt, pred, min_depth=1e-6, max_depth=80, thresh_set=1.25
    ):
        """Compute depth error statistics between ``gt`` and ``pred``.

        Both arrays are clamped to [min_depth, max_depth] on copies, so the
        caller's arrays are not modified.

        Returns:
            (rmse, medae, a1, a2, a3, lpips, ssim, psnr), where:
            - ``a_k`` is the fraction of entries with
              max(gt/pred, pred/gt) < thresh_set**k;
            - ssim uses the clamped ground-truth range as ``data_range``;
            - psnr uses ``max_depth`` as the peak value.
        """
        pred = np.clip(pred, min_depth, max_depth)
        gt = np.clip(gt, min_depth, max_depth)

        thresh = np.maximum((gt / pred), (pred / gt))
        a1 = (thresh < thresh_set).mean()
        a2 = (thresh < thresh_set**2).mean()
        a3 = (thresh < thresh_set**3).mean()

        rmse = np.sqrt(((gt - pred) ** 2).mean())
        medae = np.median(np.abs(gt - pred))

        # Metric only: no autograd graph is needed through the LPIPS network.
        with torch.no_grad():
            lpips_loss = self.lpips_fn(
                torch.from_numpy(pred).squeeze(0),
                torch.from_numpy(gt).squeeze(0),
                normalize=True,
            ).item()

        ssim = structural_similarity(
            pred.squeeze(0), gt.squeeze(0), data_range=np.max(gt) - np.min(gt)
        )

        psnr = 10 * np.log10(max_depth**2 / np.mean((pred - gt) ** 2))

        return rmse, medae, a1, a2, a3, lpips_loss, ssim, psnr

    def measure(self):
        """Return the per-metric mean over all updates."""
        assert self.N == len(self.V)
        return np.array(self.V).mean(0)

    def write(self, writer, global_step, prefix="", suffix=""):
        # Only the RMSE column is logged.
        writer.add_scalar(
            os.path.join(prefix, f"depth error{suffix}"), self.measure()[0], global_step
        )

    def report(self):
        return f"Depth_error = {self.measure()}"


class IntensityMeter:
    """Accumulate per-frame intensity metrics and report their running mean.

    Each ``update`` divides both inputs by ``scale`` and appends
    ``[rmse, medae, a1, a2, a3, lpips, ssim, psnr]`` (see
    ``compute_intensity_errors``).

    Args:
        scale: divisor applied to predictions and ground truth.
        lpips_fn: optional LPIPS callable to reuse. When None, a new
            ``lpips.LPIPS(net="alex")`` in eval mode is built.
    """

    def __init__(self, scale, lpips_fn=None):
        self.V = []
        self.N = 0
        self.scale = scale
        # Previously the argument was ignored and a new network was always built.
        if lpips_fn is None:
            lpips_fn = lpips.LPIPS(net="alex").eval()
        self.lpips_fn = lpips_fn

    def clear(self):
        self.V = []
        self.N = 0

    def prepare_inputs(self, *inputs):
        """Convert any torch tensors among ``inputs`` to NumPy arrays."""
        outputs = []
        for inp in inputs:
            if torch.is_tensor(inp):
                inp = inp.detach().cpu().numpy()
            outputs.append(inp)
        return outputs

    def update(self, preds, truths):
        preds = preds / self.scale
        truths = truths / self.scale
        # NOTE(review): shapes presumably carry a leading singleton axis
        # (see squeeze(0) in compute_intensity_errors) — confirm with the caller.
        preds, truths = self.prepare_inputs(preds, truths)

        intensity_error = self.compute_intensity_errors(truths, preds)
        self.V.append(list(intensity_error))
        self.N += 1

    def compute_intensity_errors(
        self, gt, pred, min_intensity=1e-6, max_intensity=1.0, thresh_set=1.25
    ):
        """Compute intensity error statistics between ``gt`` and ``pred``.

        Both arrays are clamped to [min_intensity, max_intensity] on copies,
        so the caller's arrays are not modified.

        Returns:
            (rmse, medae, a1, a2, a3, lpips, ssim, psnr), where:
            - ``a_k`` is the fraction of entries with
              max(gt/pred, pred/gt) < thresh_set**k;
            - psnr uses ``max_intensity`` as the peak value.
        """
        pred = np.clip(pred, min_intensity, max_intensity)
        gt = np.clip(gt, min_intensity, max_intensity)

        thresh = np.maximum((gt / pred), (pred / gt))
        a1 = (thresh < thresh_set).mean()
        a2 = (thresh < thresh_set**2).mean()
        a3 = (thresh < thresh_set**3).mean()

        rmse = np.sqrt(((gt - pred) ** 2).mean())
        medae = np.median(np.abs(gt - pred))

        # Metric only: no autograd graph is needed through the LPIPS network.
        with torch.no_grad():
            lpips_loss = self.lpips_fn(
                torch.from_numpy(pred).squeeze(0),
                torch.from_numpy(gt).squeeze(0),
                normalize=True,
            ).item()

        ssim = structural_similarity(
            pred.squeeze(0), gt.squeeze(0), data_range=np.max(gt) - np.min(gt)
        )

        psnr = 10 * np.log10(max_intensity**2 / np.mean((pred - gt) ** 2))

        return rmse, medae, a1, a2, a3, lpips_loss, ssim, psnr

    def measure(self):
        """Return the per-metric mean over all updates."""
        assert self.N == len(self.V)
        return np.array(self.V).mean(0)

    def write(self, writer, global_step, prefix="", suffix=""):
        # Only the RMSE column is logged.
        writer.add_scalar(
            os.path.join(prefix, f"intensity error{suffix}"), self.measure()[0], global_step
        )

    def report(self):
        return f"Inten_error = {self.measure()}"


class PointsMeter:
    """Accumulate point-cloud metrics (Chamfer distance, F-score) between range images.

    Each ``update`` takes the first element along the leading axis of the
    predicted and ground-truth panoramas (after dividing by ``scale``). It
    converts both to point clouds with ``pano_to_lidar`` using ``intrinsics``
    and appends ``[chamfer_distance, f_score]`` as Python floats.
    Requires CUDA.
    """

    def __init__(self, scale, intrinsics):
        self.V = []
        self.N = 0
        self.scale = scale
        self.intrinsics = intrinsics
        # Chamfer module is created on first use and reused afterwards.
        self._cham_loss = None

    def clear(self):
        self.V = []
        self.N = 0

    def prepare_inputs(self, *inputs):
        """Convert any torch tensors among ``inputs`` to NumPy arrays."""
        outputs = []
        for inp in inputs:
            if torch.is_tensor(inp):
                inp = inp.detach().cpu().numpy()
            outputs.append(inp)
        return outputs

    def update(self, preds, truths):
        preds = preds / self.scale
        truths = truths / self.scale
        preds, truths = self.prepare_inputs(preds, truths)

        cham_loss = getattr(self, "_cham_loss", None)
        if cham_loss is None:
            cham_loss = self._cham_loss = chamfer_3DDist()

        pred_lidar = pano_to_lidar(preds[0], self.intrinsics)
        gt_lidar = pano_to_lidar(truths[0], self.intrinsics)

        # Metric only: skip autograd bookkeeping.
        with torch.no_grad():
            dist1, dist2, _, _ = cham_loss(
                torch.FloatTensor(pred_lidar[None, ...]).cuda(),
                torch.FloatTensor(gt_lidar[None, ...]).cuda(),
            )
            chamfer_dis = dist1.mean() + dist2.mean()
            threshold = 0.05  # monoSDF
            f_score, _, _ = fscore(dist1, dist2, threshold)

        # Store plain floats so measure() builds a numeric array without
        # relying on NumPy's conversion of 0-d tensors.
        self.V.append([chamfer_dis.item(), f_score.cpu()[0].item()])
        self.N += 1

    def measure(self):
        """Return the mean [chamfer_distance, f_score] over all updates."""
        assert self.N == len(self.V)
        return np.array(self.V).mean(0)

    def write(self, writer, global_step, prefix="", suffix=""):
        # Only the Chamfer distance column is logged.
        writer.add_scalar(os.path.join(prefix, "CD"), self.measure()[0], global_step)

    def report(self):
        return f"Point_error (CD, f-score) = {self.measure()}"


class Trainer(object):
    def __init__(
        self,
        name,  # name of this experiment
        opt,  # extra conf (must provide ``dataloader_size``)
        model,  # network
        criterion=None,  # loss function, if None, assume inline implementation in train_step
        optimizer=None,  # factory: optimizer(model) -> optimizer; None -> plain Adam
        optimizer_pose_rot=None,  # factory: optimizer_pose_rot(model); only used when `optimizer` is given
        optimizer_pose_trans=None,  # factory: optimizer_pose_trans(model); only used when `optimizer` is given
        ema_decay=None,  # accepted for compatibility; EMA is currently disabled either way
        lr_scheduler=None,  # factory: lr_scheduler(optimizer); None -> constant LR
        lr_scheduler_pose_rot=None,  # factory applied to the rotation pose optimizer
        lr_scheduler_pose_trans=None,  # factory applied to the translation pose optimizer
        metrics=None,  # metrics for evaluation; empty -> use val_loss, else use the first metric
        depth_metrics=None,  # LiDAR meters updated during training/eval
        local_rank=0,  # which GPU am I
        world_size=1,  # total num of GPUs
        device=None,  # device to use, usually setting to None is OK. (auto choose device)
        mute=False,  # whether to mute all print
        fp16=False,  # amp optimize level
        eval_interval=50,  # eval once every $ epoch
        max_keep_ckpt=2,  # max num of saved ckpts in disk
        workspace="workspace",  # workspace to save logs & ckpts
        best_mode="min",  # the smaller/larger result, the better
        use_loss_as_metric=True,  # use loss as the first metric
        report_metric_at_train=False,  # also report metrics at training
        use_checkpoint="latest",  # which ckpt to use at init time
        use_tensorboardX=True,  # whether to use tensorboard for logging
        scheduler_update_every_step=False,  # whether to call scheduler.step() after every train step
    ):
        """Build the trainer.

        Wraps the model (DDP when ``world_size > 1``), creates the main and
        optional pose optimizers/schedulers, prepares the workspace and log
        file, and optionally resumes from a checkpoint as selected by
        ``use_checkpoint``.
        """
        # Fresh lists per instance: a shared mutable default would leak
        # meters between Trainer instances.
        if metrics is None:
            metrics = []
        if depth_metrics is None:
            depth_metrics = []

        self.pcds = None
        self.poses = None
        self.it = 0
        self.name = name
        self.opt = opt
        self.mute = mute
        self.metrics = metrics
        self.depth_metrics = depth_metrics
        self.local_rank = local_rank
        self.world_size = world_size
        self.workspace = workspace
        self.ema_decay = ema_decay
        self.fp16 = fp16
        self.best_mode = best_mode
        self.use_loss_as_metric = use_loss_as_metric
        self.report_metric_at_train = report_metric_at_train
        self.max_keep_ckpt = max_keep_ckpt
        self.eval_interval = eval_interval
        self.use_checkpoint = use_checkpoint
        self.use_tensorboardX = use_tensorboardX
        self.time_stamp = time.strftime("%Y-%m-%d_%H-%M-%S")
        self.scheduler_update_every_step = scheduler_update_every_step
        self.device = (
            device
            if device is not None
            else torch.device(
                f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu"
            )
        )
        self.console = Console()

        self.optim_direction_rot = None
        self.optim_direction_trans = None
        self.optim_gradients_trans = [[] for _ in range(self.opt.dataloader_size)]
        self.optim_gradients_rot = [[] for _ in range(self.opt.dataloader_size)]

        model.to(self.device)
        if self.world_size > 1:
            model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank]
            )
        self.model = model

        if isinstance(criterion, nn.Module):
            criterion.to(self.device)
        self.criterion = criterion

        # Pose optimizers/schedulers are optional. Default them to None rather
        # than leaving the attributes undefined or calling a None factory.
        self.optimizer_pose_rot = None
        self.optimizer_pose_trans = None
        if optimizer is None:
            self.optimizer = optim.Adam(
                self.model.parameters(), lr=0.001, weight_decay=5e-4
            )  # naive adam
        else:
            self.optimizer = optimizer(self.model)
            if optimizer_pose_rot is not None:
                self.optimizer_pose_rot = optimizer_pose_rot(self.model)
            if optimizer_pose_trans is not None:
                self.optimizer_pose_trans = optimizer_pose_trans(self.model)

        self.lr_scheduler_pose_rot = None
        self.lr_scheduler_pose_trans = None
        if lr_scheduler is None:
            self.lr_scheduler = optim.lr_scheduler.LambdaLR(
                self.optimizer, lr_lambda=lambda epoch: 1
            )  # fake scheduler
        else:
            self.lr_scheduler = lr_scheduler(self.optimizer)
            if lr_scheduler_pose_rot is not None and self.optimizer_pose_rot is not None:
                self.lr_scheduler_pose_rot = lr_scheduler_pose_rot(self.optimizer_pose_rot)
            if lr_scheduler_pose_trans is not None and self.optimizer_pose_trans is not None:
                self.lr_scheduler_pose_trans = lr_scheduler_pose_trans(self.optimizer_pose_trans)

        # EMA is disabled: no ExponentialMovingAverage is constructed, so
        # ``ema_decay`` currently has no effect.
        self.ema = None

        self.scaler = torch.cuda.amp.GradScaler(enabled=self.fp16)

        # variable init
        self.epoch = 0
        self.global_step = 0
        self.local_step = 0
        self.stats = {
            "loss": [],
            "valid_loss": [],
            "results": [],  # metrics[0], or valid_loss
            "checkpoints": [],  # record path of saved ckpt, to automatically remove old ckpt
            "best_result": None,
        }

        # auto fix
        if len(metrics) == 0 or self.use_loss_as_metric:
            self.best_mode = "min"

        # workspace prepare
        self.log_ptr = None
        if self.workspace is not None:
            os.makedirs(self.workspace, exist_ok=True)
            self.log_path = os.path.join(workspace, f"log_{self.name}.txt")
            self.log_ptr = open(self.log_path, "a+")

            self.ckpt_path = os.path.join(self.workspace, "checkpoints")
            self.best_path = f"{self.ckpt_path}/{self.name}.pth"
            os.makedirs(self.ckpt_path, exist_ok=True)

        self.log(
            f'[INFO] Trainer: {self.name} | {self.time_stamp} | {self.device} | {"fp16" if self.fp16 else "fp32"} | {self.workspace}'
        )
        self.log(
            f"[INFO] #parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}"
        )

        if self.workspace is not None:
            if self.use_checkpoint == "scratch":
                self.log("[INFO] Training from scratch ...")
            elif self.use_checkpoint == "latest":
                self.log("[INFO] Loading latest checkpoint ...")
                self.load_checkpoint()
            elif self.use_checkpoint == "latest_model":
                self.log("[INFO] Loading latest checkpoint (model only)...")
                self.load_checkpoint(model_only=True)
            elif self.use_checkpoint == "best":
                if os.path.exists(self.best_path):
                    self.log("[INFO] Loading best checkpoint ...")
                    self.load_checkpoint(self.best_path)
                else:
                    self.log(f"[INFO] {self.best_path} not found, loading latest ...")
                    self.load_checkpoint()
            else:  # path to ckpt
                self.log(f"[INFO] Loading {self.use_checkpoint} ...")
                self.load_checkpoint(self.use_checkpoint)

    def __del__(self):
        """Close the log file, tolerating a partially initialised instance."""
        # __init__ may have raised before ``log_ptr`` was assigned.
        log_ptr = getattr(self, "log_ptr", None)
        if log_ptr:
            log_ptr.close()

    def log(self, *args, **kwargs):
        if self.local_rank == 0:
            if not self.mute:
                # print(*args)
                self.console.print(*args, **kwargs)
            if self.log_ptr:
                print(*args, file=self.log_ptr)
                self.log_ptr.flush()  # write immediately to file
    def save_mesh(self, save_path=None, resolution=256, threshold=10):
        """Extract an iso-surface of the model density and export it as a mesh.

        Args:
            save_path: output file; defaults to
                ``<workspace>/meshes/<name>_<epoch>.ply``.
            resolution: grid samples per axis passed to ``extract_geometry``.
            threshold: density level used for marching cubes.
        """
        if save_path is None:
            save_path = os.path.join(
                self.workspace, "meshes", f"{self.name}_{self.epoch}.ply"
            )
        self.log(f"==> Saving mesh to {save_path}")
        os.makedirs(os.path.dirname(save_path), exist_ok=True)

        def density_at(pts):
            # Query sigma on the model device, honouring the fp16 setting.
            with torch.no_grad(), torch.cuda.amp.autocast(enabled=self.fp16):
                return self.model.density(pts.to(self.device))["sigma"]

        aabb = self.model.aabb_infer
        vertices, triangles = extract_geometry(
            aabb[:3],
            aabb[3:],
            resolution=resolution,
            threshold=threshold,
            query_func=density_at,
        )

        # process=False is required: process=True was observed to seg-fault.
        trimesh.Trimesh(vertices, triangles, process=False).export(save_path)

        self.log("==> Finished saving mesh.")

    def save_checkpoint(self, name=None, full=False, best=False, remove_old=True):
        if name is None:
            name = f"{self.name}_ep{self.epoch:04d}"

        state = {
            "epoch": self.epoch,
            "global_step": self.global_step,
            "stats": self.stats,
            "it": self.it,
        }

        if full:
            state["optimizer"] = self.optimizer.state_dict()
            state["optimizer_pose_rot"]=self.optimizer_pose_rot.state_dict()
            state["optimizer_pose_trans"]=self.optimizer_pose_trans.state_dict()
            state["lr_scheduler"] = self.lr_scheduler.state_dict()
            state["lr_scheduler_pose_rot"]=self.lr_scheduler_pose_rot.state_dict()
            state["lr_scheduler_pose_trans"]=self.lr_scheduler_pose_trans.state_dict()
            state["scaler"] = self.scaler.state_dict()
            if self.ema is not None:
                state["ema"] = self.ema.state_dict()
        
        if not best:
            state["model"] = self.model.state_dict()

            file_path = f"{self.ckpt_path}/{name}.pth"

            if remove_old:
                self.stats["checkpoints"].append(file_path)

                if len(self.stats["checkpoints"]) > self.max_keep_ckpt:
                    old_ckpt = self.stats["checkpoints"].pop(0)
                    if os.path.exists(old_ckpt):
                        os.remove(old_ckpt)

            torch.save(state, file_path)

        else:
            if len(self.stats["results"]) > 0:
                if (
                    self.stats["best_result"] is None
                    or self.stats["results"][-1] < self.stats["best_result"]
                ):
                    self.log(
                        f"[INFO] New best result: {self.stats['best_result']} --> {self.stats['results'][-1]}"
                    )
                    self.stats["best_result"] = self.stats["results"][-1]

                    # save ema results
                    if self.ema is not None:
                        self.ema.store()
                        self.ema.copy_to()

                    state["model"] = self.model.state_dict()

                    # we don't consider continued training from the best ckpt, so we discard the unneeded density_grid to save some storage (especially important for dnerf)
                    if "density_grid" in state["model"]:
                        del state["model"]["density_grid"]

                    if self.ema is not None:
                        self.ema.restore()

                    torch.save(state, self.best_path)
            else:
                self.log(
                    f"[WARN] no evaluated results found, skip saving best checkpoint."
                )
    def load_checkpoint(self, checkpoint=None, model_only=False):
        if checkpoint is None:
            checkpoint_list = sorted(glob.glob(f"{self.ckpt_path}/{self.name}_ep*.pth"))
            if checkpoint_list:
                checkpoint = checkpoint_list[-1]
                self.log(f"[INFO] Latest checkpoint is {checkpoint}")
            else:
                self.log("[WARN] No checkpoint found, model randomly initialized.")
                return

        checkpoint_dict = torch.load(checkpoint, map_location=self.device)

        if "model" not in checkpoint_dict:
            
            self.model.load_state_dict(checkpoint_dict)
            self.log("[INFO] loaded model.")
            return

        missing_keys, unexpected_keys = self.model.load_state_dict(
            checkpoint_dict["model"], strict=False
        )
        self.log("[INFO] loaded model.")
        if len(missing_keys) > 0:
            self.log(f"[WARN] missing keys: {missing_keys}")
        if len(unexpected_keys) > 0:
            self.log(f"[WARN] unexpected keys: {unexpected_keys}")

        if self.ema is not None and "ema" in checkpoint_dict:
            self.ema.load_state_dict(checkpoint_dict["ema"])

        if model_only:
            return
        self.it = checkpoint_dict["it"]
        self.stats = checkpoint_dict["stats"]
        self.epoch = checkpoint_dict["epoch"]
        self.global_step = checkpoint_dict["global_step"]
        self.log(f"[INFO] load at epoch {self.epoch}, global step {self.global_step}")

        if self.optimizer and "optimizer" in checkpoint_dict:
            try:
                self.optimizer.load_state_dict(checkpoint_dict["optimizer"])
                self.optimizer_pose_rot.load_state_dict(checkpoint_dict["optimizer_pose_rot"])
                self.optimizer_pose_trans.load_state_dict(checkpoint_dict["optimizer_pose_trans"])
                self.log("[INFO] loaded optimizer.")
            except:
                self.log("[WARN] Failed to load optimizer.")

        if self.lr_scheduler and "lr_scheduler" in checkpoint_dict:
            try:
                self.lr_scheduler.load_state_dict(checkpoint_dict["lr_scheduler"])
                self.lr_scheduler_pose_rot.load_state_dict(checkpoint_dict["lr_scheduler_pose_rot"])
                self.lr_scheduler_pose_trans.load_state_dict(checkpoint_dict["lr_scheduler_pose_trans"])
                self.log("[INFO] loaded scheduler.")
            except:
                self.log("[WARN] Failed to load scheduler.")

        if self.scaler and "scaler" in checkpoint_dict:
            try:
                self.scaler.load_state_dict(checkpoint_dict["scaler"])
                self.log("[INFO] loaded scaler.")
            except:
                self.log("[WARN] Failed to load scaler.")
   
    # Optional pose-only refinement stage (separate from regular training).
    def train_pose_one_epoch(self,loader,start1,start2):
        """Run one epoch that optimises only the per-frame pose refinements.

        Every parameter except ``se3_refine_rot.weight`` and
        ``se3_refine_trans.weight`` is frozen. Frames are ranked by ascending
        mean recorded loss (``self.model.loss_record[i][start1:]``), and only
        frames ranked ``start2`` or later (the higher-loss ones) are stepped.
        With ``opt.notinerf`` no frame is selected.

        Args:
            loader: iterable of batches with at least an ``"index"`` entry.
            start1: offset into each frame's loss history before averaging.
            start2: number of lowest-mean-loss frames to skip.

        NOTE(review): ``self.optimizer1``/``self.optimizer2`` and
        ``self.scheduler1``/``self.scheduler2`` are not created in the visible
        part of this file; confirm they are set elsewhere.
        NOTE(review): ``requires_grad`` is not restored here; confirm the
        caller re-enables it before regular training.
        """
        self.model.train()
        paralist = ['se3_refine_rot.weight', 'se3_refine_trans.weight']
        for name, param in self.model.named_parameters():
            if name not in paralist:
                param.requires_grad = False

        if self.world_size > 1:
            loader.sampler.set_epoch(self.epoch)
        if self.local_rank == 0:
            pbar = tqdm.tqdm(
                total=len(loader) * loader.batch_size,
                bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
            )
        self.local_step = 0

        # Mean recorded loss per frame, ignoring the first ``start1`` entries.
        record = [i[start1:] for i in self.model.loss_record]
        mean_loss = []
        for losses in record:
            mean_loss.append(sum(losses) / len(losses))

        # Rank frames by ascending mean loss; skip the ``start2`` lowest.
        sorted_pairs = sorted(zip(mean_loss, range(len(mean_loss))))
        mean_loss, idx = zip(*sorted_pairs)
        idx = idx[start2:]
        if self.opt.notinerf:
            idx = []

        total_loss = 0
        for data in loader:
            if data["index"] in idx:
                if data["index"] == idx[-1]:
                    # Debug output for the highest-loss selected frame.
                    print(self.model.se3_refine[idx[-1]])
                    print(self.model.se3_refine_rot.weight[idx[-1]])
                self.local_step += 1
                self.global_step += 1
                if self.opt.trans:
                    self.optimizer1.zero_grad()
                if self.opt.rot:
                    self.optimizer2.zero_grad()
                with torch.cuda.amp.autocast(enabled=self.fp16):
                    (
                        pred_intensity,
                        gt_intensity,
                        pred_depth,
                        gt_depth,
                        loss,
                    ) = self.train_step(data)

                self.scaler.scale(loss).backward()
                if self.opt.rot:
                    self.scaler.step(self.optimizer2)
                    if self.opt.scheduler:
                        self.scheduler2.step()
                if self.opt.trans:
                    self.scaler.step(self.optimizer1)
                    if self.opt.scheduler:
                        self.scheduler1.step()
                self.scaler.update()

                loss_val = loss.item()
                total_loss += loss_val

                if self.local_rank == 0:
                    if self.report_metric_at_train:
                        for i, metric in enumerate(self.depth_metrics):
                            if i < 2:  # hard code: first two meters take intensity
                                metric.update(pred_intensity, gt_intensity)
                            else:
                                metric.update(pred_depth, gt_depth)

                    if self.use_tensorboardX:
                        self.writer.add_scalar("train/loss", loss_val, self.global_step)
                        self.writer.add_scalar(
                            "train/lr",
                            self.optimizer.param_groups[0]["lr"],
                            self.global_step,
                        )

                    if self.scheduler_update_every_step:
                        pbar.set_description(
                            f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}"
                        )
                    else:
                        pbar.set_description(
                            f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})"
                        )
                    pbar.update(loader.batch_size)
        self.cal_pose_error_when_train_pose(loader)
        if self.ema is not None:
            self.ema.update()
        if self.opt.notinerf or self.local_step == 0:
            # No frame was optimised this epoch; avoid dividing by zero.
            average_loss = 0
        else:
            average_loss = total_loss / self.local_step
        self.stats["loss"].append(average_loss)
        self.log(f"average_loss: {average_loss}.")

        if self.local_rank == 0:
            pbar.close()
            if self.report_metric_at_train:
                for metric in self.depth_metrics:
                    self.log(metric.report(), style="red")
                    if self.use_tensorboardX:
                        metric.write(self.writer, self.epoch, prefix="LiDAR_train")
                    metric.clear()

        self.log(f"==> Finished Epoch {self.epoch}.")

    def save_train_pose(self, loader):
        """Placeholder hook for dumping the current training poses.

        The actual pose-saving call is disabled; this switches ``self.model``
        into eval mode and walks ``loader`` once under ``torch.no_grad()``
        without recording anything.

        NOTE(review): the model is left in eval mode on return; callers that
        continue training must switch it back themselves.
        """
        self.model.eval()
        with torch.no_grad():
            for _batch in loader:
                # Intentionally empty: the save_pose call is disabled.
                pass
    
    def cal_cosine_similarity(self, desc1, desc2):
        """Cosine similarity along the last dimension.

        Args:
            desc1: tensor ``[..., C]``.
            desc2: tensor broadcast-compatible with ``desc1``.

        Returns:
            Tensor with the last dimension reduced. ``1e-6`` is added to the
            product of norms so zero vectors give 0 instead of NaN.
        """
        dot = (desc1 * desc2).sum(dim=-1)
        norms = desc1.norm(dim=-1) * desc2.norm(dim=-1)
        return dot / (norms + 1e-6)
    def cal_rre_rte(self, pose, gt_pose):
        """Relative rotation / translation error of the first pose in a batch.

        Args:
            pose: estimated poses ``[B, 4, 4]`` (only index 0 is used).
            gt_pose: ground-truth poses ``[B, 4, 4]`` (only index 0 is used).

        Returns:
            ``(rre, rte)`` as Python floats. ``rre`` is the rotation angle of
            ``R^T @ R_gt`` in radians; the cosine is clipped to ``[-1, 1]``
            before ``acos``. ``rte`` is the Euclidean distance between the
            translations.
        """
        rot_est = pose[0, :3, :3]
        rot_gt = gt_pose[0, :3, :3]
        cos_angle = ((rot_est.T @ rot_gt).trace() - 1) / 2
        cos_angle = cos_angle.clamp(-1.0, 1.0)
        rre = cos_angle.acos().item()
        rte = (gt_pose[0, :3, 3] - pose[0, :3, 3]).norm(p=2).item()
        return rre, rte
    def cal_pose_error(self, data):
        """Record the pose error for one batch into ``self.model.rre`` / ``rte``.

        The noisy pose is ``data["pose"] @ self.model.pose_noise[idx]``. From
        epoch 2 onward, the learned refinement
        ``se3_to_SE3(self.model.se3_refine[idx])`` is composed in front of it.
        The result is compared with ``data["pose"]`` via ``cal_rre_rte`` and
        both errors are appended under ``idx``.
        """
        # The errors are stored as Python floats, so no autograd graph is needed.
        with torch.no_grad():
            idx = data["index"]
            gt_pose = data["pose"]
            pose = gt_pose @ self.model.pose_noise[idx]
            if self.epoch > 1:
                pose_refine = self.model.lie.se3_to_SE3(self.model.se3_refine[idx], self.device)
                pose = self.model.lie.compose_pair(pose_refine, pose)
            rre, rte = self.cal_rre_rte(pose, gt_pose)
        self.model.rre[idx].append(rre)
        self.model.rte[idx].append(rte)
    def cal_pose_error_when_graph_optim(self, loader):
        """Record refined-pose errors for every batch after graph optimization.

        For each batch, the noisy pose ``data["pose"] @ pose_noise[idx]`` is
        composed with the learned refinement ``se3_to_SE3(se3_refine[idx])``
        and compared with ``data["pose"]`` via ``cal_rre_rte``. The errors are
        appended to ``self.model.rre_when_graph_optim[idx]`` and
        ``self.model.rte_when_graph_optim[idx]``.
        """
        # The errors are stored as Python floats, so no autograd graph is needed.
        with torch.no_grad():
            for data in loader:
                idx = data["index"]
                gt_pose = data["pose"]
                noisy_pose = gt_pose @ self.model.pose_noise[idx]
                pose_refine = self.model.lie.se3_to_SE3(self.model.se3_refine[idx], self.device)
                refined_pose = self.model.lie.compose_pair(pose_refine, noisy_pose)
                rre, rte = self.cal_rre_rte(refined_pose, gt_pose)
                self.model.rre_when_graph_optim[idx].append(rre)
                self.model.rte_when_graph_optim[idx].append(rte)
    def cal_pose_error_when_train_pose(self, loader):
        """Record refined-pose errors for every batch after a pose-training epoch.

        For each batch, the noisy pose ``data["pose"] @ pose_noise[idx]`` is
        composed with the learned refinement ``se3_to_SE3(se3_refine[idx])``
        and compared with ``data["pose"]`` via ``cal_rre_rte``. The errors are
        appended to ``self.model.rre_when_train_pose[idx]`` and
        ``self.model.rte_when_train_pose[idx]``.
        """
        # The errors are stored as Python floats, so no autograd graph is needed.
        with torch.no_grad():
            for data in loader:
                idx = data["index"]
                gt_pose = data["pose"]
                noisy_pose = gt_pose @ self.model.pose_noise[idx]
                pose_refine = self.model.lie.se3_to_SE3(self.model.se3_refine[idx], self.device)
                refined_pose = self.model.lie.compose_pair(pose_refine, noisy_pose)
                rre, rte = self.cal_rre_rte(refined_pose, gt_pose)
                self.model.rre_when_train_pose[idx].append(rre)
                self.model.rte_when_train_pose[idx].append(rte)
    
    def train_step(self, data):
        """Run one training forward pass and assemble the LiDAR loss.

        Two modes are selected from the epoch counter:

        * Regular steps (``self.epoch % 5 != 0`` or ``opt.geo_loss`` off):
          render the sampled LiDAR rays and combine depth / raydrop /
          intensity losses. When the LiDAR patch size is > 1, optional
          smoothness and gradient terms are added; they are computed on
          patch-shaped copies, so the per-ray tensors that are returned stay
          untouched.
        * Geometry steps (every 5th epoch with ``opt.geo_loss``): render the
          full frame with a temporarily overridden ``opt.num_steps`` (always
          restored) and add Chamfer-distance and Chamfer-normal losses
          between ground-truth and predicted point clouds.

        Args:
            data: batch dict forwarded to ``self.model.render``. In geometry
                steps, ``data["num_rays_lidar"]`` is overwritten with -1.

        Returns:
            ``(pred_intensity, gt_intensity, pred_depth, gt_depth, loss)``.
            The first four are ``None`` and ``loss`` is 0 when
            ``opt.enable_lidar`` is off.
        """
        pred_intensity = None
        gt_intensity = None
        pred_depth = None
        gt_depth = None
        loss = 0

        def _kernel3x3(rows):
            # Fixed 3x3 filter shaped [1, 1, 3, 3] for F.conv2d on [P, 1, H, W] patches.
            return (
                torch.tensor(rows, dtype=torch.float32)
                .unsqueeze(0)
                .unsqueeze(0)
                .to(self.device)
            )

        sobel_x = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]
        sobel_y = [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]

        if self.opt.enable_lidar and (self.epoch % 5 != 0 or not self.opt.geo_loss):
            outputs_lidar = self.model.render(
                data,
                cal_lidar_color=True,
                staged=False,
                perturb=True,
                force_all_rays=False if self.opt.patch_size == 1 else True,
                **vars(self.opt),
            )
            self.it += 1
            self.model.progress.data.fill_(self.it / self.opt.iters)
            self.model.progress.requires_grad = False

            # Channels of the sampled rays: 0 = raydrop, 1 = intensity, 2 = depth.
            image_lidar_sample_rays = outputs_lidar["image_lidar_sample_rays"]
            gt_raydrop = image_lidar_sample_rays[:, :, 0]
            gt_intensity = image_lidar_sample_rays[:, :, 1] * gt_raydrop
            gt_depth = image_lidar_sample_rays[:, :, 2] * gt_raydrop

            pred_raydrop = outputs_lidar["intensity"][:, :, 0]
            pred_intensity = outputs_lidar["intensity"][:, :, 1] * gt_raydrop
            pred_depth = outputs_lidar["depth_lidar"] * gt_raydrop
            lidar_loss = (
                self.opt.alpha_d * self.criterion["depth"](pred_depth, gt_depth)
                + self.opt.alpha_r * self.criterion["raydrop"](pred_raydrop, gt_raydrop)
                + self.opt.alpha_i * self.criterion["intensity"](pred_intensity, gt_intensity)
            )
            loss = lidar_loss
            if len(loss.shape) == 3:  # [K, B, N]
                loss = loss.mean(0)
            loss = loss.mean()

            if isinstance(self.opt.patch_size_lidar, int):
                patch_size_x, patch_size_y = (
                    self.opt.patch_size_lidar,
                    self.opt.patch_size_lidar,
                )
            elif len(self.opt.patch_size_lidar) == 1:
                patch_size_x, patch_size_y = (
                    self.opt.patch_size_lidar[0],
                    self.opt.patch_size_lidar[0],
                )
            else:
                patch_size_x, patch_size_y = self.opt.patch_size_lidar

            if patch_size_x > 1:
                # Patch-shaped view [P, 1, px, py] in scaled units. Kept separate so
                # the per-ray pred_depth returned for metrics keeps its shape/units.
                pred_depth_patch = (
                    pred_depth.view(-1, patch_size_x, patch_size_y, 1)
                    .permute(0, 3, 1, 2)
                    .contiguous()
                    / self.opt.scale
                )
                if self.opt.sobel_grad:
                    pred_grad_x = F.conv2d(pred_depth_patch, _kernel3x3(sobel_x), padding=1)
                    pred_grad_y = F.conv2d(pred_depth_patch, _kernel3x3(sobel_y), padding=1)
                else:
                    pred_grad_y = torch.abs(
                        pred_depth_patch[:, :, :-1, :] - pred_depth_patch[:, :, 1:, :]
                    )
                    pred_grad_x = torch.abs(
                        pred_depth_patch[:, :, :, :-1] - pred_depth_patch[:, :, :, 1:]
                    )

                dy = torch.abs(pred_grad_y)
                dx = torch.abs(pred_grad_x)

                if self.opt.grad_norm_smooth:
                    grad_norm = torch.mean(torch.exp(-dx)) + torch.mean(torch.exp(-dy))
                    loss = loss + self.opt.alpha_grad_norm * grad_norm

                if self.opt.spatial_smooth:
                    spatial_loss = torch.mean(dx**2) + torch.mean(dy**2)
                    loss = loss + self.opt.alpha_spatial * spatial_loss

                if self.opt.tv_loss:
                    tv_loss = torch.mean(dx) + torch.mean(dy)
                    loss = loss + self.opt.alpha_tv * tv_loss

                if self.opt.grad_loss:
                    gt_depth_patch = (
                        gt_depth.view(-1, patch_size_x, patch_size_y, 1)
                        .permute(0, 3, 1, 2)
                        .contiguous()
                        / self.opt.scale
                    )
                    gt_raydrop_patch = (
                        gt_raydrop.view(-1, patch_size_x, patch_size_y, 1)
                        .permute(0, 3, 1, 2)
                        .contiguous()
                    )

                    # NOTE(review): only the x-direction gradient enters grad_loss.
                    # Without Sobel, pred_grad_x is an absolute difference while
                    # gt_grad_x is signed. Both behaviours are kept as-is; confirm intent.
                    if self.opt.sobel_grad:
                        gt_grad_x = F.conv2d(gt_depth_patch, _kernel3x3(sobel_x), padding=1)
                        mask_raydrop = gt_raydrop_patch
                    else:
                        gt_grad_x = gt_depth_patch[:, :, :, :-1] - gt_depth_patch[:, :, :, 1:]
                        mask_raydrop = gt_raydrop_patch[:, :, :, :-1]

                    # Ignore depth discontinuities: keep only small GT gradients.
                    grad_clip_x = 0.01
                    grad_mask_x = torch.where(torch.abs(gt_grad_x) < grad_clip_x, 1, 0)
                    mask_dx = mask_raydrop * grad_mask_x

                    if self.opt.depth_grad_loss == "cos":
                        patch_num = pred_grad_x.shape[0]
                        grad_loss = self.criterion["grad"](
                            (pred_grad_x * mask_dx).reshape(patch_num, -1),
                            (gt_grad_x * mask_dx).reshape(patch_num, -1),
                        )
                        grad_loss = 1 - grad_loss
                    else:
                        grad_loss = self.criterion["grad"](
                            pred_grad_x * mask_dx, gt_grad_x * mask_dx
                        )
                    loss = loss + self.opt.alpha_grad * grad_loss.mean()

        elif self.opt.enable_lidar and self.epoch % 5 == 0 and self.opt.geo_loss:
            data["num_rays_lidar"] = -1
            saved_num_steps = self.opt.num_steps
            self.opt.num_steps = 80 if self.opt.dataloader == "kitti360" else 160
            try:
                outputs_lidar = self.model.render(
                    data,
                    cal_lidar_color=True,
                    staged=True,
                    perturb=True,
                    force_all_rays=False if self.opt.patch_size == 1 else True,
                    **vars(self.opt),
                )
            finally:
                # Restore even if rendering fails so later steps use the configured value.
                self.opt.num_steps = saved_num_steps
            self.it += 1
            self.model.progress.data.fill_(self.it / self.opt.iters)
            self.model.progress.requires_grad = False

            image = outputs_lidar["image"]  # [B, H, W, 3]: raydrop, intensity, depth
            gt_raydrop = image[:, :, :, 0]
            gt_intensity = image[:, :, :, 1] * gt_raydrop
            gt_depth = image[:, :, :, 2] * gt_raydrop
            B_lidar, H_lidar, W_lidar, _ = image.shape

            pred_rgb_lidar = outputs_lidar["intensity"].reshape(
                B_lidar, H_lidar, W_lidar, 2
            )
            pred_raydrop = pred_rgb_lidar[:, :, :, 0]
            raydrop_mask = torch.where(pred_raydrop > 0.5, 1, 0)

            pred_intensity = pred_rgb_lidar[:, :, :, 1]
            pred_depth = outputs_lidar["depth_lidar"].reshape(B_lidar, H_lidar, W_lidar)
            lidar_loss = (
                self.opt.alpha_d * self.criterion["depth"](pred_depth * raydrop_mask, gt_depth).mean()
                + self.opt.alpha_r
                * self.criterion["raydrop"](pred_raydrop, gt_raydrop).mean()
                + self.opt.alpha_i
                * self.criterion["intensity"](pred_intensity * raydrop_mask, gt_intensity).mean()
            )

            loss = lidar_loss
            if len(loss.shape) == 3:
                loss = loss.mean(0)
            loss = loss.mean()

            pcd1, pcd2 = self.result_process(data, outputs_lidar)
            custom_loss_cd, idx1, idx2 = chamfer_distance_low_capacity(pcd1, pcd2)
            custom_loss_norm = chamfer_based_norm_loss_low_capacity(pcd1, pcd2, idx1, idx2)
            print("cd_loss:", custom_loss_cd)
            print("norm_loss:", custom_loss_norm)
            custom_loss = custom_loss_cd * 1 + custom_loss_norm * 5
            loss = 0.8 * loss + custom_loss
        else:
            loss = 0

        return (
            pred_intensity,
            gt_intensity,
            pred_depth,
            gt_depth,
            loss,
        )
    def eval_step(self, data):
        """Render one evaluation frame and compute the LiDAR loss.

        Returns a 9-tuple ``(pred_intensity, pred_depth, pred_depth_crop,
        pred_raydrop, gt_intensity, gt_depth, gt_depth_crop, gt_raydrop,
        loss)``. The ``*_crop`` slots are always ``None``. When
        ``opt.enable_lidar`` is off, every tensor slot is ``None`` and
        ``loss`` is 0.

        Predicted depth and intensity are masked by the thresholded predicted
        ray-drop channel (``> 0.5``) only inside the loss; the returned
        predictions are unmasked.
        """
        if not self.opt.enable_lidar:
            return (None,) * 8 + (0,)

        rendered = self.model.render(
            data,
            training=False,
            cal_lidar_color=True,
            staged=True,
            perturb=False,
            **vars(self.opt),
        )

        # Ground-truth channels: 0 = raydrop, 1 = intensity, 2 = depth.
        target = rendered["image"]
        gt_raydrop = target[:, :, :, 0]
        gt_intensity = target[:, :, :, 1] * gt_raydrop
        gt_depth = target[:, :, :, 2] * gt_raydrop
        batch, height, width, _ = target.shape

        channels = rendered["intensity"].reshape(batch, height, width, 2)
        pred_raydrop = channels[:, :, :, 0]
        pred_intensity = channels[:, :, :, 1]
        pred_depth = rendered["depth_lidar"].reshape(batch, height, width)
        keep = torch.where(pred_raydrop > 0.5, 1, 0)

        depth_term = self.criterion["depth"](pred_depth * keep, gt_depth).mean()
        raydrop_term = self.criterion["raydrop"](pred_raydrop, gt_raydrop).mean()
        intensity_term = self.criterion["intensity"](pred_intensity * keep, gt_intensity).mean()
        loss = (
            self.opt.alpha_d * depth_term
            + self.opt.alpha_r * raydrop_term
            + self.opt.alpha_i * intensity_term
        )

        return (
            pred_intensity,
            pred_depth,
            None,
            pred_raydrop,
            gt_intensity,
            gt_depth,
            None,
            gt_raydrop,
            loss,
        )
    def test_step(self, data, bg_color=None, perturb=False):
        pred_raydrop = None
        pred_intensity = None
        pred_depth = None
        
        if self.opt.enable_lidar:
            H_lidar, W_lidar = data["H_lidar"], data["W_lidar"]
            #self.opt.num_steps=128
            outputs_lidar = self.model.render(
                data,
                training=False,
                cal_lidar_color=True,
                staged=True,
                perturb=False,
                #force_all_rays=False if self.opt.patch_size == 1 else True,
                **vars(self.opt),
            )
            #'''

            pred_rgb_lidar = outputs_lidar["intensity"].reshape(
                -1, H_lidar, W_lidar, 2
            )
            pred_raydrop = pred_rgb_lidar[:, :, :, 0]
            raydrop_mask = torch.where(pred_raydrop > 0.5, 1, 0)
            pred_intensity = pred_rgb_lidar[:, :, :, 1]
            pred_depth = outputs_lidar["depth_lidar"].reshape(-1, H_lidar, W_lidar)
            if self.opt.alpha_r > 0:
                pred_intensity = pred_intensity * raydrop_mask
                pred_depth = pred_depth * raydrop_mask
            
            #'''
        #return 0
        return pred_raydrop, pred_intensity, pred_depth

    def fetch_all_lidar(self, loader):
        """Cache per-frame point clouds and refresh current poses for all frames.

        On the first call (``self.pcds is None``), each frame's range-image
        depth (channel 2 of ``data["image"]``) is back-projected with
        ``pano_to_lidar(depth, [10, 40])``, divided by ``opt.scale`` and
        converted to homogeneous coordinates. It is then randomly sampled to a
        fixed point count: 40000 for ``kitti360`` and 24000 otherwise, padding
        with replacement when a frame has fewer points. The result is cached in
        ``self.pcds`` with shape ``[opt.dataloader_size, npoints, 4]``.

        Every call rebuilds ``self.poses`` (``[opt.dataloader_size, 4, 4]``)
        from ``self.model.get_pose``, with translations divided by
        ``opt.scale``.
        """
        device = torch.device(self.opt.device)

        if self.pcds is None:
            npoints = 40000 if self.opt.dataloader == "kitti360" else 24000
            pcds = torch.zeros(self.opt.dataloader_size, npoints, 4, device=self.opt.device)
            for data in loader:
                idx = data["index"]
                depth = data["image"][0, :, :, 2].cpu().numpy()
                # NOTE(review): the [10, 40] argument is hard-coded for every
                # dataset (including kitti360); confirm it matches the sensor.
                pcd = pano_to_lidar(depth, [10, 40]) / self.opt.scale
                N = pcd.shape[0]
                # NOTE(review): no voxel downsampling happens here. An Open3D
                # `voxel_down_sample(voxel_size=0.2)` call at this point discarded
                # its return value (Open3D returns a new cloud), so it never took
                # effect. The Open3D round trip reduced to a float64 conversion.
                xyz = np.asarray(pcd[:, :3], dtype=np.float64)
                print(xyz.shape)
                xyz_h = np.concatenate([xyz, np.ones((xyz.shape[0], 1))], axis=1)
                if N >= npoints:
                    sample_idx = np.random.choice(N, npoints, replace=False)
                else:
                    # Keep every point once, then pad with random repeats.
                    sample_idx = np.concatenate(
                        (np.arange(N), np.random.choice(N, npoints - N, replace=True)),
                        axis=-1,
                    )
                sampled = xyz_h[sample_idx, :].astype("float32")
                pcds[idx, :, :] = torch.from_numpy(sampled[None, ...]).to(device)  # [1, npoints, 4]
            self.pcds = pcds

        poses = torch.zeros(self.opt.dataloader_size, 4, 4, device=self.opt.device)
        for data in loader:
            idx = data["index"]
            pose = self.model.get_pose(idx, data["pose"])
            pose[:, :3, 3] = pose[:, :3, 3] / self.opt.scale
            poses[idx, :, :] = pose
        self.poses = poses
    def matrix_construct(self, loader):
        """Robust Chamfer loss between each frame and its next two neighbours.

        Refreshes ``self.pcds`` / ``self.poses`` via ``fetch_all_lidar`` and
        transforms every cloud by its current pose. Frame ``i`` is paired with
        frames ``i + 1`` and ``i + 2``. The bidirectional nearest-neighbour
        distances from ``chamfer_3DDist`` (presumably squared, given the sqrt
        below) are re-weighted per pair by a softmax over
        ``t_control / max(sqrt(dist), 0.15)``, where ``t_control`` scales with
        ``self.model.progress``. They are then summed and divided by the number
        of frame pairs. The result is printed and returned; ``graph_based_train``
        back-propagates it.
        """

        def soft_weighted_sum(dist, t_control, min_dist):
            # Clamping before the square root gives the same forward value as
            # max(sqrt(dist), min_dist) but keeps the backward pass finite where
            # dist == 0. Masking after sqrt gives 0 * inf = NaN there.
            root = torch.clamp(dist, min=min_dist * min_dist).sqrt()
            scores = torch.exp(t_control / root)
            weight = scores / torch.sum(scores, dim=1, keepdim=True)
            return torch.sum(weight * dist)

        self.fetch_all_lidar(loader)
        world_h = torch.matmul(self.pcds, self.poses.permute(0, 2, 1))  # [F, N, 4]
        world = world_h[:, :, :3]  # [F, N, 3]
        num_frames = world.shape[0]
        # Gap-1 pairs followed by gap-2 pairs: [2F - 3, N, 3] each.
        target = torch.cat([world[1:], world[2:]], dim=0)
        source = torch.cat([world[: num_frames - 1], world[: num_frames - 2]], dim=0)
        dist1, dist2, _, _ = chamfer_3DDist()(source, target)

        progress_scale = 0.3 if self.opt.dataloader == "kitti360" else 0.5
        t_control = self.model.progress.data * progress_scale
        d = 0.15
        num_pairs = dist1.shape[0]
        robust_cd = (
            soft_weighted_sum(dist1, t_control, d) / num_pairs
            + soft_weighted_sum(dist2, t_control, d) / num_pairs
        )
        print(robust_cd)
        return robust_cd
    def graph_based_train(self, loader):
        """Take one graph-optimization step on the pose refinements.

        Both graph optimizers are zeroed. The robust Chamfer loss from
        ``matrix_construct`` is back-propagated through ``self.scaler``. Each
        optimizer is then stepped and its scheduler advanced (translation
        first, then rotation). Finally the scaler is updated and the resulting
        pose errors are recorded.
        """
        stages = (
            (self.optimizer_graph_trans, self.scheduler_graph_trans),
            (self.optimizer_graph_rot, self.scheduler_graph_rot),
        )
        for optimizer, _ in stages:
            optimizer.zero_grad()

        graph_loss = self.matrix_construct(loader)
        self.scaler.scale(graph_loss).backward()

        for optimizer, scheduler in stages:
            self.scaler.step(optimizer)
            scheduler.step()
        self.scaler.update()

        self.cal_pose_error_when_graph_optim(loader)
    
    
    
    
    
    def result_process(self, data, outputs_lidar):
        """Build ground-truth and predicted point clouds for the Chamfer losses.

        The ground-truth cloud is back-projected from the depth channel
        (index 2) of ``data["image"][0]``. The predicted cloud uses the
        rendered ``depth_lidar``, masked by the ground-truth ray-drop channel
        of ``image_lidar_sample_rays`` and reshaped to the range-image shape.
        Both are divided by ``opt.scale``.

        NOTE(review): the ``[10, 40]`` argument to ``pano_to_lidar`` is
        hard-coded for all datasets; confirm it matches the sensor in use.

        Returns:
            ``(gt_points, pred_points)``, each with a leading batch dim of 1;
            ``gt_points`` is a float32 tensor on ``opt.device``.
        """
        device = torch.device(self.opt.device)

        depth = data["image"][0, :, :, 2].cpu().numpy()
        gt_points = pano_to_lidar(depth, [10, 40]) / self.opt.scale
        gt_points = torch.FloatTensor(gt_points[None, ...]).to(device)

        gt_raydrop = outputs_lidar["image_lidar_sample_rays"][:, :, 0]
        pred_depth = (outputs_lidar["depth_lidar"] * gt_raydrop).view(depth.shape)
        pred_points = pano_to_lidar(pred_depth, [10, 40], is_tensor=True) / self.opt.scale

        return gt_points, pred_points.unsqueeze(0)
    def train(self, train_loader, test_loader, max_epochs):
        """Main loop: NeRF epochs interleaved with pose-graph optimization.

        For each epoch in ``self.epoch + 1 .. max_epochs``:

        1. optionally alternate the LiDAR patch size
           (``opt.change_patch_size_lidar`` / ``opt.change_patch_size_epoch``);
        2. print the current robust inter-frame Chamfer loss;
        3. train one epoch, checkpoint, and evaluate every ``eval_interval``
           epochs;
        4. run the graph-optimization stages on their epoch schedule (which
           depends on ``opt.no_gt_pose`` and the dataset);
        5. every 600 epochs after epoch 700, run one extra pose epoch when
           ``opt.rot`` / ``opt.trans`` is set.
        """
        if self.use_tensorboardX and self.local_rank == 0:
            if is_ali_cluster() and self.opt.cluster_summary_path is not None:
                summary_path = self.opt.cluster_summary_path
            else:
                summary_path = os.path.join(self.workspace, "run", self.name)
            self.writer = tensorboardX.SummaryWriter(summary_path)

        change_dataloder = False
        if self.opt.change_patch_size_lidar[0] > 1:
            change_dataloder = True

        for epoch in range(self.epoch + 1, max_epochs + 1):
            self.epoch = epoch
            if change_dataloder:
                # Large patches only on every `change_patch_size_epoch`-th epoch.
                if self.epoch % self.opt.change_patch_size_epoch == 0:
                    train_loader._data.patch_size_lidar = self.opt.change_patch_size_lidar
                    self.opt.patch_size_lidar = self.opt.change_patch_size_lidar
                else:
                    train_loader._data.patch_size_lidar = 1
                    self.opt.patch_size_lidar = 1

            # Diagnostic only: print the current robust inter-frame Chamfer loss.
            temp = self.matrix_construct(train_loader)
            print(temp)
            if self.epoch == 1:
                self.save_train_pose(train_loader)
            self.train_one_epoch(train_loader)

            if self.workspace is not None and self.local_rank == 0:
                self.save_checkpoint(full=True, best=False)

            if self.epoch % self.eval_interval == 0:
                self.evaluate_one_epoch(test_loader)
                self.save_checkpoint(full=False, best=True)

            if self.epoch % 3 == 0 or self.epoch == 1:
                self.save_train_pose(train_loader)

            # Graph-optimization schedule: stage k runs every itv_k epochs inside
            # its epoch window and performs ep_k optimization iterations.
            if self.opt.no_gt_pose:
                itv1, itv2, itv3 = 6, 10, 25
                ep1, ep2, ep3 = 60, 20, 20
                reweight_graph = 4
                bound1, bound2, bound3 = 151, 650, 1300
            else:
                itv1, itv2, itv3 = 6, 15, 30
                ep1, ep2, ep3 = 60, 30, 20
                reweight_graph = 1
                bound1, bound2, bound3 = 151, 650, 1300
                if self.opt.dataloader == "kitti360":
                    itv1, itv2, itv3 = 1, 2, 1
                    ep1, ep2, ep3 = 9, 5, 2
                    reweight_graph = 10
                    bound1, bound2, bound3 = 25, 975, 1950

            if self.opt.graph_optim and self.epoch <= bound1 and self.epoch % itv1 == 0:
                self._run_graph_optim_stage(
                    train_loader,
                    reweight_graph * 5 * 1 * 0.15 * 10,
                    reweight_graph * 5 * 0.5 * 2,
                    300,
                    ep1,
                )

            if self.opt.graph_optim and bound2 > self.epoch > bound1 and self.epoch % itv2 == 0:
                self._run_graph_optim_stage(
                    train_loader,
                    reweight_graph * 10 * 1 * 0.15 * 10,
                    reweight_graph * 2 * 0.5 * 2,
                    1200,
                    ep2,
                )

            if self.opt.graph_optim and bound2 <= self.epoch <= bound3 and self.epoch % itv3 == 0:
                self._run_graph_optim_stage(
                    train_loader,
                    reweight_graph * 10 * 1 * 0.15 * 10,
                    reweight_graph * 1 * 0.5 * 2,
                    1200,
                    ep3,
                )

            if self.opt.rot or self.opt.trans:
                if self.epoch % 600 == 0 and self.epoch > 700:
                    if self.opt.trans:
                        lr_trans = self.lr_scheduler_pose_trans.get_last_lr()[0]
                        self.optimizer1 = torch.optim.Adam(
                            self.model.get_params_pose_trans(10 * 0.5 * 10 * lr_trans),
                            betas=(0.9, 0.99),
                            eps=1e-15,
                        )
                        if self.opt.scheduler:
                            self.scheduler1 = torch.optim.lr_scheduler.LambdaLR(
                                self.optimizer1, lambda step: 0.01 ** min(step / 120, 1)
                            )
                    if self.opt.rot:
                        lr_rot = self.lr_scheduler_pose_rot.get_last_lr()[0]
                        self.optimizer2 = torch.optim.Adam(
                            self.model.get_params_pose_rot(5 * 0.5 * 2 * lr_rot),
                            betas=(0.9, 0.99),
                            eps=1e-15,
                        )
                        if self.opt.scheduler:
                            self.scheduler2 = torch.optim.lr_scheduler.LambdaLR(
                                self.optimizer2, lambda step: 0.01 ** min(step / 120, 1)
                            )
                    self.train_pose_one_epoch(train_loader, -5, -5)

        if self.use_tensorboardX and self.local_rank == 0:
            self.writer.close()

    def _run_graph_optim_stage(self, loader, trans_coef, rot_coef, decay_iters, num_iters):
        """Run ``num_iters`` graph-optimization steps with fresh Adam optimizers.

        Learning rates are ``trans_coef`` / ``rot_coef`` times the current pose
        scheduler learning rates. ``LambdaLR`` decays them exponentially to 1%
        over ``decay_iters`` steps. The net change of the refinement weights is
        stored in ``self.optim_direction_rot`` / ``self.optim_direction_trans``.
        """
        lr_trans = self.lr_scheduler_pose_trans.get_last_lr()[0]
        lr_rot = self.lr_scheduler_pose_rot.get_last_lr()[0]
        self.optimizer_graph_trans = torch.optim.Adam(
            self.model.get_params_pose_trans(trans_coef * lr_trans),
            betas=(0.9, 0.99),
            eps=1e-15,
        )
        self.optimizer_graph_rot = torch.optim.Adam(
            self.model.get_params_pose_rot(rot_coef * lr_rot),
            betas=(0.9, 0.99),
            eps=1e-15,
        )
        self.scheduler_graph_trans = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_graph_trans, lambda step: 0.01 ** min(step / decay_iters, 1)
        )
        self.scheduler_graph_rot = torch.optim.lr_scheduler.LambdaLR(
            self.optimizer_graph_rot, lambda step: 0.01 ** min(step / decay_iters, 1)
        )

        before_pose_rot = self.model.se3_refine_rot.weight.clone()
        before_pose_trans = self.model.se3_refine_trans.weight.clone()
        for _ in range(num_iters):
            self.graph_based_train(loader)
        self.save_train_pose(loader)

        after_pose_rot = self.model.se3_refine_rot.weight.clone()
        after_pose_trans = self.model.se3_refine_trans.weight.clone()
        self.optim_direction_rot = (before_pose_rot - after_pose_rot).detach_()
        self.optim_direction_trans = (before_pose_trans - after_pose_trans).detach_()
        self.save_train_pose(loader)
    def evaluate(self, loader, name=None):
        """Run one evaluation pass over ``loader`` without TensorBoard logging.

        ``self.use_tensorboardX`` is disabled while ``evaluate_one_epoch``
        runs and restored afterwards, even if evaluation raises.
        """
        saved_flag = self.use_tensorboardX
        self.use_tensorboardX = False
        try:
            self.evaluate_one_epoch(loader, name)
        finally:
            self.use_tensorboardX = saved_flag
    def test(self, loader, save_path=None, name=None, write_video=True):
        """Render every frame in ``loader`` and save the LiDAR predictions.

        When ``opt.enable_lidar`` is set, each frame's predicted depth is
        back-projected with ``pano_to_lidar`` and saved as
        ``test_{name}_{i:04d}_depth_lidar.npy`` (box-filtered for
        ``nerf_mvl``). Colour-mapped intensity and depth images are collected
        into ``{name}_lidar_rgb.mp4`` / ``{name}_depth.mp4`` when
        ``write_video`` is set. Otherwise they are written as per-frame PNGs,
        together with the binarised ray-drop mask.

        Args:
            loader: test loader; ``loader._data`` must expose ``H_lidar``,
                ``W_lidar`` and ``intrinsics_lidar``.
            save_path: output directory; defaults to ``<workspace>/results``.
            name: file-name prefix; defaults to ``{self.name}_ep{epoch:04d}``.
            write_video: write MP4s instead of per-frame PNGs.
        """
        if save_path is None:
            save_path = os.path.join(self.workspace, "results")

        if name is None:
            name = f"{self.name}_ep{self.epoch:04d}"

        os.makedirs(save_path, exist_ok=True)

        self.log(f"==> Start Test, save results to {save_path}")

        pbar = tqdm.tqdm(
            total=len(loader) * loader.batch_size,
            bar_format="{percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
        )
        self.model.eval()

        all_preds = []
        all_preds_depth = []

        try:
            with torch.no_grad():
                for i, data in enumerate(loader):
                    with torch.cuda.amp.autocast(enabled=self.fp16):
                        preds_raydrop, preds_intensity, preds_depth = self.test_step(data)

                    if self.opt.enable_lidar:
                        pred_raydrop = preds_raydrop[0].detach().cpu().numpy()
                        pred_raydrop = (np.where(pred_raydrop > 0.5, 1.0, 0.0)).reshape(
                            loader._data.H_lidar, loader._data.W_lidar
                        )
                        pred_raydrop = (pred_raydrop * 255).astype(np.uint8)

                        pred_intensity = preds_intensity[0].detach().cpu().numpy()
                        pred_intensity = (pred_intensity * 255).astype(np.uint8)

                        pred_depth = preds_depth[0].detach().cpu().numpy()
                        pred_lidar = pano_to_lidar(
                            pred_depth / self.opt.scale, loader._data.intrinsics_lidar
                        )
                        if self.opt.dataloader == "nerf_mvl":
                            pred_lidar = filter_bbox_dataset(
                                pred_lidar, data["OBB_local"][:, :3]
                            )

                        np.save(
                            os.path.join(save_path, f"test_{name}_{i:04d}_depth_lidar.npy"),
                            pred_lidar,
                        )

                        pred_depth = (pred_depth * 255).astype(np.uint8)

                        if write_video:
                            all_preds.append(cv2.cvtColor(cv2.applyColorMap(pred_intensity, 1), cv2.COLOR_BGR2RGB))
                            all_preds_depth.append(cv2.cvtColor(cv2.applyColorMap(pred_depth, 20), cv2.COLOR_BGR2RGB))
                        else:
                            cv2.imwrite(
                                os.path.join(save_path, f"test_{name}_{i:04d}_raydrop.png"),
                                pred_raydrop,
                            )
                            cv2.imwrite(
                                os.path.join(
                                    save_path, f"test_{name}_{i:04d}_intensity.png"
                                ),
                                cv2.applyColorMap(pred_intensity, 1),
                            )
                            cv2.imwrite(
                                os.path.join(save_path, f"test_{name}_{i:04d}_depth.png"),
                                cv2.applyColorMap(pred_depth, 20),
                            )

                    pbar.update(loader.batch_size)
        finally:
            pbar.close()

        # np.stack raises on an empty list, so skip the videos for an empty loader.
        if write_video and self.opt.enable_lidar and all_preds:
            all_preds = np.stack(all_preds, axis=0)
            all_preds_depth = np.stack(all_preds_depth, axis=0)
            imageio.mimwrite(
                os.path.join(save_path, f"{name}_lidar_rgb.mp4"),
                all_preds,
                fps=25,
                quality=8,
                macro_block_size=1,
            )
            imageio.mimwrite(
                os.path.join(save_path, f"{name}_depth.mp4"),
                all_preds_depth,
                fps=25,
                quality=8,
                macro_block_size=1,
            )

        self.log(f"==> Finished Test.")
    def train_one_epoch(self, loader):
        """Run one training epoch over ``loader``.

        Besides the usual forward / backward / optimizer-step loop, the
        network optimizer's learning rate is temporarily rescaled per frame:

        * frames in ``idx_test`` are stepped with a network LR of 0, so the
          network does not move on them (the pose optimizers still step when
          ``self.opt.rot`` / ``self.opt.trans`` are set);
        * when ``self.opt.reweight`` is set and ``self.epoch < 1200``, the
          frames with the highest mean of their last 5 recorded losses (top 4
          for ``kitti360``, top 5 otherwise) are stepped with a damped LR:
          0.1x before epoch 10, then ``min(0.15 + 0.85 * epoch / 1200, 1)``x.

        Each param group's LR is saved before the rescale and restored right
        after the optimizer step, so the LR schedulers never see the
        temporary values.

        Args:
            loader: Training data loader. Each batch must provide
                ``data["index"]``, which is also used to index
                ``self.model.loss_record``.
        """
        for param in self.model.parameters():
            param.requires_grad = True

        log_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        self.log(
            f"[{log_time}] ==> Start Training Epoch {self.epoch}, lr={self.optimizer.param_groups[0]['lr']:.6f} ..."
        )

        total_loss = 0
        if self.local_rank == 0 and self.report_metric_at_train:
            for metric in self.metrics:
                metric.clear()
            for metric in self.depth_metrics:
                metric.clear()

        self.model.train()

        if self.world_size > 1:
            loader.sampler.set_epoch(self.epoch)

        if self.local_rank == 0:
            pbar = tqdm.tqdm(
                total=len(loader) * loader.batch_size,
                bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
            )

        self.local_step = 0

        # Held-out frames: the network LR is zeroed on these.
        idx_test = [2, 12, 22]
        # How many of the highest-loss frames get a damped network LR.
        num_hard_frames = 4 if self.opt.dataloader == "kitti360" else 5

        for data in loader:
            if self.epoch == 1:
                self.cal_pose_error(data)
            self.local_step += 1
            self.global_step += 1

            self.optimizer_pose_rot.zero_grad()
            self.optimizer_pose_trans.zero_grad()
            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=self.fp16):
                (
                    pred_intensity,
                    gt_intensity,
                    pred_depth,
                    gt_depth,
                    loss,
                ) = self.train_step(data)
            self.scaler.scale(loss).backward()

            # Rank frames by the mean of their last 5 recorded losses (ties
            # broken by frame index) and keep the highest ones. The +0.0001
            # avoids dividing by zero for frames with no history yet.
            recent_means = [
                sum(history[-5:]) / (len(history[-5:]) + 0.0001)
                for history in self.model.loss_record
            ]
            ranked = sorted(
                range(len(recent_means)), key=lambda j: (recent_means[j], j)
            )
            indx_loss = tuple(ranked[-num_hard_frames:])
            print(indx_loss)

            # Snapshot every group's LR so the temporary rescale can be undone
            # exactly, even if groups do not share the same LR.
            saved_lrs = [group["lr"] for group in self.optimizer.param_groups]
            lr_scale = None
            if data["index"] in idx_test:
                # NOTE(review): a zero LR suppresses this step's update, but a
                # stateful optimizer (e.g. Adam moments) still absorbs this
                # held-out frame's gradient — confirm that is acceptable.
                lr_scale = 0
            elif (
                data["index"] in indx_loss
                and self.epoch < 1200
                and self.opt.reweight
            ):
                print("reweight now")
                if self.epoch < 10:
                    lr_scale = 0.1
                else:
                    lr_scale = min(0.15 + 0.85 * self.epoch / 1200, 1)
            if lr_scale is not None:
                for group in self.optimizer.param_groups:
                    group["lr"] = group["lr"] * lr_scale

            self.scaler.step(self.optimizer)
            for group, lr in zip(self.optimizer.param_groups, saved_lrs):
                group["lr"] = lr

            if self.opt.rot:
                self.scaler.step(self.optimizer_pose_rot)
            if self.opt.trans:
                self.scaler.step(self.optimizer_pose_trans)
            # One scaler update per iteration, after all optimizers stepped.
            self.scaler.update()
            if self.scheduler_update_every_step:
                self.lr_scheduler.step()
                if self.opt.rot:
                    self.lr_scheduler_pose_rot.step()
                if self.opt.trans:
                    self.lr_scheduler_pose_trans.step()

            loss_val = loss.item()
            self.cal_pose_error(data)
            self.model.loss_record[data["index"]].append(loss_val)
            total_loss += loss_val

            if self.local_rank == 0:
                if self.report_metric_at_train:
                    for i, metric in enumerate(self.depth_metrics):
                        if i < 2:  # hard code: first two metrics are intensity
                            metric.update(pred_intensity, gt_intensity)
                        else:
                            metric.update(pred_depth, gt_depth)

                if self.use_tensorboardX:
                    self.writer.add_scalar("train/loss", loss_val, self.global_step)
                    self.writer.add_scalar(
                        "train/lr",
                        self.optimizer.param_groups[0]["lr"],
                        self.global_step,
                    )

                if self.scheduler_update_every_step:
                    pbar.set_description(
                        f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f}), lr={self.optimizer.param_groups[0]['lr']:.6f}"
                    )
                else:
                    pbar.set_description(
                        f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})"
                    )
                pbar.update(loader.batch_size)

        if self.ema is not None:
            self.ema.update()

        average_loss = total_loss / self.local_step
        self.stats["loss"].append(average_loss)
        self.log(f"average_loss: {average_loss}.")

        if self.local_rank == 0:
            pbar.close()
            if self.report_metric_at_train:
                for metric in self.depth_metrics:
                    self.log(metric.report(), style="red")
                    if self.use_tensorboardX:
                        metric.write(self.writer, self.epoch, prefix="LiDAR_train")
                    metric.clear()

        if not self.scheduler_update_every_step:
            if isinstance(
                self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
            ):
                self.lr_scheduler.step(average_loss)
                if self.opt.rot:
                    self.lr_scheduler_pose_rot.step(average_loss)
                if self.opt.trans:
                    self.lr_scheduler_pose_trans.step(average_loss)
            else:
                self.lr_scheduler.step()
                if self.opt.rot:
                    self.lr_scheduler_pose_rot.step()
                if self.opt.trans:
                    self.lr_scheduler_pose_trans.step()

        self.log(f"==> Finished Epoch {self.epoch}.")
    def evaluate_one_epoch(self, loader, name=None):
        """Evaluate the model on the selected validation frames of ``loader``.

        Only frames whose ``data["index"]`` is in the evaluation set are
        processed. With ``self.opt.all_eval`` that set is every index below
        ``self.opt.dataloader_size``; otherwise it is the fixed frames 2, 12
        and 22. If an EMA is configured, its weights are swapped in for the
        evaluation and restored afterwards.

        On rank 0 each evaluated frame updates ``self.depth_metrics`` in a
        hard-coded order:

        * 0: raydrop;
        * 1 / 2: intensity / depth masked by the GT raydrop;
        * 3 and later: intensity / depth masked by the thresholded
          (> 0.5) predicted raydrop.

        When ``self.opt.enable_lidar`` is set, a stacked PNG of the predictions
        and the ``pano_to_lidar`` output are also written to
        ``<workspace>/validation``.

        The mean loss is appended to ``self.stats["valid_loss"]``. It is NaN
        when no frame was evaluated.

        Args:
            loader: Validation data loader. ``loader._data`` must expose
                ``intrinsics_lidar``.
            name: File-name prefix; defaults to ``<self.name>_ep<epoch>``.
        """
        self.log(f"++> Evaluate at epoch {self.epoch} ...")
        if name is None:
            name = f"{self.name}_ep{self.epoch:04d}"

        total_loss = 0
        if self.local_rank == 0:
            for metric in self.metrics:
                metric.clear()
            for metric in self.depth_metrics:
                metric.clear()

        self.model.eval()

        if self.ema is not None:
            self.ema.store()
            self.ema.copy_to()

        if self.local_rank == 0:
            pbar = tqdm.tqdm(
                total=len(loader) * loader.batch_size,
                bar_format="{desc}: {percentage:3.0f}% {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]",
            )

        # The evaluation set is loop-invariant, so build it once. It stays a
        # list: `in` on a list compares with ==, which also works if
        # data["index"] is a tensor (a set lookup would hash by identity).
        if self.opt.all_eval:
            eva_idx = list(range(self.opt.dataloader_size))
        else:
            eva_idx = [2, 12, 22]

        def gather_from_all_ranks(tensor):
            # All-gather `tensor` from every rank and concatenate along dim 0.
            parts = [
                torch.zeros_like(tensor).to(self.device)
                for _ in range(self.world_size)
            ]
            dist.all_gather(parts, tensor)
            return torch.cat(parts, dim=0)

        with torch.no_grad():
            self.local_step = 0

            for data in loader:
                if data["index"] not in eva_idx:
                    continue
                self.local_step += 1

                with torch.cuda.amp.autocast(enabled=self.fp16):
                    (
                        preds_intensity,
                        preds_depth,
                        preds_depth_crop,
                        preds_raydrop,
                        gt_intensity,
                        gt_depth,
                        gt_depth_crop,
                        gt_raydrop,
                        loss,
                    ) = self.eval_step(data)

                if self.world_size > 1:
                    # NOTE(review): frames are filtered per rank before these
                    # collectives. If ranks see different numbers of matching
                    # frames, the collectives will not line up — confirm the
                    # sampler guarantees this.
                    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                    loss = loss / self.world_size
                    preds_intensity = gather_from_all_ranks(preds_intensity)
                    preds_depth = gather_from_all_ranks(preds_depth)
                    preds_raydrop = gather_from_all_ranks(preds_raydrop)
                    gt_intensity = gather_from_all_ranks(gt_intensity)
                    gt_depth = gather_from_all_ranks(gt_depth)
                    gt_raydrop = gather_from_all_ranks(gt_raydrop)

                # Binary mask from the predicted ray-drop probability; the GT
                # ray-drop tensor is used directly as the GT mask.
                preds_mask = torch.where(preds_raydrop > 0.5, 1, 0)
                gt_mask = gt_raydrop

                loss_val = loss.item()
                total_loss += loss_val

                # Only rank 0 performs evaluation and writes outputs.
                if self.local_rank == 0:
                    for i, metric in enumerate(self.depth_metrics):
                        if i == 0:  # hard code
                            metric.update(preds_raydrop, gt_raydrop)
                        elif i == 1:
                            metric.update(preds_intensity * gt_mask, gt_intensity)
                        elif i == 2:
                            metric.update(preds_depth * gt_mask, gt_depth)
                        elif i == 3:
                            metric.update(preds_intensity * preds_mask, gt_intensity)
                        else:
                            metric.update(preds_depth * preds_mask, gt_depth)

                    if self.opt.enable_lidar:
                        save_path_pred = os.path.join(
                            self.workspace,
                            "validation",
                            f"{name}_{self.local_step:04d}.png",
                        )
                        os.makedirs(os.path.dirname(save_path_pred), exist_ok=True)

                        # Raw (un-thresholded) ray-drop probability as grayscale.
                        pred_raydrop = preds_raydrop[0].detach().cpu().numpy()
                        img_raydrop = (pred_raydrop * 255).astype(np.uint8)
                        img_raydrop = cv2.cvtColor(img_raydrop, cv2.COLOR_GRAY2BGR)

                        pred_intensity = preds_intensity[0].detach().cpu().numpy()
                        img_intensity = (pred_intensity * 255).astype(np.uint8)
                        img_intensity = cv2.applyColorMap(img_intensity, 1)

                        pred_depth = preds_depth[0].detach().cpu().numpy()
                        img_depth = (pred_depth * 255).astype(np.uint8)
                        img_depth = cv2.applyColorMap(img_depth, 20)

                        pred_mask = preds_mask[0].detach().cpu().numpy()
                        img_mask = (pred_mask * 255).astype(np.uint8)
                        img_raydrop_masked = cv2.cvtColor(img_mask, cv2.COLOR_GRAY2BGR)

                        img_intensity_masked = (pred_intensity * pred_mask * 255).astype(np.uint8)
                        img_intensity_masked = cv2.applyColorMap(img_intensity_masked, 1)

                        img_depth_masked = (pred_depth * pred_mask * 255).astype(np.uint8)
                        img_depth_masked = cv2.applyColorMap(img_depth_masked, 20)

                        # Top three rows: unmasked; bottom three: masked by the predicted ray-drop.
                        img_pred = cv2.vconcat([img_raydrop, img_intensity, img_depth,
                                                img_raydrop_masked, img_intensity_masked, img_depth_masked])
                        cv2.imwrite(save_path_pred, img_pred)

                        # Undo the scene scale before converting the masked range image.
                        pred_lidar = pano_to_lidar(pred_depth * pred_mask / self.opt.scale, loader._data.intrinsics_lidar)

                        np.save(
                            os.path.join(
                                self.workspace,
                                "validation",
                                f"{name}_{self.local_step:04d}_lidar.npy",
                            ),
                            pred_lidar,
                        )

                    pbar.set_description(
                        f"loss={loss_val:.4f} ({total_loss/self.local_step:.4f})"
                    )
                    pbar.update(loader.batch_size)

        if self.local_step > 0:
            average_loss = total_loss / self.local_step
        else:
            # No frame matched eva_idx: record NaN instead of dividing by zero.
            average_loss = float("nan")
        self.stats["valid_loss"].append(average_loss)

        if self.local_rank == 0:
            pbar.close()
            if len(self.depth_metrics) > 0:
                result = self.depth_metrics[-1].measure()[0]  # hard code
                self.stats["results"].append(
                    result if self.best_mode == "min" else -result
                )  # if max mode, use -result
            else:
                self.stats["results"].append(
                    average_loss
                )  # if no metric, choose best by min loss

            np.set_printoptions(linewidth=150, suppress=True, precision=8)
            for i, metric in enumerate(self.depth_metrics):
                if i == 1:
                    self.log(f"=== ↓ GT mask ↓ ==== RMSE{' '*6}MedAE{' '*8}a1{' '*10}a2{' '*10}a3{' '*8}LPIPS{' '*8}SSIM{' '*8}PSNR ===")
                if i == 3:
                    self.log(f"== ↓ Final pred ↓ == RMSE{' '*6}MedAE{' '*8}a1{' '*10}a2{' '*10}a3{' '*8}LPIPS{' '*8}SSIM{' '*8}PSNR ===")
                self.log(metric.report(), style="blue")
                if self.use_tensorboardX:
                    suffix = ""
                    if i == 1 or i == 2:
                        suffix = "_masked"
                    metric.write(self.writer, self.epoch, prefix="LiDAR_evaluate", suffix=suffix)
                metric.clear()

        if self.ema is not None:
            self.ema.restore()

        self.log(f"++> Evaluate epoch {self.epoch} Finished.")
    
