from dataclasses import dataclass, field
from typing import List
from omegaconf import MISSING

from config.DatasetConfig import DataConfig
from config.ModelConfig import ModelConfig


@dataclass
class LogConfig:
    """Logging options: wandb settings, log directory and logging frequencies.

    The "XXX" defaults look like placeholders -- presumably meant to be
    overridden by the user's own configuration (TODO confirm).
    """

    # --- wandb (Weights & Biases) ---
    wandb_entity: str = "XXX"
    wandb_group: str = "XXX"
    wandb_run_name: str = ""
    wandb_project_name: str = "mv_mimic"
    wandb_log_freq: int = 30
    wandb_offline: bool = True
    wandb_local_instance: bool = False

    # --- directory for logs ---
    dir_logs: str = "XXX"

    # --- how often each kind of logging happens ---
    # NOTE(review): the unit (steps vs. epochs) is not visible in this file;
    # confirm against the code that consumes these values.
    downstream_logging_frequency: int = 60
    coherence_logging_frequency: int = 20_000
    likelihood_logging_frequency: int = 2_000_000
    img_plotting_frequency: int = 60

    # --- debug level for wandb ---
    debug: bool = True


@dataclass
class EvalConfig:
    """Evaluation options: downstream task, RF parameters, tuning and coherence.

    All list-valued fields use ``default_factory`` so every instance gets its
    own fresh list.
    """

    # --- latent representation ---
    num_samples_train: int = 20_000
    max_iteration: int = 10_000
    eval_downstream_task: bool = True
    downstream_rf: bool = True
    save_encodings: bool = True

    # --- downstream classifiers to train (only for MIMIC) ---
    classifier_list: List[str] = field(default_factory=lambda: ["RF", "LR"])

    # --- evaluation metrics (only for MIMIC) ---
    metric_list: List[str] = field(default_factory=lambda: ["AP", "AUROC"])

    # --- RF parameters ---
    # NOTE(review): the ``f_`` fields presumably hold the fixed random-forest
    # hyperparameters used when no tuning is done -- confirm with the consumer.
    f_n_estimators: int = 5_000
    f_min_samples_split: int = 5
    f_min_samples_leaf: int = 1
    f_max_features: str = "sqrt"
    f_max_depth: int = 30
    f_criterion: str = "entropy"
    f_bootstrap: bool = True

    # --- hyperparameter tuning ---
    # NOTE(review): hp_iter / hp_cv look like the number of search iterations
    # and cross-validation folds -- verify where they are read.
    hp_tuning: bool = False
    hp_iter: int = 3
    hp_cv: int = 3
    verbosity: int = 2

    # --- RF search space (candidate values per hyperparameter) ---
    n_estimator: List[int] = field(
        default_factory=lambda: [
            50,
            100,
            150,
            200,
            250,
            300,
            500,
            1000,
            2000,
            5000,
        ]
    )
    # None is not listed here; it is added at runtime.
    max_depth: List[int] = field(
        default_factory=lambda: [10, 20, 30, 40, 50]
    )
    criterion: List[str] = field(
        default_factory=lambda: ["gini", "entropy"]
    )
    min_samples_split: List[int] = field(
        default_factory=lambda: [2, 5, 10]
    )
    min_samples_leaf: List[int] = field(
        default_factory=lambda: [1, 2, 4]
    )
    max_features: List[str] = field(
        default_factory=lambda: ["sqrt", "log2"]
    )
    bootstrap: List[bool] = field(
        default_factory=lambda: [True, False]
    )

    # --- coherence ---
    coherence: bool = True


@dataclass
class MyMVWSLConfig:
    """Top-level configuration bundling the log, dataset, model and eval sections.

    The four section fields default to omegaconf's ``MISSING`` marker, so this
    class itself supplies no default value for them.
    """

    seed: int = 1
    checkpoint_metric: str = "val/loss/loss"

    # Sections (logger, dataset, model, evaluation), in that order.
    log: LogConfig = MISSING
    dataset: DataConfig = MISSING
    model: ModelConfig = MISSING
    eval: EvalConfig = MISSING
