from __future__ import print_function

from evaluation import roc_auc_single, precision_auc_single, enrichment_factor_single, number_of_hit_single
from sklearn.metrics import mean_squared_error
from sklearn.metrics import mean_absolute_error
from sklearn.metrics import r2_score
from scipy.stats import pearsonr
import numpy as np


def output_classification_result(y_train, y_pred_on_train,
                                 y_val, y_pred_on_val,
                                 y_test, y_pred_on_test, EF_ratio_list, hit_ratio=0.01):
    """Print ranking metrics for the train / val / test splits of a classifier.

    Splits whose prediction argument is None are skipped. For each remaining
    split, this prints:

    * the value of ``precision_auc_single``;
    * the value of ``roc_auc_single``;
    * the hit count from ``number_of_hit_single`` in the top
      ``int(len(y) * hit_ratio)`` predictions, next to ``np.sum(y)``.

    For the test split, it additionally prints ``enrichment_factor_single`` for
    every ratio in ``EF_ratio_list``.

    Args:
        y_train, y_val, y_test: true labels of each split. ``np.sum`` of these
            is printed as the total, so presumably binary 0/1 labels -- confirm
            against callers.
        y_pred_on_train, y_pred_on_val, y_pred_on_test: predicted scores for
            each split, or None to skip that split.
        EF_ratio_list: iterable of ratios forwarded to
            ``enrichment_factor_single`` (test split only).
        hit_ratio: fraction of each split's size used as the top-N cutoff.

    Returns:
        None; all output goes to stdout.
    """
    def report(name, y_true, y_pred):
        # One section of output for a single split, followed by a blank line.
        print('{} precision: {}'.format(name, precision_auc_single(y_pred, y_true)))
        print('{} roc: {}'.format(name, roc_auc_single(y_pred, y_true)))
        # int() truncates, so a split smaller than 1 / hit_ratio yields a cutoff of 0.
        n_top = int(len(y_true) * hit_ratio)
        n_hit = number_of_hit_single(y_pred, y_true, N=n_top)
        print('{} hit in top {}: {} out of {}'.format(name, n_top, n_hit, np.sum(y_true)))
        print()

    # Every split, including train, is optional so callers can pass None to skip it.
    if y_pred_on_train is not None:
        report('train', y_train, y_pred_on_train)

    if y_pred_on_val is not None:
        report('val', y_val, y_pred_on_val)

    if y_pred_on_test is not None:
        report('test', y_test, y_pred_on_test)
        for EF_ratio in EF_ratio_list:
            # The third value returned (the maximum achievable EF) is not reported.
            n_actives, ef, _ = enrichment_factor_single(y_pred_on_test, y_test, EF_ratio)
            print('ratio: {}, EF: {},\tactive: {}'.format(EF_ratio, ef, n_actives))
        print()


def output_regression_result(y_train_binary, y_pred_on_train,
                             y_val_binary, y_pred_on_val,
                             y_test_binary, y_pred_on_test, EF_ratio_list, hit_ratio=0.01):
    """Print ranking metrics for regression predictions scored against binary labels.

    Splits whose prediction argument is None are skipped. For each remaining
    split, this prints:

    * the value of ``precision_auc_single``;
    * the value of ``roc_auc_single``;
    * the hit count from ``number_of_hit_single`` in the top
      ``int(len(y) * hit_ratio)`` predictions, next to ``np.sum(y)``.

    For the test split, it additionally prints ``enrichment_factor_single`` for
    every ratio in ``EF_ratio_list``.

    Args:
        y_train_binary, y_val_binary, y_test_binary: label arrays for each
            split. Presumably binarized versions of the regression targets
            (the name suggests so) -- confirm against callers.
        y_pred_on_train, y_pred_on_val, y_pred_on_test: predicted values for
            each split, or None to skip that split.
        EF_ratio_list: iterable of ratios forwarded to
            ``enrichment_factor_single`` (test split only).
        hit_ratio: fraction of each split's size used as the top-N cutoff.

    Returns:
        None; all output goes to stdout.
    """
    def report(name, y_true, y_pred):
        # One section of output for a single split, followed by a blank line.
        print('{} precision: {}'.format(name, precision_auc_single(y_pred, y_true)))
        print('{} roc: {}'.format(name, roc_auc_single(y_pred, y_true)))
        # int() truncates, so a split smaller than 1 / hit_ratio yields a cutoff of 0.
        n_top = int(len(y_true) * hit_ratio)
        n_hit = number_of_hit_single(y_pred, y_true, N=n_top)
        print('{} hit: {} out of {}'.format(name, n_hit, np.sum(y_true)))
        print()

    # Every split, including train, is optional so callers can pass None to skip it.
    if y_pred_on_train is not None:
        report('train', y_train_binary, y_pred_on_train)

    if y_pred_on_val is not None:
        report('val', y_val_binary, y_pred_on_val)

    if y_pred_on_test is not None:
        report('test', y_test_binary, y_pred_on_test)
        for EF_ratio in EF_ratio_list:
            # The third value returned (the maximum achievable EF) is not reported.
            n_actives, ef, _ = enrichment_factor_single(y_pred_on_test, y_test_binary, EF_ratio)
            print('ratio: {}, EF: {},\tactive: {}'.format(EF_ratio, ef, n_actives))
        print()


def rms_score(y_true, y_pred):
    """Return the root-mean-squared error of ``y_pred`` against ``y_true``.

    Takes the square root of sklearn's ``mean_squared_error``.
    """
    mse = mean_squared_error(y_true, y_pred)
    return np.sqrt(mse)


def mae_score(y_true, y_pred):
    """Return the mean absolute error of ``y_pred`` against ``y_true``.

    Delegates directly to sklearn's ``mean_absolute_error``.
    """
    mae = mean_absolute_error(y_true, y_pred)
    return mae


def pearson_r2_score(y, y_pred):
    """Return the squared Pearson correlation coefficient between two arrays.

    Singleton dimensions are squeezed away first, so shapes like ``(n, 1)``
    are accepted alongside flat ``(n,)`` arrays.
    """
    flat_true = np.squeeze(y)
    flat_pred = np.squeeze(y_pred)
    # Index 0 of pearsonr's result is the correlation coefficient r.
    r = pearsonr(flat_true, flat_pred)[0]
    return r ** 2


def output_regression_result_no_binary(y_train, y_pred_on_train,
                                       y_val, y_pred_on_val,
                                       y_test, y_pred_on_test):
    """Print continuous regression metrics for each split that has predictions.

    For every split (train, val, test, in that order) whose prediction argument
    is not None, this prints Pearson R^2, sklearn's ``r2_score``, RMSE and MAE,
    followed by a blank line. Splits with None predictions are skipped.

    Returns:
        None; all output goes to stdout.
    """
    def report(y_true, y_pred, mode):
        # All four metrics are computed before anything is printed for this split.
        lines = (
            'Pearson R2 on {}: {}'.format(mode, pearson_r2_score(y_true, y_pred)),
            'R2 on {}: {}'.format(mode, r2_score(y_true, y_pred)),
            'RMSE on {}: {}'.format(mode, rms_score(y_true, y_pred)),
            'MAE on {}: {}'.format(mode, mae_score(y_true, y_pred)),
        )
        for line in lines:
            print(line)
        print()

    splits = (
        (y_train, y_pred_on_train, 'train set'),
        (y_val, y_pred_on_val, 'val set'),
        (y_test, y_pred_on_test, 'test set'),
    )
    for y_true, y_pred, mode in splits:
        if y_pred is None:
            continue
        report(y_true, y_pred, mode)
