import argparse
import collections
import statistics
from collections import defaultdict
from copy import deepcopy
from datetime import datetime

import matplotlib.pyplot as plt
from embedding import *
from sklearn.cluster import DBSCAN
from trajectory import Action, ActionType, Trajectory

from data_generation.clueweb.constants import VALID_ACTIONS

# Timestamp tag (month-day-hourminute, local time when the module is imported)
# used to name output artifacts of this run, e.g. the history-diversity plot.
LOG_TAG = datetime.now().strftime("%m%d-%H%M")
# Text log path derived from LOG_TAG.
# NOTE(review): nothing in this file reads LOG_FILE; confirm it is still used
# elsewhere before relying on it.
LOG_FILE = f"./data_vis/.tmp/{LOG_TAG}.txt"


# Mind2Web reference files, listed in the order their texts are concatenated
# into the "mind2web_all" group passed to plot_embedding.
_M2W_REFERENCE_FILES = (
    "data_vis/data/final_output_full_direct_lkup_actnid_type_2_injected_scraped_test_code_model.json",
    "data_vis/data/domain_test_code_model.json",
    "data_vis/data/website_test_code_model.json",
    "data_vis/data/task_test_code_model.json",
)
_WEBARENA_DUMP = "data_vis/data/WA_json_dump.json"


def _plot_reference_embedding(
    trajectory_list, attr, webarena_field, plot_name, save_tag
):
    """Plot one text field of ``trajectory_list`` against Mind2Web and WebArena.

    Args:
        trajectory_list: Trajectories whose ``attr`` values form the
            "target_text" group.
        attr: Trajectory attribute to read from every trajectory
            (e.g. ``"obs"`` or ``"objective"``).
        webarena_field: Field name passed to ``process_webarena_data``.
        plot_name: Second argument passed to ``plot_embedding``.
        save_tag: Third argument passed to ``plot_embedding``; ``LOG_TAG``
            is used when it is empty.
    """
    target_text = [getattr(trajectory, attr) for trajectory in trajectory_list]
    mind2web_all = [
        getattr(trajectory, attr)
        for path in _M2W_REFERENCE_FILES
        for trajectory in Trajectory.from_file(path)
    ]
    webarena_data = process_webarena_data(_WEBARENA_DUMP, webarena_field)

    # Key insertion order is kept stable in case plot_embedding iterates it.
    text = {
        "mind2web_all": mind2web_all,
        "webarena": webarena_data,
        "target_text": target_text,
    }
    plot_embedding(text, plot_name, save_tag or LOG_TAG)


def retrieve_obs_embedding(trajectory_list, save_tag: str = ""):
    """Plot observation embeddings of ``trajectory_list`` against references.

    The reference groups are all Mind2Web splits combined and the WebArena
    dump ("observation" field).

    Args:
        trajectory_list: Trajectories whose ``obs`` texts are plotted.
        save_tag: Tag forwarded to ``plot_embedding``; ``LOG_TAG`` if empty.
    """
    _plot_reference_embedding(
        trajectory_list, "obs", "observation", "obs", save_tag
    )


def retrieve_task_embedding(trajectory_list, save_tag: str = ""):
    """Plot objective embeddings of ``trajectory_list`` against references.

    The reference groups are all Mind2Web splits combined and the WebArena
    dump ("objective" field).

    Args:
        trajectory_list: Trajectories whose ``objective`` texts are plotted.
        save_tag: Tag forwarded to ``plot_embedding``; ``LOG_TAG`` if empty.
    """
    _plot_reference_embedding(
        trajectory_list, "objective", "objective", "objective", save_tag
    )


def retrieve_action_counts(trajectory_list):
    """Print how often each next-action type occurs, most frequent first.

    Each output line reads ``<action_type>: <count> (<fraction>)``, with the
    fraction of all trajectories rounded to two decimals. Types with equal
    counts keep the order in which they were first seen.
    """
    tally = collections.Counter(
        traj.next_action.action_type for traj in trajectory_list
    )
    grand_total = sum(tally.values())
    for action_type, hits in tally.most_common():
        share = hits / grand_total
        print(f"{action_type}: {hits} ({share:.2f})")


def retrieve_cot(trajectory_list, debug: bool = False):
    """Print mean, standard deviation and median of chain-of-thought lengths.

    A trajectory's length is the number of whitespace-separated tokens in
    ``trajectory.next_action.cot``. With no trajectories only "No data." is
    reported; with a single one the standard deviation is reported as "n/a".

    Args:
        trajectory_list: Trajectories whose ``next_action.cot`` is measured.
        debug: When true, also print every chain-of-thought followed by a
            separator line.
    """
    cot_length = []
    for trajectory in trajectory_list:
        cot = trajectory.next_action.cot
        cot_length.append(len(cot.split()))
        if debug:
            print(cot)
            print("-------------------------")
    print("---------------------------")
    print("cot_length:")
    if not cot_length:
        # statistics.mean/median raise StatisticsError on empty input.
        print("No data.")
        return
    # statistics.stdev needs at least two data points.
    std_dev = (
        f"{statistics.stdev(cot_length):.2f}" if len(cot_length) > 1 else "n/a"
    )
    print(f"Mean: {statistics.mean(cot_length):.2f}")
    print(f"Standard Deviation: {std_dev}")
    print(f"Median: {statistics.median(cot_length):.2f}")


