import os 
import json 
from transformers import AutoTokenizer
import argparse
def count_output_tokens(model_path: str, output_jsonl_path: str) -> float:
    """Print and return the mean token count of the ``text`` field in a JSONL file.

    Each non-blank line of ``output_jsonl_path`` is parsed as a JSON object and
    its ``"text"`` value is encoded with the tokenizer loaded from ``model_path``.

    Args:
        model_path: Tokenizer location handed to ``AutoTokenizer.from_pretrained``
            after ``~`` expansion.
        output_jsonl_path: Path to the JSONL file, one JSON object per line.

    Returns:
        The average number of token ids per record.

    Raises:
        ValueError: If the file contains no records.
    """
    # Read the records first: it is cheap, closes the file promptly, and lets
    # an input without records fail before the tokenizer is loaded.
    with open(output_jsonl_path, "r", encoding="utf-8") as jsonl_file:
        # Blank lines (e.g. an extra newline at end of file) are not records.
        outputs = [json.loads(line) for line in jsonl_file if line.strip()]
    if not outputs:
        raise ValueError(f"No records found in {output_jsonl_path}")

    model_path = os.path.expanduser(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    total_tokens = sum(len(tokenizer.encode(record["text"])) for record in outputs)
    average_tokens = total_tokens / len(outputs)
    print(f"Average tokens: {average_tokens}")
    return average_tokens


if __name__ == "__main__":
    # Command-line entry point: collect the tokenizer location and the JSONL
    # file to measure, then hand both to count_output_tokens.
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "--model_path",
        type=str,
        default="/mnt/petrelfs/usr/models/models--Qwen--Qwen1.5-1.8B-Chat",
    )
    cli.add_argument("--output_jsonl_path", type=str, required=True)
    options = cli.parse_args()

    count_output_tokens(
        model_path=options.model_path,
        output_jsonl_path=options.output_jsonl_path,
    )