from dataclasses import dataclass, field
from typing import List

from omegaconf import MISSING


@dataclass
class DataConfig:
    """Base dataset configuration shared by every dataset-specific config.

    ``name`` and ``num_views`` default to OmegaConf's ``MISSING`` marker, i.e.
    they are mandatory values that a subclass (or the user's config) must
    supply before they are read.
    """

    name: str = MISSING
    # NOTE(review): presumably the number of data-loader worker processes —
    # confirm at the call site.
    num_workers: int = 40
    # Number of views per sample (e.g. frontal + lateral X-rays for MIMIC-CXR).
    num_views: int = MISSING


@dataclass
class MimicCXRDataConfig(DataConfig):
    """Dataset configuration for MIMIC-CXR chest X-rays.

    Two views (frontal and lateral) and the 14 labels listed in
    ``target_list``. Path fields default to the "INSERT PATH" placeholder and
    must be filled in before use.
    """

    name: str = "mimic_cxr"

    # num views = 2: lateral (LATERAL + LL) and frontal (AP + PA)
    num_views: int = 2
    dir_data: str = "INSERT PATH"
    dir_cache: str = "INSERT PATH"
    use_cache: bool = True

    # split settings
    # NOTE(review): presumably train_val_split is the training fraction and
    # test_val_split divides the remainder between test and validation —
    # confirm against the split implementation.
    splitting_method: str = "random"
    train_val_split: float = 0.8
    test_val_split: float = 0.5
    split_seed: int = 0
    # Accepted values: "one_frontal_one_lateral" or "all_combi_no_missing"
    # (presumably controls how the frontal/lateral images of a study are
    # combined into samples).
    studies_policy: str = "all_combi_no_missing"
    reduced_dataset: bool = False

    # labels
    # default_factory gives each instance its own list; dataclasses reject
    # mutable (list) defaults on fields.
    target_list: List[str] = field(
        default_factory=lambda: [
            "Atelectasis",
            "Cardiomegaly",
            "Consolidation",
            "Edema",
            "Enlarged Cardiomediastinum",
            "Fracture",
            "Lung Lesion",
            "Lung Opacity",
            "No Finding",
            "Pleural Effusion",
            "Pleural Other",
            "Pneumonia",
            "Pneumothorax",
            "Support Devices",
        ]
    )
    # NOTE(review): both counts are hard-coded to len(target_list) == 14;
    # presumably they must be updated together if target_list is overridden.
    n_clfs_outputs: int = 14
    num_labels: int = 14

    img_size: int = 224  # input image size; this config expects 224
    image_channels: int = 1  # single channel, matching img_RGB = False below

    # Hyperparameters copied from the CelebA config; the text-related entries
    # (filter_dim_text, beta_text) are not used for this dataset.
    num_layers_img: int = 5
    filter_dim_img: int = 64
    filter_dim_text: int = 64
    beta_img: float = 1.0
    beta_text: float = 1.0
    skip_connections_img_weight_a: float = 1.0
    skip_connections_img_weight_b: float = 1.0

    use_rec_weight: bool = True
    include_channels_rec_weight: bool = False

    # img settings
    img_RGB: bool = False  # set to False if you want to use grayscale images
    pre_load_images: bool = False



@dataclass
class PolyMNISTDataConfig(DataConfig):
    """Settings shared by every PolyMNIST variant.

    ``name`` is left as ``MISSING`` here; each variant subclass supplies it
    together with its train/test/classifier path suffixes (presumably joined
    onto the ``*_base`` directories below — confirm in the data module).
    """

    # Three views per PolyMNIST sample.
    num_views: int = 3
    # Placeholder root directories; replace "INSERT PATH" before running.
    dir_data_base: str = "INSERT PATH"
    dir_clfs_base: str = "INSERT PATH"
    # Ten classifier outputs (presumably the digit classes), one label each.
    n_clfs_outputs: int = 10
    num_labels: int = 1


@dataclass
class PMvanillaDataConfig(PolyMNISTDataConfig):
    """PolyMNIST, "vanilla" variant."""

    name: str = "PM_vanilla"
    # Sub-paths for the train/test data and the classifier checkpoints
    # (presumably relative to dir_data_base / dir_clfs_base).
    suffix_data_train: str = "PolyMNIST_vanilla/train"
    suffix_data_test: str = "PolyMNIST_vanilla/test"
    suffix_clfs: str = "vanilla_resnet"


@dataclass
class PMtranslatedData50Config(PolyMNISTDataConfig):
    """PolyMNIST, "translated_50" variant."""

    name: str = "PM_translated_50"
    suffix_data_train: str = "PolyMNIST_translated_50/train"
    suffix_data_test: str = "PolyMNIST_translated_50/test"
    # NOTE(review): "translatedl50" has an extra "l" compared with the
    # "translated60_resnet"-style names of the other variants; confirm it
    # matches the classifier directory on disk before changing it.
    suffix_clfs: str = "translatedl50_resnet"


