from typing import *

from data_generators import StableDiffusionGenerator
from concept_sampler import SamplingMethod, ConceptSampler
from concept_extractor import LLAVAConceptExtractor, GroundingDinoSamExtractor

def generate_fair_data(
    dataset_path: str,
    save_dir: str,
    class_list: List[str],
    mode: SamplingMethod,
    model_name: str = "runwayml/stable-diffusion-v1-5",
) -> None:
    """
    Run the concept-extraction, concept-sampling and image-generation stages in order.

    The three stages are executed one after another purely for their side
    effects; no value is passed from one stage to the next and nothing is
    returned.

    Args:
        dataset_path (str): The path to the dataset; forwarded to the concept extractor.
        save_dir (str): Directory forwarded to the concept extractor as its second
            argument. It is not handed to the sampler or the generator here.
        class_list (List[str]): The list of class names given to the concept sampler.
        mode (SamplingMethod): The sampling method given to the concept sampler.
        model_name (str, optional): The model name for the Stable Diffusion Generator.
            Defaults to "runwayml/stable-diffusion-v1-5".
    """
    # Stage 1: extract concepts from the dataset.
    concept_extractor = LLAVAConceptExtractor()
    concept_extractor.extract(dataset_path, save_dir)

    # TODO: the graph stages are not wired in yet. The intended steps are
    #   GraphBuilder().build(save_dir)  followed by  GraphAnalyzer(save_dir).analyze()
    # Until then the sampler is handed an empty placeholder instead of a graph.
    concept_graph = None

    # Stage 2: sample concepts for the requested classes.
    # NOTE(review): the result of sample() is discarded — confirm whether the
    # sampler persists its output itself or whether it should feed the generator.
    concept_sampler = ConceptSampler(mode, class_list, concept_graph)
    concept_sampler.sample()

    # Stage 3: image generation.
    # NOTE(review): generate() is called without arguments — verify that the
    # generator knows what to generate and where to write it.
    image_generator = StableDiffusionGenerator(model_name)
    image_generator.generate()
