# -*- coding: utf-8 -*-
"""k-means in JAX

Automatically generated by Colaboratory.

Original file is located at
    https://colab.research.google.com/drive/1AwS4haUx6swF82w3nXr6QKhajdF8aSvA
"""

from functools import partial

import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

@jax.jit
def vector_quantize(points, codebook):
    """Map every point to its closest codebook entry.

    Args:
        points: array whose leading axis enumerates the points to quantize.
        codebook: array whose leading axis enumerates the candidate codes;
            each code must be broadcast-compatible with a single point.

    Returns:
        A pair ``(assignment, distns)``: for each point, the index of the
        codebook entry with the smallest norm of the difference (ties go to
        the lowest index, as with ``argmin``), and that norm.
    """
    # Norm of each entry along the leading axis (default `jnp.linalg.norm`).
    norm_per_entry = jax.vmap(jnp.linalg.norm)

    def closest_code(point):
        # Distance from this one point to every code, then pick the smallest.
        return jnp.argmin(norm_per_entry(codebook - point))

    assignment = jax.vmap(closest_code)(points)
    # Distance from each point to the code it was just assigned to.
    residuals = codebook[assignment, :] - points
    distns = norm_per_entry(residuals)
    return assignment, distns

@partial(jax.jit, static_argnums=(2,))
def kmeans_run(key, points, k, thresh=1e-5):
    """Run a single randomly initialised k-means fit until it stops improving.

    Args:
        key: JAX PRNG key used to choose the initial centroids.
        points: 2-D array of data, indexed as (data point, data dimension).
        k: number of clusters. Static under ``jit``, so every distinct value
            triggers a recompilation.
        thresh: iteration continues only while the mean distortion drops by
            more than this amount from one iteration to the next.

    Returns:
        A pair ``(centroids, distortion)``. ``centroids`` has one row per
        cluster. ``distortion`` is the mean point-to-assigned-centroid
        distance measured during the last iteration, i.e. against the
        centroids that iteration started from (before its final update).
        A cluster that ends up with no points gets an all-zero centroid.
    """
    cluster_ids = jnp.arange(k)

    def improve_centroids(val):
        # Loop state: (centroids, distortion of the previous centroids,
        # distortion from the iteration before that).
        prev_centroids, prev_distn, _ = val
        assignment, distortions = vector_quantize(points, prev_centroids)

        # Count number of points assigned per centroid, shape (k, 1).
        # (Thanks to Jean-Baptiste Cordonnier for pointing out this approach,
        # which is much faster and more readable than a per-cluster loop.)
        # The count is clamped to at least 1 so that an empty cluster divides
        # 0/1 instead of 0/0. `jnp.maximum` is used rather than
        # `.clip(a_min=...)` because newer JAX releases dropped that keyword.
        counts = jnp.maximum(
            (assignment[jnp.newaxis, :] == cluster_ids[:, jnp.newaxis])
            .sum(axis=1, keepdims=True),
            1,
        )

        # Sum over points in a centroid by zeroing others out
        new_centroids = jnp.sum(
            jnp.where(
                # axes: (data points, clusters, data dimension)
                assignment[:, jnp.newaxis, jnp.newaxis]
                == cluster_ids[jnp.newaxis, :, jnp.newaxis],
                points[:, jnp.newaxis, :],
                0.,
            ),
            axis=0,
        ) / counts

        # NOTE: an alternative, also due to Jean-Baptiste Cordonnier, is to
        # `jax.lax.scatter_add` the points into a zero array indexed by
        # `assignment` and divide by `counts`; it was reported to achieve
        # similar speed here but may be worth considering in other workloads.

        return new_centroids, jnp.mean(distortions), prev_distn

    # Pick k distinct points as the initial centroids. `jax.random.shuffle`
    # is deprecated and missing from current JAX releases; `permutation` of
    # an integer n shuffles `arange(n)` in the same way.
    initial_indices = jax.random.permutation(key, points.shape[0])[:k]
    # Run one iteration to initialize distortions and cause it'll never hurt...
    initial_val = improve_centroids((points[initial_indices, :], jnp.inf, None))
    # ...then iterate until convergence!
    centroids, distortion, _ = jax.lax.while_loop(
        # val[2] is the older distortion, val[1] the newer one.
        lambda val: (val[2] - val[1]) > thresh,
        improve_centroids,
        initial_val,
    )
    return centroids, distortion

@partial(jax.jit, static_argnums=(2, 3))
def kmeans(key, points, k, restarts, **kwargs):
    """Run `kmeans_run` several times and keep the lowest-distortion result.

    Args:
        key: JAX PRNG key; split into one sub-key per restart.
        points: data passed unchanged to every `kmeans_run` call.
        k: number of clusters (static under ``jit``).
        restarts: number of independent runs (static under ``jit``).
        **kwargs: forwarded to `kmeans_run`.

    Returns:
        The ``(centroids, distortion)`` pair of the run whose distortion is
        smallest.
    """
    restart_keys = jax.random.split(key, restarts)

    def single_run(run_key):
        return kmeans_run(run_key, points, k, **kwargs)

    # Vectorise the independent runs over their PRNG keys.
    centroids_per_run, distortion_per_run = jax.vmap(single_run)(restart_keys)
    best = jnp.argmin(distortion_per_run)
    return centroids_per_run[best], distortion_per_run[best]


def find_n_clusters(key, points, n_clusters):
    """Cluster `points` into `n_clusters` groups and return the labels.

    Fits centroids with `kmeans` using a single restart, then assigns each
    point to a centroid with `vector_quantize`.

    Args:
        key: JAX PRNG key forwarded to `kmeans`.
        points: data to cluster.
        n_clusters: number of clusters to fit.

    Returns:
        The per-point assignment array produced by `vector_quantize`.
    """
    # Only the centroids are needed from the fit; its distortion is discarded.
    centroids = kmeans(key, points, n_clusters, 1)[0]
    # Likewise keep only the assignments, not the per-point distances.
    labels = vector_quantize(points, centroids)[0]
    return labels