import json
import csv
from constants import COCO_CLASS_TRAIN, VG_CLASS_TRAIN, ADE_CLASS_TRAIN, COCO_CLASS_VALIDATION, ADE_CLASS_VALIDATION
from embedding_similarity_utils import max_cosine_similarity

def categorize_errors_total(data, wv_from_bin, threshold, data_flag):
    """Count and categorize every misprediction across all entries in ``data``.

    An object is an error when its lower-cased ``name`` differs from its
    lower-cased ``prediction`` with leading/trailing ``.`` removed. Each
    error is classified as:

    * ``image_hallucination`` -- ``max_cosine_similarity(wv_from_bin,
      predicted_name, object_set)`` is greater than ``threshold``;
    * ``text_hallucination`` -- the predicted name is in the class list
      chosen for the entry's ``data_source``;
    * ``dual_category_errors`` -- both of the above hold (the image and
      text counters are incremented as well);
    * ``other_errors`` -- neither holds, or the prediction is not a string.

    Args:
        data: Iterable of entries, each providing ``'object_set'``,
            ``'data_source'`` and ``'objects'``; every object provides
            ``'name'`` and optionally ``'prediction'``.
        wv_from_bin: Passed through unchanged to ``max_cosine_similarity``.
        threshold: Similarity value that must be exceeded for an error to
            count as an image hallucination.
        data_flag: ``'validation'`` or ``'train'``; selects which class
            lists are consulted.

    Returns:
        A ``(error_categories, error_percentages)`` tuple. Percentages are
        each count divided by the total number of errors, times 100 (all 0
        when there are no errors). Because dual-category errors are also
        counted as image and text hallucinations, the percentages can add
        up to more than 100.

    Raises:
        ValueError: If ``data_flag`` is neither ``'validation'`` nor
            ``'train'``.
    """
    # Fail fast: previously an unknown flag left the class list unbound and
    # only surfaced as an UnboundLocalError at the first misprediction.
    if data_flag not in ('validation', 'train'):
        raise ValueError(
            f"data_flag must be 'validation' or 'train', got {data_flag!r}"
        )

    error_categories = {
        'image_hallucination': 0,
        'text_hallucination': 0,
        'dual_category_errors': 0,
        'other_errors': 0
    }
    total_errors = 0

    for entry in data:
        object_set = entry['object_set']
        data_source = entry['data_source']

        # Any source other than the explicitly named ones falls back to ADE.
        if data_flag == 'validation':
            class_list = COCO_CLASS_VALIDATION if data_source == 'COCO' else ADE_CLASS_VALIDATION
        elif data_source == 'COCO':
            class_list = COCO_CLASS_TRAIN
        elif data_source == 'VG':
            class_list = VG_CLASS_TRAIN
        else:
            class_list = ADE_CLASS_TRAIN

        for obj in entry['objects']:
            actual_name = obj['name'].lower()
            prediction = obj.get('prediction', "")

            # A non-string prediction cannot be compared or embedded.
            if not isinstance(prediction, str):
                total_errors += 1
                error_categories['other_errors'] += 1
                continue

            predicted_name = prediction.lower().strip('.')
            if actual_name == predicted_name:
                continue

            total_errors += 1
            similarity = max_cosine_similarity(wv_from_bin, predicted_name, object_set)
            is_object_set_error = similarity > threshold
            is_class_error = predicted_name in class_list

            if is_object_set_error:
                error_categories['image_hallucination'] += 1
            if is_class_error:
                error_categories['text_hallucination'] += 1
            if is_object_set_error and is_class_error:
                error_categories['dual_category_errors'] += 1
            elif not (is_object_set_error or is_class_error):
                error_categories['other_errors'] += 1

    error_percentages = {
        key: (value / total_errors * 100) if total_errors > 0 else 0
        for key, value in error_categories.items()
    }
    return error_categories, error_percentages


