import functools
import numbers
from typing import Callable, NamedTuple, Sequence, TypeAlias

import jax
import jax._src.scipy.sparse.linalg as jssl
import jax.flatten_util as jfu
import jax.lax as lax
import jax.numpy as jnp
import jax.tree_util as jtu
import numpy as np
import optax
from jax.core import Jaxpr
from nix import psum_if_pmap
from nix.utils.jax_utils import pall_to_all, pgather, pidx
from nix.utils.tree_utils import (
    tree_add,
    tree_mul,
    tree_squared_norm,
)

from globe.nn import ParamTree
from globe.nn.globe import Globe
from globe.utils import (
    group_configs,
    tree_generator_zip,
)
from globe.utils.config import SystemConfigs, group_by_config, inverse_group_idx
from globe.utils.jax_utils import pmean_if_pmap
from globe.utils.jnp_utils import tree_scale


class NaturalGradientState(NamedTuple):
    """State threaded through successive calls of the preconditioners in this module."""

    # Damping coefficient; the preconditioners add it to the diagonal of the
    # matrix they invert.
    damping: jax.Array
    # Value cached by the previous preconditioner call; the preconditioners
    # read it back as warm start / decayed history term.
    last_grad: ParamTree


def make_natural_gradient_preconditioner(
    globe: Globe,
    decay_factor: float = 0.99,
    center: bool = True,
    **kwargs,
) -> Callable[
    [ParamTree, jax.Array, jax.Array, SystemConfigs, jax.Array, NaturalGradientState],
    tuple[ParamTree, NaturalGradientState, dict],
]:
    """
    Returns a function that computes a natural-gradient update by solving the
    damped Fisher system with the conjugate-gradient solver `cg`.

    Args:
    - globe: Globe object whose `apply` is differentiated
    - decay_factor: factor applied to the previous natural gradient before it
      is scaled by the damping and added to the right-hand side
    - center: whether network outputs / cotangents are centered (mean over
      axis 0, averaged across devices) inside the Jacobian products
    - kwargs: keyword arguments forwarded to `cg`
    Returns:
    - jitted function `nat_cg(params, electrons, atoms, config, dloss_dlog_p,
      natgrad_state, weights=None)` returning `(natgrad, new_state, aux_data)`;
      `weights` is accepted but not used
    """
    # Batched network: only the electrons argument is mapped (over axis 0).
    network = jax.vmap(globe.apply, in_axes=(None, 0, None, None, None))
    n_dev = jax.device_count()

    @functools.partial(jax.jit, static_argnames='config')
    def nat_cg(
        params,
        electrons,
        atoms,
        config: SystemConfigs,
        dloss_dlog_p: jax.Array,
        natgrad_state: NaturalGradientState,
        weights: jax.Array | None = None,
    ):
        # Total number of entries of dloss_dlog_p across all devices; outputs
        # are scaled by 1/sqrt(n).
        n = dloss_dlog_p.size * n_dev
        norm = (1 / jnp.sqrt(n)).astype(dloss_dlog_p.dtype)
        struc_params = globe.apply(params, atoms, config, method=globe.structure_params)
        damping = natgrad_state.damping.astype(dloss_dlog_p.dtype)

        def log_p_closure(p):
            # Only the 'params' collection is differentiated; the remaining
            # collections of `params` are closed over unchanged.
            total_params = {**params, 'params': p}
            result = network(total_params, electrons, atoms, config, struc_params)
            return result * norm

        _, vjp_fn = jax.vjp(log_p_closure, params['params'])
        _, jvp_fn = jax.linearize(log_p_closure, params['params'])
        if center:

            def center_fn(x):
                # This centers on a per molecule basis rather than center everything
                x = x.reshape(dloss_dlog_p.shape)
                center = pmean_if_pmap(jnp.mean(x, axis=0, keepdims=True))
                return x - center

            def vjp(x):
                return psum_if_pmap(vjp_fn(center_fn(x).astype(dloss_dlog_p.dtype))[0])

            def jvp(x):
                return center_fn(jvp_fn(x))
        else:
            vjp = lambda x: vjp_fn(x)[0]
            jvp = jvp_fn

        # NOTE(review): in the centered branch `vjp` already applies
        # psum_if_pmap; it is applied again here and once more inside
        # Fisher_matmul — confirm the repeated reduction is intended.
        grad = psum_if_pmap(vjp(dloss_dlog_p * norm))
        last_grad = natgrad_state.last_grad
        # Cast the cached gradient leaf-wise to the dtype of the new gradient.
        last_grad = jtu.tree_map(jax.lax.convert_element_type, last_grad, grad)
        decayed_last_grad = tree_mul(last_grad, decay_factor)
        # Right-hand side: grad + damping * decay_factor * last_grad.
        b = tree_add(grad, tree_mul(decayed_last_grad, damping))

        def Fisher_matmul(v):
            result = vjp(jvp(v))
            # add damping
            result = tree_add(result, tree_scale(damping, v))
            # synchronize across GPUs
            result = psum_if_pmap(result)
            return result

        # Compute natural gradient
        # The solve is warm-started from the previous natural gradient; with
        # more than one device a fixed iteration count is used.
        natgrad = cg(
            A=Fisher_matmul,
            b=b,
            x0=last_grad,
            fixed_iter=n_dev > 1,  # multi gpu
            **kwargs,
        )[0]

        aux_data = dict(
            grad_norm=tree_norm(grad),
            natgrad_norm=tree_norm(natgrad),
            decayed_last_grad_norm=tree_norm(decayed_last_grad),
        )
        # The damping is passed through unchanged; the new natural gradient
        # becomes the cached `last_grad`.
        return natgrad, NaturalGradientState(natgrad_state.damping, natgrad), aux_data

    return nat_cg