@dataclass
class PMtranslatedData55Config(PolyMNISTDataConfig):
    """PolyMNIST, "translated_55" variant."""

    name: str = "PM_translated_55"
    suffix_data_train: str = "PolyMNIST_translated_55/train"
    suffix_data_test: str = "PolyMNIST_translated_55/test"
    # NOTE(review): "translatedl55" has an extra "l" compared with the
    # "translated60_resnet"-style names of the other variants; confirm it
    # matches the classifier directory on disk before changing it.
    suffix_clfs: str = "translatedl55_resnet"


@dataclass
class PMtranslatedData60Config(PolyMNISTDataConfig):
    """PolyMNIST, "translated_60" variant."""

    name: str = "PM_translated_60"
    # Sub-paths for the train/test data and the classifier checkpoints.
    suffix_data_train: str = "PolyMNIST_translated_60/train"
    suffix_data_test: str = "PolyMNIST_translated_60/test"
    suffix_clfs: str = "translated60_resnet"


@dataclass
class PMtranslatedData65Config(PolyMNISTDataConfig):
    """PolyMNIST, "translated_65" variant."""

    name: str = "PM_translated_65"
    # Sub-paths for the train/test data and the classifier checkpoints.
    suffix_data_train: str = "PolyMNIST_translated_65/train"
    suffix_data_test: str = "PolyMNIST_translated_65/test"
    suffix_clfs: str = "translated65_resnet"


@dataclass
class PMtranslatedData70Config(PolyMNISTDataConfig):
    """PolyMNIST, "translated_70" variant."""

    name: str = "PM_translated_70"
    # NOTE(review): unlike every other variant, these data suffixes lack the
    # "PolyMNIST_" prefix ("translated_70/..." rather than
    # "PolyMNIST_translated_70/..."); confirm this matches the on-disk layout.
    suffix_data_train: str = "translated_70/train"
    suffix_data_test: str = "translated_70/test"
    suffix_clfs: str = "translated70_resnet"


@dataclass
class PMtranslatedData75Config(PolyMNISTDataConfig):
    """PolyMNIST, "translated_scale075" variant."""

    # NOTE(review): the name has no underscore before the number
    # ("PM_translated75") unlike "PM_translated_50" etc.; any code comparing
    # against it must use this exact string.
    name: str = "PM_translated75"
    # The directory name suggests "75" refers to a scale of 0.75 — confirm.
    suffix_data_train: str = "PolyMNIST_translated_scale075/train"
    suffix_data_test: str = "PolyMNIST_translated_scale075/test"
    suffix_clfs: str = "translated75_resnet"


@dataclass
class PMtranslatedData50FixedConfig(PolyMNISTDataConfig):
    """PolyMNIST, "translated_50_fixed" variant."""

    name: str = "PM_translated_50_fixed"
    # Sub-paths for the train/test data and the classifier checkpoints.
    suffix_data_train: str = "PolyMNIST_translated_50_fixed/train"
    suffix_data_test: str = "PolyMNIST_translated_50_fixed/test"
    suffix_clfs: str = "translated_50_fixed_resnet"


@dataclass
class PMrotatedDataConfig(PolyMNISTDataConfig):
    """PolyMNIST, "rotated" variant."""

    name: str = "PM_rotated"
    # Sub-paths for the train/test data and the classifier checkpoints.
    suffix_data_train: str = "PolyMNIST_rotated/train"
    suffix_data_test: str = "PolyMNIST_rotated/test"
    suffix_clfs: str = "rotated_resnet"


@dataclass
class CelebADataConfig(DataConfig):
    """Dataset configuration for CelebA with an image view and a text view.

    The text view is encoded over an alphabet of ``num_features`` symbols.
    Directory fields default to the "INSERT PATH" placeholder and must be
    filled in before use.
    """

    name: str = "celeba"
    num_views: int = 2
    dir_data: str = "INSERT PATH"
    dir_alphabet: str = "INSERT PATH"
    dir_clf: str = "INSERT PATH"

    # Text view.
    len_sequence: int = 256
    random_text_ordering: bool = False
    random_text_startindex: bool = True
    # Image view (crop_size_img presumably applied before resizing to img_size).
    img_size: int = 64
    image_channels: int = 3
    crop_size_img: int = 148
    # 40 labels / classifier outputs (presumably the CelebA attributes).
    n_clfs_outputs: int = 40
    num_labels: int = 40

    num_features: int = 41  # alphabet size, i.e. len(alphabet)
    # Model hyperparameters.
    num_layers_img: int = 5
    filter_dim_img: int = 64
    filter_dim_text: int = 64
    beta_img: float = 1.0
    beta_text: float = 1.0
    skip_connections_img_weight_a: float = 1.0
    skip_connections_img_weight_b: float = 1.0
    skip_connections_text_weight_a: float = 1.0
    skip_connections_text_weight_b: float = 1.0

    use_rec_weight: bool = True
    include_channels_rec_weight: bool = False
