import torch.nn as nn
import torch

class BPRLoss(nn.Module):
    """Bayesian Personalized Ranking (BPR) loss over paired scores.

    For each pair ``i`` the per-element loss is
    ``-log(sigmoid(pos_pd[i] - neg_pd[i]) + gamma)``. ``gamma`` keeps the
    argument of ``log`` strictly positive even when the sigmoid underflows
    to 0.

    Args:
        gamma: additive stabiliser inside the logarithm.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'`` (return the
            per-pair losses unreduced).
    """

    def __init__(self, gamma=1e-10, reduction: str = 'mean'):
        # Validate the reduction mode before initialising nn.Module state.
        assert reduction in ('mean', 'sum', 'none')
        super().__init__()
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, pos_pd: torch.FloatTensor, neg_pd: torch.FloatTensor):
        """Return the BPR loss for two equally long 1-D score vectors.

        Raises:
            AssertionError: if either input is not 1-D or the lengths differ.
        """
        # Rank is checked first so that shape[0] is only read on 1-D inputs.
        both_1d = pos_pd.ndim == 1 and neg_pd.ndim == 1
        assert both_1d and pos_pd.shape[0] == neg_pd.shape[0]

        score_gap = pos_pd - neg_pd
        per_pair = torch.log(torch.sigmoid(score_gap) + self.gamma).neg()

        if self.reduction == 'mean':
            return per_pair.mean()
        if self.reduction == 'sum':
            return per_pair.sum()
        # 'none': hand back the element-wise losses.
        return per_pair

class MarginLossZeroOne(nn.Module):
    """Hinge (margin ranking) loss on paired positive/negative scores.

    Each element contributes ``max(0, margin - (pos_pd - neg_pd))``. This is
    zero once the positive score exceeds the negative one by at least
    ``margin``, and grows linearly otherwise.

    Args:
        margin: required gap between positive and negative scores.
        reduction: ``'mean'`` (default), ``'sum'`` or ``'none'`` (return the
            element-wise losses unreduced).
    """

    def __init__(self, margin=0.5, reduction: str = 'mean') -> None:
        assert reduction in ['mean', 'sum', 'none'], f"unsupported reduction: {reduction!r}"
        super().__init__()
        self.margin = margin
        self.reduction = reduction

    def forward(self, pos_pd: torch.Tensor, neg_pd: torch.Tensor) -> torch.Tensor:
        """Return the margin loss for ``pos_pd`` versus ``neg_pd``.

        The inputs are combined element-wise with ordinary tensor
        broadcasting. No shape check is performed here.
        """
        # Out-of-place clamp instead of an in-place boolean-mask write on an
        # autograd intermediate. Values are identical, and clamp's backward
        # keeps the gradient where x >= 0, so entries sitting exactly on the
        # margin still receive gradient, as they did with the masked version.
        losses = torch.clamp(self.margin - (pos_pd - neg_pd), min=0)
        if self.reduction == 'mean':
            return losses.mean()
        if self.reduction == 'sum':
            return losses.sum()
        # 'none': element-wise losses.
        return losses
