import torch 
import torch.nn as nn 

from streaming_gfn.utils import Environment 

class Set(Environment):
    """Environment whose states are subsets of a ground set of ``src_size`` elements.

    ``self.state`` holds one integer row per batch entry and one column per
    element of the ground set. ``apply`` increments the chosen column and
    ``backward`` decrements it. A state is flagged as stopped once its row sums
    to ``set_size``.

    The action masks are always derived from ``self.state``: the forward mask
    is ``1 - state`` and the backward mask is ``state`` (both cast to the
    masks' floating-point dtype).
    """

    def __init__(self, src_size, set_size, batch_size, log_reward, device='cpu'):
        """Create a batch of empty sets.

        Args:
            src_size: number of elements in the ground set (columns of ``state``).
            set_size: row sum at which a state is flagged as stopped; also
                passed to ``Environment.__init__`` as its second positional
                argument.
            batch_size: number of parallel states (rows of ``state``).
            log_reward: forwarded unchanged to ``Environment.__init__``.
            device: torch device on which the tensors are allocated.
        """
        super().__init__(batch_size, set_size, log_reward, device=device)
        self.src_size = src_size
        self.set_size = set_size
        # Integer counts, shape (batch_size, src_size); all zeros = empty set.
        self.state = torch.zeros((self.batch_size, self.src_size), device=self.device, dtype=int)
        # For the empty set every element may be added and none may be removed.
        self.forward_mask = torch.ones((self.batch_size, self.src_size), device=self.device)
        self.backward_mask = torch.zeros((self.batch_size, self.src_size), device=self.device)

    def _update_masks(self):
        """Recompute ``forward_mask`` and ``backward_mask`` from ``self.state``."""
        self.forward_mask = 1 - self.state.type(self.forward_mask.dtype)
        self.backward_mask = self.state.type(self.backward_mask.dtype)

    @torch.no_grad()
    def apply(self, indices):
        """Increment column ``indices`` of each row of ``state`` (add an element).

        ``indices`` is combined with ``self.batch_ids`` by advanced indexing.
        It is presumably one column index per batch entry (TODO confirm).
        All states become non-initial. States whose row sum equals
        ``set_size`` are flagged as stopped.
        """
        self.state[self.batch_ids, indices] = self.state[self.batch_ids, indices] + 1
        self.is_initial[:] = 0.
        self.stopped[:] = (self.state.sum(dim=1) == self.set_size)
        self._update_masks()

    @torch.no_grad()
    def backward(self, indices):
        """Decrement column ``indices`` of each row of ``state`` (remove an element).

        States whose row sum drops to zero are flagged as initial, and no
        state is stopped afterwards.

        Returns:
            The ``indices`` argument, unchanged.
        """
        self.state[self.batch_ids, indices] = self.state[self.batch_ids, indices] - 1
        self.is_initial[:] = (self.state.sum(dim=1) == 0)
        self.stopped[:] = 0
        self._update_masks()
        return indices

    @torch.no_grad()
    def merge(self, batch_state):
        """Append the rows of ``batch_state`` to this batch."""
        super().merge(batch_state)
        self.state = torch.vstack([self.state, batch_state.state])
        # Previously the masks kept their pre-merge row count and no longer
        # matched ``state``. Deriving them from the merged state restores the
        # invariant whether or not the base-class merge stacks them itself.
        self._update_masks()

    @property
    def unique_input(self):
        """Return ``state`` cast to the backward mask's floating-point dtype."""
        return self.state.type(self.backward_mask.dtype)
