import torch
import torch.nn as nn

import exp_utils as PQ


class SBlockGradOptimizer(nn.Module):
    """Gradient-based search over a batch of candidate states.

    A ``(batch_size, dim_state)`` tensor ``s`` is held as a learnable
    parameter and updated with Adam so that the mean of the ``'hard_obj'``
    entry returned by ``crabs.obj_eval(s)`` increases.

    Args:
        dim_state: Number of columns of the candidate tensor.
        crabs: Object exposing ``obj_eval(s)``, which must return a mapping.
            ``step`` and ``discard_worst`` read its ``'hard_obj'`` entry;
            ``evaluate`` additionally reads ``'constraint'`` and ``'obj'``.
            Entries are assumed to hold one value per row of ``s`` —
            TODO confirm against the ``obj_eval`` implementation.
        normalizer: Stored on the instance; not used by any method of this
            class.
        batch_size: Number of candidate states optimized in parallel.
        reset_block: Number of rows re-sampled by each ``reinit`` call.
    """

    def __init__(self, dim_state, crabs, normalizer, batch_size=10000, reset_block=0):
        super().__init__()
        self.crabs = crabs
        # Candidates start as standard-normal samples.
        self.s = nn.Parameter(torch.randn(batch_size, dim_state), requires_grad=True)
        self.opt = torch.optim.Adam([self.s], lr=1e-3, betas=(0, 0.999))
        self.normalizer = normalizer
        # Cursor into `s` marking where the next `reinit` block begins.
        self.index = 0
        self.reset_block = reset_block

    def step(self):
        """Run one Adam update that ascends the mean of ``'hard_obj'``.

        Returns:
            dict: ``{'optimal': float}`` with the largest ``'hard_obj'`` value
            observed *before* the parameter update was applied.
        """
        obj = self.crabs.obj_eval(self.s)['hard_obj']
        # Minimizing the negated mean is the same as maximizing the objective.
        loss = (-obj).mean()

        self.opt.zero_grad()
        loss.backward()
        self.opt.step()
        return {
            'optimal': obj.max().item(),
        }

    def discard_worst(self, ratio=0.5):
        """Re-sample the candidates with the lowest ``'hard_obj'`` values.

        Args:
            ratio: Fraction of the batch to replace; ``int(len(s) * ratio)``
                rows are overwritten with fresh standard-normal samples.
        """
        n = len(self.s)
        k = int(n * ratio)
        if k <= 0:
            return
        scores = self.crabs.obj_eval(self.s)['hard_obj'].detach()
        _, indices = torch.topk(scores, k, largest=False)
        # Indexing with a tensor yields a copy, so initializing `self.s[indices]`
        # in place would leave the parameter untouched; assign into it instead.
        with torch.no_grad():
            self.s[indices] = torch.randn_like(self.s[indices])

    @torch.no_grad()
    def reinit(self):
        """Re-sample the next ``reset_block`` rows of ``s``, cycling through the batch.

        The block starts at ``self.index`` and wraps around to the beginning
        of the batch when it runs past the end, so every row is eventually
        visited. ``self.index`` is advanced to the row after the block.
        """
        n = len(self.s)
        if n == 0 or self.reset_block <= 0:
            return
        block = min(self.reset_block, n)
        end = self.index + block
        # Basic slices are views, so in-place initialization updates `s` itself.
        nn.init.normal_(self.s[self.index:min(end, n)])
        if end > n:
            # The block ran past the end: re-sample the wrapped-around part too.
            nn.init.normal_(self.s[:end - n])
        self.index = end % n

    def evaluate(self, *, step):
        """Log statistics of the current best candidate.

        Args:
            step: Keyword-only; accepted but not used by this method.

        Returns:
            dict: ``{'optimal': float}`` with the largest ``'hard_obj'`` value
            over the batch.
        """
        result = self.crabs.obj_eval(self.s)
        hardD = result['hard_obj']
        constraint = result['constraint']
        U = result['obj']
        # Report the constraint and objective of the candidate with the best hard objective.
        idx = hardD.argmax()
        best = hardD.max().item()
        PQ.log.info(f"[S cont grad opt] hardD = {best:.6f}, constraint = {constraint[idx].item():.6f}, "
                    f"U = {U[idx].item():.6f}, "
                    f"inside = {(constraint <= 0).sum().item()}")

        return {
            'optimal': best,
        }
