import numpy as np
import logging
import math

from ...constants import BINARY, MULTICLASS, REGRESSION, SOFTCLASS
from .....try_import import try_import_catboostdev
from ....metrics import soft_log_loss


logger = logging.getLogger(__name__)

# TODO: Add weight support?
# TODO: Can these be optimized? What computational cost do they have compared to the default catboost versions?
class CustomMetric:
    """Base class for Python metric wrappers exposing the custom-metric hooks used in this module.

    Stores the wrapped metric callable together with two flags describing it.
    Subclasses must override ``evaluate``.

    Parameters
    ----------
    metric
        Object stored as ``self.metric``; subclasses call it as ``metric(y_true, y_pred)``.
    is_higher_better
        Value reported by ``is_max_optimal``.
    needs_pred_proba
        Flag subclasses consult to decide what kind of predictions the metric receives.
    """

    def __init__(self, metric, is_higher_better, needs_pred_proba):
        self.needs_pred_proba = needs_pred_proba
        self.is_higher_better = is_higher_better
        self.metric = metric

    def evaluate(self, approxes, target, weight):
        """Compute the metric; not implemented on the base class."""
        raise NotImplementedError

    def is_max_optimal(self):
        """Return the ``is_higher_better`` flag given at construction."""
        return self.is_higher_better

    def get_final_error(self, error, weight):
        """Return ``error`` unchanged; ``weight`` is not used."""
        return error


class BinaryCustomMetric(CustomMetric):
    """Metric wrapper for binary classification.

    ``evaluate`` only supports metrics constructed with a truthy
    ``needs_pred_proba``; for any other metric it raises ``NotImplementedError``.
    """

    def _get_y_pred_proba(self, approxes):
        """Return the first entry of ``approxes`` converted to a numpy array."""
        first_dimension = approxes[0]
        return np.array(first_dimension)

    def _get_y_pred(self, y_pred_proba):
        """Turn scores into labels by rounding with ``np.round``."""
        return np.round(y_pred_proba)

    def evaluate(self, approxes, target, weight):
        """Score ``approxes[0]`` against ``target`` with the wrapped metric.

        Returns a ``(score, 1)`` tuple. ``weight`` is not used.

        Raises
        ------
        NotImplementedError
            If ``self.needs_pred_proba`` is falsy.
        """
        y_pred_proba = self._get_y_pred_proba(approxes=approxes)

        if not self.needs_pred_proba:
            # Label-based scoring (rounding via _get_y_pred and passing the labels to the
            # metric) is disabled: the values catboost supplies in approxes are not
            # probabilities, so rounding them does not yield valid class labels.
            raise NotImplementedError('Custom Catboost Binary prob metrics are not supported by AutoGluon.')

        # TODO: Binary log_loss doesn't work for some reason
        y_true = np.array(target)
        return self.metric(y_true, y_pred_proba), 1


class MulticlassCustomMetric(CustomMetric):
    """Metric wrapper for multiclass classification.

    ``evaluate`` only supports label-based metrics (falsy ``needs_pred_proba``);
    for metrics that need probabilities it raises ``NotImplementedError``.
    """

    def _get_y_pred_proba(self, approxes):
        """Return all of ``approxes`` converted to a single numpy array."""
        return np.array(approxes)

    def _get_y_pred(self, y_pred_proba):
        """Pick, for each column, the index of the largest value along axis 0."""
        return y_pred_proba.argmax(axis=0)

    def evaluate(self, approxes, target, weight):
        """Score the argmax labels of ``approxes`` against ``target``.

        Returns a ``(score, 1)`` tuple. ``weight`` is not used.

        Raises
        ------
        NotImplementedError
            If ``self.needs_pred_proba`` is truthy.
        """
        y_pred_proba = self._get_y_pred_proba(approxes=approxes)

        if self.needs_pred_proba:
            # Probability-based scoring is disabled: the values catboost supplies in
            # approxes are not probabilities, so they cannot be handed to the metric as-is.
            raise NotImplementedError('Custom Catboost Multiclass proba metrics are not supported by AutoGluon.')

        predicted_labels = self._get_y_pred(y_pred_proba=y_pred_proba)
        y_true = np.array(target)
        return self.metric(y_true, predicted_labels), 1


class RegressionCustomMetric(CustomMetric):
    """Metric wrapper for regression: scores ``approxes[0]`` directly with the wrapped metric."""

    def _get_y_pred(self, approxes):
        """Return the first entry of ``approxes`` converted to a numpy array."""
        first_dimension = approxes[0]
        return np.array(first_dimension)

    def evaluate(self, approxes, target, weight):
        """Return ``(metric(target, predictions), 1)``; ``weight`` is not used."""
        predictions = self._get_y_pred(approxes=approxes)
        y_true = np.array(target)
        return self.metric(y_true, predictions), 1


