from pickle import NONE
import sys,os
import torch
import torch.nn as nn
import numpy as np
import math

import config as cfg


class LossFlow(torch.nn.Module):
    """Flow-supervision loss weighted by ``cfg.lambda_flow``.

    Compares predicted and ground-truth per-sample 3-vectors with a default
    (``'l2'``) ``FlowLoss`` and scales the result by ``cfg.lambda_flow``.
    """

    def __init__(self):
        super(LossFlow, self).__init__()

        self.flow_loss = FlowLoss()
        # NOTE(review): cycle_loss is never used by forward(); kept so any
        # external code referencing this attribute keeps working.
        self.cycle_loss = FlowLoss()

    def forward(self, deform_gt, deform_pred, eval=False):
        """Compute the weighted flow loss.

        Args:
            deform_gt: ground-truth tensor. Its first two dims are read as
                ``(batch_size, num_samples)``, and it must hold
                ``batch_size * num_samples * 3`` elements.
            deform_pred: predicted tensor with the same number of elements.
            eval: unused; kept for interface compatibility.

        Returns:
            ``(loss_total, loss_flow)``. ``loss_total`` is a shape-``(1,)``
            tensor. When ``cfg.lambda_flow`` is falsy, ``loss_flow`` is a
            shape-``(1,)`` zero tensor. Otherwise it is the scalar returned
            by ``FlowLoss``.
        """
        batch_size = deform_gt.shape[0]
        num_samples = deform_gt.shape[1]

        loss_total = torch.zeros((1), dtype=deform_gt.dtype, device=deform_gt.device)

        # Flow error; skipped entirely when its weight is zero/disabled.
        loss_flow = torch.zeros((1), dtype=deform_gt.dtype, device=deform_gt.device)
        if cfg.lambda_flow:
            # reshape (rather than view) also accepts non-contiguous inputs,
            # e.g. tensors produced by transpose/permute.
            loss_flow = self.flow_loss(
                deform_pred.reshape(batch_size, num_samples, 3),
                deform_gt.reshape(batch_size, num_samples, 3),
            )
            loss_total = loss_total + cfg.lambda_flow * loss_flow

        return loss_total, loss_flow
    

class FlowLoss(nn.Module):
    """Point-wise loss between predicted and ground-truth vectors.

    Vector components are taken along the last dimension, so inputs may have
    any leading shape (e.g. ``(batch, num_points, 3)`` or ``(num_points, 3)``).

    Supported ``type`` values:

    * ``'l2'``           -- mean over points of half the squared Euclidean distance.
    * ``'l2_cosin'``     -- mean squared Euclidean distance (no 0.5 factor) plus
                            ``weight_cosin`` * mean(1 - cosine similarity).
    * ``'coord_square'`` -- mean squared error over every coordinate.
    * ``'coord_abs'``    -- mean absolute error over every coordinate.
    """

    _TYPES = ("l2", "l2_cosin", "coord_square", "coord_abs")

    def __init__(self, type='l2', weight_cosin=0.1):
        super(FlowLoss, self).__init__()

        # NOTE(review): not used by forward(); kept for backward compatibility.
        self.criterion_L1 = nn.L1Loss(reduction='none')

        self.type = type
        # Weight of the cosine term; only used by the 'l2_cosin' variant.
        self.weight_cosin = weight_cosin

    def forward(self, points_pred, points_gt):
        """Return the 0-dim loss between ``points_pred`` and ``points_gt``.

        Args:
            points_pred: predicted vectors (components along the last dim).
            points_gt: ground-truth vectors, broadcast-compatible with
                ``points_pred``.

        Raises:
            ValueError: if ``self.type`` is not one of the supported types.
        """
        diff = points_pred - points_gt

        if self.type == "l2":
            # Half squared Euclidean distance per point, averaged over points.
            return torch.mean(diff.pow(2).sum(dim=-1) / 2.0)

        if self.type == "l2_cosin":
            # Squared Euclidean distance plus a cosine-direction penalty.
            # The 1e-6 keeps zero-length vectors from dividing by zero.
            eps = 1e-6
            pred_unit = points_pred / (torch.norm(points_pred, dim=-1, keepdim=True) + eps)
            gt_unit = points_gt / (torch.norm(points_gt, dim=-1, keepdim=True) + eps)
            l2_term = torch.mean(diff.pow(2).sum(dim=-1))
            cos_term = torch.mean(1.0 - (pred_unit * gt_unit).sum(dim=-1))
            return l2_term + self.weight_cosin * cos_term

        if self.type == "coord_square":
            # Squared coordinate errors.
            return torch.mean(diff.pow(2))

        if self.type == "coord_abs":
            # Absolute coordinate errors.
            return torch.mean(diff.abs())

        raise ValueError(
            f"Unsupported FlowLoss type {self.type!r}; expected one of {self._TYPES}"
        )