import dataclasses
import functools
from typing import Callable

import chex
import jax
import jax.numpy as jnp
import optax
from clu import metrics as clu_metrics
from flax.training import train_state

from tabular_mvdrl import mmd
from tabular_mvdrl.envs.mrp import MarkovRewardProcess
from tabular_mvdrl.kernels import Kernel
from tabular_mvdrl.models import EWPModel, TabularProbabilityModel
from tabular_mvdrl.state import WeightedParticleState
from tabular_mvdrl.trainer import MVDRLTransferTrainer
from tabular_mvdrl.types import MRPTransitionBatch
from tabular_mvdrl.utils import jitpp, support_init
from tabular_mvdrl.utils.discrete_distributions import DiscreteDistribution
from tabular_mvdrl.utils.jitpp import Bind, Donate, Static

# Metric tag under which both training phases report their scalar loss.
LOSS_MMD = "loss__mmd"

# Signature shared by the per-phase update functions that `train_step`
# dispatches between: (PRNG key, current state, transition batch) -> new state.
CatUpdateStep = Callable[
    [chex.PRNGKey, WeightedParticleState, MRPTransitionBatch],
    WeightedParticleState,
]


@dataclasses.dataclass(frozen=True, kw_only=True)
class CatProjectedTDTrainer(MVDRLTransferTrainer[WeightedParticleState]):
    """Trainer combining an EWP warm-up phase with categorical projected TD.

    ``train_step`` picks the phase from the state's step counter:

    1. For the first ``ewp_steps`` updates, ``train_step_ewp`` trains the
       per-state support locations (``state.support_map``). It uses a
       squared-MMD TD loss in which every atom is weighted equally.
    2. After that, ``train_step_cat`` trains the per-state atom probabilities
       (``state.params``). It regresses them onto the MMD projection of the
       bootstrapped target distribution onto the current support.

    Attributes:
        optim: Optimizer used for both the support-map and the probability
            train states.
        kernel: Kernel defining the MMD used by the losses and projections.
        discount: Discount factor applied to next-state return atoms.
        support_map_initializer: Produces the initial support locations. The
            size of axis 1 of its output is taken as the number of atoms.
        ewp_steps: Number of initial updates spent in the EWP phase.
        signed: Forwarded to ``mmd.mmd_projection_pre`` in the categorical
            phase and reflected in ``identifier``. It presumably allows signed
            projected weights; confirm against ``mmd``.
    """

    optim: optax.GradientTransformation
    kernel: Kernel
    discount: float
    support_map_initializer: support_init.SupportMapInitializer
    ewp_steps: int
    signed: bool = False

    @property
    def identifier(self):
        """Human-readable name such as ``"Signed-Cat-TD:<num_atoms>"``."""
        # Run the initializer with a fixed key purely to read off the atom count.
        num_atoms = self.support_map_initializer(jax.random.PRNGKey(0)).shape[1]
        prefix = "Signed-" if self.signed else ""
        return f"{prefix}Cat-TD:{num_atoms}"

    @functools.cached_property
    def kernel_inv_blocks(self) -> tuple[chex.Array, chex.Array]:
        """Blocks of the inverse per-state Gram matrices of the initial support.

        ``K`` stacks one kernel (Gram) matrix per state. Its leading axis
        indexes states and its last two axes index atoms. The value is built
        from ``self.state``, i.e. the initial state, and is cached. No training
        step shown here uses it.

        Returns:
            A pair ``(a, b)``, with one slice per state:
            ``a = K_inv[..., :-1, :-1]`` keeps all but the last atom on both
            atom axes, and ``b = K_inv[..., -1, :-1]`` is the last row without
            its last column.
        """
        supports = jax.vmap(self.state.support_map.apply_fn, in_axes=(None, 0))(
            self.state.support_map.params, jnp.arange(self.env.num_states)
        )
        K = jax.vmap(mmd.kernel_matrix, in_axes=(None, 0, 0))(
            self.kernel, supports, supports
        )
        K_inv = jnp.linalg.inv(K)
        # Slice only the trailing (atom) axes. The leading axis is the state
        # batch added by the vmap above, so it must not be truncated. Indexing
        # from the end also removes the need for a stored atom count.
        return K_inv[..., :-1, :-1], K_inv[..., -1, :-1]

    @functools.cached_property
    def metrics(self) -> clu_metrics.Collection:
        """Metric collection that averages each logged loss tag."""
        metric_tags = [LOSS_MMD]
        metric_keepers = {
            tag: clu_metrics.Average.from_output(tag) for tag in metric_tags
        }
        return clu_metrics.Collection.create(**metric_keepers)

    @functools.cached_property
    def state(self) -> WeightedParticleState:
        """Initial training state.

        It wraps two pieces: a ``TrainState`` for the support locations, built
        from ``support_map_initializer``, and per-state probability parameters
        initialized to the constant ``1 / num_atoms``.
        """
        locs_key, probs_key = jax.random.split(jax.random.PRNGKey(self.seed))
        # TODO: determine num_atoms from initializer without computing params
        support_map = self.support_map_initializer(locs_key)
        num_atoms = support_map.shape[1]
        locs_model = EWPModel(self.env.num_states, self.env.reward_dim, num_atoms)
        locs_params = locs_model.init_with_support(
            locs_key, jnp.int32(0), self.support_map_initializer
        )
        support_map_state = train_state.TrainState.create(
            apply_fn=locs_model.apply, params=locs_params, tx=self.optim
        )

        # Probabilities are stored directly (logits=False) and start uniform.
        probs_model = TabularProbabilityModel(
            self.env.num_states,
            num_atoms,
            logits=False,
            initializer=jax.nn.initializers.constant(1 / num_atoms),
        )
        probs_params = probs_model.init(probs_key, jnp.int32(0))
        return WeightedParticleState.create(
            params=probs_params,
            apply_fn=probs_model.apply,
            tx=self.optim,
            support_map=support_map_state,
            metrics=self.metrics.empty(),
        )

    @jitpp.jit
    @staticmethod
    def train_step_ewp(
        key: chex.PRNGKey,
        state: Donate[WeightedParticleState],
        batch: MRPTransitionBatch,
        *,
        kernel: Bind[Static[Kernel]],
        discount: Bind[float],
    ) -> WeightedParticleState:
        """EWP-phase update: fit the support locations with a squared-MMD TD loss.

        Every atom gets uniform weight. Gradients are taken only with respect
        to ``state.support_map.params``. ``key`` is unused.
        """

        def _mmd_loss(eta_pred: chex.Array, eta_target: chex.Array) -> chex.Scalar:
            num_atoms = eta_pred.shape[0]
            uniform_probs = jnp.ones(num_atoms) / num_atoms
            # Scale by the atom count to counter the 1/num_atoms weighting.
            return num_atoms * mmd.mmd2(
                kernel, eta_pred, eta_target, uniform_probs, uniform_probs
            )

        @jax.value_and_grad
        def loss_fn(params: chex.ArrayTree, batch_: MRPTransitionBatch):
            eta_t = jax.vmap(state.support_map.apply_fn, in_axes=(None, 0))(
                params, batch_.o_t
            )
            # The target reads the captured params instead of `params`, so no
            # gradient flows through the bootstrap target (semi-gradient TD).
            eta_tp1 = jax.vmap(state.support_map.apply_fn, in_axes=(None, 0))(
                state.support_map.params, batch_.o_tp1
            )
            eta_target = batch_.r_t[:, None, ...] + discount * eta_tp1
            return jnp.mean(jax.vmap(_mmd_loss)(eta_t, eta_target))

        loss, grads = loss_fn(state.support_map.params, batch)
        metrics = {LOSS_MMD: loss}
        # Only the support-map TrainState applies gradients here, so the outer
        # step counter is advanced explicitly. `train_step` relies on it.
        return state.replace(
            support_map=state.support_map.apply_gradients(grads=grads),
            metrics=state.metrics.single_from_model_output(**metrics),
            step=state.step + 1,
        )

    @jitpp.jit
    @staticmethod
    def train_step_cat(
        key: chex.PRNGKey,
        state: Donate[WeightedParticleState],
        batch: MRPTransitionBatch,
        *,
        kernel: Bind[Static[Kernel]],
        env: Bind[MarkovRewardProcess],
        discount: Bind[float],
        signed: Bind[Static[bool]],
    ) -> WeightedParticleState:
        """Categorical-phase update: regress probabilities onto a projected target.

        The target distribution has atoms ``r_t + discount * locs(o_tp1)``
        weighted by the current probabilities at ``o_tp1``. It is projected
        with ``mmd.mmd_projection_pre`` onto the support at ``o_t``. The
        probability parameters are then moved toward that projection with a
        summed L2 loss; gradients are taken only with respect to
        ``state.params``.

        ``key`` and ``env`` are unused. They are kept so the signature, and the
        attributes bound by ``jitpp``, stay unchanged.
        """

        def _locs(observations: chex.Array) -> chex.Array:
            # Support locations for each observation in the batch.
            return jax.vmap(state.support_map.apply_fn, in_axes=(None, 0))(
                state.support_map.params, observations
            )

        locs_t = _locs(batch.o_t)
        locs_tp1 = _locs(batch.o_tp1)
        probs_target = jax.vmap(state.apply_fn, in_axes=(None, 0))(
            state.params, batch.o_tp1
        )
        locs_target = batch.r_t[:, None, ...] + discount * locs_tp1

        # Build the Gram matrix of each batch element's own support. Deriving
        # it from `locs_t` guarantees that row b uses the support of state
        # `o_t[b]`, for any batch size or order. A per-state tensor in state
        # order only lines up with the batch when o_t == arange(num_states).
        gram_t = jax.vmap(mmd.kernel_matrix, in_axes=(None, 0, 0))(
            kernel, locs_t, locs_t
        )

        projected_probs = jax.vmap(
            functools.partial(mmd.mmd_projection_pre, signed=signed),
            in_axes=(0, None, 0, 0, 0),
        )(gram_t, kernel, locs_t, locs_target, probs_target)

        @jax.value_and_grad
        def loss_fn(params: chex.ArrayTree) -> float:
            probs = jax.vmap(state.apply_fn, in_axes=(None, 0))(params, batch.o_t)
            # `projected_probs` is computed outside loss_fn, so it is a fixed
            # regression target.
            return jnp.sum(optax.losses.l2_loss(probs, projected_probs))

        loss, grads = loss_fn(state.params)
        metrics = {LOSS_MMD: loss}
        return state.apply_gradients(
            grads, state.metrics.single_from_model_output(**metrics)
        )

    @jitpp.jit
    @staticmethod
    def train_step(
        key: chex.PRNGKey,
        state: Donate[WeightedParticleState],
        batch: MRPTransitionBatch,
        *,
        train_step_cat: Bind[Static[CatUpdateStep]],
        train_step_ewp: Bind[Static[CatUpdateStep]],
        ewp_steps: Bind[int],
    ) -> WeightedParticleState:
        """Dispatch to the EWP update for the first ``ewp_steps`` steps, then to
        the categorical update."""
        # Assuming `state.step` starts at 0 (as with flax `TrainState.create`),
        # steps 0 .. ewp_steps - 1 use the EWP update, which is exactly
        # `ewp_steps` of them. A strict `>` would run one extra EWP step.
        return jax.lax.cond(
            state.step >= ewp_steps, train_step_cat, train_step_ewp, key, state, batch
        )

    def return_distribution(
        self, state: WeightedParticleState, i: int
    ) -> DiscreteDistribution:
        """Return the learned return distribution of state ``i``.

        The stored probabilities are re-projected onto the state's own support
        with ``signed=False`` before they are returned.
        """
        i = jnp.int32(i)
        locs = state.support_map.apply_fn(state.support_map.params, i)
        probs = state.apply_fn(state.params, i)
        # NOTE(review): the unsigned projection presumably maps signed or
        # unnormalized weights back to a valid distribution; confirm against
        # `mmd.mmd_projection`.
        new_probs = mmd.mmd_projection(self.kernel, locs, locs, probs, signed=False)
        return DiscreteDistribution(locs=locs, probs=new_probs)
