import torch
import torch.nn as nn
import exp_utils as PQ


class FLAGS(PQ.BaseFLAGS):
    """Hyper-parameters read by ``ObjEvaluator.__init__``.

    NOTE(review): presumably these defaults can be overridden through the
    ``PQ.BaseFLAGS`` machinery — TODO confirm against exp_utils.
    """
    # Tolerance around the level 1 used by ObjEvaluator.forward for the
    # ``L < 1 + eps`` and ``U > 1 - eps`` comparisons.
    eps = 0.0
    # Copied to ObjEvaluator.neg_coef; ObjEvaluator.forward does not read it.
    neg_coef = 1.0


class ObjEvaluator(nn.Module):
    """Evaluate objective/constraint tensors from ``safe_invariant.L`` and ``.U``.

    Both quantities are compared against the level ``1`` relaxed by
    ``eps``: a point counts as inside the mask when ``L < 1 + eps`` and
    ``U > 1 - eps``.
    """

    FLAGS = FLAGS

    def __init__(self, safe_invariant):
        """Capture the hyper-parameters from ``FLAGS`` and the wrapped invariant.

        Args:
            safe_invariant: object exposing callables ``L(s)`` and ``U(s)``.
                NOTE(review): presumably both return tensors of the same (or
                broadcastable) shape — TODO confirm.
        """
        super().__init__()
        self.eps = FLAGS.eps
        # NOTE(review): neg_coef is stored but not read anywhere in this class;
        # confirm whether external callers rely on it.
        self.neg_coef = FLAGS.neg_coef
        self.safe_invariant = safe_invariant

    def forward(self, s):
        """Compute the objective dictionary for input ``s``.

        Args:
            s: input forwarded unchanged to ``safe_invariant.L`` and
                ``safe_invariant.U``.

        Returns:
            dict with keys:
                ``'L'``, ``'U'``, ``'s'``: the raw values and the input.
                ``'obj'``: the objective (currently ``U``).
                ``'constraint'``: ``L - 1 - eps`` (negative iff ``L < 1 + eps``).
                ``'mask'``: bool tensor, ``L < 1 + eps`` and ``U > 1 - eps``.
                ``'max_obj'``: max of ``obj`` with masked-out entries replaced
                    by 0 (so 0 when the mask is empty).
                ``'hard_obj'``: ``U - 1 + eps`` where ``L < 1 + eps``, else the
                    penalty ``-L - 1000``.
        """
        L = self.safe_invariant.L(s)
        U = self.safe_invariant.U(s)

        below_level = L < 1 + self.eps
        mask = below_level & (U > 1 - self.eps)
        obj = U

        # Select with `where` rather than `obj * mask`: a non-finite obj
        # outside the mask would give inf * 0 = nan, and max() propagates NaN.
        # For finite values the result (and gradient) is unchanged.
        max_obj = torch.where(mask, obj, torch.zeros_like(obj)).max()

        return {
            'L': L,
            'U': U,
            's': s,
            'obj': obj,
            'constraint': L - 1 - self.eps,
            'mask': mask,
            'max_obj': max_obj,
            # The penalty offset is a modest finite constant, not e.g. 1e30:
            # in floating point 100 + 1e30 == 1e30, so a huge offset would
            # swamp -L and make all penalized points indistinguishable.
            'hard_obj': torch.where(below_level, U - 1 + self.eps, -L - 1000),
        }