def get_jacobian(
    globe: Globe,
    params,
    electrons: jax.Array,
    atoms: jax.Array,
    config: SystemConfigs,
    pad_to_devices: bool = True,
):
    """Build the dense, normalised Jacobian of `globe.apply` w.r.t. `params['params']`.

    Gradients are computed group by group (as yielded by `group_configs`),
    flattened along the parameter axis, concatenated, restored to the original
    ordering with `inverse_group_idx` and reshaped to `(N, -1)` with
    `N = electrons.shape[0] * config.n_mols`.

    Args:
        globe: model whose `apply` is differentiated.
        params: variable collections; only the `'params'` entry is differentiated.
        electrons: electron inputs, split per group along axis 1.
        atoms: atom inputs.
        config: system configuration.
        pad_to_devices: if True, append zero columns until the column count is
            a multiple of `jax.device_count()`.

    Returns:
        The Jacobian cast to float64 and divided by `sqrt(N * jax.device_count())`.
    """
    n_dev = jax.device_count()
    n_rows = electrons.shape[0] * config.n_mols

    def summed_output(p, electrons: jax.Array, atoms: jax.Array, config: SystemConfigs):
        return globe.apply({**params, 'params': p}, electrons, atoms, config).sum()  # type: ignore

    # Inner map runs over axis 0 of both electrons and atoms, the outer one
    # over axis 0 of the electrons only.
    inner_grad = jax.vmap(jax.grad(summed_output), in_axes=(None, 0, 0, None))
    batched_grad = jax.vmap(inner_grad, in_axes=(None, 0, None, None))

    def flat_group_jacobian(elec, nuc, conf):
        grad_tree = batched_grad(params['params'], elec, nuc, conf)
        # Keep the two leading (mapped) axes and flatten each leaf's parameter axes.
        flat_leaves = [
            leaf.reshape(*leaf.shape[:2], -1) for leaf in jtu.tree_leaves(grad_tree)
        ]
        return jnp.concatenate(flat_leaves, axis=-1)

    per_group = [
        flat_group_jacobian(elec, nuc, conf)
        for elec, nuc, conf in group_configs(config, electrons, atoms, elec_axis=1)
    ]
    merged = jnp.concatenate(per_group, axis=1)
    # Undo the grouping permutation before collapsing the two leading axes.
    restore_order = inverse_group_idx(config)
    jacobian = merged[:, restore_order].reshape(n_rows, -1).astype(jnp.float64)

    if pad_to_devices:
        overhang = jacobian.shape[1] % n_dev
        if overhang > 0:
            zero_cols = jnp.zeros((n_rows, n_dev - overhang))
            jacobian = jnp.concatenate([jacobian, zero_cols], axis=1)
    return jacobian / jnp.sqrt(n_rows * n_dev)


def make_exact_natural_gradient_preconditioner(
    globe: Globe,
):
    """Create a preconditioner that solves the damped system with a dense direct solve.

    The returned jitted function materialises the centered Jacobian via
    `get_jacobian`, forms the square matrix `J J^T` across devices, solves
    `(J J^T + damping * I) x = dloss_dlog_p / sqrt(N * n_dev)` and maps the
    solution back to parameter space.

    Returns:
        Function `(params, electrons, atoms, config, dloss_dlog_p,
        natgrad_state, weights=None) -> (natgrad, NaturalGradientState)`.
        `weights` is accepted but not used.
    """
    n_dev = jax.device_count()

    @functools.partial(jax.jit, static_argnames='config')
    def cg(
        params,
        electrons: jax.Array,
        atoms: jax.Array,
        config: SystemConfigs,
        dloss_dlog_p: jax.Array,
        natgrad_state: NaturalGradientState,
        weights: jax.Array | None = None,
    ):
        damping = natgrad_state.damping
        flat_params, unravel = jfu.ravel_pytree(params['params'])
        n_params = flat_params.size
        n_total = electrons.shape[0] * config.n_mols * n_dev

        # Centered local Jacobian (columns padded to a multiple of n_dev).
        jacobian = get_jacobian(
            globe, params, electrons, atoms, config, pad_to_devices=True
        )
        column_mean = pmean_if_pmap(jnp.mean(jacobian, axis=0))
        jacobian = jacobian - column_mean

        # Trade parameter columns for sample rows across devices, then sum the
        # partial products.
        jac_slab = pall_to_all(jacobian, split_axis=1, concat_axis=0, tiled=True)
        gram = psum_if_pmap(jac_slab @ jac_slab.T)

        gathered_loss = pgather(dloss_dlog_p.reshape(-1)).reshape(-1)

        identity = jnp.eye(n_total)
        solution = jnp.linalg.solve(
            gram + damping * identity, gathered_loss / jnp.sqrt(n_total)
        )
        # Keep only the slice of the solution that belongs to this device.
        local_solution = solution.reshape(n_dev, -1)[pidx()]

        flat_natgrad = pmean_if_pmap(local_solution @ jacobian)
        # Drop the padding columns before restoring the parameter tree.
        natgrad = unravel(flat_natgrad[:n_params])
        return natgrad, NaturalGradientState(damping, natgrad)

    return cg


