import argparse
import os
import torch
import shutil

def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--data', type=str, default='./data')
    parser.add_argument('--dataset', type=str, default='mnist', choices=['mnist'])
    parser.add_argument('--log_name', type=str, required=True)
    parser.add_argument('--alg', type=str, default='vae', choices=['our', 'vae'])
    parser.add_argument('--log_path', type=str, default='logs')
    parser.add_argument('--no-cuda', action='store_true')
    parser.add_argument('--beta', type=float, default=1)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--latent_dim', type=int, default=50)
    parser.add_argument('--num_epochs', type=int, default=50)


    # OUR ARGS
    parser.add_argument('--lambda_', type=float, default=1e-3)
    parser.add_argument('--num_samples', type=int, default=3)
    parser.add_argument('--rotate', type=float, default=22.5)
    parser.add_argument('--translate', type=float, default=4)
    parser.add_argument('--trans_distance', choices=['kl', 'wass'], default='wass', type=str)
    args = parser.parse_args()

    args.cuda = True if torch.cuda.is_available() and not args.no_cuda else False
    assert args.alg == 'our'

    return args
