from bc_params_factory import *
import os
from rts_bc import RTSBC
from rts_gradient_handler import RTSStairGradientHandler

# Summary log shared across runs: gsm_pipeline appends one line per run of the
# form '<network_type>,<try_arg>:top1=...,azr=...;...'.
OVERALL_LOG_FILE = 'gsm_overall_log.txt'

def decide_train_batch_size(network_type):
    """Return the default training batch size for ``network_type``.

    Rules are checked in order and the first match wins. Apart from the exact
    name lists below, matching is by substring of ``network_type``.

    Args:
        network_type: network identifier string, e.g. 'resnet50' or 'vgg16'.

    Returns:
        The batch size as an int.

    Raises:
        ValueError: if no rule matches ``network_type``.
    """
    # Checked first: 'mobilenet' contains the substring 'lenet'. Testing it
    # after the LeNet rule left this branch unreachable and gave MobileNet
    # LeNet's batch size (256) instead of 96.
    if network_type == 'mobilenet':
        return 96
    if any(tag in network_type for tag in ('rc', 'dc', 'vc', 'rh', 'dh', 'vh', 'vx', 'vy', 'cfqk')):
        return 64
    if 'wrn' in network_type:
        return 128
    if 'lenet' in network_type:
        return 256
    if network_type in ['resnet18', 'resnet34', 'resnet50', 'seresnet50']:
        return 64
    if network_type in ['resnet101', 'resnet152']:
        return 32
    if 'alex' in network_type:
        return 128
    if network_type in ['densenet121', 'inception3']:
        return 64
    if 'vgg' in network_type:
        return 32
    # Raise rather than `assert False`: asserts are stripped under `python -O`,
    # which would make this silently return None.
    raise ValueError('unknown network_type: {!r}'.format(network_type))

def decide_base_l2_factor(network_type):
    """Return the default L2 regularization factor for ``network_type``.

    Rules are checked in order and the first match wins. Apart from the exact
    names below, matching is by substring of ``network_type``. For example,
    'res' covers every ResNet variant including 'seresnet50', but only after
    the 'wrn' rule.

    Args:
        network_type: network identifier string, e.g. 'resnet50' or 'vgg16'.

    Returns:
        The L2 factor as a float.

    Raises:
        ValueError: if no rule matches ``network_type``.
    """
    # Checked first: 'mobilenet' contains the substring 'lenet'. Testing it
    # after the LeNet rule left this branch unreachable and gave MobileNet
    # 5e-4 instead of 4e-5.
    if network_type == 'mobilenet':
        return 4e-5
    if any(tag in network_type for tag in ('rc', 'dc', 'vc', 'rh', 'dh', 'vh', 'vx', 'vy')):
        return 1e-4
    if 'cfqk' in network_type:
        return 4e-3
    if 'wrn' in network_type:
        return 5e-4
    if 'res' in network_type:
        return 1e-4
    if 'alex' in network_type or 'lenet' in network_type:
        return 5e-4
    if network_type in ['densenet121', 'inception3']:
        return 1e-4
    if 'vgg' in network_type:
        return 5e-4
    # Raise rather than `assert False`: asserts are stripped under `python -O`,
    # which would make this silently return None.
    raise ValueError('unknown network_type: {!r}'.format(network_type))


def _flatten_kernel_values(kernel_values):
    """Return all arrays in ``kernel_values`` raveled and joined into one 1-D array."""
    return np.concatenate([np.ravel(v) for v in kernel_values])