def make_spring_preconditioner(
    globe: Globe, decay_factor: float = 0.99, norm_constraint: float = 1e-3
):
    """Create a SPRING-style preconditioner based on the dense Jacobian.

    The returned jitted function builds the centered Jacobian, solves the
    regularised square system against the residual that remains after
    subtracting the decayed previous update, adds the decayed previous update
    back and finally rescales the flat result so that its norm does not exceed
    `sqrt(norm_constraint)`.

    Args:
        globe: model passed on to `get_jacobian`.
        decay_factor: multiplier applied to the cached previous update.
        norm_constraint: the flat update norm is capped at `sqrt(norm_constraint)`.

    Returns:
        Function `(params, electrons, atoms, config, dloss_dlog_p,
        natgrad_state, weights=None) -> (natgrad, NaturalGradientState)`; the
        state caches the flat (padded) update. `weights` is not used.
    """
    n_dev = jax.device_count()

    @functools.partial(jax.jit, static_argnames='config')
    def spring(
        params,
        electrons: jax.Array,
        atoms: jax.Array,
        config: SystemConfigs,
        dloss_dlog_p: jax.Array,
        natgrad_state: NaturalGradientState,
        weights: jax.Array | None = None,
    ):
        damping = natgrad_state.damping
        flat_params, unravel = jfu.ravel_pytree(params['params'])
        n_params = flat_params.size
        n_total = electrons.shape[0] * config.n_mols * n_dev

        jacobian = get_jacobian(
            globe, params, electrons, atoms, config, pad_to_devices=True
        )
        column_mean = pmean_if_pmap(jnp.mean(jacobian, axis=0))
        jacobian = jacobian - column_mean

        jac_slab = pall_to_all(jacobian, split_axis=1, concat_axis=0, tiled=True)
        gram = psum_if_pmap(jac_slab @ jac_slab.T)

        # Fall back to zeros when no flat array has been cached yet.
        if isinstance(natgrad_state.last_grad, jax.Array):
            previous = natgrad_state.last_grad
        else:
            previous = jnp.zeros((jacobian.shape[-1],), dtype=jacobian.dtype)
        decayed_previous: jax.Array = decay_factor * previous  # type: ignore

        residual = (
            dloss_dlog_p.reshape(-1) / jnp.sqrt(n_total)
            - jacobian @ decayed_previous
        )
        residual = pgather(residual).reshape(-1)

        identity = jnp.eye(n_total)
        # The scalar 1 / n_total is broadcast onto every entry of the matrix.
        solved = jnp.linalg.solve(
            gram + damping * identity + 1 / n_total, residual
        )
        local_solved = solved.reshape(n_dev, -1)[pidx()]

        natgrad_flat = psum_if_pmap(local_solved @ jacobian) + decayed_previous
        shrink = jnp.minimum(
            jnp.ones(()), jnp.sqrt(norm_constraint) / jnp.linalg.norm(natgrad_flat)
        )
        natgrad_flat = natgrad_flat * shrink
        # Padding columns are dropped only for the returned tree, not the cache.
        natgrad = unravel(natgrad_flat[:n_params])
        return natgrad, NaturalGradientState(damping, natgrad_flat)

    return spring


def greedy_group_tensors(
    tensors: Sequence[jax.Array], max_size: int | None, axis: int
) -> list[jax.Array]:
    if max_size is None:
        max_size = max(tensor.shape[axis] for tensor in tensors)
    current_stack: list[jax.Array] = []
    current_size = 0
    result = []
    for tensor in tensors:
        if current_size + tensor.shape[axis] > max_size:
            if len(current_stack) > 0:
                result.append(jnp.concatenate(current_stack, axis=axis))
            current_stack = [tensor]
            current_size = tensor.shape[axis]
        else:
            current_stack.append(tensor)
            current_size += tensor.shape[axis]
    result.append(jnp.concatenate(current_stack, axis=axis))
    return result


# Recursive alias: a DictTree is either a jax.Array leaf or a str-keyed dict
# whose values are again DictTrees.
DictTree: TypeAlias = dict[str, 'DictTree'] | jax.Array


def sum_in_order(items: list[jax.Array], order: Sequence[int] | np.ndarray):
    result = jnp.zeros(())
    for i in order:
        result = result + items[i]
    return result


def outvar_depth_from_jaxpr(jaxpr: Jaxpr):
    """Return, for every output variable of `jaxpr`, its depth in the dataflow graph.

    Inputs, constants and literals have depth 0. The outputs of an equation
    have depth `1 + max(depth of its inputs)`; an equation without inputs
    yields depth 0. Sorting the outputs by this depth gives the order in which
    they become available.

    The depths are computed in a single forward pass over `jaxpr.eqns`, which
    relies on the equations being listed in execution order. This avoids the
    Python recursion limit on long dependency chains.

    Args:
        jaxpr: the jaxpr to analyse.

    Returns:
        A numpy array with one depth per entry of `jaxpr.outvars`.
    """
    try:  # public location of the IR classes in recent JAX releases
        from jax.extend.core import Literal
    except ImportError:  # older JAX releases only expose it via jax.core
        from jax.core import Literal

    depths = {v: 0 for v in (*jaxpr.invars, *jaxpr.constvars)}

    def depth_of(atom):
        # Literals are constants and must not be used as dictionary keys.
        return 0 if isinstance(atom, Literal) else depths[atom]

    for eqn in jaxpr.eqns:
        eqn_depth = (1 + max(map(depth_of, eqn.invars))) if eqn.invars else 0
        for out_var in eqn.outvars:
            depths[out_var] = eqn_depth

    return np.array([depth_of(v) for v in jaxpr.outvars])


def summation_order(fn, *args):
    """Return the output indices of `fn` sorted by when each output is computed.

    `fn` is traced with `jax.make_jaxpr` (the argument at position 3 is treated
    as static) and its outputs are ordered by their dataflow depth. Summing the
    outputs in this order avoids keeping early results in memory longer than
    necessary.
    """
    traced = jax.make_jaxpr(fn, static_argnums=3)(*args)
    return np.argsort(outvar_depth_from_jaxpr(traced.jaxpr))


