from functools import partial
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib import ticker
from trajdata.utils.arr_utils import agent_aware_diff

from .radian_formatter import Multiple


def pointwise_speed(df: pd.DataFrame, ax: Optional[plt.Axes] = None, bins: int = 100, **kwargs):
    """Histogram the per-row speed, i.e. the Euclidean norm of ``(vx, vy)``.

    The speeds are kept in a local Series instead of being written back into
    ``df``, so the caller's DataFrame is left unmodified.

    Args:
        df: Frame with ``vx`` and ``vy`` columns. When ``ax`` is None its index
            must also carry an ``agent_type`` level, which is used for grouping.
        ax: Axes to draw on. When given, a seaborn proportion histogram with a
            log-scaled y-axis is drawn on it and nothing is returned.
        bins: Bin specification forwarded to the histogram routine.
        **kwargs: Forwarded to ``sns.histplot`` when plotting, otherwise to
            ``np.histogram``.

    Returns:
        None when ``ax`` is given; otherwise a dict mapping each agent type to
        the ``(counts, bin_edges)`` tuple produced by ``np.histogram``.
    """
    # Same index and name as the column the plot/grouping needs, but detached
    # from ``df`` so no column is added to the caller's frame.
    speed = pd.Series(np.linalg.norm(df[["vx", "vy"]], axis=1), index=df.index, name="speed")

    if ax is not None:
        ax = sns.histplot(speed, bins=bins, stat="proportion", ax=ax, **kwargs)
        ax.set_xlabel(r"Speed $(m/s)$")
        ax.set_yscale("log")
        return None

    return {
        agent_type: np.histogram(values, bins=bins, **kwargs)
        for agent_type, values in speed.groupby("agent_type")
    }


def pointwise_acceleration(df: pd.DataFrame, ax: Optional[plt.Axes] = None, bins: int = 100, **kwargs):
    """Histogram the per-row acceleration magnitude, the norm of ``(ax, ay)``.

    Rows containing a NaN in any column are discarded first. The magnitudes
    are stored in a standalone Series rather than assigned into the filtered
    frame, which avoids pandas' ``SettingWithCopyWarning`` and leaves the
    caller's DataFrame untouched.

    Args:
        df: Frame with ``ax`` and ``ay`` columns. When ``ax`` (the Axes
            argument) is None its index must also carry an ``agent_type``
            level, which is used for grouping.
        ax: Axes to draw on. When given, a seaborn proportion histogram with a
            log-scaled y-axis is drawn on it and nothing is returned.
        bins: Bin specification forwarded to the histogram routine.
        **kwargs: Forwarded to ``sns.histplot`` when plotting, otherwise to
            ``np.histogram``.

    Returns:
        None when ``ax`` is given; otherwise a dict mapping each agent type to
        the ``(counts, bin_edges)`` tuple produced by ``np.histogram``.
    """
    valid = df.dropna()
    # Build the Series on the filtered index instead of doing
    # ``valid["acc"] = ...``: writing into a dropna() result is flagged by
    # pandas as a possible write to a copy.
    acc = pd.Series(np.linalg.norm(valid[["ax", "ay"]], axis=1), index=valid.index, name="acc")

    if ax is not None:
        ax = sns.histplot(acc, bins=bins, stat="proportion", ax=ax, **kwargs)
        ax.set_xlabel(r"Acceleration $(m/s^2)$")
        ax.set_yscale("log")
        return None

    return {
        agent_type: np.histogram(values, bins=bins, **kwargs)
        for agent_type, values in acc.groupby("agent_type")
    }


def pointwise_jerk(df: pd.DataFrame, DT: float, ax: Optional[plt.Axes] = None, bins: int = 100, **kwargs):
    """Histogram the per-row jerk derived from the acceleration magnitude.

    Rows containing a NaN in any column are discarded first. The norm of
    ``(ax, ay)`` is then differenced with ``agent_aware_diff`` (keyed on the
    ``agent_id`` index level) and divided by ``DT``. The result is stored in a
    standalone Series rather than assigned into the filtered frame, which
    avoids pandas' ``SettingWithCopyWarning``.

    Args:
        df: Frame with ``ax`` and ``ay`` columns and an ``agent_id`` index
            level. When ``ax`` (the Axes argument) is None the index must also
            carry an ``agent_type`` level, which is used for grouping.
        DT: Divisor applied to the differenced accelerations (presumably the
            sampling period in seconds -- confirm with callers).
        ax: Axes to draw on. When given, a seaborn proportion histogram with a
            log-scaled y-axis is drawn on it and nothing is returned.
        bins: Bin specification forwarded to the histogram routine.
        **kwargs: Forwarded to ``sns.histplot`` when plotting, otherwise to
            ``np.histogram``.

    Returns:
        None when ``ax`` is given; otherwise a dict mapping each agent type to
        the ``(counts, bin_edges)`` tuple produced by ``np.histogram``.
    """
    valid = df.dropna()
    acc = np.linalg.norm(valid[["ax", "ay"]], axis=1)

    # NOTE(review): agent_aware_diff presumably differences consecutive rows
    # without mixing values across different agent ids -- confirm in trajdata.
    jerk_values = agent_aware_diff(acc, valid.index.get_level_values("agent_id")) / DT

    # Build the Series on the filtered index instead of doing
    # ``valid["jerk"] = ...``: writing into a dropna() result is flagged by
    # pandas as a possible write to a copy.
    jerk = pd.Series(jerk_values, index=valid.index, name="jerk")

    if ax is not None:
        ax = sns.histplot(jerk, bins=bins, stat="proportion", ax=ax, **kwargs)
        ax.set_xlabel(r"Jerk $(m/s^3)$")
        ax.set_yscale("log")
        return None

    return {
        agent_type: np.histogram(values, bins=bins, **kwargs)
        for agent_type, values in jerk.groupby("agent_type")
    }


