import argparse
import itertools
import os

import pandas as pd


def get_train_acc(log_file, best_epoch):
    """Look up the train accuracy at the best epoch and at the final epoch.

    The log is parsed as a tab-separated table; its first line and its last
    five lines are skipped. The column names are matched verbatim, including
    the spaces around them ('Epoch ' and ' Train Acc ').

    Args:
        log_file: Path to the training log to parse.
        best_epoch: Value of the 'Epoch ' column identifying the best epoch.

    Returns:
        A ``(best_train_acc, last_train_acc)`` tuple with both values
        multiplied by 100: the ' Train Acc ' entry of the first row whose
        'Epoch ' equals ``best_epoch``, and the entry of the last row.

    Raises:
        IndexError: If no row matches ``best_epoch`` or the table is empty.
    """
    table = pd.read_csv(
        log_file,
        sep='\t',
        skiprows=1,
        skipfooter=5,
        engine='python',  # skipfooter is only supported by the python engine
    )
    is_best_row = table['Epoch '] == best_epoch
    acc_at_best = table[is_best_row][' Train Acc '].iloc[0]
    acc_at_end = table[' Train Acc '].iloc[-1]
    return 100 * acc_at_best, 100 * acc_at_end


def main():
    """Summarize LipConvnet training logs into ``lot-results.csv``.

    Reads the results root from the required ``--data``/``-d`` command-line
    argument. For every (aux, opt, epochs, blocks) configuration it looks for
    the matching results folder and parses the tail of its ``output.log``:

    * 5th line from the end must start with 'Total train time:' (otherwise
      the run is reported as unfinished and skipped),
    * 3rd line from the end holds the tab-separated "best" metrics,
    * last line holds the tab-separated "last" metrics.

    Missing folders and unfinished runs are reported on stdout and skipped.
    One row per finished run is written to ``lot-results.csv`` in the current
    working directory, with floats formatted to two decimals.
    """
    parser = argparse.ArgumentParser(description='Summarize results to dataframe')
    parser.add_argument('--data', '-d', type=str, required=True,
            help='Path to results data containing LipConvnet* folders')
    args = parser.parse_args()

    aux_options = ['None', '50k', '100k', '200k', '500k', '1m', '5m', '10m']
    opt_options = ['ms', 'oc']
    epoch_options = [200, 400, 600]
    block_options = [2, 4, 8]
    dir_template = 'LipConvnet_cifar10_{}_70_{}_lr0.1_e{}_{}_lot_32_hh1_cr0.5_res'

    # Create dataframe with results
    records = []
    # product() iterates in the same order as four nested loops would.
    for config in itertools.product(aux_options, opt_options, epoch_options, block_options):
        results_dir = os.path.join(args.data, dir_template.format(*config))
        if not os.path.exists(results_dir):
            print('Missing aux={} opt={} epochs={} blocks={}'.format(*config))
            continue

        log_file = os.path.join(results_dir, 'output.log')
        try:
            with open(log_file, 'r') as f:
                lines = f.readlines()
        except FileNotFoundError:
            # The folder exists but nothing was logged yet; handled below.
            lines = []

        # A finished log ends with a 5-line summary; a shorter log cannot
        # contain it, so guard the negative index before using it.
        if len(lines) < 5 or not lines[-5].startswith('Total train time:'):
            print('Unfinished aux={} opt={} epochs={} blocks={}'.format(*config))
            continue

        best = lines[-3].split('\t')
        best_epoch = int(best[0])
        best_std_acc = 100 * float(best[2])
        best_cert_acc = 100 * float(best[3])

        # The "last" row is read with the same field positions as the "best"
        # row: field 2 is the standard accuracy, field 3 the certified one.
        last = lines[-1].split('\t')
        last_std_acc = 100 * float(last[2])
        last_cert_acc = 100 * float(last[3])

        best_train_acc, last_train_acc = get_train_acc(log_file, best_epoch)
        records.append((*config, best_epoch, best_train_acc, best_std_acc, best_cert_acc,
                        last_train_acc, last_std_acc, last_cert_acc))

    columns = ['aux', 'opt', 'epochs', 'blocks', 'best_epoch', 'best_train_acc', 'best_std_acc',
               'best_cert_acc', 'last_train_acc', 'last_std_acc', 'last_cert_acc']
    df = pd.DataFrame.from_records(records, columns=columns)
    df.to_csv('lot-results.csv', index=False, float_format='%.02f')


# Run the summary only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