def make_memefficient_spring_preconditioner(
    globe: Globe,
    damping: float | jax.Array,
    decay_factor: float,
    momentum: float,
    norm_constraint: float | None,
    dtype: str = 'float64',
    only_use_wf: bool = False,
    **_,
):
    """Create a SPRING preconditioner that never materialises the full Jacobian.

    The square matrix `J J^T` is accumulated one parameter leaf at a time, and
    the products with `J` / `J^T` needed for the residual and the final update
    are done with `jax.jvp` / `jax.vjp`.

    Args:
        globe: model whose outputs are differentiated.
        damping: coefficient added to the diagonal of `J J^T`.
        decay_factor: multiplier applied to the cached previous update.
        momentum: if truthy, the result is blended as
            `(1 - momentum) * update + momentum * last_grad`.
        norm_constraint: accepted but not used in this function.
        dtype: dtype the inputs are cast to before differentiation.
        only_use_wf: if True, the result of `globe.get_mol_params` is wrapped in
            `stop_gradient` and the `'gnn'` part of the JVP tangent is zeroed.
        **_: further keyword arguments are ignored.

    Returns:
        Jitted function `(params, electrons, atoms, config, dloss_dlog_p,
        natgrad_state, weights=None) -> (update, NaturalGradientState, {})`.
        `weights` is accepted but not used.
    """
    # Store the damping as a 0-d array of the target dtype.
    damping = jnp.ones((), dtype=dtype) * damping

    @functools.partial(jax.jit, static_argnames='config')
    def spring(
        params,
        electrons: jax.Array,
        atoms: jax.Array,
        config: SystemConfigs,
        dloss_dlog_p: jax.Array,
        natgrad_state: NaturalGradientState,
        weights: jax.Array | None = None,
    ):
        N_per_device = electrons.shape[0] * config.n_mols
        N_dev = jax.device_count()
        N = N_per_device * N_dev
        # Outputs are scaled by 1/sqrt(N).
        norm = 1 / jnp.sqrt(N)

        # Run everything in target dtype
        def to_target_dtype(x):
            return jtu.tree_map(lambda x: x.astype(dtype), x)

        # Remember the original parameter dtypes so the update can be cast back.
        out_dtypes = jtu.tree_map(lambda x: x.dtype, params['params'])
        params, electrons, atoms, dloss_dlog_p = to_target_dtype(
            (params, electrons, atoms, dloss_dlog_p)
        )

        def center_fn(x: jax.Array):
            # This centers on a per molecule basis rather than center everything
            x = x.reshape(dloss_dlog_p.shape)
            # subtract the mean over axis 0, averaged across devices
            center = pmean_if_pmap(jnp.mean(x, axis=0))
            x -= center
            # cast back to the target dtype
            return x.astype(dtype)

        # indices to restore original sorting
        inv_idx = inverse_group_idx(config)

        def to_covariance(jacs: list[jax.Array]) -> jax.Array:
            # Contribution of a single parameter leaf to J J^T.
            # merge all systems into a single jacobian
            jac = jnp.concatenate(jacs, axis=1)[:, inv_idx].astype(jnp.float64)
            # pad and transpose across devices
            if jac.shape[-1] % N_dev != 0:
                pad = jnp.zeros((*jac.shape[:-1], N_dev - jac.shape[-1] % N_dev))
                jac = jnp.concatenate([jac, pad], axis=-1)
            # N, mols, params
            jac = pall_to_all(jac, split_axis=2, concat_axis=0, tiled=True)
            # Subtract the mean over axis 0; written as .at[].add in the hope
            # of an in-place update (not verified).
            jac = jac.at[:].add(-jac.mean(axis=0))  # this should be inplace?
            jac = jac.reshape(N, -1)
            return jac @ jac.T

        jacs: list[list[jax.Array]] = []  # list of lists of tensors
        # One entry per config group; every entry holds one flattened Jacobian
        # per parameter leaf.
        for (elec, (spins, charges)), nuc in tree_generator_zip(
            group_by_config(
                config, electrons, lambda s, c: np.sum(s), return_config=True, axis=1
            ),
            group_by_config(config, atoms, lambda s, c: len(c)),
        ):

            @functools.partial(jax.vmap, in_axes=(None, 0, None, None))
            @functools.partial(jax.vmap, in_axes=(None, 0, 0, None))
            @jax.grad
            def jac_fn(p, electrons, atoms, config):
                total_params = {**params, 'params': p}
                mol_params = None
                if only_use_wf:
                    # Block gradients through the molecule parameters.
                    mol_params = globe.apply(
                        total_params, atoms, config, method=globe.get_mol_params
                    )
                    mol_params = jax.lax.stop_gradient(mol_params)
                result = globe.apply(
                    total_params,
                    electrons,
                    atoms,
                    config,
                    mol_params=mol_params,
                    method=globe.wf,
                ).sum()  # type: ignore
                return result * norm

            conf = SystemConfigs((spins,), (charges,))
            jac_tree = jac_fn(params['params'], elec, nuc, conf)
            # Keep the two leading axes of `elec` and flatten the parameter axes.
            jac_tree = jtu.tree_map(lambda x: x.reshape(*elec.shape[:2], -1), jac_tree)
            jac_tensors = jtu.tree_leaves(jac_tree)
            jacs.append(jac_tensors)

        # Compute the covariance matrix
        # Accumulate leaf by leaf so only one leaf's Jacobian is processed at a time.
        JT_J = jnp.zeros((N, N), dtype=jnp.float64)
        for i in range(len(jacs[0])):
            JT_J += to_covariance([j[i] for j in jacs])
        JT_J = psum_if_pmap(JT_J)

        def log_p_closed(p):
            fn = jax.vmap(globe.apply, in_axes=(None, 0, None, None))
            return fn({**params, 'params': p}, electrons, atoms, config) * norm

        _, vjp_fn = jax.vjp(log_p_closed, params['params'])

        def vjp(x):
            # Centered cotangent pulled back to parameter space, summed over devices.
            return psum_if_pmap(vjp_fn(center_fn(x))[0])

        def jvp(x):
            # NOTE(review): this overwrites x['gnn'] in the caller's tree; the
            # tree passed in below (decayed_last_grad) is reused afterwards —
            # confirm the in-place zeroing is intended.
            if only_use_wf:
                x['gnn'] = jtu.tree_map(jnp.zeros_like, x['gnn'])
            return center_fn(jax.jvp(log_p_closed, (params['params'],), (x,))[1])

        # construct epsilon tilde: the normalised loss gradient minus the part
        # already explained by the decayed previous update
        last_grad = to_target_dtype(natgrad_state.last_grad)
        decayed_last_grad = tree_mul(last_grad, decay_factor)
        epsilon_tilde = dloss_dlog_p * norm - jvp(decayed_last_grad)
        epsilon_tilde = pgather(epsilon_tilde, axis=0, tiled=True)
        epsilon_tilde = epsilon_tilde.astype(jnp.float64).reshape(-1)

        # Regularised system matrix; the scalar 1 / N is broadcast onto every entry.
        T = JT_J + damping * np.eye(N) + 1 / N

        x = jax.scipy.linalg.solve(T, epsilon_tilde, assume_a='pos', check_finite=False)
        # Keep only the slice of the solution that belongs to this device.
        x = x.reshape(N_dev, -1)[pidx()].astype(dtype)
        preconditioned = vjp(x)
        natgrad = tree_add(preconditioned, decayed_last_grad)

        # Add momentum
        if momentum:
            natgrad = tree_add(
                tree_mul(natgrad, 1 - momentum),
                tree_mul(last_grad, momentum),
            )

        # The same tree is cached in the state and returned as the update.
        cache, update = natgrad, natgrad
        # Return proper dtypes
        update = jtu.tree_map(jax.lax.convert_element_type, update, out_dtypes)
        return update, NaturalGradientState(damping, cache), {}

    return spring


