import gc
import json
import os
import pickle

import numpy as np
import pandas as pd
# from pypai.io import TableReader
# from aistudio_common.utils import env_utils
from sklearn.preprocessing import StandardScaler
from sklearn.utils.validation import check_array


class NewKBinsDiscretizer(object):
    """Bin continuous features into at most ``n_bins`` intervals per feature.

    Bin edges are placed halfway between neighbouring distinct training
    values and are chosen from the sample counts of those values, so the
    bins are driven by how many samples they hold, not by a fixed width.
    The procedure is modelled on LightGBM's bin finder, see
    <https://github.com/microsoft/LightGBM/blob/v2.2.2/src/io/bin.cpp#L74-L150>
    for more details.

    Args:
        n_bins (int, optional): The maximum number of bins to produce for
            each feature. Defaults to 32.

    After :meth:`fit` the instance carries:
        n_features: number of columns seen during fitting.
        n_bins_: integer array with the number of bins actually produced
            for each feature.
        bin_edges_: one list of edges per feature; every list starts with
            ``-1e50`` and ends with ``1e50``.
    """
    def __init__(self, n_bins=32):
        self.n_bins = n_bins

    def fit(self, X):
        """Fit the discretizer.

        Args:
            X (numeric array-like): The data to be used to determine the interval.
                The shape should be (n_samples, n_features).
        Returns:
            self
        Raises:
            ValueError: If ``n_bins`` is smaller than 1.
        """
        if self.n_bins < 1:
            raise ValueError(
                f"n_bins must be at least 1, received {self.n_bins}.")

        Xt = check_array(X)  # Xt is np.ndarray
        n_features = Xt.shape[1]
        # `np.int` was removed in NumPy 1.24; the builtin `int` is equivalent.
        n_bins = np.full(n_features, self.n_bins, dtype=int)

        bin_edges = []
        for feature_id in range(n_features):
            edges = self._find_bin_edges(Xt[:, feature_id], self.n_bins)
            n_bins[feature_id] = len(edges) - 1
            bin_edges.append(edges)

        self.n_features = n_features
        self.n_bins_ = n_bins
        self.bin_edges_ = bin_edges
        return self

    @staticmethod
    def _find_bin_edges(column, max_bins):
        """Return the edges (with -1e50 / 1e50 sentinels) for one feature column."""
        vals, counts = np.unique(column, return_counts=True)
        # Plain Python numbers keep the stored edges free of NumPy scalars.
        vals = vals.tolist()
        counts = counts.tolist()
        n_distinct = len(vals)

        if n_distinct <= max_bins:
            # Few distinct values: every distinct value gets its own bin.
            inner_edges = [(low + high) / 2
                           for low, high in zip(vals, vals[1:])]
            return [-1e50] + inner_edges + [1e50]

        mean_bin_size = len(column) / max_bins  # mean size for one bin
        # A value is "big" when it alone holds more samples than the mean bin.
        is_big = [count > mean_bin_size for count in counts]

        inner_edges = []
        bin_sizes = []  # number of samples in each bin
        in_bin = 0
        for i in range(n_distinct - 1):
            in_bin += counts[i]
            # Close the current bin when it holds a big value, when it has
            # grown past the mean size, or when the next value is big (so a
            # big value always starts a bin of its own).
            if is_big[i] or is_big[i + 1] or in_bin > mean_bin_size:
                inner_edges.append((vals[i] + vals[i + 1]) / 2)
                bin_sizes.append(in_bin)
                in_bin = 0
        # The loop stops one value short, so the last distinct value's count
        # has to be added to the final bin explicitly.
        bin_sizes.append(in_bin + counts[-1])

        edges = [-1e50] + inner_edges + [1e50]

        # Too many bins: repeatedly merge the smallest bin into a neighbour
        # until at most `max_bins` remain.
        while len(bin_sizes) > max_bins:
            smallest = bin_sizes.index(min(bin_sizes))
            # Recomputed on every pass because the list shrinks as bins merge.
            last = len(bin_sizes) - 1
            if smallest == 0:
                merge_right = True
            elif smallest == last:
                merge_right = False
            else:
                # Prefer the smaller of the two neighbours.
                merge_right = bin_sizes[smallest + 1] < bin_sizes[smallest - 1]

            if merge_right:
                bin_sizes[smallest + 1] += bin_sizes[smallest]
                del edges[smallest + 1]  # edge shared with the right bin
            else:
                bin_sizes[smallest - 1] += bin_sizes[smallest]
                del edges[smallest]  # edge shared with the left bin
            del bin_sizes[smallest]
        return edges

    @staticmethod
    def _smallest_int_dtype(max_value):
        """Return the narrowest signed integer dtype able to hold `max_value`."""
        for dtype in (np.int8, np.int16, np.int32):
            if max_value <= np.iinfo(dtype).max:
                return dtype
        return np.int64

    def get_params(self):
        """Return the fitted ``(n_bins_, bin_edges_)`` pair."""
        return self.n_bins_, self.bin_edges_

    def transform(self, X):
        """Transform the data

        Args:
            X (numeric array-like): The data to be transformed.
                The shape should be (n_samples, n_features).
        Returns:
            The transformed data. The shape is (n_samples, n_features).
            Labels start at 1. The dtype is ``int8`` whenever every possible
            label fits, otherwise the next wider signed integer type.
        Raises:
            ValueError: If the number of columns differs from the fitted data.
        """
        Xt = check_array(X)

        n_features = self.n_features
        if Xt.shape[1] != n_features:
            raise ValueError(
                f"Expecting {n_features} feature(s), received {Xt.shape[1]}.")

        # Values below the upper sentinel get labels 1..n_bins_; a value at or
        # above the sentinel gets n_bins_ + 1, so size the dtype for that.
        max_label = int(np.max(self.n_bins_)) + 1
        labels = np.empty(Xt.shape, dtype=self._smallest_int_dtype(max_label))
        for feature_id, edges in enumerate(self.bin_edges_):
            # Skip the lower sentinel so the first bin maps to index 0.
            labels[:, feature_id] = np.digitize(Xt[:, feature_id],
                                                edges[1:]) + 1
        return labels


class NewQuantileTransformer(object):
    """Transform the data into uniform distribution data based on quantile information.

    Every feature is mapped to two integer columns in ``[0, n_quantiles]``:
    one built from the bins/labels meant for ``P(X <= x)`` and one built from
    the bins/labels meant for ``P(X < x)``, each scaled by ``n_quantiles``.

    Args:
        n_quantiles (int, optional): The number of n_quantiles to be calculated. Defaults to 10000.

    After :meth:`fit` the instance carries:
        n_features: number of columns seen during fitting.
        bins_: ``[bins for P(X <= x), bins for P(X < x)]``; each entry holds
            one list of boundaries per feature, wrapped in ``-1e50`` / ``1e50``.
        labels_: quantile fractions matching ``bins_``; each label list is one
            element shorter than its bin list.
    """
    def __init__(self, n_quantiles=10000):
        self.n_quantiles = n_quantiles

    def fit(self, X):
        """Fit the transformer.

        Args:
            X (numeric array-like): The data to be used to calculate quantiles.
                The shape should be (n_samples, n_features).

        Returns:
            self

        Raises:
            ValueError: If ``n_quantiles`` is smaller than 1.
        """
        n_quantiles = self.n_quantiles
        if n_quantiles < 1:
            raise ValueError(
                f"n_quantiles must be at least 1, received {n_quantiles}.")

        Xt = check_array(X)  # Xt is np.ndarray
        # Read the shape from the validated array so list input works too.
        n_features = Xt.shape[1]

        # the quantiles: 0, 1/n, 2/n, ..., 1
        references = np.arange(n_quantiles + 1) / n_quantiles

        le_bins, le_labels = [], []  # used to calculate P(X\le x)
        lt_bins, lt_labels = [], []  # used to calculate P(X\lt x)
        for feature_id in range(n_features):
            quantiles = pd.Series(
                Xt[:, feature_id]).quantile(references).sort_values()

            # For repeated quantile values keep the highest fraction.
            upper = quantiles.drop_duplicates(keep="last")
            le_bins.append([-1e50] + upper.tolist() + [1e50])
            le_labels.append([0] + upper.index.tolist())

            # For repeated quantile values keep the lowest fraction.
            lower = quantiles.drop_duplicates(keep="first")
            lt_bins.append([-1e50] + lower.tolist() + [1e50])
            lt_labels.append(lower.index.tolist() + [1])

        self.n_features = n_features
        self.bins_ = [le_bins, lt_bins]  # the bins that is used to cut the value
        self.labels_ = [le_labels, lt_labels]  # the labels of bins
        return self

    def get_params(self):
        """Return the ``(n_quantiles, bins_, labels_)`` triple."""
        return self.n_quantiles, self.bins_, self.labels_

    def transform(self, X):
        """Transform the data.

        Args:
            X (numeric array-like): The data to be transformed.
                The shape should be (n_samples, n_features).

        Returns:
            The transformed data. The shape is (n_samples, n_features * 2).
            Column ``i`` comes from ``bins_[0][i]`` and column
            ``i + n_features`` from ``bins_[1][i]``. The dtype is ``int16``
            when ``n_quantiles`` fits into it, otherwise ``int64``.

        Raises:
            ValueError: If the number of columns differs from the fitted data.
        """
        Xt = check_array(X)  # Xt is np.ndarray

        n_features = self.n_features
        if Xt.shape[1] != n_features:
            raise ValueError(
                f"Expecting {n_features} feature(s), received {Xt.shape[1]}.")

        # Write into a float buffer: the labels are fractions, which would be
        # truncated if stored into an integer-typed copy of the input.
        scaled = np.empty((Xt.shape[0], 2 * n_features), dtype=np.float64)
        for feature_id in range(n_features):
            column = Xt[:, feature_id]
            for kind in (0, 1):
                labels = np.asarray(self.labels_[kind][feature_id],
                                    dtype=np.float64)
                # cut the value (lower sentinel skipped so indices start at 0)
                bin_index = np.digitize(column,
                                        self.bins_[kind][feature_id][1:])
                # Inputs at or above the 1e50 sentinel would index past the
                # label list; clamp them to the last label.
                np.minimum(bin_index, len(labels) - 1, out=bin_index)
                scaled[:, feature_id + kind * n_features] = labels[bin_index]

        scaled *= self.n_quantiles
        # (i / n) * n is not always exactly i in floating point, so round to
        # the nearest integer instead of truncating.
        if self.n_quantiles <= np.iinfo(np.int16).max:
            out_dtype = np.int16
        else:
            out_dtype = np.int64
        return np.rint(scaled).astype(out_dtype)


# TODO: remove this section before the official release and move it into the tests
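# def fit(project_table_name, o_tunnel_endpoint, feature_names, params_n_bins, params_n_quantiles, output_file):
#     """
#     Pre-fit on a node table or an edge table to obtain the bin edges and the
#     quantile-transform boundaries and boundary labels, then save them to a
#     prefit/preconfig_* file.

#     Args:
#         project_table_name: name of the table to pre-fit on (node table or edge table)
#         o_tunnel_endpoint: [description]
#         feature_names: names of the feature columns to preprocess
#         params_n_bins: number of bins
#         params_n_quantiles: number of quantiles for the quantile transform
#         output_file: path of the file to write
#     """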
#     o = env_utils.get_odps_instance()
#     o._tunnel_endpoint = o_tunnel_endpoint
#     o = env_utils.get_odps_instance()
#     reader = TableReader.from_ODPS_type(o, project_table_name)
#     data = reader.to_pandas(columns=feature_names)
#     data = data.astype('float')
#     print('data.head: ', data.head())
#     print('feature_names len: ', len(feature_names))
#     print('feature_names: ', feature_names)

#     # Binning
#     feature_kbins = NewKBinsDiscretizer(n_bins=params_n_bins)
#     feature_kbins.fit(data.values)
#     n_bins_, bin_edges_ = feature_kbins.get_params()
#     print('n_bins_: ', n_bins_)
#     print('bin_edges_: ', bin_edges_)

#     # Quantile transform
#     feature_qts = NewQuantileTransformer(n_quantiles=params_n_quantiles)
#     feature_qts.fit(data.values)
#     n_quantiles, bins_, labels_ = feature_qts.get_params()
#     print('n_quantiles: ', n_quantiles)
#     print('bins_: ', bins_)
#     print('labels_: ', labels_)

#     # Write the output file as seven constants
#     f = open(output_file, 'w')
#     f.write('n_features = ' + str(len(feature_names)) + '\n')   # number of features
#     f.write('feature_names = ' + str(feature_names) + '\n')     # feature names
#     f.write('n_bins_ = ' + str(params_n_bins) + '\n')           # number of bins; transformed values lie in [1, n_bins_+1], counting from 1
#     f.write('bin_edges_ = ' + str(bin_edges_) + '\n')           # bin edges: a 2-D list; the first dimension indexes the features, the second holds that feature's edge values
#     f.write('n_quantiles = ' + str(params_n_quantiles) + '\n')  # number of quantiles
#     f.write('bins_ = ' + str(bins_) + '\n')                     # quantile boundaries [left, right]: bins_[0] holds the first kind of boundaries, bins_[1] the second kind
#     f.write('labels_ = ' + str(labels_) + '\n')                 # quantile labels [left, right]; the data satisfy len(bins_[0][i]) == len(labels_[0][i]) + 1
#     f.close()
#     '''
#     Example of the output_file format:
#     n_features = 1
#     feature_names = ["user_cnt"]
#     n_bins_ = 64
#     bin_edges_ = [[-1e50,0.2,0.4,1.3,1e50],[-1e50,0.1,5.6,7.8,12.3,1e50]]
#     n_quantiles = 1000
#     bins_ = [[[-1e50,0.2,3.4,7.6,1e50],[-1e50,2.3,5.4,1e50]],[[-1e50,0.2,3.4,7.6,1e50],[-1e50,2.3,5.4,1e50]]]
#     labels_ = [[[0.0,0.1,0.3,0.9],[0.0,0.5,0.8]],[[0.1,0.2,0.9,1.0],[0.5,0.8,1.0]]]
#     '''

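# if __name__ == "__main__":
#     ####### Parameters to configure; no other code needs changing (start) ########
#     # Table to pre-fit on (must be edited); each feature occupies its own column
#     project_table_name = "ant_p13n_dev.demo_iriskgraph_fraud_ubd_feature_table"
#     o_tunnel_endpoint = "http://service-us.odps.aliyun-inc.com/api" # usually does not need changing
#     # Feature columns to preprocess (must be edited); keep the same order as the node/edge features used in later preprocessing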
#     feature_names = ["model_user_dnbd_score_avg_e_6h","model_inacc_dnbd_v5_score_max_3d","model_inacc_dnbd_v5_score_max_7d","model_inacc_dnbd_v5_score_max_1d","model_oacc_redrecharge_amt_3h","mer_fraud_inacc_model_graphy_gang_score","mer_std_history_both_side_relation_1_90d","model_acc_opp_wp_amt_6h","usr_opp_cer_diff_dd","model_inacc_dnbd_v5_score_avg_7d","ind_opp_city","model_acc_change_payment_5m","opp_acc_rain_mod","model_opps_phn_city_rate","user_age_period","b_amt","oppcert_ts_ucnt_15d_r_oppo_max_30d","model_oppo_wp_rec_amt_7d_oppo_min_30d","trade_total_daycnt_180d","model_offfp_score","model_oppo_wp_rec_amt_7d","usr_wppayamt_1d_rt","model_paychannel_credit","acc_opp_trd_lbs_distance","gmt_sign_days","model_oacc_all_punish_fraud_event_cnt_24h","model_acc_change_payment_5m_oppo_avg_30d","model_pair_amt_avg_7d","model_inacc_certno_complaint_45d_cnt","oacc_inacc_contact_phone_relation","oppcert_ts_cnt_15d_r","fraud_oacc_inac_3d_lbs_same_flag","usr_amtlevel_rt","model_opp_weijin_task_30d_rt","model_user_dnbd_score_avg_e_1h","mer_merch_rcving_creditcard_pay_ratio","c_oppo_balance","model_oppo_repurchase_1d","model_opp_lbsprovpay_cnt_30d","mer_gamble_user_trade_scene_switch","opp_user_cert_days","mer_gamble_direct_uid_select_count_15min","model_oacc_asset_level","model_fraud_occa_lbs_lon","model_opp_avg_payamt_90d","model_opp_lbsprov_inc_dt_30d","zl_seller_fg_user_pct_e_1d","head_portrait_score","model_oppos_crt_days","model_opp_dayseq_lstm_score","model_oppo_sign_reject_cnt","model_buyer_scan_code_select_ratio_30m_7d","model_oppo_buyers_7d","model_usr_scancode_search_cnt_15m_oppo_avg_30d","usr_bustype_rt","model_oppo_succtradedays_90d","opp_ts_cnt_15d_r","model_oacc_avg_payamt_90d_oppo_max_30d","mer_gamble_direct_qrcode_lbs_city","model_oppphn_reped_cnt_30d","model_accopp_first_amt_rt_7d","model_opp_wp_lbsprov_cnt_7d_oppo_avg_30d","opp_rain_score_mobile","model_accopp_first_amt_rt_6h","model_gamble_gambler_level","model_fraud_occa_lbs_lat","model_opp_offline_mod_
score","mer_trade_qr_f2f_prov_cnt7d_uid","c_oppo_balance_oppo_avg_30d","opp_fr_cnt","model_opp_scan_cnt_7d","model_oacc_payamt_7d","model_oppo_credit_amt_ratio_1d","mer_merch_night_trade_ratio_7d","opp_avg_pay_cnt_90d","model_opp_lbsprov_inc_dt_30d_oppo_avg_30d","model_oacc_wp_payamt_1d_oppo_max_30d","opp_trd_dd_90d","model_gamble_total_trade_amt_total_7d","model_amt_vs_opp_succ_rec_30d","mer_gamble_direct_qrcodescantype_risk_1d","mer_fraud_inacc_bhvior_interger_cnt_ratio_7d","model_oacc_amt_vs_wp_per_trd_1_30d","mer_buyer_if_taobao_seller_copy","modelp_complaint_gambling_uid_proportion_30d3","model_fraud_oacc_rcv_acc_cnt_7d","trade_total_certcnt_90d","model_opp_wp_rec_amt_std_2h","mer_gamble_direct_qrcodescantype_highrisk_rate_7d","model_inacc_dnbd_v5_score_avg_1d","model_oppo_rec_amt_1d","model_oppo_tradeamtsuccrate_7d","mer_merch_trade_amt_10s_7d","mer_trade_amt_ten_eight_five","model_usr_scancode_search_cnt_15m","usr_payamt_1h_rt","model_fraud_oacc_behavior_load_loan_app_lvl6","model_oacc_wp_payamt_1d","mer_gamble_direct_gambler_trade_cnt_7d_150","model_amt_vs_usr_per_c2c_trd_14d"]
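#     params_n_bins = 64          # number of bins; does not need changing
#     params_n_quantiles = 1000   # number of quantiles; does not need changing
#     # Output file (must be edited)
#     output_file = './preconfig_fraud_ubd.py'
#     ####### Parameters to configure; no other code needs changing (end) ########

#     father_path = os.path.dirname(output_file)
#     if not os.path.exists(father_path):
#         os.makedirs(father_path)

#     # Fit the data and save the binning and quantile-transform results to output_file. The generated file is very large: do not open it directly or the editor will hang for a while; use <head -n 5 output_file> to inspect part of it
#     fit(project_table_name, o_tunnel_endpoint, feature_names, params_n_bins, params_n_quantiles, output_file)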