def eval_with_global_sparse_ratios(network_type, deps, use_dense_layer, weights, sparse_ratios, last_to_save, target_layers, data_dir):
    """Evaluate a model after global magnitude pruning at each ratio in ``sparse_ratios``.

    Builds an eval RTSBC and loads ``weights`` into it. Then, for every ratio
    ``sr``:

    - when ``sr > 0``, zero every target-kernel entry whose magnitude is below
      one threshold chosen over all target kernels together, so that about
      ``(1 - sr)`` of the entries remain non-zero;
    - re-read the kernels to measure the actual zero ratio (azr);
    - run ``simple_eval``.

    After the last ratio, the current weights are saved to ``last_to_save``.

    Masking is applied to the live variables and is cumulative: each ratio
    masks the values left by the previous one, and ``sr == 0`` does not restore
    pruned weights. Ratios are therefore expected in ascending order.

    Args:
        network_type, deps, data_dir: forwarded to ``default_params_for_eval``.
        use_dense_layer: forwarded to ``default_my_params``.
        weights: hdf5 file to load before pruning.
        sparse_ratios: sequence of target zero ratios.
        last_to_save: hdf5 path that receives the weights after the last ratio.
        target_layers: restricts which kernels are pruned and measured.

    Returns:
        ``(message, kernel_name_to_mask_value)``.
        ``message`` holds one 'top1=...,azr=...;' entry per ratio plus a
        trailing newline. ``kernel_name_to_mask_value`` maps each kernel
        variable name to its float32 0/1 mask from the last ``sr > 0``; it is
        empty if there was none.
    """
    eval_params = default_params_for_eval(network_type, deps=deps, data_dir=data_dir)
    eval_my_params = default_my_params(init_hdf5=None, eval_log_file='gsm_global_sparse_eval_records.txt',
        just_compile=True, use_dense_layer=use_dense_layer)
    rts_params = get_rts_params(step_size=None, step_thres=None, zeroout_decay=None, subsequent_strategy=None,
        power=None, thresh_delay=None, target_layers=target_layers)
    eval_bc = RTSBC(params=eval_params, my_params=eval_my_params, rts_params=rts_params)
    eval_bc.run()

    eval_bc.load_weights_from_hdf5(weights)
    kts = eval_bc.get_kernel_variables(target_layers=target_layers)
    kvs = eval_bc.get_value(kts)
    # Magnitudes of every target-kernel entry, sorted ascending. The pruning
    # threshold for each ratio is read from this single global ranking of the
    # originally loaded weights.
    all_abs_weights = np.abs(_flatten_kernel_values(kvs))
    sorted_abs_weights = np.sort(all_abs_weights)

    message = ''
    kernel_name_to_mask_value = {}
    last_index = len(sparse_ratios) - 1

    for i, sr in enumerate(sparse_ratios):
        if sr > 0:
            nz = int((1 - sr) * len(all_abs_weights))
            print('we keep {} non-zero values'.format(nz))
            # sorted_abs_weights[-0] is the *smallest* magnitude and would keep
            # every weight. When nothing should be kept, use a threshold that
            # no finite value reaches.
            thresh = sorted_abs_weights[-nz] if nz > 0 else np.inf
            assign_t = []
            assign_v = []
            for t, v in zip(kts, kvs):
                # All entries tied at the threshold are kept, so slightly more
                # than nz values may survive.
                mask = np.array(np.abs(v) >= thresh, dtype=np.float32)
                kernel_name_to_mask_value[t.name] = mask
                assign_t.append(t)
                assign_v.append(v * mask)
            eval_bc.batch_set_value(assign_t, assign_v)

        # Re-read the kernels to measure the zero ratio actually achieved.
        # kts/kvs are rebound here, so the next ratio masks these values.
        kts = eval_bc.get_kernel_variables(target_layers=target_layers)
        kvs = eval_bc.get_value(kts)
        flat_weights = _flatten_kernel_values(kvs)
        zero_cnt = np.sum(flat_weights == 0)
        actual_zero_ratio = zero_cnt / np.size(flat_weights)

        result_dict = eval_bc.simple_eval(eval_record_comment='azr={:.4f}'.format(actual_zero_ratio))
        message += 'top1={:.5f},azr={:.4f};'.format(result_dict['top1'], actual_zero_ratio)

        # Compare by position, so an earlier ratio equal to the last one does
        # not trigger a redundant save.
        if i == last_index:
            eval_bc.save_weights_to_hdf5(last_to_save)

    return message + '\n', kernel_name_to_mask_value