def make_block_spring_preconditioner(
    globe: Globe,
    damping: float | jax.Array,
    decay_factor: float,
    momentum: float,
    norm_constraint: float | None,
    **_,
):
    """Create a SPRING preconditioner that solves one system per mapped slice.

    For every config group, `comp_update` is vmapped over the group's entries
    and solves an independent regularised system for each of them. The
    resulting updates are summed and divided by `config.n_mols`.

    Args:
        globe: model whose outputs are differentiated.
        damping: coefficient added to the diagonal of every block matrix.
        decay_factor: multiplier applied to the cached previous update.
        momentum: if truthy, the result is blended as
            `(1 - momentum) * update + momentum * last_grad`.
        norm_constraint: accepted but not used in this function.
        **_: further keyword arguments are ignored.

    Returns:
        Jitted function `(params, electrons, atoms, config, dloss_dlog_p,
        natgrad_state, weights=None) -> (update, NaturalGradientState, {})`.
        `weights` is accepted but not used.
    """
    # Store the damping as a float64 0-d array.
    damping = jnp.ones((), dtype=jnp.float64) * damping

    @functools.partial(jax.jit, static_argnames='config')
    def block_spring(
        params,
        electrons: jax.Array,
        atoms: jax.Array,
        config: SystemConfigs,
        dloss_dlog_p: jax.Array,
        natgrad_state: NaturalGradientState,
        weights: jax.Array | None = None,
    ):
        # Unlike the non-block variants, the sample count here is not
        # multiplied by config.n_mols.
        N_per_device = electrons.shape[0]
        N_dev = jax.device_count()
        N = N_per_device * N_dev
        norm = 1 / jnp.sqrt(N)

        # Run everything in float64
        def to_float64(x):
            return jtu.tree_map(lambda x: x.astype(jnp.float64), x)

        # Remember the original parameter dtypes so the update can be cast back.
        out_dtypes = jtu.tree_map(lambda x: x.dtype, params['params'])
        params, electrons, atoms, dloss_dlog_p = to_float64(
            (params, electrons, atoms, dloss_dlog_p)
        )

        def center(x):
            # Subtract the mean over all entries, averaged across devices.
            return x - pmean_if_pmap(x.mean())

        @functools.partial(jax.vmap, in_axes=(None, 0, None, None))
        @jax.grad
        def jac_fn(p, electrons, atoms, config):
            total_params = {**params, 'params': p}
            return globe.apply(total_params, electrons, atoms, config).sum() * norm  # type: ignore

        last_grad = to_float64(natgrad_state.last_grad)
        decayed_last_grad = tree_mul(last_grad, decay_factor)

        # Accumulator for the summed per-block updates.
        updates = jtu.tree_map(jnp.zeros_like, params['params'])
        for (elec, (spins, charges)), nuc, epsilon in tree_generator_zip(
            group_by_config(
                config, electrons, lambda s, c: np.sum(s), return_config=True, axis=1
            ),
            group_by_config(config, atoms, lambda s, c: len(c)),
            group_by_config(config, dloss_dlog_p, lambda s, c: 1, axis=1),
        ):
            conf = SystemConfigs((spins,), (charges,))

            # Mapped over axis 1 of elec / epsilon and axis 0 of nuc.
            @functools.partial(jax.vmap, in_axes=(1, 0, 1))
            def comp_update(elec, nuc, epsilon):
                assert epsilon.ndim == 2
                assert epsilon.shape[1] == 1
                jac_tree = jac_fn(params['params'], elec, nuc, conf)
                # Accumulate J J^T one parameter leaf at a time.
                JT_J = 0
                for jac in jtu.tree_leaves(jac_tree):
                    jac = jac.reshape(*epsilon.shape, -1)
                    jac -= pmean_if_pmap(jnp.mean(jac, axis=0))
                    jac = jac.reshape(-1, jac.shape[-1])
                    # Pad the parameter axis to a multiple of the device count
                    # so it can be split evenly by the all-to-all below.
                    if jac.shape[-1] % N_dev != 0:
                        pad = jnp.zeros((N_per_device, N_dev - jac.shape[-1] % N_dev))
                        jac = jnp.concatenate([jac, pad], axis=-1)
                    jac = pall_to_all(jac, split_axis=1, concat_axis=0, tiled=True)
                    JT_J += jac @ jac.T
                JT_J = psum_if_pmap(JT_J)

                def log_p_closed(p):
                    fn = jax.vmap(globe.apply, in_axes=(None, 0, None, None))
                    return fn({**params, 'params': p}, elec, nuc, conf) * norm

                _, vjp_fn = jax.vjp(log_p_closed, params['params'])

                def vjp(x):
                    # Centered cotangent pulled back to parameter space,
                    # summed over devices.
                    return psum_if_pmap(vjp_fn(center(x))[0])

                def jvp(x):
                    uncentered = jax.jvp(log_p_closed, (params['params'],), (x,))[1]
                    return center(uncentered)

                # Residual: normalised loss gradient minus the part already
                # explained by the decayed previous update.
                epsilon_tilde = epsilon * norm - jvp(decayed_last_grad)
                epsilon_tilde = pgather(epsilon_tilde, axis=0, tiled=True)
                epsilon_tilde = epsilon_tilde.reshape(-1)

                # Regularised block matrix; the scalar 1 / N is broadcast onto
                # every entry.
                T = JT_J + damping * np.eye(N) + 1 / N

                # Solve and keep only this device's slice of the solution.
                x = jax.scipy.linalg.solve(
                    T, epsilon_tilde, assume_a='pos', check_finite=False
                ).reshape(N_dev, N_per_device)[pidx()]
                preconditioned = vjp(x.reshape(epsilon.shape))
                return tree_add(preconditioned, decayed_last_grad)

            # Sum the per-block results (leading vmapped axis) into the accumulator.
            updates = jtu.tree_map(
                lambda x, y: x + y.sum(0),
                updates,
                comp_update(elec, nuc, epsilon),
            )
        # Average the summed block updates over config.n_mols.
        updates = tree_mul(updates, 1 / config.n_mols)

        # Add momentum
        if momentum:
            updates = tree_add(
                tree_mul(updates, 1 - momentum),
                tree_mul(last_grad, momentum),
            )

        # The same tree is cached in the state and returned as the update.
        cache, update = updates, updates
        # Return proper dtypes
        update = jtu.tree_map(jax.lax.convert_element_type, update, out_dtypes)
        return update, NaturalGradientState(damping, cache), {}

    return block_spring


