from PIL import Image
import torch
import numpy as np
import json

import sys
    
# ==================
#        Viz
# ==================
def make_grid(images, size=(128, 128)):
    """Stitch images side by side into a single horizontal RGB strip.

    Each input is resized to ``size`` (width, height) and the results are
    pasted left to right, starting at x = 0, along the top edge.

    Args:
        images: Iterable of PIL images (each must support ``resize`` and be
            pasteable onto a PIL ``Image``).
        size: Target (width, height) applied to every image before stitching.

    Returns:
        A new ``'RGB'`` image. Its width is the sum of the resized widths and
        its height is the tallest resized height.

    Raises:
        ValueError: if ``images`` is empty (``max`` over an empty sequence).
    """
    tiles = list(map(lambda im: im.resize(size), images))
    canvas_w = sum(tile.width for tile in tiles)
    canvas_h = max(tile.height for tile in tiles)
    canvas = Image.new('RGB', (canvas_w, canvas_h))
    left = 0
    for tile in tiles:
        canvas.paste(tile, (left, 0))
        left += tile.width
    return canvas

# ==================
#      General
# ==================
def exclude(lst, el):
    """Return the elements of ``lst`` that are not equal to ``el``.

    The input is converted to a NumPy array and filtered with a boolean
    mask. The result is therefore always a new ``np.ndarray`` (original
    order preserved), never a Python list.
    """
    values = np.asarray(lst)
    keep_mask = values != el
    return values[keep_mask]

# ==================
#       ICL
# ==================
def idx_to_mc(idx):
    """Convert a 0-based option index to its multiple-choice letter.

    ``0 -> 'A'``, ``1 -> 'B'``, ..., ``25 -> 'Z'``. This is the inverse of
    ``mc_to_idx`` over that range.

    Raises:
        ValueError: if ``idx`` is outside ``0..25``.
    """
    # Without this check, out-of-range indices map to punctuation
    # (26 -> '[', -1 -> '@') that mc_to_idx cannot turn back into an index.
    if not 0 <= idx < 26:
        raise ValueError(f"option index must be in [0, 25], got {idx!r}")
    return chr(idx + ord('A'))

def mc_to_idx(mc):
    """Convert a multiple-choice letter to its 0-based option index.

    Case-insensitive: ``'a'``/``'A'`` -> 0, ..., ``'z'``/``'Z'`` -> 25.

    Returns:
        The index, or -1 when the uppercased input is not exactly one
        character in ``'A'..'Z'``. This covers ``''``, multi-character
        strings such as ``'AB'``, and characters like ``'ß'`` (whose
        uppercase form is ``'SS'``).
    """
    mc = mc.upper()
    # The length check is required: under string ordering 'A' <= 'AB' <= 'Z'
    # is True, and ord() on a multi-character string raises TypeError.
    if len(mc) == 1 and 'A' <= mc <= 'Z':
        return ord(mc) - ord('A')
    return -1

def format_options(options):
    """Render answer options as newline-separated lettered lines.

    For example, ``['cat', 'dog']`` becomes ``'A. cat'`` and ``'B. dog'``
    joined by a newline. Letters come from ``idx_to_mc``. Options are joined
    by string concatenation, so each one must already be a ``str``.
    """
    count = len(options)
    lines = (idx_to_mc(k) + ". " + options[k] for k in range(count))
    return "\n".join(lines)

def prepare_icl_prompt(texts, answers, prefix=None, exclude_last=True):
    """Build one ``USER: <image>\\n<text> ASSISTANT:`` prompt per example.

    Args:
        texts: Question texts. ``texts[i]`` is paired with ``answers[i]``, so
            it needs at least ``len(answers)`` entries.
        answers: Option indices. Each revealed answer is rendered as a letter
            via ``idx_to_mc`` and appended after ``ASSISTANT:``.
        prefix: Currently ignored; the code that prepended it is disabled.
        exclude_last: If truthy, the final prompt is left unanswered (it ends
            in ``ASSISTANT:``). Otherwise every prompt gets its answer.

    Returns:
        A list of prompt strings, one per entry in ``answers``.
    """
    conversation = []
    for pos, answer in enumerate(answers):
        turn = f"USER: <image>\n{texts[pos]} ASSISTANT:"
        is_last = pos >= len(answers) - 1
        # Only the final turn may be withheld, and only when exclude_last is set.
        if not (is_last and exclude_last):
            turn += " " + idx_to_mc(answer)
        conversation.append(turn)
    return conversation