from tfm_callbacks import GradientHandler
import tensorflow as tf
import numpy as np

from tf_utils import eliminate_all_patterns_and_starting_vs
# Patterns handed to eliminate_all_patterns_and_starting_vs() together with a
# variable name; the handler below compares the resulting normalized names to
# decide which variables are target kernels.
# NOTE(review): entries look like regex fragments (e.g. 'tower_[0-9]/') that
# strip tower/scope prefixes and the ':0' suffix -- confirm against tf_utils.
ELIMINATE_PATTERNS = ['tower_[0-9]/', ':0', 'cg/']

def get_metric_func(metric_type):
    """Return the element-wise importance metric selected by *metric_type*.

    The returned callable takes ``(g, v)`` -- a gradient and the matching
    variable (or flattened concatenations of them) -- and returns a value of
    the same shape.

    Supported keys:
        'abs_gv': ``|g * v|``
        'abs_v':  ``|v|``
        'abs_g':  ``|g|``
        'sign_g': ``-g * sign(v)``
        'rel_g':  ``-g / v``

    Raises:
        ValueError: if *metric_type* is not one of the keys above.  (The
            previous ``assert False`` disappeared under ``python -O`` and let
            the function return ``None``.)
    """
    # The lambdas only touch ``tf`` when they are called, so building the
    # table is cheap and has no graph side effects.
    metric_funcs = {
        'abs_gv': lambda g, v: tf.abs(g * v),
        'abs_v': lambda g, v: tf.abs(v),
        'abs_g': lambda g, v: tf.abs(g),
        'sign_g': lambda g, v: -g * tf.sign(v),
        'rel_g': lambda g, v: -g / v,
    }
    try:
        return metric_funcs[metric_type]
    except (KeyError, TypeError):
        # TypeError covers unhashable inputs, which can never be valid keys.
        raise ValueError(
            'unknown metric_type {!r}; expected one of {}'.format(
                metric_type, sorted(metric_funcs)))


class RTSStairGradientHandler(GradientHandler):
    """Mask kernel gradients by a global top-k metric and add weight decay.

    For every "target" kernel (a rank-2 or rank-4 variable, optionally
    restricted by name) an element-wise metric of ``(gradient, variable)`` is
    computed.  Entries whose metric is strictly above a global threshold keep
    their gradient plus ``l2_factor * v``; the remaining entries have their
    gradient dropped and receive ``fast_decay_factor * v`` instead.

    Args:
        zero_rate: fraction of all target-kernel entries to mask out; the
            threshold is the ``int((1 - zero_rate) * N)``-th largest metric.
        l2_factor: decay coefficient for unmasked entries and for kernels
            that are not targets.
        fast_decay_factor: decay coefficient for masked-out entries.
        bc: object exposing ``get_kernel_variables(target_layers=...)``; only
            consulted when ``target_layers`` is not None.
        target_layers: forwarded to ``bc.get_kernel_variables``; ``None``
            makes every rank-2/rank-4 variable a target.
        metric_type: key understood by ``get_metric_func``.
        l2_on_vecs: when true, variables that are not rank 2 or 4 also get
            ``l2_factor * v`` added to their gradient.
    """

    # Variable ranks that are treated as kernels.
    _KERNEL_RANKS = (2, 4)

    def __init__(self, zero_rate, l2_factor, fast_decay_factor, bc, target_layers, metric_type, l2_on_vecs=False):
        self.zero_rate = zero_rate
        self.l2_factor = l2_factor
        self.fast_decay_factor = fast_decay_factor
        print('RTS stair GH: zero_rate={}, l2_factor={}, fast_decay_factor={}'.format(zero_rate, l2_factor, fast_decay_factor))
        #TODO no L2 on vecs
        self.bc = bc
        self.target_layers = target_layers
        self.metric_type = metric_type
        self.l2_on_vecs = l2_on_vecs

    def _resolve_target_names(self):
        """Return normalized names of the selected kernels, or None for 'all kernels'."""
        if self.target_layers is None:
            return None
        kernels = self.bc.get_kernel_variables(target_layers=self.target_layers)
        return [eliminate_all_patterns_and_starting_vs(k.name, ELIMINATE_PATTERNS) for k in kernels]

    def handle_gradient(self, origin_grads_and_vars, device_idx, device):
        """Return a new ``[(grad, var), ...]`` list with masking and decay applied.

        Args:
            origin_grads_and_vars: iterable of ``(gradient, variable)`` pairs.
            device_idx: unused by this handler.
            device: unused by this handler.

        Returns:
            A list with one ``(new_gradient, variable)`` pair per input pair,
            in the same order.

        Raises:
            ValueError: if no variable qualifies as a target kernel, or if
                ``zero_rate`` leaves no entry to keep.
        """
        # Kept as a list on the instance exactly as before; a set is used
        # locally so each membership test is O(1).
        self.given_v_names = self._resolve_target_names()
        name_filter = None if self.given_v_names is None else set(self.given_v_names)

        # Single pass: decide once per variable whether it is a kernel and
        # whether it is a target, and flatten target kernels for the global
        # threshold.  Doing this once keeps the two stages from disagreeing.
        classified = []
        flat_grads = []
        flat_vars = []
        for g, v in origin_grads_and_vars:
            is_kernel = len(v.get_shape().as_list()) in self._KERNEL_RANKS
            is_target = is_kernel and (
                name_filter is None
                or eliminate_all_patterns_and_starting_vs(v.name, ELIMINATE_PATTERNS) in name_filter)
            classified.append((g, v, is_kernel, is_target))
            if is_target:
                flat_grads.append(tf.reshape(g, [-1]))
                flat_vars.append(tf.reshape(v, [-1]))

        if not flat_grads:
            # tf.concat on an empty list fails with an unrelated-looking
            # message; report the actual cause instead.
            raise ValueError(
                'no target kernel variables found (target_layers={!r})'.format(self.target_layers))

        all_g = tf.concat(flat_grads, axis=0)
        all_v = tf.concat(flat_vars, axis=0)
        metric_func = get_metric_func(self.metric_type)
        all_metrics = metric_func(all_g, all_v)
        num_entries = all_metrics.get_shape().as_list()[0]
        nz = int((1 - self.zero_rate) * num_entries)
        print('metric={}, we keep {} non-zero values'.format(self.metric_type, nz))
        if nz < 1:
            # top_k with k == 0 yields an empty tensor, so indexing [-1]
            # below would fail.
            raise ValueError(
                'zero_rate={} keeps no value out of {} target kernel entries'.format(
                    self.zero_rate, num_entries))
        top_values, _ = tf.nn.top_k(all_metrics, nz)
        # Smallest of the nz largest metrics acts as the global threshold.
        thres = top_values[-1]

        result = []
        for g, v, is_kernel, is_target in classified:
            if is_target:
                # NOTE(review): the comparison is strict, so the entry equal
                # to thres is masked out and fewer than nz entries pass --
                # confirm whether '>=' was intended.
                mask = tf.cast((metric_func(g, v) > thres), tf.float32)
                result.append((mask * g + v * (self.l2_factor * mask + self.fast_decay_factor * (1 - mask)), v))
                print('mask a kernel gradient ! ', v.name)
            elif is_kernel:
                print('normal weight decay on a kernel ', v.name)
                result.append((g + v * self.l2_factor, v))
            elif self.l2_on_vecs:
                print('do weight decay on a vec')
                result.append((g + v * self.l2_factor, v))
            else:
                result.append((g, v))

        return result


