import torch
import torch.nn as nn
import torch.nn.functional as F


class MseExpLoss(nn.Module):
    """Element-wise loss that is quadratic for small errors and saturating for large ones.

    With ``delta = |pred - target|``:

    * ``delta <= 1``: the loss is the squared error ``delta ** 2``.
    * ``delta > 1``: the loss is ``scale * (1 - exp(-delta))``, where ``scale``
      is chosen so both branches evaluate to 1 at ``delta == 1``. This branch
      is bounded above by ``scale``, so large errors contribute a bounded
      amount to the loss.
    """

    @staticmethod
    def forward(
        pred: torch.Tensor,
        target: torch.Tensor,
        reduction: str = "mean",
    ) -> torch.Tensor:
        """Compute the combined MSE / exponential loss.

        Args:
            pred: Predicted values.
            target: Ground-truth values, combined element-wise with ``pred``.
            reduction: ``"mean"`` (default) averages the element-wise loss,
                ``"sum"`` sums it, ``"none"`` returns it unreduced.

        Returns:
            A scalar tensor for ``"mean"`` / ``"sum"``, otherwise the
            element-wise loss tensor.

        Raises:
            NotImplementedError: If ``reduction`` is not a supported mode.
        """
        # Fail fast on an unsupported mode, before doing any tensor work.
        if reduction not in ("mean", "sum", "none"):
            raise NotImplementedError(
                f"Unsupported reduction {reduction!r}; "
                "expected 'mean', 'sum' or 'none'."
            )

        # Approximately 1 / (1 - e^-1): rescales the exponential branch so that
        # it equals the squared error (== 1) at delta == 1.
        # NOTE(review): the literal looks like the float32 rounding of that
        # value; it is kept verbatim to preserve existing numerics.
        exp_scale = 1.5819767713546753

        delta = (pred - target).abs()
        mse_loss = F.mse_loss(pred, target, reduction="none")
        exp_loss = 1 - torch.exp(-delta)
        # If delta > 1 use the (rescaled) exponential loss, squared error otherwise.
        loss = torch.where(delta > 1, exp_scale * exp_loss, mse_loss)

        if reduction == "mean":
            return loss.mean()
        if reduction == "sum":
            return loss.sum()
        return loss