def retrieve_history_length(trajectory_list):
    """Print mean, standard deviation and median of history lengths.

    A trajectory's length is ``len(trajectory.history)`` (a character count
    when the history is a string). With no trajectories only "No data." is
    reported; with a single one the standard deviation is reported as "n/a".
    """
    history_lengths = [len(trajectory.history) for trajectory in trajectory_list]
    print("---------------------------")
    print("history_length:")
    if not history_lengths:
        # statistics.mean/median raise StatisticsError on empty input.
        print("No data.")
        return
    # statistics.stdev needs at least two data points.
    std_dev = (
        f"{statistics.stdev(history_lengths):.2f}"
        if len(history_lengths) > 1
        else "n/a"
    )
    print(f"Mean: {statistics.mean(history_lengths):.2f}")
    print(f"Standard Deviation: {std_dev}")
    print(f"Median: {statistics.median(history_lengths):.2f}")


def retrieve_target_element(trajectory_list):
    """Print how often each observation element is targeted per action type.

    Only trajectories whose next action has ``axt_node_id != -1`` are counted.
    For each of them, every observation line containing ``[<axt_node_id>]``
    adds one count to ``(second whitespace token of that line, action_type)``.
    Matching lines with no second token are skipped.

    The output has one row per (element, action type) pair, formatted as
    fixed-width columns. Elements containing a single quote, double quote or
    dot are left out of the output. Presumably these are text fragments
    rather than element roles; verify against the observation format.
    """
    counter_dict = defaultdict(lambda: defaultdict(int))
    for trajectory in trajectory_list:
        action = trajectory.next_action
        if action.axt_node_id == -1:
            continue
        marker = f"[{action.axt_node_id}]"
        for line in trajectory.obs.split("\n"):
            if marker not in line:
                continue
            tokens = line.split()
            # A line holding only the marker has no element token to count.
            if len(tokens) < 2:
                continue
            counter_dict[tokens[1]][action.action_type] += 1

    for element, per_action in counter_dict.items():
        if any(ch in element for ch in ("'", ".", '"')):
            continue
        for action_type, count in per_action.items():
            print("{:<20} {:<20} {:<10}".format(element, action_type, count))


def retrieve_code_nl_ratio(trajectory_list):
    """Print character-count ratios between the code and NL parts of actions.

    The "code" size of an action is the sum of:
      * ``len(typed_string)``;
      * the length of the action type's name (text after the last ``.`` in
        ``str(action_type)``);
      * the length of ``str(axt_node_id)``.

    This size is compared against the chain of thought, the action
    description, and the sum of CoT, description and subtask lengths. A
    ratio whose denominator is zero (e.g. no trajectories) is printed as
    "n/a" instead of raising ``ZeroDivisionError``.
    """
    total_code = total_cot = total_nl = total_subtask = 0
    for trajectory in trajectory_list:
        action = trajectory.next_action
        total_code += (
            len(action.typed_string)
            + len(str(action.action_type).split(".")[-1])
            + len(str(action.axt_node_id))
        )
        total_cot += len(action.cot)
        total_nl += len(action.action_description)
        total_subtask += len(action.subtask)

    def _ratio(denominator):
        return f"{total_code / denominator:.2f}" if denominator else "n/a"

    print(f"total_code/total_cot: {_ratio(total_cot)}")
    print(f"total_code/total_nl: {_ratio(total_nl)}")
    print(
        "total_code/(total_cot+total_nl+total_subtask): "
        f"{_ratio(total_cot + total_nl + total_subtask)}"
    )


def retrieve_stop_length(trajectory_list, debug=False):
    """Print mean, standard deviation and median of STOP-answer lengths.

    Only trajectories whose next action type equals ``ActionType.STOP`` are
    measured; the length is ``len(next_action.typed_string)``. If no STOP
    actions are found, only "No data." is reported. With a single STOP
    action, the standard deviation is reported as "n/a".

    Args:
        trajectory_list: Trajectories to scan.
        debug: When true, print each STOP action's typed string.
    """
    stop_lengths = []
    for trajectory in trajectory_list:
        action = trajectory.next_action
        if action.action_type == ActionType.STOP:
            if debug:
                print(action.typed_string)
            stop_lengths.append(len(action.typed_string))

    print("---------------------------")
    print("stop_lengths:")
    if not stop_lengths:
        # statistics.mean/median raise StatisticsError on empty input.
        print("No data.")
        return
    # statistics.stdev needs at least two data points.
    std_dev = (
        f"{statistics.stdev(stop_lengths):.2f}" if len(stop_lengths) > 1 else "n/a"
    )
    print(f"Mean: {statistics.mean(stop_lengths):.2f}")
    print(f"Standard Deviation: {std_dev}")
    print(f"Median: {statistics.median(stop_lengths):.2f}")


