import functools
import time
from typing import Any, Dict, NamedTuple

import jax
import jax.numpy as jnp
import optax
import flax
import haiku as hk
import tensorflow_datasets as tfds

import input_pipeline


class cfg:
    """Run configuration, read directly as class attributes (never instantiated).

    Defaults follow the Flax ImageNet example. For full ImageNet, set
    ``dataset`` to ``'imagenet2012:5.*.*'``.
    """

    # --- Data -------------------------------------------------------------
    dataset = 'imagenette'
    image_size = 224
    cache = True

    # --- Optimisation -----------------------------------------------------
    seed = 42
    batch_size = 512          # global batch size across all hosts
    num_epochs = 100.0
    warmup_epochs = 5.0
    learning_rate = 0.1       # scaled by batch_size / 256 in main()
    momentum = 0.9
    half_precision = False    # not referenced elsewhere in this file

    # --- Differential privacy (passed to optax.differentially_private_aggregate)
    l2_norm_clip = 1.0
    noise_multiplier = 1.1
    


def model_fn(features, num_classes, **_):
    """Forward pass of a small MLP classifier (to be wrapped in ``hk.transform``).

    Args:
        features: input batch; ``hk.Flatten`` flattens every dimension after
            the leading one before the dense layers.
        num_classes: width of the output layer, i.e. number of logits.
        **_: extra keyword arguments, accepted and ignored.

    Returns:
        Unnormalised logits of shape ``(batch, num_classes)``.
    """
    mlp = hk.Sequential([
        hk.Flatten(),
        hk.Linear(300), jax.nn.relu,
        hk.Linear(100), jax.nn.relu,
        # The logit count must follow num_classes; a hard-coded width breaks
        # any dataset whose label count differs from it.
        hk.Linear(num_classes),
    ])
    return mlp(features)
    

def prep_data(ds, distributed=False):
    """Convert an iterable of tensor pytrees into NumPy batches.

    Args:
        ds: iterable of pytrees whose leaves expose ``.numpy()`` (presumably
            the ``tf.data.Dataset`` returned by ``input_pipeline.create_split``).
        distributed: if True, reshape every leaf from ``(batch, ...)`` to
            ``(local_device_count, batch // local_device_count, ...)`` for a
            ``jax.pmap``-ed step, and prefetch 2 batches onto local devices.

    Returns:
        A lazy iterator over the converted batches.
    """
    ldc = jax.local_device_count()

    def _prepare(x):
        x = x.numpy()
        if not distributed:
            return x
        # Split the leading axis across local devices; the reshape raises if
        # the batch is not divisible by the local device count.
        return x.reshape((ldc, -1) + x.shape[1:])

    # jax.tree_map is deprecated/removed in recent JAX; tree_util is stable.
    it = map(functools.partial(jax.tree_util.tree_map, _prepare), ds)
    return flax.jax_utils.prefetch_to_device(it, 2) if distributed else it


class TrainState(NamedTuple):
    """Training state threaded through the (pmapped) train step.

    Attributes:
        params: model parameters as returned by the Haiku ``init`` function
            (a nested mapping, not necessarily a plain ``dict``).
        opt_state: optimizer state from ``optim.init``; for an ``optax.chain``
            this is a tuple of per-transformation states, not a dict.
    """
    params: Any
    opt_state: Any


def get_train_step(model, criterion, optim):
    """Build a per-device training step that computes per-example gradients.

    Args:
        model: transformed Haiku model; ``model.apply(params, x)`` -> logits.
        criterion: ``criterion(logits, labels)`` -> scalar loss.
        optim: optax ``GradientTransformation``. It receives gradients with a
            leading per-example axis, so it is expected to aggregate them
            (in this script via ``optax.differentially_private_aggregate``)
            and to perform any cross-device averaging itself.

    Returns:
        ``train_step(state, batch) -> TrainState`` where ``batch`` is a dict
        with ``'image'`` and ``'label'`` entries sharing a leading example axis.
    """
    def per_example_loss(params, example):
        # vmap removes the example axis; re-add a batch of one for the model.
        inputs = example['image'][None]
        targets = example['label'][None]
        logits = model.apply(params, inputs)
        return criterion(logits, targets), logits

    # Parameters are shared (None); every batch leaf is mapped over axis 0.
    per_example_grads = jax.vmap(
        jax.value_and_grad(per_example_loss, has_aux=True),
        in_axes=(None, 0))

    def train_step(state, batch):
        _, grads = per_example_grads(state.params, batch)
        # No pmean here: the optimizer chain is expected to average
        # gradients across devices.
        updates, opt_state = optim.update(grads, state.opt_state, state.params)
        params = optax.apply_updates(state.params, updates)
        return TrainState(params, opt_state)

    return train_step

def optim_pmean(axis_name):
    """Return a stateless optax transformation that averages updates over a mapped axis.

    The returned update calls ``jax.lax.pmean`` over ``axis_name``, so it must run
    inside a ``jax.pmap`` (or similar) that binds that axis name.
    """
    def _init(params):
        del params  # Stateless: nothing to track.
        return ()

    def _update(updates, state, params=None):
        del params
        averaged = jax.lax.pmean(updates, axis_name=axis_name)
        return averaged, state

    return optax.GradientTransformation(_init, _update)
    
    
def cosine_decay(lr, step, total_steps):
    """Cosine-anneal ``lr`` from its full value down to 0 (adapted from Flax).

    Args:
        lr: peak learning rate.
        step: position in the schedule, in the same units as ``total_steps``
            (epochs in this script). Negative values, e.g. during warmup,
            return the full ``lr``.
        total_steps: length of the decay; must be positive.

    Returns:
        ``lr * 0.5 * (1 + cos(pi * ratio))`` where ``ratio = step / total_steps``
        is clamped to [0, 1]. Because of the clamp, the rate stays at 0 once
        ``step`` passes ``total_steps`` instead of climbing back up.
    """
    ratio = jnp.minimum(1., jnp.maximum(0., step / total_steps))
    mult = 0.5 * (1. + jnp.cos(jnp.pi * ratio))
    return mult * lr


