
import json
import os
import sys
import time
import argparse

from PIL import Image
from matplotlib import pyplot as plt
from tqdm import tqdm
import numpy as np
import torch

from transformers import DetrImageProcessor, DetrForObjectDetection

# Command-line interface. The arguments are parsed at module level because
# main() reads the module-level ``args``.
parser = argparse.ArgumentParser(
    description=(
        "Run DETR object detection over a dataset's images and save the "
        "detections as JSON."
    ),
)
parser.add_argument(
    "--dataset",
    type=str,
    default='nsd',
    help=(
        "Dataset name; selects the dataset directory and the image "
        "filename prefix (default: %(default)s)."
    ),
)

args = parser.parse_args()


def _load_batch(image_dir, dataset, image_ids):
    """Load the images with the given ids as RGB numpy arrays.

    Files are expected at ``{image_dir}/{dataset}_image_{id:06}.png``.
    """
    batch = []
    for image_id in image_ids:
        path = f'{image_dir}/{dataset}_image_{image_id:06}.png'
        # The context manager releases the file handle as soon as the pixel
        # data has been copied into the array.
        with Image.open(path) as image:
            batch.append(np.array(image.convert('RGB')))
    return batch


def _to_instances(detection, id2label):
    """Convert one post-processed detection dict into JSON-serializable dicts.

    ``detection`` holds the ``labels``, ``boxes`` and ``scores`` tensors for a
    single image; ``id2label`` maps a label id to its class name.
    """
    # One ``tolist()`` per tensor converts everything to plain Python values
    # in a single transfer instead of one ``.item()`` call per element.
    labels = detection['labels'].tolist()
    boxes = detection['boxes'].tolist()
    scores = detection['scores'].tolist()
    return [
        {
            'label_id': label_id,
            'label_str': id2label[label_id],
            'bbox': box,
            'score': score,
        }
        for label_id, box, score in zip(labels, boxes, scores)
    ]


def main():
    """Detect objects in every dataset image and write them to a JSON file.

    Images are read from ``{root_dir}/images`` in batches, passed through the
    pretrained ``facebook/detr-resnet-101`` model, and detections scoring
    above 0.5 are written to ``{root_dir}/{dataset}_instances.json`` as a list
    of ``{'image_id': ..., 'instances': [...]}`` entries.
    """
    root_dir = f'/mnt/NSD_dataset/datasets/{args.dataset}'
    image_dir = f'{root_dir}/images'
    # The image count is taken from the directory listing; this assumes the
    # directory contains only images numbered contiguously from 0.
    total = len(os.listdir(image_dir))
    batch_size = 200
    print(f'[Total]: {total}')

    # Use the GPU when one is available, otherwise fall back to the CPU
    # instead of crashing on the hard-coded device.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-101", revision="no_timm")
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-101", revision="no_timm").to(device)
    id2label = model.config.id2label

    info_list = []
    for start in tqdm(range(0, total, batch_size)):
        image_ids = range(start, min(start + batch_size, total))
        images = _load_batch(image_dir, args.dataset, image_ids)

        with torch.no_grad():
            inputs = processor(images=images, return_tensors="pt").to(device)
            outputs = model(**inputs)
            # NOTE(review): no ``target_sizes`` is passed, so the boxes are
            # not rescaled to each image's pixel size -- confirm that this is
            # what consumers of the JSON expect.
            detections = processor.post_process_object_detection(outputs, threshold=0.5)

        # Convert to plain Python values right away so that per-image result
        # tensors do not accumulate on the device for the whole dataset.
        for image_id, detection in zip(image_ids, detections):
            info_list.append({
                'image_id': image_id,
                'instances': _to_instances(detection, id2label)
            })

    with open(f'{root_dir}/{args.dataset}_instances.json', 'w') as f:
        json.dump(info_list, f, indent=4)


# Run the detection pipeline only when this file is executed as a script.
if __name__ == '__main__':
    main()
