from dataclasses import dataclass, field
from typing import List, Dict, Optional

from omegaconf import MISSING


@dataclass
class LogConfig:
    """Logging settings: wandb options plus the local log directory.

    The defaults of `wandb_entity` and `dir_logs` are the literal placeholder
    string "PUT YOUR INFO HERE"; they are meant to be replaced by the user.
    """

    # wandb
    # Placeholder default -- replace with your own wandb entity.
    wandb_entity: str = "PUT YOUR INFO HERE"
    # Group, run name and project name all share the same default label.
    wandb_group: str = "mw_wsl"
    wandb_run_name: str = "mw_wsl"
    wandb_project_name: str = "mw_wsl"
    # Logging frequency; the unit (steps vs. epochs) is not visible in this
    # file -- TODO confirm against the training loop.
    wandb_log_freq: int = 50
    # Presumably switches wandb to offline mode when True -- verify against
    # the logger setup.
    wandb_offline: bool = False

    # logs
    # Placeholder default -- replace with the directory used for logs.
    dir_logs: str = "PUT YOUR INFO HERE"


@dataclass
class ModelConfig:
    """Hyperparameters common to every model configuration.

    Serves as the base class for the model-specific configs in this file
    (`DRPMModelConfig`, `JointModelConfig`, `MixedPriorModelConfig`), each of
    which adds a `name` field; this base class itself declares no `name`.
    """

    # Device identifier; presumably passed to the deep-learning framework
    # as-is -- TODO confirm.
    device: str = "cuda"
    # Optimisation settings: batch size, learning rate, number of epochs.
    batch_size: int = 128
    lr: float = 5e-3
    epochs: int = 250

    # Dimensionality of the latent space.
    latent_dim: int = 2

    # NOTE(review): looks like a switch for resampling at evaluation time;
    # exact semantics are not visible here -- confirm against the consumer.
    resample_eval: bool = False

    # loss hyperparameters
    beta: float = 1.0

    # weight on N(0,1) in mixed prior
    stdnormweight: float = 0.0

    # network architectures
    # NOTE: the `use_resnets` option below is commented out, i.e. it is not
    # a field of this config.
    # use_resnets: bool = True


@dataclass
class EvalConfig:
    """Evaluation settings: latent-representation evaluation and coherence."""

    # latent representation
    # Number of training samples used for this evaluation (presumably to fit
    # a downstream model -- TODO confirm).
    num_samples_train: int = 2560
    # Iteration cap; presumably for the downstream model's optimiser --
    # verify against the evaluation code.
    max_iteration: int = 10000
    # Toggle for the downstream-task evaluation.
    eval_downstream_task: bool = True

    # coherence
    # Toggle for the coherence evaluation (off by default).
    coherence: bool = False


@dataclass
class DRPMModelConfig(ModelConfig):
    """Model config for the "drpm" model.

    Extends `ModelConfig` with a `name` field and DRPM-specific settings:
    partition options, temperature annealing, extra loss weights and
    switches for how the DRPM parameters are learned.
    """

    # Identifier of this model variant.
    name: str = "drpm"
    # drpm
    # Number of groups (presumably the number of partition subsets -- TODO
    # confirm against the model code).
    n_groups: int = 2
    # NOTE(review): `hard_pi` / `add_gumbel_noise` look like hard-vs-relaxed
    # sampling switches; their exact effect is not visible in this file.
    hard_pi: bool = True
    add_gumbel_noise: bool = False

    # temperature annealing
    # Start value, end value and number of steps of the annealing schedule;
    # the shape of the schedule is defined elsewhere.
    init_temp: float = 1.0
    final_temp: float = 0.5
    num_steps_annealing: int = 100000

    # loss hyperparameters
    # Additional loss weights on top of `beta` inherited from `ModelConfig`.
    gamma: float = 3.0
    delta: float = 0.03

    # learning drpm parameters
    learn_const_dist_params: bool = False
    encoders_rpm: bool = True


@dataclass
class JointModelConfig(ModelConfig):
    """Model config for the "joint" model.

    Inherits every hyperparameter from `ModelConfig` and only adds the
    `name` identifier.
    """

    # Identifier of this model variant.
    name: str = "joint"


@dataclass
class MixedPriorModelConfig(ModelConfig):
    """Model config for the "mixedprior" model.

    Inherits every hyperparameter from `ModelConfig` (including
    `stdnormweight`, documented there as the weight on N(0,1) in the mixed
    prior) and only adds the `name` identifier.
    """

    # Identifier of this model variant.
    name: str = "mixedprior"


@dataclass
class DataConfig:
    """Base dataset config.

    `name` and `num_views` default to OmegaConf's `MISSING` marker, i.e. they
    are mandatory values that a subclass or an override has to supply.
    """

    # Dataset identifier; mandatory (no usable default).
    name: str = MISSING
    # Worker count (presumably for the data loader -- TODO confirm).
    num_workers: int = 8
    # num views
    # Mandatory (no usable default).
    num_views: int = MISSING


@dataclass
class LFPDataConfig(DataConfig):
    """Dataset config for the "LFP" dataset.

    Fills in the mandatory `name` / `num_views` fields of `DataConfig` and
    adds the data directory.
    """

    name: str = "LFP"
    num_views: int = 5
    # Placeholder default -- replace with the directory holding the data.
    dir_data: str = "PUT YOUR INFO HERE"


@dataclass
class SPIKEDataConfig(DataConfig):
    """Dataset config for the "SPIKE" dataset.

    Fills in the mandatory `name` / `num_views` fields of `DataConfig` and
    adds the data directory.
    """

    name: str = "SPIKE"
    num_views: int = 5
    # Placeholder default -- replace with the directory holding the data.
    dir_data: str = "PUT YOUR INFO HERE"


@dataclass
class MyRATSWSLConfig:
    """Top-level config bundling the log, dataset, model and eval sections.

    All four sections default to OmegaConf's `MISSING` marker, so each one
    has to be supplied (e.g. by composition or override) before it is read.
    """

    # Random seed.
    seed: int = 0
    # Metric key used for checkpointing; presumably must match the name under
    # which the validation loss is logged -- verify against the trainer.
    checkpoint_metric: str = "val/loss/loss"
    # logger
    log: LogConfig = MISSING
    # dataset
    dataset: DataConfig = MISSING
    # model
    model: ModelConfig = MISSING
    # eval
    eval: EvalConfig = MISSING