def tree_norm(tree):
    """Return the square root of `tree_squared_norm(tree)`."""
    squared = tree_squared_norm(tree)
    return squared ** 0.5


def clip_tree_norm(tree, max_norm):
    """Scale `tree` down so that its norm does not exceed `max_norm`.

    The tree is returned untouched when `max_norm` is None. Otherwise it is
    multiplied by `min(1, max_norm / tree_norm(tree))`.
    """
    if max_norm is None:
        return tree
    ratio = max_norm / tree_norm(tree)
    # A zero-norm tree makes the ratio a division by zero; nan_to_num replaces
    # a resulting NaN by 0 so the scale stays finite.
    factor = jnp.nan_to_num(jnp.minimum(1.0, ratio))
    return tree_mul(tree, factor)


def get_spring_update_fn(
    globe: Globe,
    damping=0.001,
    mu=0.99,
    momentum=0.0,
):
    """Create a SPRING update function that works on the dense Jacobian.

    No cross-device collectives are used; everything is computed from the
    local Jacobian returned by `get_jacobian`.

    Args:
        globe: model passed on to `get_jacobian`.
        damping: coefficient added to the diagonal of the regularised matrix.
        mu: decay applied to the previous update.
        momentum: the result is blended as
            `(1 - momentum) * update + momentum * previous_update`.

    Returns:
        Function `(params, electrons, atoms, config, dloss_dlog_p,
        natgrad_state, weights=None) -> (update, new_state)` where both entries
        are the same parameter-shaped tree. Passing `natgrad_state=None` starts
        from a zero previous update. `weights` is accepted but not used.
    """

    def spring_update_fn(
        params,
        electrons: jax.Array,
        atoms: jax.Array,
        config: SystemConfigs,
        dloss_dlog_p: jax.Array,
        natgrad_state,
        weights: jax.Array | None = None,
    ):
        nchains = electrons.shape[0] * config.n_mols

        if natgrad_state is None:
            natgrad_state = jtu.tree_map(jnp.zeros_like, params['params'])

        prev_grad, unravel_fn = jfu.ravel_pytree(natgrad_state)
        prev_grad_decayed = mu * prev_grad

        # Request the unpadded Jacobian: its columns must line up with the flat
        # parameter vector `prev_grad`. Device padding would append extra zero
        # columns and break the products below.
        # NOTE(review): get_jacobian normalises by sqrt(N * device_count) while
        # epsilon_bar below uses sqrt(nchains); the two agree on one device only.
        log_psi_grads = get_jacobian(
            globe, params, electrons, atoms, config, pad_to_devices=False
        )
        Ohat = log_psi_grads - jnp.mean(log_psi_grads, axis=0, keepdims=True)

        T = Ohat @ Ohat.T
        ones = jnp.ones((nchains, 1))
        # Regularise with the rank-one all-ones term and the damping.
        T_reg = T + ones @ ones.T / nchains + damping * jnp.eye(nchains)

        epsilon_bar = dloss_dlog_p.reshape(-1) / jnp.sqrt(nchains)
        # Residual left after accounting for the decayed previous update.
        epsilon_tilde = epsilon_bar - Ohat @ prev_grad_decayed

        dtheta_residual = Ohat.T @ jax.scipy.linalg.solve(
            T_reg, epsilon_tilde, assume_a='pos'
        )

        SR_G = dtheta_residual + prev_grad_decayed
        SR_G = (1 - momentum) * SR_G + momentum * prev_grad

        result = unravel_fn(SR_G)
        return result, result

    return spring_update_fn


