from typing import NamedTuple
from numpy.typing import NDArray
from jax import vmap, random
from jax import numpy as jnp

from fairgym.envs.state import State, create_state


class ActionResult(NamedTuple):
    """
    Outcome statistics returned by ``threshold_action``.

    ``*_frac`` fields are joint proportions over (Y, Yhat); ``*_rate`` fields
    are the same quantities conditioned on the true label Y.
    """

    # The thresholds that produced this result.
    action: NDArray

    # Joint fractions over (Y, Yhat).
    fn_frac: NDArray
    tp_frac: NDArray
    tn_frac: NDArray
    fp_frac: NDArray

    # Proportion accepted (1 minus the mass at or below the threshold).
    accept_rate: NDArray

    # Rates conditioned on Y.
    fn_rate: NDArray
    tp_rate: NDArray
    tn_rate: NDArray
    fp_rate: NDArray


@vmap
def _take_old(array, single_action):
    """
    Evaluate a binned CDF at one threshold by linear interpolation.

    Wrapped in ``vmap``, so the function maps over the leading axis of both
    arguments: each row of a 2-D ``array`` is paired with one entry of
    ``single_action``.

    :param array: one row of values P(Z <= z), indexed by bin position
        0 .. n_bins - 1.
    :param single_action: threshold for this row, expected in [0, 1].
    :return: the row linearly interpolated at the bin position that
        corresponds to ``single_action``.
    """
    n_bins = array.shape[-1]
    bin_positions = jnp.arange(n_bins)
    # Map the threshold onto bin-index space [0, n_bins - 1]. The extra 0.999
    # factor pulls every query slightly toward bin 0.
    # NOTE(review): the purpose of 0.999 is undocumented -- confirm before changing.
    query = single_action * (n_bins - 1) * 0.999
    return jnp.interp(query, bin_positions, array)


def _take_cdf(array, actions):
    """
    Look up one entry per row of ``array`` using the matching threshold in ``actions``.

    This does NOT interpolate. Each action is mapped to the bin index
    ``ceil(action * n_bins)`` and that single entry is returned. The index is
    clipped to ``[0, n_bins - 1]``, so actions near the top of the range
    select the last bin instead of reading past the end of the row.

    :param array: 2-D array with one row per threshold. Presumably each row is
        a cumulative P(Z <= z) over evenly spaced bins -- TODO confirm.
    :param actions: 1-D array with one threshold in [0, 1] per row of ``array``.
    :return: 1-D array holding one looked-up value per row.
    """
    # TODO need to do interpolation or something else here
    n_bins = array.shape[1]
    action_indices = jnp.ceil(actions * n_bins).astype(int)
    # For any action in ((n_bins - 1) / n_bins, 1] the ceil yields n_bins,
    # one past the last valid index; clamp it into range.
    action_indices = jnp.clip(action_indices, 0, n_bins - 1)
    return jnp.take_along_axis(array, action_indices[:, None], axis=1).flatten()


# `_take` is the CDF lookup used by threshold_action. It is bound to the
# interpolating `_take_old`; TODO: use `_take_old` until/if `_take_cdf` is fixed.
_take = _take_old


def threshold_action(state: State, action: NDArray):
    """
    Apply one threshold per group and summarise the resulting outcomes.

    The mass at or below ``action`` (as returned by ``_take``) is treated as
    rejected (Yhat = 0); the remainder is accepted. Two families of
    statistics are reported:

    * ``*_frac`` -- joint proportions over (Y, Yhat); e.g. ``fn_frac`` is the
      Y = 1 mass that was rejected and ``tp_frac = pr_Y1 - fn_frac``.
    * ``*_rate`` -- the same quantities conditioned on Y; e.g.
      ``fn_rate = fn_frac / pr_Y1`` and ``tp_rate = 1 - fn_rate``.

    NOTE: a group with ``pr_Y1 == 0`` (or ``pr_Y0 == 0``) yields non-finite
    rates, because the rates divide by those values.

    :param state: environment state. Reads ``pr_Y1alX``, ``pr_Y0alX``,
        ``pr_lX``, ``pr_Y1`` and ``pr_Y0``. Presumably each of the first three
        holds one cumulative distribution per group -- TODO confirm against
        ``create_state``.
    :param action: one threshold per group.
    :return: an ``ActionResult`` containing ``action`` and every frac/rate,
        each indexed by group.
    """
    rejected_pos = _take(state.pr_Y1alX, action)
    rejected_neg = _take(state.pr_Y0alX, action)
    rejected_any = _take(state.pr_lX, action)

    miss_rate = rejected_pos / state.pr_Y1
    reject_neg_rate = rejected_neg / state.pr_Y0

    return ActionResult(
        action=action,
        fn_frac=rejected_pos,
        tp_frac=state.pr_Y1 - rejected_pos,
        tn_frac=rejected_neg,
        fp_frac=state.pr_Y0 - rejected_neg,
        accept_rate=1.0 - rejected_any,
        fn_rate=miss_rate,
        tp_rate=1.0 - miss_rate,
        tn_rate=reject_neg_rate,
        fp_rate=1.0 - reject_neg_rate,
    )


if __name__ == "__main__":
    # Quick smoke test: build a random 3-group state, apply one threshold per
    # group, and print every resulting statistic.
    key = random.PRNGKey(758493)
    # JAX PRNG draws are a pure function of the key: reusing one key would make
    # pr_x and pr_y1Gx identical. Split it so the two draws are independent.
    key_x, key_y = random.split(key)

    pr_g = jnp.array([0.3, 0.4, 0.3])
    pr_x = random.uniform(key_x, shape=(3, 1000))
    pr_x /= jnp.sum(pr_x, axis=1)[:, None]
    pr_y1Gx = random.uniform(key_y, shape=(3, 1000))
    # NOTE(review): if pr_y1Gx means P(Y=1 | X=x), it should not be normalised
    # over x. It is kept as-is pending a check of what create_state expects.
    pr_y1Gx /= jnp.sum(pr_y1Gx, axis=1)[:, None]

    test_state = create_state(pr_g, pr_x, pr_y1Gx)
    test_action = jnp.array([0.7, 0.1, 0.8])

    # Exercise the function under test and show each field of the result.
    test_result = threshold_action(test_state, test_action)
    for field_name, value in test_result._asdict().items():
        print(f"{field_name}: {value}")