def categorize_errors_bbox(data, wv_from_bin, threshold, data_flag):
    """Count and categorize mispredictions separately for each ``bbox_number``.

    An object is an error when its lower-cased ``name`` differs from its
    lower-cased ``prediction`` with leading/trailing ``.`` removed. Each
    error is classified as:

    * ``image_hallucination`` -- ``max_cosine_similarity(wv_from_bin,
      predicted_name, object_set)`` is greater than ``threshold``;
    * ``text_hallucination`` -- the predicted name is in the class list
      chosen for the entry's ``data_source``;
    * ``dual_category_errors`` -- both of the above hold (the image and
      text counters are incremented as well);
    * ``other_errors`` -- neither holds, or the prediction is not a string.

    Args:
        data: Iterable of entries, each providing ``'object_set'``,
            ``'data_source'`` and ``'objects'``; every object provides
            ``'bbox_number'``, ``'name'`` and optionally ``'prediction'``.
        wv_from_bin: Passed through unchanged to ``max_cosine_similarity``.
        threshold: Similarity value that must be exceeded for an error to
            count as an image hallucination.
        data_flag: ``'validation'`` or ``'train'``; selects which class
            lists are consulted.

    Returns:
        A dict mapping each ``bbox_number`` seen to a dict of counters with
        keys ``'image_hallucination'``, ``'text_hallucination'``,
        ``'dual_category_errors'``, ``'other_errors'`` and
        ``'total_errors'``. A bbox number gets an entry even when none of
        its objects were mispredicted.

    Raises:
        ValueError: If ``data_flag`` is neither ``'validation'`` nor
            ``'train'``.
    """
    # Fail fast: previously an unknown flag left the class list unbound and
    # only surfaced as an UnboundLocalError at the first misprediction.
    if data_flag not in ('validation', 'train'):
        raise ValueError(
            f"data_flag must be 'validation' or 'train', got {data_flag!r}"
        )

    bbox_errors = {}
    for entry in data:
        object_set = entry['object_set']
        data_source = entry['data_source']

        # Any source other than the explicitly named ones falls back to ADE.
        if data_flag == 'validation':
            class_list = COCO_CLASS_VALIDATION if data_source == 'COCO' else ADE_CLASS_VALIDATION
        elif data_source == 'COCO':
            class_list = COCO_CLASS_TRAIN
        elif data_source == 'VG':
            class_list = VG_CLASS_TRAIN
        else:
            class_list = ADE_CLASS_TRAIN

        for obj in entry['objects']:
            # Register the bbox number up front so it is reported even when
            # it never accumulates an error.
            counts = bbox_errors.setdefault(obj['bbox_number'], {
                'image_hallucination': 0,
                'text_hallucination': 0,
                'dual_category_errors': 0,
                'other_errors': 0,
                'total_errors': 0
            })

            actual_name = obj['name'].lower()
            prediction = obj.get('prediction', "")

            # A non-string prediction cannot be compared or embedded.
            if not isinstance(prediction, str):
                counts['total_errors'] += 1
                counts['other_errors'] += 1
                continue

            predicted_name = prediction.lower().strip('.')
            if actual_name == predicted_name:
                continue

            counts['total_errors'] += 1
            similarity = max_cosine_similarity(wv_from_bin, predicted_name, object_set)
            is_object_set_error = similarity > threshold
            is_class_error = predicted_name in class_list

            if is_object_set_error:
                counts['image_hallucination'] += 1
            if is_class_error:
                counts['text_hallucination'] += 1
            if is_object_set_error and is_class_error:
                counts['dual_category_errors'] += 1
            elif not (is_object_set_error or is_class_error):
                counts['other_errors'] += 1

    return bbox_errors



def save_to_csv_total(data, filename):
    """Write the overall error percentages to ``filename`` as a two-row CSV.

    The first row is a fixed header; the second holds the four percentages
    formatted with two decimals and a trailing ``%``.

    Args:
        data: Indexable whose element at position 1 is a mapping with the
            keys ``'image_hallucination'``, ``'text_hallucination'``,
            ``'dual_category_errors'`` and ``'other_errors'`` (for example
            the tuple returned by ``categorize_errors_total``).
        filename: Path of the CSV file to create or overwrite.
    """
    header = [
        'Image Hallucination (%)',
        'Text Hallucination (%)',
        'Dual Category Errors (%)',
        'Other Errors (%)',
    ]
    # Column order of the value row, matching the header above.
    columns = (
        'image_hallucination',
        'text_hallucination',
        'dual_category_errors',
        'other_errors',
    )
    with open(filename, 'w', newline='') as out_file:
        csv_out = csv.writer(out_file)
        csv_out.writerow(header)
        shares = data[1]
        csv_out.writerow([f"{shares[column]:.2f}%" for column in columns])

def save_to_csv_bbox(bbox_data, filename):
    """Write per-bbox error percentages to ``filename`` as CSV.

    After a fixed header row, each bbox number gets one row with its four
    category percentages (two decimals, trailing ``%``), followed by an
    empty row.

    Args:
        bbox_data: Mapping from bbox number to a dict of counters with the
            keys ``'image_hallucination'``, ``'text_hallucination'``,
            ``'dual_category_errors'``, ``'other_errors'`` and
            ``'total_errors'``.
        filename: Path of the CSV file to create or overwrite.
    """
    category_keys = (
        'image_hallucination',
        'text_hallucination',
        'dual_category_errors',
        'other_errors',
    )
    with open(filename, 'w', newline='') as out_file:
        csv_out = csv.writer(out_file)
        csv_out.writerow([
            'BBox Number',
            'Image Hallucination %',
            'Text Hallucination %',
            'Dual Category Errors %',
            'Other Errors %',
        ])
        for bbox_number, counts in bbox_data.items():
            error_total = counts['total_errors']
            row = [bbox_number]
            for key in category_keys:
                # With no errors recorded, every share is reported as 0.
                if error_total > 0:
                    share = counts[key] / error_total * 100
                else:
                    share = 0
                row.append(f'{share:.2f}%')
            csv_out.writerow(row)
            # Each record is followed by an empty row in the output.
            csv_out.writerow([])