"""
Parameters for single testing run of interaction prediction model.

To update params, follow these steps:

1) Increment the current version number (see _curr_version).
2) In _create_params, add the new params behind a version flag.
3) In _attempt_patch, add the appropriate patch behind a version flag.
4) In the version history at the end of this docstring, document what the updated params are.

== 1.0 update ==

Initial release version.

"""
import src.util.versioned_params as vp


class TestParams(vp.VersionedParams):
    """Versioned parameters for a single interaction-model testing run."""

    @classmethod
    def _create_params(cls, inputs, version):
        """Initialize and check parameters.

        Args:
            inputs: mapping of raw input values. It must contain
                'dataset_tfrecords' and 'num_testing' (both are indexed
                directly, so a missing key raises KeyError).
            version: params version, stored under the 'version' key.

        Returns:
            dict of testing params: the required values followed by each
            optional setting in the order listed below.
        """
        # Required values come first, in the same key order as always.
        params = dict(
            version=version,
            dataset_tfrecords=inputs['dataset_tfrecords'],
            num_testing=inputs['num_testing'],
        )

        # Optional settings as (key, default) pairs. Each pair is handed to
        # _set_or_default (defined on vp.VersionedParams; presumably it takes
        # inputs[key] when present and falls back to the default otherwise).
        # Order matters: it fixes the key order of the resulting dict.
        optional_settings = (
            ('num_rolls', 20),
            ('num_directions', 20),
            ('batch_size', 40),
            ('towers', 1),
            ('check_nans', False),
            ('neg_roll_sync', True),
            ('directionless', True),
            ('num_interleaved', 20),
            ('prune_file_testing', ''),
            ('keep_file_testing', ''),
            ('shuffle_buffer', 100),
            ('seq_src', ''),
            ('keep_file_pairs_testing', ''),
        )
        for key, default in optional_settings:
            cls._set_or_default(params, inputs, key, default)
        return params

    def _get_creation_inputs(self):
        """Get arguments used to create the param file from existing params.

        Only the required inputs of _create_params are recovered. Optional
        settings are not included.
        """
        required_keys = ('num_testing', 'dataset_tfrecords')
        return {key: self.params[key] for key in required_keys}

    @classmethod
    def _curr_version(cls):
        """Return the current params version number."""
        return 1.0


def init_params(args, version=None):
    """Build a fresh TestParams from a parsed-arguments object.

    Args:
        args: object whose attribute dict (via vars()) supplies the inputs;
            presumably an argparse.Namespace, so confirm this against the callers.
        version: optional params version, forwarded to TestParams.create.

    Returns:
        Whatever TestParams.create returns for these inputs.
    """
    creation_inputs = vars(args)
    return TestParams.create(creation_inputs, version=version)


def load_params(param_json, new_version=None):
    """Load saved TestParams via TestParams.load_updated.

    Args:
        param_json: saved params source, passed through unchanged.
        new_version: optional target version, forwarded positionally.

    Returns:
        Whatever TestParams.load_updated returns.
    """
    loaded_params = TestParams.load_updated(param_json, new_version)
    return loaded_params
