from dataclasses import dataclass
from omegaconf import MISSING


@dataclass
class DataConfig:
    """Base structured config shared by every dataset config in this module.

    Fields whose default is omegaconf's ``MISSING`` carry no value here and
    must be supplied (by a subclass default or a config override) before
    they are read.
    """

    name: str = MISSING  # dataset identifier; the concrete subclasses set it
    num_workers: int = 8  # presumably the data-loader worker count — confirm at call site
    # number of views in the dataset; subclasses provide the default
    num_views: int = MISSING


@dataclass
class PolyMNISTDataConfig(DataConfig):
    """Settings shared by every PolyMNIST variant.

    ``name`` stays ``MISSING`` here. Each variant subclass supplies it together
    with its ``suffix_*`` dataset/classifier path fields.
    """

    num_views: int = 3
    # Placeholder roots: replace with local paths (or override from the CLI).
    dir_data_base: str = "PUT YOUR DATA FOLDER HERE"
    dir_clfs_base: str = "PUT YOUR PRETRAINED CLF FOLDER HERE"
    n_clfs_outputs: int = 10  # presumably one output per digit class — confirm
    num_labels: int = 1  # presumably a single label (the digit) per sample — confirm


@dataclass
class PMvanillaDataConfig(PolyMNISTDataConfig):
    """Dataset paths and classifier folder for the "PM_vanilla" PolyMNIST variant.

    The ``suffix_*`` values are relative paths; presumably joined with
    ``dir_data_base`` / ``dir_clfs_base`` by the loader — confirm at call site.
    """

    name: str = "PM_vanilla"
    suffix_data_train: str = "PolyMNIST_vanilla/train"
    suffix_data_test: str = "PolyMNIST_vanilla/test"
    suffix_clfs: str = "vanilla_resnet"


@dataclass
class PMtranslatedData50Config(PolyMNISTDataConfig):
    """Dataset paths and classifier folder for the "PM_translated_50" PolyMNIST variant."""

    name: str = "PM_translated_50"
    suffix_data_train: str = "PolyMNIST_translated_50/train"
    suffix_data_test: str = "PolyMNIST_translated_50/test"
    # NOTE(review): "translatedl50" (extra "l") differs from the "translated60"
    # pattern used by the 60/65/70/75 variants — confirm the folder name on disk.
    suffix_clfs: str = "translatedl50_resnet"


@dataclass
class PMtranslatedData55Config(PolyMNISTDataConfig):
    """Dataset paths and classifier folder for the "PM_translated_55" PolyMNIST variant."""

    name: str = "PM_translated_55"
    suffix_data_train: str = "PolyMNIST_translated_55/train"
    suffix_data_test: str = "PolyMNIST_translated_55/test"
    # NOTE(review): "translatedl55" (extra "l") differs from the "translated60"
    # pattern used by the 60/65/70/75 variants — confirm the folder name on disk.
    suffix_clfs: str = "translatedl55_resnet"


@dataclass
class PMtranslatedData60Config(PolyMNISTDataConfig):
    """Dataset paths and classifier folder for the "PM_translated_60" PolyMNIST variant."""

    name: str = "PM_translated_60"
    suffix_data_train: str = "PolyMNIST_translated_60/train"
    suffix_data_test: str = "PolyMNIST_translated_60/test"
    suffix_clfs: str = "translated60_resnet"


@dataclass
class PMtranslatedData65Config(PolyMNISTDataConfig):
    """Dataset paths and classifier folder for the "PM_translated_65" PolyMNIST variant."""

    name: str = "PM_translated_65"
    suffix_data_train: str = "PolyMNIST_translated_65/train"
    suffix_data_test: str = "PolyMNIST_translated_65/test"
    suffix_clfs: str = "translated65_resnet"


@dataclass
class PMtranslatedData70Config(PolyMNISTDataConfig):
    """Dataset paths and classifier folder for the "PM_translated_70" PolyMNIST variant."""

    name: str = "PM_translated_70"
    # NOTE(review): unlike every sibling variant, these suffixes lack the
    # "PolyMNIST_" prefix (cf. "PolyMNIST_translated_65/train") — confirm the
    # dataset folder is really named "translated_70".
    suffix_data_train: str = "translated_70/train"
    suffix_data_test: str = "translated_70/test"
    suffix_clfs: str = "translated70_resnet"


@dataclass
class PMtranslatedData75Config(PolyMNISTDataConfig):
    """Dataset paths and classifier folder for the "PM_translated75" PolyMNIST variant."""

    # NOTE(review): no underscore before "75", unlike "PM_translated_70" etc.
    # Left as-is because other code may match on this exact name — confirm.
    name: str = "PM_translated75"
    # NOTE(review): folder pattern "translated_scale075" differs from the other
    # translated variants — confirm against the dataset on disk.
    suffix_data_train: str = "PolyMNIST_translated_scale075/train"
    suffix_data_test: str = "PolyMNIST_translated_scale075/test"
    suffix_clfs: str = "translated75_resnet"


@dataclass
class PMtranslatedData50FixedConfig(PolyMNISTDataConfig):
    """Dataset paths and classifier folder for the "PM_translated_50_fixed" PolyMNIST variant."""

    name: str = "PM_translated_50_fixed"
    suffix_data_train: str = "PolyMNIST_translated_50_fixed/train"
    suffix_data_test: str = "PolyMNIST_translated_50_fixed/test"
    suffix_clfs: str = "translated_50_fixed_resnet"


@dataclass
class PMrotatedDataConfig(PolyMNISTDataConfig):
    """Dataset paths and classifier folder for the "PM_rotated" PolyMNIST variant."""

    name: str = "PM_rotated"
    suffix_data_train: str = "PolyMNIST_rotated/train"
    suffix_data_test: str = "PolyMNIST_rotated/test"
    suffix_clfs: str = "rotated_resnet"


@dataclass
class CelebADataConfig(DataConfig):
    """Config for the two-view CelebA dataset (an image view and a text view).

    The ``dir_*`` fields hold placeholders that must be replaced with local
    paths (or overridden from the CLI). The remaining fields configure text
    handling, image preprocessing, classifier outputs, and per-view network
    and loss settings.
    """

    name: str = "celeba"
    num_views: int = 2
    # Placeholder paths: replace with local paths.
    dir_data: str = "PUT YOUR DATA FOLDER HERE"
    dir_alphabet: str = "PUT YOUR PATH TO ALPHABET HERE"
    dir_clf: str = "PUT YOUR PRETRAINED CLF FOLDER HERE"

    # Text view
    len_sequence: int = 256  # presumably max text length in characters — confirm
    random_text_ordering: bool = False
    random_text_startindex: bool = True
    # Image view
    img_size: int = 64
    image_channels: int = 3
    crop_size_img: int = 148
    # Classifier outputs / labels (presumably the 40 CelebA attributes — confirm)
    n_clfs_outputs: int = 40
    num_labels: int = 40

    num_features: int = 41  # len(alphabet)
    # num_layers_text: int = 7  -- disabled: intentionally not a config field
    num_layers_img: int = 5
    filter_dim_img: int = 64
    filter_dim_text: int = 64
    # Per-view weights (presumably loss-term weights — confirm in the model code)
    beta_img: float = 1.0
    beta_text: float = 1.0
    skip_connections_img_weight_a: float = 1.0
    skip_connections_img_weight_b: float = 1.0
    skip_connections_text_weight_a: float = 1.0
    skip_connections_text_weight_b: float = 1.0

    # Reconstruction-weight switches (semantics defined by the consuming model code)
    use_rec_weight: bool = True
    include_channels_rec_weight: bool = False