def _history_action_signature(history: str, actions) -> str:
    """Return the ``-``-joined action names called in ``history``, in order.

    Each line of ``history`` is stripped. Lines starting with ``#`` are
    ignored. Any other line that starts with ``<action>(`` contributes the
    first matching name from ``actions``; non-matching lines contribute
    nothing.
    """
    called = []
    for raw_line in history.split("\n"):
        line = raw_line.strip()
        if line.startswith("#"):
            continue
        for action in actions:
            if line.startswith(f"{action}("):
                called.append(action)
                break
    return "-".join(called)


def retrieve_history_diversity(trajectory_list: list[Trajectory]) -> None:
    """Check the diversity of the history.

    Each trajectory's history is reduced to the sequence of action names it
    calls (restricted to ``VALID_ACTIONS``). The function prints the number
    of distinct sequences. It also saves a bar chart of the (up to) 100 most
    frequent sequences to
    ``./data_vis/.tmp/plots/<LOG_TAG>_history_diversity.pdf``.
    """
    comb_counter = collections.Counter(
        _history_action_signature(traj.history, VALID_ACTIONS)
        for traj in trajectory_list
    )
    print("---------------------------")
    print(f"Number of unique histories: {len(comb_counter)}")

    # Most frequent first; equal counts keep first-seen order.
    top = comb_counter.most_common(100)
    keys = [signature for signature, _ in top]
    values = [count for _, count in top]

    fig, ax = plt.subplots()
    bar_width = 0.5  # width of each bar
    # Place bar centres (1 + bar_width) apart to leave room between bars.
    x = np.arange(len(keys)) * (1 + bar_width)
    ax.bar(x, values, width=bar_width)
    ax.set_xticks(x)
    ax.set_xticklabels(keys, rotation=90, fontsize=4)
    fig.savefig(
        f"./data_vis/.tmp/plots/{LOG_TAG}_history_diversity.pdf",
        bbox_inches="tight",
    )
    # pyplot keeps every figure alive until closed; release this one.
    plt.close(fig)


def retrieve_task_diversity(
    trajectory_list: list[Trajectory],
    eps: float = 1.0,
    min_samples: int = 5,
) -> None:
    """Cluster task objectives with DBSCAN and print the number of clusters.

    Objectives are embedded with ``get_embeddings`` and clustered with
    ``DBSCAN(eps=eps, min_samples=min_samples)``. DBSCAN assigns noise
    samples the label -1. That label is not a cluster, so it is excluded
    from the count, and the number of noise points is printed alongside.

    Args:
        trajectory_list: Trajectories whose ``objective`` texts are clustered.
        eps: DBSCAN neighbourhood radius.
        min_samples: DBSCAN minimum neighbourhood size for a core point.
    """
    objectives = [traj.objective for traj in trajectory_list]
    text = {"target_text": objectives}
    embeddings = np.vstack(get_embeddings(text, "objective"))
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit(embeddings).labels_
    n_noise = int(np.count_nonzero(labels == -1))
    n_clusters = len(set(labels.tolist()) - {-1})
    print("---------------------------")
    print(f"DBSCAN: {n_clusters} clusters ({n_noise} noise points)")


def config(argv=None):
    """Build the command-line parser and parse the arguments.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``.

    Returns:
        ``argparse.Namespace`` with the positional ``file_name`` plus one
        boolean per flag (``False`` unless the flag is given).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("file_name", type=str, help="training data file")
    flag_groups = (
        # basic statistics
        ("--history_length", "--action_count", "--stop_length", "--stop_length_print"),
        # diversity measure
        ("--obs_embedding", "--task_embedding", "--task_clustering", "--history_diversity"),
        # quality, accuracy: chain-of-thought, then element grounding
        ("--cot", "--cot_print", "--code_nl_ratio", "--target_element"),
    )
    for group in flag_groups:
        for flag in group:
            parser.add_argument(flag, action="store_true")
    return parser.parse_args(argv)


def main():
    """Load the trajectory file and run every analysis whose flag is set.

    Analyses always run in the fixed order listed below, regardless of the
    order in which flags appear on the command line.
    """
    args = config()
    trajectories = Trajectory.from_file(args.file_name)

    # (enabled, analysis) pairs in execution order.
    jobs = (
        (args.obs_embedding, lambda: retrieve_obs_embedding(trajectories)),
        (args.task_embedding, lambda: retrieve_task_embedding(trajectories)),
        (args.action_count, lambda: retrieve_action_counts(trajectories)),
        (
            args.cot or args.cot_print,
            lambda: retrieve_cot(trajectories, args.cot_print),
        ),
        (args.history_length, lambda: retrieve_history_length(trajectories)),
        (args.target_element, lambda: retrieve_target_element(trajectories)),
        (args.code_nl_ratio, lambda: retrieve_code_nl_ratio(trajectories)),
        (
            args.stop_length or args.stop_length_print,
            lambda: retrieve_stop_length(trajectories, args.stop_length_print),
        ),
        (args.task_clustering, lambda: retrieve_task_diversity(trajectories)),
        (
            args.history_diversity,
            lambda: retrieve_history_diversity(trajectories),
        ),
    )
    for enabled, job in jobs:
        if enabled:
            job()


# Run the CLI only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
