from dataclasses import dataclass, field
from typing import Dict, List, Optional

from omegaconf import OmegaConf, MISSING


@dataclass
class FourierMappingConfig:
    """Hyper-parameters of a Fourier feature mapping.

    Every field has a default, so ``FourierMappingConfig()`` is usable as-is.
    A ``type`` of ``None`` is how ``DecoderWithCrossAttentionConfig`` switches
    the mapping off (its ``__post_init__`` replaces the sub-config with None).
    """

    # Name of the mapping variant to build.
    type: Optional[str] = field(default="deterministic_transinr")
    # Whether the mapping is learnable -- presumably; confirm in the model code.
    trainable: bool = field(default=False)
    # Scale of the mapping; assumed to bound the frequency range -- TODO confirm.
    ff_sigma: int = field(default=1024)
    # Dimensionality of the mapped features.
    ff_dim: int = field(default=128)
    # Optional lower counterpart to ``ff_sigma``; ``None`` means "not set".
    ff_sigma_min: Optional[float] = field(default=None)


@dataclass
class HypoNetActivationConfig:
    """Activation settings for the hypo-network."""

    # Activation name to use.
    type: str = field(default="relu")
    # w0 factor for SIREN-style (sine) activations; presumably unused for other
    # activation types -- confirm in the model code.
    siren_w0: Optional[float] = field(default=30.0)


@dataclass
class CrossAttentionBlockConfig:
    """Settings for one cross-attention block.

    ``embed_dim`` and ``n_head`` default to OmegaConf's ``MISSING`` marker, so
    they are mandatory and must be filled in before the values are read.
    """

    # Attention embedding width (mandatory).
    embed_dim: int = field(default=MISSING)
    # Number of attention heads (mandatory).
    n_head: int = field(default=MISSING)
    # Dropout probability inside the block.
    dropout: float = field(default=0.0)
    # Whether the block's projections carry a bias term -- presumably; confirm.
    bias: bool = field(default=True)
    # Whether to apply a LayerNorm to the block input.
    input_layernorm: bool = field(default=False)
    # Whether to add a residual connection around the block.
    residual: bool = field(default=True)


@dataclass
class DecoderWithCrossAttentionConfig:
    """Configuration of an MLP decoder combined with cross-attention to latents.

    Fields left at ``MISSING`` (``hidden_dim``, ``latent_dim``,
    ``attention_layer_idxs``) are mandatory and must be supplied.

    Nested sub-configs are created through ``default_factory`` so that every
    instance owns its own copy. A dataclass instance used directly as a field
    default is rejected by ``dataclasses`` on Python 3.11+ because it is
    unhashable. On older versions it would be one object shared by all configs.
    """

    n_layer: int = 5
    hidden_dim: List[int] = MISSING
    use_bias: bool = True
    input_dim: int = 2
    output_dim: int = 3
    output_bias: float = 0.5

    latent_dim: int = MISSING

    # ``None`` (or a mapping whose ``type`` is None) disables the mapping;
    # see ``__post_init__``.
    fourier_mapping: Optional[FourierMappingConfig] = field(
        default_factory=FourierMappingConfig
    )
    activation: HypoNetActivationConfig = field(
        default_factory=HypoNetActivationConfig
    )

    cross_attention: CrossAttentionBlockConfig = field(
        default_factory=CrossAttentionBlockConfig
    )

    # Presumably indices of the MLP layers that get cross-attention -- confirm.
    attention_layer_idxs: List[int] = MISSING

    normalize_mlp_weights: bool = False

    def __post_init__(self):
        """Collapse a Fourier mapping whose ``type`` is None to ``None``."""
        # ``type=None`` is the switch-off value; normalizing it presumably lets
        # consumers test only ``fourier_mapping is None``.
        if self.fourier_mapping is not None and self.fourier_mapping.type is None:
            self.fourier_mapping = None
                

# Alias: ``CrossAttentionLayerConfig`` is the very same class object as
# ``CrossAttentionBlockConfig`` (not a subclass). Presumably kept so code or
# configs that use this name keep working -- confirm before removing.
CrossAttentionLayerConfig = CrossAttentionBlockConfig


@dataclass
class QueryFourierMappingConfig:
    """Fourier mapping settings whose scales are given as lists.

    ``ff_sigma`` and ``ff_sigma_min`` are mandatory (OmegaConf ``MISSING``).
    They presumably hold one entry per band of the multi-band decoder --
    confirm against the model code.
    """

    # Name of the mapping variant to build.
    type: Optional[str] = field(default="deterministic_transinr_range")
    # Whether the mapping is learnable -- presumably; confirm in the model code.
    trainable: bool = field(default=False)
    # Dimensionality of the mapped features.
    ff_dim: int = field(default=128)
    # Upper scale values (mandatory).
    ff_sigma: List[float] = field(default=MISSING)
    # Lower scale values, entries may be None (mandatory).
    ff_sigma_min: List[Optional[float]] = field(default=MISSING)


@dataclass
class MultiBandDecoderWithCrossAttentionConfig:
    """Configuration of a multi-band MLP decoder with cross-attention to latents.

    Fields left at ``MISSING`` (``hidden_dim``, ``latent_dim``,
    ``num_latent_tokens``) are mandatory and must be supplied.

    Nested sub-configs are created through ``default_factory`` so that every
    instance owns its own copy. A dataclass instance used directly as a field
    default is rejected by ``dataclasses`` on Python 3.11+ because it is
    unhashable. On older versions it would be one object shared by all configs.
    """

    n_mlp_layer: int = 2
    hidden_dim: List[int] = MISSING
    use_bias: bool = True
    input_dim: int = 2
    output_dim: int = 3
    output_bias: float = 0.5

    latent_dim: int = MISSING
    num_latent_tokens: int = MISSING

    attn_fourier_mapping: Optional[FourierMappingConfig] = field(
        default_factory=FourierMappingConfig
    )
    query_fourier_mapping: QueryFourierMappingConfig = field(
        default_factory=QueryFourierMappingConfig
    )
    activation: HypoNetActivationConfig = field(
        default_factory=HypoNetActivationConfig
    )

    cross_attention: CrossAttentionBlockConfig = field(
        default_factory=CrossAttentionBlockConfig
    )

    n_mod_layer: int = 1
    output_from_every_layer: bool = False
    use_first_query_for_init_hidden: bool = False

    normalize_mlp_weights: bool = False
