import argparse
import os

import pandas as pd


def get_test_stats(log_file, expected_epochs):
    """Return train accuracy at the most robust epoch and at the last epoch.

    The log is a tab-separated file whose header names include padding
    spaces; the columns read here are ``'epoch '``, ``' acc '`` and
    ``' Test Robust '``.

    Args:
        log_file: Path to the tab-separated log file.
        expected_epochs: Currently unused; kept so existing callers that
            pass it keep working.

    Returns:
        A ``(best_train_acc, last_train_acc)`` tuple, both scaled by 100:
        the ``' acc '`` value of the first row whose epoch matches the row
        with the highest ``' Test Robust '`` value, and the ``' acc '``
        value of the final row.

    Raises:
        ValueError: If the log has no rows (``idxmax`` on an empty column).
        KeyError: If one of the expected columns is missing.
    """
    df = pd.read_csv(log_file, sep='\t')
    # The row with the highest robust test accuracy defines the "best" epoch.
    best_row = df.loc[df[' Test Robust '].idxmax()]
    best_epoch = best_row['epoch ']
    best_train_acc = df[df['epoch '] == best_epoch][' acc '].iloc[0]
    last_train_acc = df[' acc '].iloc[-1]
    return 100 * best_train_acc, 100 * last_train_acc


def _results_dir_name(aux, dropout, epochs):
    """Return the results folder name for one (aux, dropout, epochs) run."""
    if dropout != 1.0:
        epoch_str = '0,0,{},{},{}'.format(epochs // 15, epochs - epochs // 60, epochs)
        return 'CIFAR10_{}_0.7_SortMLPModel(depth=6,width=5120,scalar=True,dropout={})_mixture(lam0=0.2,lam_end=0.002)_p8.0_p_end1000.0_eps0.09411_epoch{}_bs512_lr0.02_wd0.02__2021__0'.format(aux, dropout, epoch_str)
    epoch_str = '0,0,{},{},{}'.format(epochs // 8, epochs - epochs // 16, epochs)
    return 'CIFAR10_{}_0.7_SortMLPModel(depth=6,width=5120,identity_val=10.0,scalar=False,dropout={})_hinge_p8.0_p_end1000.0_eps0.1569_epoch{}_bs512_lr0.02_wd0.02__2021__0'.format(aux, dropout, epoch_str)


def _read_complete_log(log_file, expected_rows):
    """Load a tab-separated log; return None unless it has exactly expected_rows rows."""
    if not os.path.isfile(log_file):
        return None
    try:
        log = pd.read_csv(log_file, sep='\t')
    except pd.errors.EmptyDataError:
        # A zero-byte log is treated the same way as an incomplete one.
        return None
    if log.shape[0] != expected_rows:
        return None
    return log


def main():
    """Summarize SortNet / Linf-net runs found under ``--data`` into a CSV.

    For every (aux, dropout, epochs) configuration, the matching results
    folder is looked up under the ``--data`` directory. Configurations whose
    folder is absent are reported as "Missing". Configurations whose
    ``test_inf.log`` or ``train_inf.log`` is absent, empty, or does not have
    the expected number of rows (``epochs // 5 + 1`` and ``epochs // 5``
    respectively) are reported as "Unfinished". Both kinds are skipped.

    ``test_inf.log`` must provide the columns ``epoch``, ``acc`` and
    ``certified``; ``train_inf.log`` must provide ``epoch`` and ``acc``.
    The best epoch is the test row with the highest ``certified`` value.
    All accuracies are multiplied by 100 before being stored.

    The summary is written to ``sortnet-results.csv`` in the current working
    directory, with floats formatted to two decimals.
    """
    parser = argparse.ArgumentParser(description='Summarize results to dataframe')
    parser.add_argument('--data', '-d', type=str, required=True,
            help='Path to results data containing SortNet* folders')
    args = parser.parse_args()
    # Create dataframe with results
    records = []
    for aux in ['None', '50k', '100k', '200k', '500k', '1m', '5m', '10m']:
        # Dropout 1.0 is Linf-net
        for dropout in ['None', 0.85, 1.0]:
            epochs_range = [3000, 6000] if dropout != 1.0 else [800, 1600]
            for epochs in epochs_range:
                config = (aux, dropout, epochs)
                results_dir = os.path.join(args.data, _results_dir_name(*config))
                if not os.path.exists(results_dir):
                    print('Missing aux={} dropout={} epochs={}'.format(*config))
                    continue

                # The test log is expected to hold one row more than the train log.
                test_log = _read_complete_log(
                    os.path.join(results_dir, 'test_inf.log'), epochs // 5 + 1)
                if test_log is None:
                    print('Unfinished aux={} dropout={} epochs={}'.format(*config))
                    continue
                train_log = _read_complete_log(
                    os.path.join(results_dir, 'train_inf.log'), epochs // 5)
                if train_log is None:
                    print('Unfinished aux={} dropout={} epochs={}'.format(*config))
                    continue

                best_row = test_log.loc[test_log['certified'].idxmax()]
                best_epoch = int(best_row['epoch'])
                best_std_acc = 100 * best_row['acc']
                best_cert_acc = 100 * best_row['certified']
                last_row = test_log.iloc[-1]
                last_std_acc = 100 * last_row['acc']
                last_cert_acc = 100 * last_row['certified']

                # NOTE: raises IndexError if the train log has no row for best_epoch.
                best_train_acc = 100 * train_log[train_log['epoch'] == best_epoch]['acc'].iloc[0]
                last_train_acc = 100 * train_log['acc'].iloc[-1]
                records.append((*config, best_epoch, best_train_acc, best_std_acc, best_cert_acc, last_train_acc, last_std_acc, last_cert_acc))
    columns = ['aux', 'dropout', 'epochs', 'best_epoch', 'best_train_acc', 'best_std_acc', 'best_cert_acc', 'last_train_acc', 'last_std_acc', 'last_cert_acc']
    df = pd.DataFrame.from_records(records, columns=columns)
    df.to_csv('sortnet-results.csv', index=False, float_format='%.02f')


# Run the summary only when executed as a script, not when imported.
if __name__ == '__main__':
    main()