# Objectives for SOFTCLASS problem_type
# TODO: these require catboost_dev or catboost>=0.24
class SoftclassCustomMetric(CustomMetric):
    """Evaluation metric wrapper for the SOFTCLASS problem_type.

    The actual metric object lives in ``self.softlogloss`` (an instance of the nested
    ``SoftLogLossMetric``); ``evaluate`` on this wrapper simply delegates to it.
    The ``metric`` constructor argument is stored by the base class but never called.
    """
    # Executed once, when the class body is evaluated at import time of this module.
    try_import_catboostdev()
    from catboost_dev import MultiRegressionCustomMetric

    def __init__(self, metric, is_higher_better, needs_pred_proba):  # metric is ignored
        super().__init__(metric, is_higher_better, needs_pred_proba)
        try_import_catboostdev()
        self.softlogloss = self.SoftLogLossMetric()  # the metric object to pass to CatBoostRegressor

    def evaluate(self, approxes, target, weight):
        """Delegate to ``self.softlogloss.evaluate`` and return its ``(error, weight)`` tuple."""
        return self.softlogloss.evaluate(approxes, target, weight)

    class SoftLogLossMetric(MultiRegressionCustomMetric):
        """Soft log-loss computed on a softmax of the raw ``approxes``."""

        def get_final_error(self, error, weight):
            """Return ``error`` unchanged; ``weight`` is not used."""
            return error

        def is_max_optimal(self):
            """Always True: larger values of this metric are treated as better."""
            return True

        def evaluate(self, approxes, target, weight):
            """Apply a softmax to ``approxes`` and score it against ``target``.

            Returns ``(soft_log_loss(target, softmax(approxes)), len(target))``.
            ``weight`` is not used.

            Raises
            ------
            AssertionError
                If ``target`` and ``approxes`` differ in outer length or in the
                length of their first element.
            """
            assert len(target) == len(approxes)
            assert len(target[0]) == len(approxes[0])
            weight_sum = len(target)

            # NOTE(review): the softmax normalizes along axis 1 of np.array(approxes), exactly
            # as before; confirm this axis is the class dimension for the catboost version in use.
            logits = np.array(approxes)
            if logits.size:
                # Softmax is invariant to a per-row shift. Subtracting the row maximum keeps
                # every exponent <= 0, so np.exp cannot overflow to inf (which previously
                # turned the normalized values into NaN for large raw scores).
                logits = logits - logits.max(axis=1, keepdims=True)
            exp_logits = np.exp(logits)
            probabilities = exp_logits / exp_logits.sum(axis=1, keepdims=True)

            error_sum = soft_log_loss(np.array(target), probabilities)
            return error_sum, weight_sum

class SoftclassObjective(object):
    """Training objective wrapper for the SOFTCLASS problem_type.

    The objective object itself is ``self.softlogloss``, an instance of the nested
    ``SoftLogLossObjective``.
    """
    # Executed once, when the class body is evaluated at import time of this module.
    try_import_catboostdev()
    from catboost_dev import MultiRegressionCustomObjective

    def __init__(self):
        try_import_catboostdev()
        self.softlogloss = self.SoftLogLossObjective()  # the objective object to pass to CatBoostRegressor

    class SoftLogLossObjective(MultiRegressionCustomObjective):
        # TODO: Consider replacing with C++ implementation (but requires building catboost from source).
        # This pure Python is 3x faster than optimized Numpy implementation. Tested C++ implementation was 3x faster than this one.
        def calc_ders_multi(self, approxes, targets, weight):
            """Return first and second derivatives of the soft log-loss for one sample.

            Parameters
            ----------
            approxes
                Iterable of raw per-class values; a softmax is applied to them.
            targets
                Indexable of per-class target values; its length sets the size of the outputs.
            weight
                Multiplier applied to every gradient and Hessian entry.

            Returns
            -------
            tuple
                ``(grad, hess)`` where ``grad[j] = (targets[j] - p[j]) * weight`` and
                ``hess[j2][j] = (p[j] * p[j2] - (j == j2) * p[j]) * weight``, with ``p``
                the softmax of ``approxes``.
            """
            raw_values = list(approxes)
            # Softmax is invariant to a constant shift. Subtracting the maximum keeps every
            # exponent <= 0, so math.exp can no longer raise OverflowError on large raw values.
            shift = max(raw_values, default=0.0)
            exp_approx = [math.exp(val - shift) for val in raw_values]
            exp_sum = sum(exp_approx)
            exp_approx = [val / exp_sum for val in exp_approx]

            class_indices = range(len(targets))
            grad = [(targets[j] - exp_approx[j])*weight for j in class_indices]
            hess = [[(exp_approx[j] * exp_approx[j2] - (j==j2)*exp_approx[j]) * weight
                    for j in class_indices] for j2 in class_indices]
            return (grad, hess)


