import torch
import exp_utils as PQ


class FLAGS(PQ.BaseFLAGS):
    # Experiment configuration read by CRABS (attached as CRABS.FLAGS).
    # NOTE(review): presumably values can be overridden through the
    # exp_utils.BaseFLAGS machinery — confirm against exp_utils.
    class obj(PQ.BaseFLAGS):
        # Settings for CRABS.obj_eval.
        # Constant added to the uncertainty U to form the objective: obj = U + eps.
        eps = 0.0
        # NOTE(review): not read by any code visible in this file; presumably
        # consumed elsewhere — confirm before changing or removing.
        neg_coef = 1.0


class CRABS(torch.nn.Module):
    """Combine a barrier function, a policy and an uncertainty model.

    ``L(s)`` evaluates the barrier, ``U(s, a)`` evaluates the uncertainty
    model (defaulting ``a`` to the policy's actions), and ``obj_eval`` packs
    both into a dictionary of objective terms.
    """

    FLAGS = FLAGS

    def __init__(self, barrier, uncertainty, policy, env_barrier_fn, normalizer):
        """Store the component callables.

        Args:
            barrier: callable ``states -> L values``; used by :meth:`L`.
            uncertainty: callable ``(states, actions) -> U values``; used by :meth:`U`.
            policy: callable ``states -> actions``; supplies default actions in :meth:`U`.
            env_barrier_fn: stored as an attribute; not used by this class's
                own methods (presumably read by callers — confirm).
            normalizer: stored as an attribute; not used by this class's
                own methods (presumably read by callers — confirm).

        Components that are ``torch.nn.Module`` instances are registered as
        submodules by ``nn.Module.__setattr__``; the assignment order below
        therefore fixes the submodule / state_dict ordering, so keep it.
        """
        super().__init__()
        self.barrier = barrier
        self.policy = policy
        self.uncertainty = uncertainty
        self.env_barrier_fn = env_barrier_fn
        self.normalizer = normalizer

    def U(self, states, actions=None):
        """Evaluate the uncertainty model at ``(states, actions)``.

        When ``actions`` is None, the actions come from ``self.policy(states)``.
        """
        if actions is None:
            actions = self.policy(states)
        return self.uncertainty(states, actions)

    def L(self, states):
        """Evaluate the barrier function on ``states``."""
        return self.barrier(states)

    def obj_eval(self, s):
        """Evaluate the barrier/uncertainty objective on the states ``s``.

        Args:
            s: states accepted by the barrier, the policy and the uncertainty
                model (assumed to be a batched tensor — TODO confirm shape).

        Returns:
            dict with keys:
                ``'L'`` / ``'constraint'``: barrier values ``L(s)`` (same tensor).
                ``'U'``: uncertainty ``U(s)`` under the current policy.
                ``'s'``: the input states, unchanged.
                ``'obj'``: ``U + eps`` with ``eps = self.FLAGS.obj.eps``.
                ``'mask'``: boolean tensor, True where ``L < 0`` and ``obj > 0``.
                ``'max_obj'``: maximum of ``obj`` over ``mask``; 0 if no
                    entry is masked in (masked-in values are all > 0).
                ``'hard_obj'``: ``obj`` where ``L < 0``, else ``-L - 1000``.

        Note:
            ``max_obj`` uses ``Tensor.max()``, which raises on an empty tensor,
            so an empty batch is an error.
        """
        L = self.L(s)
        U = self.U(s)

        eps = self.FLAGS.obj.eps
        obj = U + eps
        mask = (L < 0) & (obj > 0)

        # Select rather than multiply by the mask: with ``obj * mask`` an
        # inf/NaN at a masked-out position becomes NaN (0 * inf == NaN) and
        # poisons max(). Finite inputs give the same value (and gradient).
        max_obj = torch.where(mask, obj, torch.zeros_like(obj)).max()

        # Where L >= 0 the hard objective is -L - 1000, i.e. <= -1000 and
        # decreasing as L grows. The offset is kept moderate on purpose: it
        # can't be 1e30, because in floating point 100 + 1e30 == 1e30, so a
        # huge offset would swallow the -L term.
        hard_obj = torch.where(L < 0, obj, -L - 1000)

        return {
            'L': L,
            'U': U,
            's': s,
            'obj': obj,
            'constraint': L,
            'mask': mask,
            'max_obj': max_obj,
            'hard_obj': hard_obj,
        }
