"""Experiment that shows arbitrary off-policy behavior of TD."""

import argparse
from distutils.util import strtobool
import json
import os
from pathlib import Path
import pickle

import numpy as np
import ray
from ray import tune
import tensorflow as tf
import tensorflow_probability as tfp


from policy_evaluation.utils import PROJECT_ROOT, get_git_rev
from policy_evaluation import environments
from policy_evaluation import policies

from .mountain_car import MountainCarExperimentRunner, datetime_stamp
from .experiment_runner import set_gpu_memory_growth


# Runs at import time, ahead of every other statement in this module.
# NOTE(review): presumably switches TensorFlow to on-demand GPU memory
# allocation -- confirm against `experiment_runner.set_gpu_memory_growth`.
set_gpu_memory_growth(True)

# Short alias for the TensorFlow Probability distributions namespace.
tfd = tfp.distributions


# Discount factor used as `gamma` in every algorithm config below and as the
# discount for the value-function estimate computed in `train`.
DISCOUNT = 0.98
# Path of this module; the cache directory is derived from it.
CURRENT_FILE_PATH = Path(__file__)
# Root of the pickle caches written by `train`:
# `<directory of this file>/data/<this file's name without suffix>`.
CACHE_DIR = CURRENT_FILE_PATH.parent / 'data' / CURRENT_FILE_PATH.stem


# Static experiment settings. `train` merges these into the runner config
# together with 'total_samples' and 'epoch_length', which are taken from its
# arguments instead of being fixed here.
experiment_params = dict(
    n_episodes=1,  # n_eps
    episodic=False,
    name="puddle_world",
    title="Puddle World",
    criterion="RMSE",
)

# Settings forwarded to the experiment runner. `train` adds 'run_eagerly'
# and 'seed' to this dict (in place) before launching the trials.
run_params = dict(
    num_samples=5,  # n_indep
    verbose=100,
)


# Environment description. `train` unpacks `config` as keyword arguments to
# `environments.PuddleWorld` and also serializes this dict into cache keys.
environment_params = dict(
    class_name='PuddleWorld',
    config=dict(
        reset_fn=(0.0, 0.0),
        puddle_cost_weight=1.0,
    ),
)


def normalize_puddle_world_states(states, environment):
    """Rescale raw PuddleWorld states into the range [-1, 1].

    Args:
        states: States expressed in the environment's observation space.
        environment: Environment whose `observation_space.low` and
            `observation_space.high` give the bounds of the raw states.

    Returns:
        The rescaled states.

    Raises:
        AssertionError: If any rescaled value has magnitude above 1, or if
            the mean over all rescaled values is not within 0.1 of zero.
    """
    observation_space = environment.observation_space
    rescaled = environments.utils.rescale_values(
        states,
        observation_space.low,
        observation_space.high,
        new_low=-1.0,
        new_high=1.0)

    # Sanity checks: values stay inside the target range and are roughly
    # centred around zero.
    assert np.all(np.abs(rescaled) <= 1.0)
    np.testing.assert_allclose(np.mean(rescaled), 0, atol=1e-1)
    return rescaled


def normalize_puddle_world_rewards(rewards):
    """Return `rewards` unchanged.

    Rewards are currently passed through as-is; the function exists so
    callers treat rewards symmetrically with state normalization.
    """
    return rewards


class PuddleWorldPolicy(policies.BasePolicy):
    """Policy that picks action 1 or action 3 uniformly at random."""

    def actions(self, inputs):
        """Sample one action per input row.

        Only the leading (batch) dimension of `inputs` is used; the result
        has shape `(batch, 1)` and contains the values 1 or 3.
        """
        batch_size = tf.shape(inputs)[0]
        # Draw an index in {0, 1} for every row, then map it onto {1, 3}.
        sampled_indices = tf.random.uniform(
            (batch_size, 1), minval=0, maxval=2, dtype=tf.int64)
        return tf.gather([1, 3], sampled_indices)

    def log_probs(self, *args, **kwargs):
        """Return the elementwise log of `probs(*args, **kwargs)`."""
        probabilities = self.probs(*args, **kwargs)
        return tf.math.log(probabilities)

    def probs(self, inputs, actions):
        """Return 0.5 where `actions` is 1 or 3 and 0.0 elsewhere.

        `inputs` is accepted for interface compatibility and is not used.
        """
        is_supported_action = tf.logical_or(actions == 1, actions == 3)
        return tf.where(is_supported_action, 0.5, 0.0)