def create_lr_fn(config,
                 base_learning_rate: float,
                 steps_per_epoch: int):
    """Build a linear-warmup + cosine-decay learning-rate schedule (adapted from Flax).

    Args:
        config: object exposing ``warmup_epochs`` and ``num_epochs``; both are
            read each time the schedule is evaluated.
        base_learning_rate: peak learning rate, reached at the end of warmup.
        steps_per_epoch: optimizer steps per epoch, used to convert a step
            count into a (fractional) epoch.

    Returns:
        ``step_fn(step) -> learning rate``, e.g. for ``optax.scale_by_schedule``.

    Raises:
        ValueError: if ``steps_per_epoch`` is not positive (e.g. the dataset is
            smaller than one batch). Without this check the schedule would
            silently produce inf/NaN rates.
    """
    if steps_per_epoch <= 0:
        raise ValueError(
            f'steps_per_epoch must be positive, got {steps_per_epoch}')

    def step_fn(step):
        epoch = step / steps_per_epoch
        lr = cosine_decay(base_learning_rate,
                          epoch - config.warmup_epochs,
                          config.num_epochs - config.warmup_epochs)
        if config.warmup_epochs > 0:
            warmup = jnp.minimum(1., epoch / config.warmup_epochs)
        else:
            # No warmup: avoid 0/0 = NaN at step 0, which would poison params.
            warmup = 1.
        return lr * warmup
    return step_fn


def main():
    """Train the MLP with DP-SGD on ``cfg.dataset`` and time the epochs.

    Returns:
        A list with one entry per completed epoch: wall-clock seconds elapsed
        since the training loop started. The first entry includes XLA
        compilation time.
    """
    rng = jax.random.PRNGKey(cfg.seed)

    # Dataset
    n_classes = 10 if cfg.dataset == 'imagenette' else 1000
    # Each host feeds only its own devices, so it loads its share of the
    # global batch (jax.host_count is the deprecated name of process_count).
    local_batch_size = cfg.batch_size // jax.process_count()
    ds_builder = tfds.builder(cfg.dataset)
    if cfg.dataset == 'imagenette':
        ds_builder.download_and_prepare()
    train_iter = prep_data(input_pipeline.create_split(
        ds_builder, local_batch_size, image_size=cfg.image_size,
        train=True, cache=cfg.cache), distributed=True)
    # Evaluation must read the held-out split, hence train=False.
    eval_iter = prep_data(input_pipeline.create_split(
        ds_builder, local_batch_size, image_size=cfg.image_size,
        train=False, cache=cfg.cache), distributed=True)

    # Training vars
    steps_per_epoch = ds_builder.info.splits[
        'train'].num_examples // cfg.batch_size
    base_lr = cfg.learning_rate * cfg.batch_size / 256.
    num_steps = int(steps_per_epoch * cfg.num_epochs)
    # Number of full validation batches; reserved for the TODO eval loop.
    steps_per_eval = ds_builder.info.splits[
        'validation'].num_examples // cfg.batch_size

    # Create Optimizer
    lr_fn = create_lr_fn(cfg, base_lr, steps_per_epoch)
    optim = optax.chain(
        # Expects gradients with a leading per-example axis (see the vmap in
        # get_train_step).
        # NOTE(review): opt_state is replicated below, so every device starts
        # from the same noise seed -- confirm against the privacy accounting.
        optax.differentially_private_aggregate(
            l2_norm_clip=cfg.l2_norm_clip,
            noise_multiplier=cfg.noise_multiplier,
            seed=cfg.seed),
        optim_pmean('batch'),
        optax.trace(decay=cfg.momentum, nesterov=False),
        # optax.apply_updates *adds* updates to params, so the learning rate
        # must be negated to descend the loss (as optax.sgd does).
        optax.scale_by_schedule(lambda count: -lr_fn(count)),
    )

    # Create Model
    model = hk.without_apply_rng(hk.transform(functools.partial(
        model_fn, num_classes=n_classes)))
    # Input spatial size must match the pipeline's images, because the first
    # Linear layer's width depends on the flattened feature size.
    params = model.init(rng, jnp.ones((4, cfg.image_size, cfg.image_size, 3)))
    # pmap runs over this host's devices only; jax.devices() would also list
    # other hosts' devices in a multi-host run.
    state = jax.device_put_replicated(TrainState(
        params=params,
        opt_state=optim.init(params)
    ), jax.local_devices())

    def criterion(logits, labels):
        return optax.softmax_cross_entropy(
            logits, jax.nn.one_hot(labels, n_classes)).mean()

    p_train_step = jax.pmap(get_train_step(model, criterion, optim),
                            axis_name='batch')
    p_eval_step = None  # TODO: evaluation step over eval_iter / steps_per_eval.

    timings = []
    start_time = time.perf_counter()
    for step, batch in zip(range(num_steps), train_iter):
        state = p_train_step(state, batch)

        # `step` is 0-based, so an epoch ends after step steps_per_epoch - 1.
        if (step + 1) % steps_per_epoch == 0:
            # pmap dispatches asynchronously; wait for the device work so the
            # measurement covers the computation, not just the dispatch.
            jax.tree_util.tree_map(lambda x: x.block_until_ready(), state)
            duration = time.perf_counter() - start_time
            timings.append(duration)

    return timings


if __name__ == '__main__':
    # Run training only when executed as a script, not when imported.
    main()