from typing import Callable, Iterable, Sequence

import flax.linen as nn
from jax.lax import stop_gradient


class MLP(nn.Module):
    """Stack of Dense layers with optional per-layer BatchNorm.

    Each entry of ``hidden_sizes`` adds one ``nn.Dense`` layer. The activation
    ``act`` is applied after every layer except the last, so the final output
    is the un-activated (optionally batch-normalized) projection.

    Attributes:
        hidden_sizes: output width of each Dense layer, in order. Needs
            ``len()`` and ordered iteration, hence ``Sequence``.
        bnorm: per-layer flags; ``bnorm[i]`` enables a BatchNorm (with no
            learnable bias or scale) after the i-th Dense layer. Must have at
            least ``len(hidden_sizes)`` entries; extra entries are ignored.
        act: activation applied between layers (default: ``nn.activation.relu``).
        detach_head: if True, the input is wrapped in ``stop_gradient`` so no
            gradient flows back through ``x``.
    """
    hidden_sizes: Sequence[int]
    bnorm: Sequence[bool]
    act: Callable = nn.activation.relu
    detach_head: bool = False

    @nn.compact
    def __call__(self, x, train=True):
        """Apply the MLP to ``x``.

        Args:
            x: input array fed to the first Dense layer.
            train: when True, BatchNorm layers use batch statistics; when
                False, they use their running averages.

        Returns:
            Output of the last Dense (and optional BatchNorm) layer, without
            the activation.

        Raises:
            ValueError: if ``bnorm`` has fewer entries than ``hidden_sizes``.
        """
        n_layers = len(self.hidden_sizes)
        # Fail early with a clear message instead of an IndexError halfway
        # through layer construction.
        if len(self.bnorm) < n_layers:
            raise ValueError(
                f"bnorm has {len(self.bnorm)} entries but hidden_sizes has "
                f"{n_layers}; one BatchNorm flag is required per layer"
            )

        if self.detach_head:
            x = stop_gradient(x)

        # Submodules are created in the same order as before (Dense, then the
        # optional BatchNorm, per layer), so Flax's auto-generated names and
        # therefore parameter trees are unchanged.
        for i, (features, use_bn) in enumerate(zip(self.hidden_sizes, self.bnorm)):
            x = nn.Dense(features=features)(x)
            if use_bn:
                x = nn.BatchNorm(
                    use_running_average=not train, use_bias=False, use_scale=False
                )(x)
            # No activation after the final layer.
            if i < n_layers - 1:
                x = self.act(x)

        return x
    

class ID(nn.Module):
    """Identity module that can optionally block gradients through its input.

    Has the same ``detach_head`` attribute and ``__call__(x, train)``
    signature as ``MLP``.

    Attributes:
        detach_head: if True, the input is wrapped in ``stop_gradient`` so no
            gradient flows back through ``x``.
    """
    detach_head: bool = False

    @nn.compact
    def __call__(self, x, train=True):
        """Return ``x`` unchanged (gradient-stopped if ``detach_head``).

        Args:
            x: input array.
            train: unused; accepted for signature compatibility with ``MLP``.

        Returns:
            ``x``, or ``stop_gradient(x)`` when ``detach_head`` is True.
        """
        if self.detach_head:
            # JAX arrays have no ``.detach()`` (that is a PyTorch method);
            # ``stop_gradient`` is the JAX equivalent.
            x = stop_gradient(x)
        return x