import torch
import torch.nn as nn
import numpy as np


class SafeAlgo2(SafeAlgo):
    """``SafeAlgo`` variant with a learned, per-sample Lagrange multiplier.

    Each sample gets its own multiplier ``lambda_i = exp(log_lambda[i])``.
    Parametrising in log-space keeps every multiplier strictly positive.
    The objective shared by all updates is ``U_s - lambda * L_s``:

    * ``optimize_s`` ascends it w.r.t. the samples ``_var_s``, adapts the
      multipliers from the sign of ``L_s``, then applies a clipped
      Langevin-style noise step to the samples.
    * ``optimize_L`` descends it with ``opt_L``.

    ``U_s``, ``L_s``, ``_var_s``, ``opt_s`` and ``opt_L`` come from the base
    class (not visible here).
    """

    def __init__(self, L, U):
        super().__init__(L, U)
        # One log-multiplier per sample; zeros => lambda = 1 initially.
        # NOTE(review): assumes FLAGS.opt_s.batch_size equals the leading
        # dimension of U_s / L_s -- confirm against the base class.
        self.log_lambda = nn.Parameter(torch.zeros(FLAGS.opt_s.batch_size), requires_grad=True)
        self.opt_lambda = torch.optim.Adam([self.log_lambda], lr=3e-3)

    def init_lambda(self, min_lambda=1e-6) -> None:
        """Initialise each multiplier from the current gradients w.r.t. ``_var_s``.

        ``lambda_i`` is set to ``<dU_i, dL_i> / ||dL_i||^2``, the coefficient
        that makes the gradient of ``U_s - lambda * L_s`` orthogonal to ``dL_i``.
        Here ``dU_i`` / ``dL_i`` are the rows of the gradients of
        ``sum(U_s)`` / ``sum(L_s)``, reduced over the last dimension. They are
        per-sample only if row i of ``_var_s`` affects only ``U_s[i]`` /
        ``L_s[i]`` -- presumably true, TODO confirm.

        Args:
            min_lambda: Positive floor applied before taking the log. The raw
                coefficient can be negative, or 0/0 when ``dL_i == 0``. Without
                the floor, ``log`` would yield NaN / -inf multipliers.
        """
        sum_U_s = self.U_s.sum()
        sum_L_s = self.L_s.sum()
        # retain_graph: U_s and L_s may share intermediate nodes. Freeing the
        # graph after the first call would then break the second one.
        d_U_s = torch.autograd.grad(sum_U_s, self._var_s, retain_graph=True)[0]
        d_L_s = torch.autograd.grad(sum_L_s, self._var_s)[0]

        numer = (d_U_s * d_L_s).sum(dim=-1)
        # Avoid 0/0 when dL vanishes: the numerator is then 0, so the ratio
        # becomes 0 and is lifted to min_lambda below.
        denom = (d_L_s.norm(dim=-1) ** 2).clamp_min(torch.finfo(d_L_s.dtype).tiny)
        ratio = numer / denom

        with torch.no_grad():
            self.log_lambda.copy_(ratio.clamp_min(min_lambda).log())

    def optimize_s(self) -> None:
        """One sample-ascent step, one multiplier step, then clipped noise."""
        # The multiplier is a constant for the sample update. Detaching skips a
        # gradient into log_lambda that opt_lambda.zero_grad() would discard anyway.
        lam = self.log_lambda.exp().detach()
        obj = self.U_s - lam * self.L_s

        self.opt_s.zero_grad()
        (-obj).mean().backward()  # maximize obj by minimizing its negation
        self.opt_s.step()

        # d(loss_lambda)/d(log_lambda_i) = -L_s[i] / B, so descent tends to
        # raise log_lambda where L_s > 0 and lower it where L_s < 0.
        loss_lambda = -self.log_lambda * self.L_s.detach()
        self.opt_lambda.zero_grad()
        loss_lambda.mean().backward()
        self.opt_lambda.step()

        if np.random.rand() < 0.002:  # occasional debug print
            print(self.log_lambda.exp()[:4], self.L_s[:4], self.U_s[:4])

        with torch.no_grad():
            noisy = self._var_s + torch.randn_like(self._var_s) * FLAGS.opt_s.langevin_eps
            # Per-coordinate box: first coordinate in [-pi/2, pi/2], second in
            # [-1, 1]. NOTE(review): assumes the last dim of _var_s has size 2.
            # Matching dtype avoids promoting _var_s under mixed precision.
            upper = torch.tensor([np.pi / 2, 1], device=noisy.device, dtype=noisy.dtype)
            lower = -upper
            noisy = torch.min(upper, torch.max(noisy, lower))
            self._var_s.copy_(noisy)

    def optimize_L(self) -> None:
        """One descent step on ``U_s - lambda * L_s`` using ``opt_L``."""
        # The multiplier is held fixed here. Its gradient would be discarded
        # by opt_lambda.zero_grad() before use, so skip computing it.
        lam = self.log_lambda.exp().detach()
        obj = self.U_s - lam * self.L_s

        self.opt_L.zero_grad()
        obj.mean().backward()  # minimize
        self.opt_L.step()