def pointwise_heading(df: pd.DataFrame, ax: Optional[plt.Axes] = None, bins: int = 50, **kwargs):
    """Histogram the ``heading`` column, either per agent type or as a plot.

    Args:
        df: Frame with a ``heading`` column. When ``ax`` is None its index must
            also carry an ``agent_type`` level, which is used for grouping.
        ax: Axes to draw on. The plotting branch hides ``ax.spines["polar"]``,
            so the Axes must provide that spine. When given, nothing is
            returned.
        bins: Bin specification forwarded to the histogram routine.
        **kwargs: Forwarded to ``np.histogram`` only; the plotting branch does
            not use them.

    Returns:
        None when ``ax`` is given; otherwise a dict mapping each agent type to
        the ``(counts, bin_edges)`` tuple produced by ``np.histogram``.
    """
    if ax is None:
        # No axes supplied: return the raw per-agent-type histograms.
        return {
            agent_type: np.histogram(headings, bins=bins, **kwargs)
            for agent_type, headings in df["heading"].groupby("agent_type")
        }

    ax = sns.histplot(df.heading, bins=bins, stat="proportion", ax=ax, linewidth=0.1)
    ax.set_xlabel("Heading (radians)")
    ax.set_ylabel(None)
    ax.spines["polar"].set_visible(False)

    # Tick labels come from the project's Multiple helper (built with
    # denominator=4); the y-axis shows proportions as whole percentages.
    radian_ticks = Multiple(denominator=4)
    ax.xaxis.set_major_formatter(radian_ticks.formatter())
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))
    ax.set_axisbelow(True)


def agent_ego_distance(df: pd.DataFrame, ax: Optional[plt.Axes] = None, bins: int = 100, **kwargs):
    """Histogram the distance between each non-ego agent and the ego agent.

    Ego rows are those whose second index level equals ``"ego"``. Every other
    row is paired with the ego row sharing its ``scene_id`` and ``scene_ts``,
    and the Euclidean distance between the two ``(x, y)`` positions is taken.

    Args:
        df: Frame with ``x`` and ``y`` columns and a MultiIndex that contains
            the levels ``scene_id``, ``agent_id``, ``agent_type`` and
            ``scene_ts``, with the agent id as the second level.
        ax: Axes to draw on. When given, a seaborn proportion histogram with a
            percent-formatted y-axis is drawn on it and nothing is returned.
        bins: Bin specification forwarded to the histogram routine.
        **kwargs: Forwarded to ``np.histogram`` only; the plotting branch does
            not use them.

    Returns:
        None when ``ax`` is given; otherwise a dict mapping each agent type to
        the ``(counts, bin_edges)`` tuple produced by ``np.histogram``.
    """
    positions = df[["x", "y"]].copy()

    # Ego track keyed only by (scene_id, scene_ts) so it can be joined on.
    ego_positions = positions.loc[pd.IndexSlice[:, "ego"], :].reset_index(
        level=["agent_id", "agent_type"], drop=True
    )

    # All remaining agents, with agent_type moved out of the index into a
    # regular column so it survives the merge.
    others = positions.drop(index="ego", level=1).reset_index(level="agent_type")

    paired = others.merge(
        ego_positions, how="left", on=["scene_id", "scene_ts"], suffixes=("", "_ego")
    )
    # A left merge must not duplicate rows: at most one ego row per key.
    assert len(paired) == len(others)

    delta_x = paired["x"] - paired["x_ego"]
    delta_y = paired["y"] - paired["y_ego"]
    paired["distance_to_ego"] = np.sqrt(delta_x ** 2 + delta_y ** 2)

    distances = paired.set_index("agent_type")["distance_to_ego"]

    if ax is None:
        return {
            agent_type: np.histogram(values, bins=bins, **kwargs)
            for agent_type, values in distances.groupby("agent_type")
        }

    ax = sns.histplot(distances, bins=bins, stat="proportion", ax=ax)
    ax.set_xlabel("Agent-Ego Distance (m)")
    ax.yaxis.set_major_formatter(ticker.PercentFormatter(1.0, decimals=0))
