from dataclasses import dataclass

from omegaconf import MISSING

from config.DatasetConfig import DataConfig
from config.ModelConfig import ModelConfig


@dataclass()
class LogConfig:
    """Settings for experiment logging.

    Bundles three groups of options: the Weights & Biases (wandb) run
    settings, the local log directory, and how often each kind of
    evaluation/plot is logged. String defaults of the form
    ``"PUT YOUR ... HERE"`` are placeholders to be replaced before running.
    """

    # ---- Weights & Biases run settings ------------------------------------
    wandb_entity: str = "PUT YOUR WANDB USER NAME HERE"  # placeholder
    wandb_group: str = "mv_wsl"
    wandb_run_name: str = ""  # empty by default
    wandb_project_name: str = "mvvae"
    wandb_log_freq: int = 50
    wandb_offline: bool = False
    wandb_local_instance: bool = False

    # ---- Local output -----------------------------------------------------
    dir_logs: str = "PUT YOUR LOG DIR HERE"  # placeholder: log directory path

    # ---- How often each evaluation / plot is produced ---------------------
    # NOTE(review): the unit (epochs vs. steps) is determined by the code that
    # consumes these values, not here -- confirm before tuning.
    downstream_logging_frequency: int = 50
    coherence_logging_frequency: int = 1
    likelihood_logging_frequency: int = 1000
    img_plotting_frequency: int = 50
    fid_logging_frequency: int = 1

    # ---- Debugging --------------------------------------------------------
    debug: bool = False  # debug switch (originally labelled "debug level wandb")


@dataclass()
class EvalConfig:
    """Settings controlling which evaluations run and with what resources.

    Covers the latent-representation (downstream task) evaluation, the
    coherence evaluation toggle, and the location of the Inception network
    weights used for FID. ``path_inception_weights`` defaults to a
    placeholder string that must be replaced before FID can be computed.
    """

    # ---- Latent-representation / downstream evaluation --------------------
    num_samples_train: int = 10_000
    max_iteration: int = 10_000
    eval_downstream_task: bool = True

    # ---- Coherence evaluation ---------------------------------------------
    coherence: bool = True  # toggles the coherence evaluation

    # ---- FID --------------------------------------------------------------
    path_inception_weights: str = "PUT YOUR INCEPTION NET PATH HERE"  # placeholder


@dataclass()
class MyMVWSLConfig:
    """Top-level experiment configuration composed of the sub-configs.

    ``seed`` and ``checkpoint_metric`` have concrete defaults. The four
    sub-configs default to OmegaConf's ``MISSING`` marker, i.e. they are
    mandatory and must be supplied when the config is built (presumably via
    config composition / command-line overrides -- confirm with the caller).
    """

    seed: int = 0
    checkpoint_metric: str = "val/loss/loss"

    # ---- Required sub-configurations (no usable default) ------------------
    log: LogConfig = MISSING  # logger settings
    dataset: DataConfig = MISSING  # dataset settings
    model: ModelConfig = MISSING  # model settings
    eval: EvalConfig = MISSING  # evaluation settings