def gsm_pipeline(network_type, try_arg, init_weights,
                 lr_warmup_epochs, lr_values, lr_epoch_boundaries, max_epochs,
                 zero_rate, deps, num_gpus,
                 apply_l2_on_vecs=True, use_dense_layer=False, batch_size=None,
                 adjust_dropout_rate_thresh=None, target_layers=None,
                 metric_type='abs_gv', momentum=0.9,
                 init_step=0, load_ckpt=None, summary_verbosity=1, data_format='NHWC',
                 zero_relax=1e-4, num_steps_per_hdf5=100000, specify_l2=None, data_dir=None):
    """Train with an RTSStairGradientHandler, then prune globally and evaluate.

    Files are named from ``network_type`` and ``try_arg``:

    - ``<net>_<try>_train``: training directory.
    - ``<net>_<try>_savedweights.hdf5``: trained weights.
    - ``<net>_<try>_final_azr_<zero_rate:.4f>.hdf5``: weights after global
      magnitude pruning at ``zero_rate``.

    If the saved-weights file already exists, the function returns at once.
    Neither training nor evaluation is rerun in that case.

    Args (selected):
        zero_rate: target fraction of zeroed kernel weights. The gradient
            handler receives ``zero_rate + zero_relax``. The final evaluation
            runs at ratios 0 and ``zero_rate``; ratios >= 1.0 are dropped.
        batch_size: used instead of ``decide_train_batch_size(network_type)``
            when truthy.
        specify_l2: used instead of ``decide_base_l2_factor(network_type)``
            as both ``l2_factor`` and ``fast_decay_factor`` of the handler.
        target_layers: forwarded to the RTS params, the gradient handler and
            the pruning evaluation.

    Returns:
        None. The evaluation summary is appended to ``OVERALL_LOG_FILE``.
    """
    if batch_size is not None:
        print('specify batchsize: ', batch_size)

    final_sparsified_weights_file = '{}_{}_final_azr_{:.4f}.hdf5'.format(network_type, try_arg, zero_rate)
    train_dir = '{}_{}_train'.format(network_type, try_arg)
    save_hdf5 = '{}_{}_savedweights.hdf5'.format(network_type, try_arg)

    # NOTE(review): evaluation is also skipped when save_hdf5 exists. A run that
    # crashed after training must have that file removed before it is retried.
    if os.path.exists(save_hdf5):
        return

    # One coefficient is used for both the l2 and the fast-decay factor.
    # Computing it once also fails fast on an unknown network_type, before any
    # graph is built.
    l2_factor = specify_l2 if specify_l2 is not None else decide_base_l2_factor(network_type)

    # weight_decay=0 here; the L2 coefficient is handed to the gradient handler
    # below instead (presumably applied there — confirm in RTSStairGradientHandler).
    train_params = default_params_for_train(network_type, train_dir=train_dir,
        batch_size=batch_size or decide_train_batch_size(network_type),
        optimizer_name='momentum', num_gpus=num_gpus, weight_decay=0,
        max_epochs=max_epochs, use_default_lr=False, data_format=data_format, deps=deps, use_distortions=True,
        lr_warmup_epochs=lr_warmup_epochs, lr_values=lr_values, lr_epoch_boundaries=lr_epoch_boundaries,
        summary_verbosity=summary_verbosity, adjust_dropout_rate_thresh=adjust_dropout_rate_thresh,
        momentum=momentum, save_model_steps=50000, data_dir=data_dir)
    train_my_params = default_my_params(init_hdf5=init_weights,
        save_hdf5=save_hdf5,
        num_steps_per_hdf5=num_steps_per_hdf5, show_variables=True,
        save_mvav=False, auto_continue=True, frequently_save_interval=None,
        frequently_save_last_epochs=None,
        should_write_graph=False, init_global_step=init_step, load_ckpt=load_ckpt,
        apply_l2_on_vector_params=apply_l2_on_vecs, use_dense_layer=use_dense_layer)
    rts_params = get_rts_params(step_size=None, step_thres=None, zeroout_decay=None, subsequent_strategy=None,
        power=None, thresh_delay=None, target_layers=target_layers)
    train_bc = RTSBC(params=train_params, my_params=train_my_params, rts_params=rts_params)
    gh = RTSStairGradientHandler(zero_rate=zero_rate + zero_relax,
        l2_factor=l2_factor,
        fast_decay_factor=l2_factor, bc=train_bc, target_layers=target_layers,
        metric_type=metric_type, l2_on_vecs=apply_l2_on_vecs)

    train_bc.set_gradient_handler(gh)
    train_bc.run()
    del train_bc

    srs = [0, zero_rate]

    # The per-kernel masks are not needed here; only the summary is logged.
    azr_msg, _ = eval_with_global_sparse_ratios(network_type, deps=deps,
            use_dense_layer=use_dense_layer, weights=save_hdf5, sparse_ratios=[s for s in srs if s < 1.0],
            last_to_save=final_sparsified_weights_file,
            target_layers=target_layers, data_dir=data_dir)
    with open(OVERALL_LOG_FILE, 'a') as f:
        f.write('{},{}:'.format(network_type, try_arg) + azr_msg)