class PuddleWorldExperimentRunner(MountainCarExperimentRunner):
    """Experiment runner for PuddleWorld.

    Inherits all behavior from `MountainCarExperimentRunner` without
    overriding anything.
    """


# Candidate step sizes searched over for both GTD2 learning rates.
_GTD2_STEP_SIZES = (1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1.0)

# Per-algorithm settings, keyed by the name passed as `train(algorithm=...)`.
# Every entry uses the module-wide DISCOUNT as its `gamma`.
algorithm_params = {
    'bbo-rp': dict(
        class_name='BBORandomizedPrior',
        config=dict(
            gamma=DISCOUNT,
            num_phi_steps=10,
            phi_lr=1e-2,
            omega_lr=3e-2,
            prior_loc=0.0,
            prior_scale=3.0,
            prior_loss_weight=3e-4,
        ),
    ),
    'td0': dict(
        class_name='TD0',
        config=dict(
            gamma=DISCOUNT,
            alpha=3e-3,
        ),
    ),
    'tdc': dict(
        class_name='TDC',
        config=dict(
            gamma=DISCOUNT,
            alpha=1e-3,
            beta=1e-3,
        ),
    ),
    'gtd2': dict(
        class_name='GTD2',
        config=dict(
            gamma=DISCOUNT,
            # Each grid_search receives its own list so the two search
            # specs do not share a mutable object.
            alpha=tune.grid_search(list(_GTD2_STEP_SIZES)),
            beta=tune.grid_search(list(_GTD2_STEP_SIZES)),
        ),
    ),
}

# Algorithm class names whose value network uses a non-default activation.
_ACTIVATION_BY_ALGORITHM_CLASS = {
    'TDC': 'tanh',
    'GTD2': 'tanh',
}

value_function_params = {
    'hidden_layer_sizes': (256, ),
    # Chosen from the trial spec: read the algorithm class name (under
    # spec['config'] when that key exists, otherwise directly on spec) and
    # fall back to 'relu' for algorithms not listed above.
    'activation': tune.sample_from(
        lambda spec: _ACTIVATION_BY_ALGORITHM_CLASS.get(
            spec.get('config', spec)['algorithm_params']['class_name'],
            'relu')),
}


def _load_or_compute_pickle(cache_dir, file_name, compute):
    """Load a pickled object from the cache, computing and storing it on a miss.

    Args:
        cache_dir: Directory holding the cache file. Created (if missing)
            only when a new entry has to be written.
        file_name: Name of the cache file inside `cache_dir`.
        compute: Zero-argument callable that produces the object when no
            cache file exists.

    Returns:
        The unpickled cached object, or the freshly computed one.
    """
    cache_path = cache_dir / file_name
    if os.path.exists(cache_path):
        # NOTE: unpickling can execute arbitrary code; the cache directory
        # must only contain files written by this script.
        with open(cache_path, 'rb') as f:
            return pickle.load(f)

    result = compute()
    os.makedirs(cache_dir, exist_ok=True)
    with open(cache_path, 'wb') as f:
        pickle.dump(result, f)
    return result