def cg_solve(
    A,
    b,
    x0=None,
    *,
    maxiter,
    M=jssl._identity,
    min_lookback=10,
    lookback_frac=0.1,
    eps=5e-6,
):
    """Preconditioned conjugate gradients with a lookback-based stopping rule.

    After every iteration a scalar progress value is stored in a cache. The
    loop stops when the relative change of that value over a lookback window
    drops below `eps * window`, when the value itself is (almost) zero, or when
    the iteration budget is exhausted.

    Args:
        A: callable applying the linear operator to a pytree.
        b: right-hand side pytree.
        x0: initial guess with the structure of `b`.
            NOTE(review): the default None is passed straight to `A`; callers
            apparently always have to provide a value.
        maxiter: iteration budget; also the length of the progress cache.
        M: callable applying the preconditioner (identity by default).
        min_lookback: lower bound for the lookback window.
        lookback_frac: window size as a fraction of the iteration index.
        eps: relative tolerance per lookback step.

    Returns:
        The final iterate `x`.
    """
    # Lookback window per iteration count k: lookback_frac * (k - 1), raised to
    # at least min_lookback and then capped at k - 1.
    steps = jnp.arange(maxiter + 1) - 1
    gaps = (steps * lookback_frac).astype(jnp.int32)
    gaps = jnp.where(gaps < min_lookback, min_lookback, gaps)
    gaps = jnp.where(gaps > steps, steps, gaps)

    def cond_fun(value):
        x, r, gamma, p, k, cache = value
        gap = gaps[k]
        # From here on k is the index of the most recently written cache entry.
        k = k - 1
        relative_still = jnp.logical_and(
            jnp.abs((cache[k] - cache[k - gap]) / cache[k]) < eps * gap, gap >= 1
        )
        # NOTE(review): the comparison uses the decremented k, so the body can
        # run maxiter + 1 times — confirm this is intended.
        over_max = k >= maxiter
        # We check that we are after the third iteration because the first ones may have close to 0 error.
        converged = jnp.logical_and(k > 2, jnp.abs(cache[k]) < 1e-7)
        return ~(relative_still | over_max | converged)

    def body_fun(value):
        x, r, gamma, p, k, cache = value
        # Standard preconditioned CG step.
        Ap = A(p)
        alpha = gamma / jssl._vdot_real_tree(p, Ap)
        x_ = jssl._add(x, jssl._mul(alpha, p))
        r_ = jssl._sub(r, jssl._mul(alpha, Ap))
        z_ = M(r_)
        gamma_ = jssl._vdot_real_tree(r_, z_)
        beta_ = gamma_ / gamma
        p_ = jssl._add(z_, jssl._mul(beta_, p))

        Ax = jssl._add(r_, b)

        # Progress value recorded for the stopping rule: since Ax - b equals
        # r_, this is the inner product of the new residual and the new iterate.
        val = jtu.tree_reduce(
            jnp.add, jtu.tree_map(lambda a, b, c: jnp.vdot(a - b, c), Ax, b, x_)
        )
        cache_ = cache.at[k].set(val)
        return x_, r_, gamma_, p_, k + 1, cache_

    r0 = jssl._sub(b, A(x0))
    p0 = z0 = M(r0)
    gamma0 = jssl._vdot_real_tree(r0, z0)
    # Loop state: iterate, residual, <r, z>, search direction, iteration
    # counter and the progress cache.
    initial_value = (x0, r0, gamma0, p0, 0, jnp.zeros((maxiter,)))

    x_final, _, _, _, _, _ = lax.while_loop(cond_fun, body_fun, initial_value)

    return x_final


def cg_solve_fixediter(A, b, x0=None, *, maxiter, M=jssl._identity):
    """Run exactly `maxiter` preconditioned conjugate-gradient iterations.

    Args:
        A: callable applying the linear operator to a pytree.
        b: right-hand side pytree.
        x0: initial guess with the structure of `b`.
        maxiter: number of iterations to perform.
        M: callable applying the preconditioner (identity by default).

    Returns:
        The iterate after `maxiter` steps.
    """

    def step(carry, _):
        sol, res, res_dot, direction = carry
        A_dir = A(direction)
        step_size = res_dot / jssl._vdot_real_tree(direction, A_dir)
        sol_next = jssl._add(sol, jssl._mul(step_size, direction))
        res_next = jssl._sub(res, jssl._mul(step_size, A_dir))
        precond_res = M(res_next)
        res_dot_next = jssl._vdot_real_tree(res_next, precond_res)
        mix = res_dot_next / res_dot
        direction_next = jssl._add(precond_res, jssl._mul(mix, direction))
        return (sol_next, res_next, res_dot_next, direction_next), None

    res0 = jssl._sub(b, A(x0))
    z0 = M(res0)
    carry0 = (x0, res0, jssl._vdot_real_tree(res0, z0), z0)

    # The scanned indices are ignored by `step`; they only fix the trip count.
    final_carry, _ = lax.scan(step, carry0, jnp.arange(maxiter), maxiter)
    return final_carry[0]


