import csv
import itertools
import os

from dataset_diy.modules import template


def getOnePrompts(objects=None):
    """Print one "a photo of <object>." prompt line per object.

    Args:
        objects: Iterable of object names to build prompts for. When omitted
            (or None), ``template.occupations`` is used, which matches the
            original zero-argument behavior.
    """
    if objects is None:
        objects = template.occupations
    for target in objects:
        print("a photo of {}.".format(target))


def getTwoPrompts(objects=None):
    """Print a "a photo of <a> and <b>." prompt for every pair of objects.

    Each pair is taken from two distinct positions i < j in ``objects``, so a
    pair is printed once (never both "a and b" and "b and a").

    Args:
        objects: Iterable of object names. When omitted (or None),
            ``template.occupations`` is used, which matches the original
            zero-argument behavior.
    """
    if objects is None:
        objects = template.occupations
    # combinations() yields (objects[i], objects[j]) for i < j in the same
    # order as the former nested index loops.
    for first, second in itertools.combinations(objects, 2):
        print("a photo of {} and {}.".format(first, second))


def getMapping(data_row, image_path):
    """Pair the first two fields of ``data_row`` with every PNG in ``image_path``.

    Args:
        data_row: Sequence with at least two items; only ``data_row[0]`` and
            ``data_row[1]`` are copied into the output.
        image_path: Directory to scan (non-recursively) for ``*.png`` files.

    Returns:
        A list of ``[data_row[0], data_row[1], "<dir name>/<file name>"]``
        rows, one per PNG file, sorted by file name. The third field always
        uses "/" as the separator, regardless of the host OS.
    """
    # The last path component of the directory itself. normpath strips any
    # trailing separator so basename does not return an empty string. This is
    # portable, unlike splitting an os.path.join result on "/", which breaks on
    # Windows where the joined separator is a backslash.
    dir_name = os.path.basename(os.path.normpath(image_path))
    # sorted() makes the output order deterministic; os.listdir order is
    # arbitrary.
    return [
        [data_row[0], data_row[1], dir_name + "/" + image]
        for image in sorted(os.listdir(image_path))
        if image.endswith(".png")
    ]


# Hard-coded data locations. They are relative paths, so they resolve against
# the current working directory, not against this file's location.
#   file_path       -- the "key prompts" CSV under ../src
#   image_path      -- an image output directory (presumably the kind of
#                      directory getMapping scans; not referenced in this module)
#   data_label_path -- CSV inside image_path, named by the commented-out
#                      writer at the bottom of this module
file_path, image_path, data_label_path = (
    "../src/key prompts.csv",
    "../data/2024-01-17-sdxl",
    "../data/2024-01-17-sdxl/data_label.csv",
)


def getTrainData(train_data_path):
    """Load prompts, image paths and sensitive-attribute strings from a CSV.

    The CSV must have a header row containing the columns ``Prompts``,
    ``Image_ID``, ``Gender``, ``Race``, ``Region`` and ``Religion``.

    Args:
        train_data_path: Path to the CSV file.

    Returns:
        A tuple ``(prompt, image, stereotype)`` of lists with one entry per
        data row:
          - prompt: the row's ``Prompts`` value.
          - image: the row's ``Image_ID`` joined onto the directory that
            contains the CSV file.
          - stereotype: ``Gender``, ``Race``, ``Region`` and ``Religion``
            joined with commas.

    Raises:
        KeyError: if one of the required columns is missing from the header.
    """
    prompt = []
    image = []
    stereotype = []

    # Image_ID values are resolved relative to the CSV's own directory. That
    # directory is the same for every row, so compute it once.
    parent_dir = os.path.dirname(train_data_path)

    # The csv module requires newline="" so that quoted fields containing line
    # breaks are read verbatim instead of being altered by newline translation.
    with open(train_data_path, "r", newline="") as f:
        for row in csv.DictReader(f):
            prompt.append(row["Prompts"])
            image.append(os.path.join(parent_dir, row["Image_ID"]))
            stereotype.append(
                ",".join((row["Gender"], row["Race"], row["Region"], row["Religion"]))
            )
    return prompt, image, stereotype


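# Disabled: appends label rows to data_label_path as CSV. `data_label` is not
# defined in this module; it is presumably a list of getMapping() results (one
# list of rows per prompt), so define it before re-enabling this code.
# f = open(data_label_path, "a", newline="")
#
# for group in data_label:
#     for item in group:
#         try:
#             writer = csv.writer(f)
#             writer.writerow(item)
#         except:
#             print(item)
# f.close()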
