
from exp_stability import exp_stability
import argparse
import os
import yaml
from datetime import datetime
from pytz import timezone
import numpy as np

def save_yaml_config(config, path):
    """Save a config dict to a file in YAML format.

    Args:
        config (dict): Config to serialize.
        path (str): Destination file path; an existing file is overwritten.
    """
    # default_flow_style=False makes PyYAML emit block-style collections
    # (one mapping entry per line) rather than inline {...} / [...] flow style.
    with open(path, 'w', encoding='utf-8') as outfile:
        yaml.dump(config, outfile, default_flow_style=False)

def get_datetime_str(add_random_str=False):
    """Return a timestamp string for the current moment.

    The timestamp has the form ``YYYY-MM-DD_HH-MM-SS-mmm``: the ``%f``
    microseconds field is truncated to milliseconds by dropping its last
    three digits. Time is taken in pytz's ``'EST'`` zone.

    Args:
        add_random_str (bool): If true, append ``_<k>`` where ``k`` comes from
            ``np.random.randint(low=1, high=10000)``. This is presumably meant
            to keep directory names unique across runs started at the same
            millisecond.

    Returns:
        str: The timestamp, optionally followed by the random suffix.
    """
    # NOTE(review): pytz 'EST' is a fixed UTC-05:00 offset with no DST;
    # confirm that a DST-aware zone such as 'US/Eastern' was not intended.
    now = datetime.now(timezone('EST'))
    stamp = now.strftime('%Y-%m-%d_%H-%M-%S-%f')
    stamp = stamp[:len(stamp) - 3]
    if not add_random_str:
        return stamp
    suffix = np.random.randint(low=1, high=10000)
    return '{}_{}'.format(stamp, suffix)

def main(argv=None):
    """Parse CLI args, create a run directory, save the args, and run the experiment.

    The run directory is
    ``experiments/<data>/<exp_type>/<get_datetime_str(add_random_str=True)>``.
    The parsed arguments, plus ``work_dir``, are written to
    ``args_info.yaml`` inside it. ``exp_stability`` is then called with that
    directory as its last argument.

    Args:
        argv (list[str] | None): Argument list to parse. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(description='Stability measured by Euclidean distance')
    # Help strings use %(default)s so they always report the real defaults.
    # NOTE(review): the previous help text advertised different defaults
    # (diabetes, 10, 0.1, different_eta, 790); confirm which set is intended.
    parser.add_argument('-d', '--data', default='svmguide3',
                        help='name of the dataset (default: %(default)s)')
    parser.add_argument('-t', '--topology_type', default='ring', type=str,
                        help='topology type (default: %(default)s)')
    parser.add_argument('-n', '--num_nodes', default=3, type=int,
                        help='number of nodes (default: %(default)s)')
    parser.add_argument('-e', '--eta_p', default=1, type=float,
                        help='eta_p (default: %(default)s)')
    parser.add_argument('--exp_type', default='test', type=str,
                        help='different experiments (default: %(default)s)')
    parser.add_argument('--sample_size', default=0, type=int,
                        help='sample_size on each node (default: %(default)s)')
    args = parser.parse_args(argv)

    args.work_dir = os.path.join('experiments',
                                 args.data,
                                 args.exp_type,
                                 get_datetime_str(add_random_str=True))
    # exist_ok=True replaces the racy "exists? then makedirs" check.
    os.makedirs(args.work_dir, exist_ok=True)

    save_yaml_config(vars(args), path=os.path.join(args.work_dir, 'args_info.yaml'))

    exp_stability(args.data, args.eta_p, args.topology_type, args.num_nodes,
                  args.sample_size, args.work_dir)


if __name__ == '__main__':
    main()