def cg(
    A,
    b,
    x0=None,
    *,
    maxiter=None,
    min_lookback=10,
    lookback_frac=0.1,
    eps=5e-6,
    M=None,
    fixed_iter=False,
):
    """Solve `A x = b` with conjugate gradients, wrapped in `lax.custom_linear_solve`.

    By default the adaptive solver `cg_solve` (stopping criterion following
    Martens 2010) is used; with `fixed_iter=True` the solver
    `cg_solve_fixediter` runs a fixed number of iterations instead.

    Args:
        A (Callable): Linear operator A in Ax=b.
        b (jax.Array): Right-hand side (array or pytree of arrays).
        x0 (jax.Array, optional): Initial value for x. Defaults to zeros shaped like b.
        maxiter (int, optional): Maximum number of iterations. Defaults to
            ten times the total number of elements of b.
        min_lookback (int, optional): Minimum lookback distance. Defaults to 10.
        lookback_frac (float, optional): Fraction of iterations to look back. Defaults to 0.1.
        eps (float, optional): Tolerance of the stopping criterion. Defaults to 5e-6.
        M (Callable, optional): Preconditioner. Defaults to the identity.
        fixed_iter (bool, optional): Run exactly `maxiter` iterations. Defaults to False.

    Returns:
        tuple: `(x, info)` where x is the solution and info is always None.

    Raises:
        ValueError: if x0 and b differ in tree structure or leaf shapes.
    """
    if x0 is None:
        x0 = jtu.tree_map(jnp.zeros_like, b)

    b, x0 = jax.device_put((b, x0))

    if maxiter is None:
        maxiter = 10 * sum(leaf.size for leaf in jtu.tree_leaves(b))

    A = jssl._normalize_matvec(A)
    M = jssl._normalize_matvec(jssl._identity if M is None else M)

    x0_structure, b_structure = jtu.tree_structure(x0), jtu.tree_structure(b)
    if x0_structure != b_structure:
        raise ValueError(
            'x0 and b must have matching tree structure: '
            f'{x0_structure} vs {b_structure}'
        )

    x0_shapes, b_shapes = jssl._shapes(x0), jssl._shapes(b)
    if x0_shapes != b_shapes:
        raise ValueError(
            'arrays in x0 and b must have matching shapes: '
            f'{x0_shapes} vs {b_shapes}'
        )

    # Both solvers share these arguments; the adaptive one takes the
    # stopping-rule parameters in addition.
    solver_kwargs = dict(x0=x0, maxiter=maxiter, M=M)
    if fixed_iter:
        solver = cg_solve_fixediter
    else:
        solver = cg_solve
        solver_kwargs.update(
            min_lookback=min_lookback, lookback_frac=lookback_frac, eps=eps
        )
    solve = functools.partial(solver, **solver_kwargs)

    # real-valued positive-definite linear operators are symmetric
    symmetric = not any(
        issubclass(leaf.dtype.type, np.complexfloating)
        for leaf in jtu.tree_leaves(b)
    )
    x = lax.custom_linear_solve(
        A, b, solve=solve, transpose_solve=solve, symmetric=symmetric
    )
    return x, None


def make_schedule(params: dict) -> Callable[[int], float]:
    """Simple function to create different kind of schedules.

    Supported inputs:
    - a number: constant schedule returning that number,
    - a callable: returned unchanged,
    - a dict with the keys `init` and `delay` and optionally `schedule`
      (`'hyperbola'`, the default, or `'exponential'`), `decay` (hyperbola
      only, defaults to 1) and `min` (lower bound applied to the schedule).

    The hyperbola schedule is `init * (1 / (1 + t / delay)) ** decay`, the
    exponential one `init * exp(-t / delay)`.

    Args:
        params (dict): Parameters for the schedules.

    Returns:
        Callable[[int], float]: schedule function

    Raises:
        ValueError: for an unknown schedule name or an unsupported type of `params`.
    """
    if isinstance(params, numbers.Number):

        def result(t):
            return params
    elif callable(params):
        result = params
    elif isinstance(params, dict):
        kind = params.get('schedule', 'hyperbola')
        if kind not in ('hyperbola', 'exponential'):
            raise ValueError(
                f"Unknown schedule {kind!r}; expected 'hyperbola' or 'exponential'."
            )
        assert 'init' in params, "schedule parameters require the key 'init'"
        assert 'delay' in params, "schedule parameters require the key 'delay'"
        init = params['init']
        delay = params['delay']

        if kind == 'hyperbola':
            decay = params.get('decay', 1)

            def result(t):
                return init * jnp.power(1 / (1 + t / delay), decay)
        else:

            def result(t):
                return init * jnp.exp(-t / delay)

        if 'min' in params:
            # Read the bound once, like `init` and `delay`, so later changes to
            # the dict do not alter an existing schedule.
            lower_bound = params['min']
            unclipped = result

            def result(t):
                return jnp.maximum(unclipped(t), lower_bound)
    else:
        raise ValueError(
            'Schedule parameters must be a number, a callable or a dict, '
            f'got {type(params).__name__}.'
        )
    return result


def scale_by_trust_ratio_embeddings(
    min_norm: float = 0.0,
    trust_coefficient: float = 1.0,
    eps: float = 0.0,
) -> optax.GradientTransformation:
    """Trust-ratio scaling where the norms are taken over the last axis only.

    Intended for embeddings, where the norm should not be computed over all
    parameters of a leaf but separately along the last dimension.

    Args:
        min_norm: minimum value the parameter and update norms are clipped to.
        trust_coefficient: multiplier for the trust ratio.
        eps: additive constant in the denominator of the trust ratio.

    Returns:
        The gradient transformation; its update function raises a ValueError
        when called without parameters.
    """

    def _rescale(update, param):
        # Norms along the last axis, clipped from below at min_norm.
        p_norm = optax.safe_norm(param, min_norm, axis=-1, keepdims=True)
        u_norm = optax.safe_norm(update, min_norm, axis=-1, keepdims=True)
        ratio = trust_coefficient * p_norm / (u_norm + eps)

        # Without minimum-norm clipping a zero norm would freeze the
        # parameters, so fall back to a ratio of 1 in that case.
        degenerate = jnp.logical_or(p_norm == 0.0, u_norm == 0.0)
        unit = jnp.array(1.0, dtype=param.dtype)
        return update * jnp.where(degenerate, unit, ratio)

    def init_fn(params):
        del params
        return optax.ScaleByTrustRatioState()

    def update_fn(updates, state, params):
        if params is None:
            raise ValueError(optax.NO_PARAMS_MSG)
        return jtu.tree_map(_rescale, updates, params), state

    return optax.GradientTransformation(init_fn, update_fn)
