from dataclasses import dataclass

from ...util.hparams import HyperParams
from typing import Optional, Any, List
import yaml


@dataclass
class MENDHyperParams(HyperParams):
    """Hyper-parameter container for MEND, normally built from a YAML file.

    Every field declared here except ``batch_size`` is required, so the
    mapping handed to :meth:`from_hparams` must supply a value for each of
    them. The meaning of the individual fields is defined by the code that
    consumes this object and is deliberately not restated here.
    """

    # NOTE(review): presumably the names of the model parameters being
    # edited -- confirm against the consumer of this config.
    inner_params: List[str]

    archive: Any

    # Method
    lr: float
    edit_lr: float
    lr_lr: float
    lr_scale: float
    seed: int
    cedit: float
    cloc: float
    cbase: float
    dropout: float
    train_base: bool
    no_grad_layers: Any
    one_sided: bool
    n_hidden: int
    hidden_dim: Any
    init: str
    norm: bool
    combine: bool
    x_only: bool
    delta_only: bool
    act: str
    rank: int
    mlp_class: str
    shared: bool

    # The only field with a default value; it may be omitted from the YAML.
    batch_size: int = 1

    @classmethod
    def from_hparams(cls, hparams_name_or_path: str) -> "MENDHyperParams":
        """Build an instance from a YAML hyper-parameter file.

        Args:
            hparams_name_or_path: Path to the YAML file. The ``.yaml``
                extension is appended when the path does not already end
                with it.

        Returns:
            An instance of ``cls`` whose fields are populated from the
            top-level mapping of the YAML document.

        Raises:
            FileNotFoundError: If the resolved path does not exist.
            ValueError: If the document is empty or its top level is not a
                mapping.
            TypeError: If the mapping's keys do not match the dataclass
                fields (a required field is missing or a key is unknown).
        """
        # Test the suffix rather than a substring: a path that merely
        # contains ".yaml" somewhere (e.g. in a directory name) still needs
        # the extension appended.
        if not hparams_name_or_path.endswith('.yaml'):
            hparams_name_or_path = hparams_name_or_path + '.yaml'

        # Read as UTF-8 explicitly so parsing does not depend on the
        # platform's default encoding.
        with open(hparams_name_or_path, "r", encoding="utf-8") as stream:
            config = yaml.safe_load(stream)

        # safe_load returns None for an empty document and may return a
        # list or scalar; none of those can be expanded into keyword
        # arguments, so fail here with a message that names the file.
        if not isinstance(config, dict):
            raise ValueError(
                f"Expected a YAML mapping in '{hparams_name_or_path}', "
                f"got {type(config).__name__}"
            )

        # Inherited helper; presumably converts values such as "1e-5" that
        # YAML parsed as strings into floats -- see HyperParams.
        config = super().construct_float_from_scientific_notation(config)

        return cls(**config)