# Maps a problem_type constant to the CustomMetric subclass that wraps a Python metric
# for it. SOFTCLASS has no entry: construct_custom_catboost_metric returns a
# SoftclassCustomMetric for that problem_type before this table is consulted.
metric_classes_dict = {
    BINARY: BinaryCustomMetric,
    MULTICLASS: MulticlassCustomMetric,
    REGRESSION: RegressionCustomMetric,
}


def construct_custom_catboost_metric(metric, is_higher_better, needs_pred_proba, problem_type):
    """Choose the eval-metric object to use for ``metric`` on the given ``problem_type``.

    Parameters
    ----------
    metric
        Metric object; only its ``name`` attribute is inspected here.
    is_higher_better, needs_pred_proba
        Forwarded to the wrapper class when a Python wrapper is built.
    problem_type
        One of the problem-type constants (BINARY, MULTICLASS, REGRESSION, SOFTCLASS).

    Returns
    -------
    str or CustomMetric
        A metric-name string (presumably a CatBoost built-in metric identifier) when one
        matches ``metric.name`` for this problem type, a ``SoftclassCustomMetric`` for
        SOFTCLASS, otherwise an instance of the wrapper class registered in
        ``metric_classes_dict``.

    Raises
    ------
    KeyError
        If no string shortcut applies and ``problem_type`` is not in ``metric_classes_dict``.
    """
    if problem_type == SOFTCLASS:
        if metric.name != 'soft_log_loss':
            logger.warning("Setting metric=soft_log_loss, the only metric supported for softclass problem_type")
        return SoftclassCustomMetric(metric=None, is_higher_better=True, needs_pred_proba=True)

    metric_name = metric.name

    # 'accuracy' maps to a string regardless of problem_type or needs_pred_proba.
    if metric_name == 'accuracy':
        return 'Accuracy'

    # log_loss has a string form for both classification types, but only for proba metrics.
    if metric_name == 'log_loss' and needs_pred_proba:
        if problem_type == MULTICLASS:
            return 'MultiClass'
        if problem_type == BINARY:
            return 'Logloss'

    # Label-based binary metrics that have a string form.
    if problem_type == BINARY and not needs_pred_proba:
        binary_label_metrics = {
            'f1': 'F1',
            'balanced_accuracy': 'BalancedAccuracy',
            'recall': 'Recall',
            'precision': 'Precision',
        }
        builtin_name = binary_label_metrics.get(metric_name)
        if builtin_name is not None:
            return builtin_name

    # Everything else is wrapped in the Python metric class for the problem type.
    wrapper_class = metric_classes_dict[problem_type]
    return wrapper_class(metric=metric, is_higher_better=is_higher_better, needs_pred_proba=needs_pred_proba)


## Scratch ##
class OLDSoftclassObjective(object):
    """ Custom training objective for SOFTCLASS problem_type.
        Ignores weights for now.
    """
    def calc_ders_multi(self, approx, target, weight):
        """Return ``(grad, hess)`` of a softmax objective for a single sample.

        ``approx`` holds the raw per-class values and ``target`` is compared against each
        class index with ``==``. ``grad`` is a list with one entry per class and ``hess``
        is a list of rows (one per class). ``weight`` is accepted but not applied.
        """
        # Shift by the maximum so the exponentials cannot overflow.
        shifted = np.array(approx) - max(approx)
        exps = np.exp(shifted)
        total = exps.sum()
        class_ids = range(len(shifted))

        # First derivatives: -softmax[j], plus 1 for the class equal to `target`.
        grad = []
        for cls in class_ids:
            first_der = -exps[cls] / total
            if cls == target:
                first_der += 1
            grad.append(first_der)  # To respect weights: grad.append(first_der * weight)

        def second_derivative(row, col):
            # softmax[row] * softmax[col], minus softmax[row] on the diagonal.
            value = exps[row] * exps[col] / (total**2)
            if col == row:
                value -= exps[row] / total
            return value  # To respect weights: return value * weight

        hess = [[second_derivative(row, col) for col in class_ids] for row in class_ids]

        return (grad, hess)
