import torch
import torch.nn as nn

class UpperBound(nn.Module):
    """Loss module computing ``mean(log(r + b) + r)`` where ``r = D.ar(x_s, x_p)``.

    ``D`` is any object exposing an ``ar(x_s, x_p)`` method that returns a
    tensor. The local name ``inv_ar`` suggests an inverse ratio estimated by
    ``D``; TODO confirm against the definition of ``D``. If ``D`` is an
    ``nn.Module``, assigning it here registers it as a submodule, so its
    parameters appear in this module's ``parameters()``.
    """

    def __init__(self, D):
        super().__init__()
        self.D = D

    def forward(self, x_s: torch.Tensor, x_p, b: float = 1e-10) -> torch.Tensor:
        """Return the scalar mean of ``log(inv_ar + b) + inv_ar``.

        Args:
            x_s: First input forwarded to ``D.ar``. It must not require grad;
                callers are expected to pass a detached tensor.
            x_p: Second input forwarded unchanged to ``D.ar``.
            b: Small non-negative constant added inside the log so a zero
                entry of ``inv_ar`` gives a finite value instead of ``-inf``.
                Pass ``0.0`` for the unguarded ``log(inv_ar)``.

        Returns:
            A 0-dim tensor: the mean over all elements of the expression above.

        Raises:
            ValueError: if ``x_s.requires_grad`` is True.
        """
        # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
        if x_s.requires_grad:
            raise ValueError(
                "x_s must not require grad; pass a detached tensor (e.g. x_s.detach())"
            )
        inv_ar = self.D.ar(x_s, x_p)
        # ``b`` keeps the log finite when an entry of ``inv_ar`` is exactly zero.
        loss = torch.mean(torch.log(inv_ar + b) + inv_ar)
        return loss
