from typing import Optional
from dataclasses import dataclass

from omegaconf import OmegaConf, MISSING


@dataclass
class GINRDiTConfig:
    """Structured OmegaConf schema for the "ginr-diffusion" model config.

    Fields defaulting to ``MISSING`` are mandatory. OmegaConf reports them
    as missing values until a user-supplied config fills them in (see
    ``create``). All other fields carry concrete defaults.
    """

    # Identifier string for this config family.
    type: str = "ginr-diffusion"
    # NOTE(review): presumably an EMA decay factor; None leaves it unset.
    ema: Optional[float] = None

    # --- Mandatory fields: must be provided by the merged-in config. ---
    input_size: int = MISSING
    in_channels: int = MISSING
    hidden_size: int = MISSING
    depth: int = MISSING
    num_heads: int = MISSING
    mlp_ratio: float = MISSING
    class_dropout_prob: float = MISSING
    num_classes: int = MISSING
    learn_sigma: bool = MISSING

    # --- Optional fields with defaults. ---
    # Optional path string; None when not supplied.
    # TODO confirm what file this points to (statistics file, per the name?).
    stats_path: Optional[str] = None
    # NOTE(review): presumably a multiplicative scale for latents — confirm.
    latent_scale: float = 1.0

    @classmethod
    def create(cls, config):
        """Build an OmegaConf config by merging ``config`` over the defaults.

        A default instance of this dataclass is converted with
        ``OmegaConf.structured``. ``config`` is then merged on top of it, so
        keys in ``config`` override the schema defaults and are checked
        against the declared field types.

        Args:
            config: Any object accepted by ``OmegaConf.merge``, such as a
                dict or DictConfig.

        Returns:
            The merged OmegaConf config object.
        """
        return OmegaConf.merge(OmegaConf.structured(cls()), config)