def train(num_samples,
          num_steps,
          epoch_length,
          debug,
          use_wandb,
          algorithm='bbo-rp',
          experiment_name=None):
    """Run the PuddleWorld value-prediction experiment with Ray Tune.

    One trial is launched per seed (and per grid-search combination in the
    selected algorithm config). Datasets and the reference value function
    are generated once up front, cached on disk under `CACHE_DIR`, and
    shared with the trials through the Ray object store.

    Args:
        num_samples: Number of distinct seeds to draw from `range(25)`;
            must be between 1 and 25 inclusive.
        num_steps: Passed to the runner as 'total_samples'.
        epoch_length: Passed to the runner as 'epoch_length'.
        debug: When true, Ray runs in local mode, the runner is asked to run
            eagerly, and results are stored under a 'debug' sub-directory.
        use_wandb: Accepted for interface compatibility; not used here.
        algorithm: Key into the module-level `algorithm_params`.
        experiment_name: Optional suffix appended to the datetime stamp that
            names the Tune experiment.

    Raises:
        ValueError: If `num_samples` is outside [1, 25].
        KeyError: If `algorithm` is not a key of `algorithm_params`.
    """
    # Validate the arguments before starting Ray or generating any data, so
    # bad input fails immediately instead of after the expensive setup.
    if num_samples < 1 or 25 < num_samples:
        raise ValueError("num_samples must be between 1 and 25.")
    selected_algorithm_params = algorithm_params[algorithm]

    ray.init(resources={}, local_mode=debug, include_webui=False)

    seeds = np.sort(np.random.choice(25, num_samples, replace=False)).tolist()

    # Mutates the module-level dict; the seed grid makes Tune create one
    # trial per seed.
    run_params.update({
        'run_eagerly': debug,
        'seed': tune.grid_search(seeds),
    })

    policy = PuddleWorldPolicy(input_shapes=[2], output_shape=[1])
    behavior_policy = target_policy = policy

    def generate_dataset(seed,
                         behavior_policy,
                         target_policy,
                         environment_params,
                         *args,
                         **kwargs):
        """Return the (cached) dataset for `seed` with normalized states."""
        environment_params = environment_params.copy()
        environment = environments.PuddleWorld(
            **environment_params['config'])

        # NOTE(review): `seed` only distinguishes cache entries; it is not
        # forwarded to `environments.utils.generate_dataset` below.
        dataset_key = json.dumps({
            'seed': seed,
            'environment': environment_params,
            'behavior_policy': behavior_policy.get_config(),
            'target_policy': target_policy.get_config(),
            **kwargs,
        }, sort_keys=True, separators=(',', ':'))

        dataset = _load_or_compute_pickle(
            CACHE_DIR / 'datasets',
            f'dataset-{dataset_key}',
            lambda: environments.utils.generate_dataset(
                environment,
                behavior_policy,
                target_policy,
                *args,
                **kwargs))

        # The cache stores raw values; normalization is applied on every
        # load so it never ends up applied twice.
        samples = dataset['samples']
        samples['state_0'] = normalize_puddle_world_states(
            samples['state_0'], environment)
        samples['state_1'] = normalize_puddle_world_states(
            samples['state_1'], environment)
        samples['reward'] = normalize_puddle_world_rewards(
            samples['reward'])

        return dataset

    datasets = {
        seed: generate_dataset(
            seed,
            behavior_policy,
            target_policy,
            environment_params,
            # Dataset size is fixed here and independent of `num_steps`.
            num_samples=20000,
            independent_samples=True)
        for seed in seeds
    }

    dataset_object_ids = {
        str(seed): ray.put(dataset, weakref=False)
        for seed, dataset in datasets.items()
    }

    def compute_value_function(policy, environment_params):
        """Return the (cached) reference `(states, values)` for `policy`."""
        environment_params = environment_params.copy()
        environment = environments.PuddleWorld(
            **environment_params['config'])

        discount = DISCOUNT
        num_rollouts = 1000

        # NOTE(review): the key is built from the enclosing
        # `behavior_policy`/`target_policy`, not from the `policy` argument;
        # they are the same object for the only call below.
        value_function_key = json.dumps({
            'environment': environment_params,
            'behavior_policy': behavior_policy.get_config(),
            'target_policy': target_policy.get_config(),
            'discount': discount,
            'num_rollouts': num_rollouts,
        }, sort_keys=True, separators=(',', ':'))

        value_function = _load_or_compute_pickle(
            CACHE_DIR / 'value_functions',
            f'value_function-{value_function_key}',
            lambda: environments.utils.estimate_value_function(
                environment,
                policy,
                discount=discount,
                num_rollouts=num_rollouts,
                max_rollout_length=1000))

        states, values = value_function
        states = normalize_puddle_world_states(states, environment)
        values = normalize_puddle_world_rewards(values)

        return states, values

    value_function = compute_value_function(target_policy, environment_params)
    value_function_object_id = ray.put(value_function, weakref=False)

    experiment_config = {
        'dataset_object_ids': dataset_object_ids,
        'value_function_object_id': value_function_object_id,
        'algorithm_params': selected_algorithm_params,
        'value_function_params': value_function_params,
        'experiment_params': {
            'total_samples': num_steps,
            'epoch_length': epoch_length,
            # Entries of the module-level dict win over the two above.
            **experiment_params,
        },
        'run_params': run_params,
        'environment_params': environment_params,
        'task_params': {
            'class_name': 'ValuePredictionTask',
            'config': {
                'criteria': ['RMSE', 'MSE', ],
                'batch_size': 512,
            },
        },
        'git_rev': get_git_rev(PROJECT_ROOT),
    }

    if experiment_name is not None:
        experiment_name = '-'.join((datetime_stamp(), experiment_name))
    else:
        experiment_name = datetime_stamp()

    # An empty path component is dropped by os.path.join, so non-debug runs
    # land directly in `<PROJECT_ROOT>/data/puddle_world`.
    local_dir = os.path.join(
        PROJECT_ROOT, 'data', ('debug' if debug else ''), 'puddle_world')

    tune.run(
        PuddleWorldExperimentRunner,
        name=experiment_name,
        config=experiment_config,
        resources_per_trial={
            'cpu': 1,
            'gpu': 0.0,
        },
        local_dir=local_dir,
        num_samples=1,  # Use seeds to generate multiple samples
        checkpoint_freq=0,
        checkpoint_at_end=False,
        max_failures=0,
        restore=None,
        with_server=False,
        scheduler=None,
        loggers=(ray.tune.logger.CSVLogger, ray.tune.logger.JsonLogger),
        reuse_actors=True)


