import json
import argparse

import random

# Input files, resolved relative to the directory the script is run from.
BASE_PATH = '../data/latest/proofwiki.json'
GROUND_TRUTH_PATH = '../data/latest/proofwiki__refs_ground_truth.json'
RETRIEVED_PATH = '../data/latest/proofwiki__refs_retrieved.json'

SPLITS = ('train', 'valid', 'test')
# At most this many entries of a theorem's 'ctxs' list are named in its prompt.
MAX_REFS_PER_PROMPT = 20
# Fixed seed so the shuffled train split is reproducible between runs.
SHUFFLE_SEED = 19260817
# Literal backslash followed by 'n' (two characters), NOT a newline character.
LINE_SEP = '\\n'


def _parse_args(argv=None):
    """Parse the command line; ``argv=None`` makes argparse read ``sys.argv[1:]``."""
    parser = argparse.ArgumentParser(
        description='Convert ProofWiki JSON data into prompt/completion JSONL files.',
    )
    parser.add_argument(
        '--refs', default='no', choices=['no', 'gt', 'ret'],
        help="references named in each theorem prompt: 'no' = none, "
             "'gt' = from the ground-truth refs file, "
             "'ret' = from the retrieved refs file",
    )
    parser.add_argument(
        '--ref-pretrain', action='store_true',
        help='also add one title-to-contents example per ProofWiki '
             'theorem/definition/other entry to the train split, then shuffle it',
    )
    return parser.parse_args(argv)


def _load_json(path):
    """Parse the JSON document at *path*, closing the file once it is read."""
    with open(path, encoding='utf-8') as f:
        return json.load(f)


def _theorem_pair(theorem, with_refs):
    """Build the prompt/completion dict for one theorem record.

    Reads the record's 'title', 'text' and 'target' keys. 'ctxs' is only read
    when *with_refs* is true, and only its first MAX_REFS_PER_PROMPT entries
    (each contributing its 'title') are used.
    """
    parts = [
        f'<theorem> <title> {theorem["title"]} </title> <content> {theorem["text"]} </content> </theorem>'
    ]
    if with_refs:
        parts.extend(
            f' <reference> {ref["title"]} </reference>'
            for ref in theorem['ctxs'][:MAX_REFS_PER_PROMPT]
        )
    parts.append(' <proof>')
    return {
        'prompt': ''.join(parts),
        'completion': f' {theorem["target"]} </proof>',
    }


def _reference_pair(ref):
    """Build the prompt/completion dict mapping a reference's title to its contents.

    Reads the record's 'type', 'title' and 'contents' keys; the content lines
    are joined with LINE_SEP.
    """
    kind = ref['type']
    body = LINE_SEP.join(ref['contents'])
    return {
        'prompt': f'<{kind}> <title> {ref["title"]} </title> <content>',
        'completion': f' {body} </content> </{kind}>',
    }


def main(argv=None):
    """Write one JSONL file of prompt/completion pairs per split.

    For each of the train/valid/test splits, every theorem in the selected
    refs file becomes one pair. With ``--ref-pretrain``, the train split is
    additionally extended with one pair per ProofWiki theorem, definition and
    'other' entry, and then shuffled with a fixed seed. The valid and test
    splits are never extended or shuffled.

    Args:
        argv: Argument list to parse instead of ``sys.argv[1:]`` (optional).
    """
    args = _parse_args(argv)

    dataset = _load_json(BASE_PATH)['dataset']
    refs = dataset['theorems'] + dataset['definitions'] + dataset['others']

    # 'no' and 'gt' both read the ground-truth file; only the prompt differs.
    input_path = RETRIEVED_PATH if args.refs == 'ret' else GROUND_TRUTH_PATH
    ds = _load_json(input_path)

    random.seed(SHUFFLE_SEED)

    with_refs = args.refs in ('gt', 'ret')
    suffix = '_ref-pretrain' if args.ref_pretrain else ''

    for split in SPLITS:
        pairs = [_theorem_pair(theorem, with_refs) for theorem in ds[split]]

        if args.ref_pretrain and split == 'train':
            pairs.extend(_reference_pair(ref) for ref in refs)
            random.shuffle(pairs)

        # NOTE(review): inputs are read from '../data/latest' but outputs are
        # written to 'data/latest' under the working directory, which must
        # already exist -- confirm this asymmetry is intended.
        output_path = f'data/latest/gpt3ft_proofwiki_{args.refs}refs{suffix}.{split}.jsonl'
        with open(output_path, 'w', encoding='utf-8') as f:
            f.writelines(json.dumps(pair) + '\n' for pair in pairs)


if __name__ == '__main__':
    main()

