import argparse
import os
import random
from pathlib import Path

import yaml


def main(args):
    """Write one baselines YAML config per run for a dataset/metric/protected combo.

    Creates the directory ``{dataset}_{metric}_{protected}`` in the current
    working directory (no error if it already exists). It then writes
    ``num_runs`` files named
    ``config_{dataset}_{metric}_{protected}_{i}_baselines.yaml`` into it, for
    ``i`` in ``range(num_runs)``.

    One set of hyperparameters is sampled once with ``random.choice`` and
    shared by every run's config.

    Args:
        args: Namespace with ``dataset``, ``metric``, ``protected`` and
            ``num_runs`` attributes. ``protected`` and ``num_runs`` may be
            strings (as argparse delivers them by default) or ints. Both are
            converted to ``int`` and written back onto ``args``.

    Raises:
        ValueError: if ``protected`` or ``num_runs`` is not integer-like.
    """
    # Convert *before* building any path so the directory name and the file
    # paths below agree (e.g. "01" vs 1), and so bad input does not leave an
    # empty directory behind.
    args.num_runs = int(args.num_runs)
    args.protected = int(args.protected)

    prefix = f'{args.dataset}_{args.metric}_{args.protected}'
    out_dir = Path(prefix)
    out_dir.mkdir(exist_ok=True)

    hyperparam_dict = {
        'num_deep': [2, 5, 10, 15],
        'hid': [16, 32, 64, 128],
        'dropout_p': [0.1, 0.2, 0.3, 0.5]
    }

    # Sampled once, outside the loop: every run shares these hyperparameters.
    hyperparam = {k: random.choice(v) for k, v in hyperparam_dict.items()}

    for i in range(args.num_runs):
        baselines_config = {
            'experiment_name': f'{prefix}_{i}_baselines',
            'dataset': args.dataset,
            'protected': args.protected,
            'modelpath': f'models/{args.dataset}_{i}_model.pt',
            'metric': args.metric,
            'models': [
                'default',
                'ROC',
                'EqOdds',
                'CalibEqOdds',
                'random',
                'adversarial'
            ],
            'CalibEqOdds': {'cost_constraint': 'fpr'},
            'random': {'num_trials': 201},
            'adversarial': {'epochs': 16, 'critic_steps': 201, 'actor_steps': 101, 'batch_size': 64, 'lambda': 0.75},
            'hyperparameters': hyperparam
        }

        config_path = out_dir / f'config_{prefix}_{i}_baselines.yaml'
        with open(config_path, 'w', encoding='utf-8') as fh:
            yaml.dump(baselines_config, fh)


if __name__ == "__main__":
    # Command-line entry point: parse the four positional arguments and
    # generate the config files via main().
    parser = argparse.ArgumentParser()

    parser.add_argument("dataset", help="Which dataset")
    parser.add_argument("metric", help="which metric")
    # type=int makes argparse reject non-integer values with a usage message,
    # instead of main() failing later with a bare ValueError traceback.
    parser.add_argument("protected", type=int, help="which protected")
    parser.add_argument("num_runs", type=int, help="Number of runs")

    args = parser.parse_args()

    main(args)