if __name__ == '__main__':
    def _parse_bool_flag(value):
        """Convert a command-line truthy/falsy string into a bool.

        NOTE: `strtobool` comes from `distutils`, which was removed from the
        standard library in Python 3.12.
        """
        return bool(strtobool(value))

    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--mode',
        type=str,
        choices=('train', 'visualize'),
        default='train')

    parser.add_argument('--num-samples', type=int, default=25)
    parser.add_argument('--num-steps', type=int, default=5000)
    parser.add_argument('--epoch-length', type=int, default=25)
    parser.add_argument('--experiment-path', type=str, default=None)
    parser.add_argument('--experiment-name', type=str, default=None)
    # Restrict to the configured algorithms so a typo is rejected by the
    # parser instead of surfacing as a KeyError inside `train`.
    parser.add_argument(
        '--algorithm',
        type=str,
        choices=tuple(algorithm_params),
        default='bbo-rp')
    # Both boolean flags may be given bare (`--debug`) or with an explicit
    # value (`--debug false`).
    parser.add_argument(
        '--use-wandb',
        type=_parse_bool_flag,
        nargs='?',
        const=True,
        default=False)
    parser.add_argument(
        '--debug',
        type=_parse_bool_flag,
        nargs='?',
        const=True,
        default=False,
        help="Whether or not to execute sequentially to allow breakpoints.")

    args = parser.parse_args()
    if args.mode == 'train':
        train(num_samples=args.num_samples,
              num_steps=args.num_steps,
              epoch_length=args.epoch_length,
              debug=args.debug,
              use_wandb=args.use_wandb,
              algorithm=args.algorithm,
              experiment_name=args.experiment_name)
    elif args.mode == 'visualize':
        # Visualization is not implemented for this experiment;
        # `--experiment-path` is parsed but currently unused.
        raise NotImplementedError(args.mode)
