import numpy as np
import multiprocessing as mp
from .metrics import *
from typing import Sequence, Dict, Union, Set
from itertools import chain


class Evaluate(object):
    """Top-k ranking evaluator computing NDCG, Precision and Recall at several cutoffs.

    For every user, the items listed in ``except_U2I`` are masked out of that
    user's score row. The ``max(topks)`` highest-scoring items are then ranked,
    and ``ndcg_k`` / ``precision_k`` / ``recall_k`` (presumably provided by the
    ``.metrics`` star import) are evaluated against the user's items in
    ``test_U2I``.
    """

    def __init__(self, cores: int = 1, topks: Sequence = (10, 20)):
        """
        :param cores: number of worker processes. A value of ``1`` (or less)
            evaluates in the calling process without starting a pool.
            ``None`` lets ``multiprocessing.Pool`` pick the count.
        :param topks: ranking cutoffs; stored sorted ascending.
        """
        self.core_num = cores
        self.topks = sorted(topks)
        # Populated by ``evaluate``; kept on the instance for later inspection.
        self.rating_mat = None
        self.except_U2I = None
        self.test_U2I = None
        # Built from the *sorted* cutoffs so that name i labels column i of the
        # vectors produced by ``compute_single_user_metric`` (which iterates
        # ``self.topks``).
        self.metric_names = list(chain.from_iterable(
            (f"NDCG@{topk}", f"Precision@{topk}", f"Recall@{topk}") for topk in self.topks
        ))

    def evaluate(self, uid_list: Union[Sequence, Set], rating_mat: np.ndarray, except_U2I: Dict, test_U2I: Dict):
        """Rank items for every user and compute the top-k metrics.

        Requirements:
            (1) ``uid_list`` must be aligned with ``rating_mat``: the i-th uid
                owns the i-th row. A ``Set`` has no guaranteed order, so a
                sequence is strongly preferred.
            (2) ``test_U2I[uid]`` must exist and, per the original contract,
                be non-empty for every uid.

        :param uid_list: user ids, one per row of ``rating_mat``.
        :param rating_mat: score matrix, one row per user; a column index is
            treated as an item id.
        :param except_U2I: uid -> item ids masked out of that user's ranking.
            Users absent from the mapping have nothing masked.
        :param test_U2I: uid -> item ids the ranking is scored against.
        :return: ``(perf_mat, perf_dict)``.
            ``perf_mat`` has one row per user, in ``uid_list`` order, and one
            column per entry of ``self.metric_names``.
            ``perf_dict`` maps uid -> that user's row.
        :raises ValueError: if ``uid_list`` and ``rating_mat`` differ in length.
        """
        self.rating_mat = rating_mat
        self.except_U2I = except_U2I
        self.test_U2I = test_U2I
        num_users = len(uid_list)
        if len(rating_mat) != num_users:
            # zip() would silently truncate and pair users with the wrong rows.
            raise ValueError(
                f"uid_list has {num_users} users but rating_mat has {len(rating_mat)} rows"
            )
        if num_users == 0:
            # np.vstack cannot stack an empty list.
            return np.empty((0, 3 * len(self.topks)), dtype=np.float64), {}

        # Send each task only its own user's entries instead of the full dicts,
        # which would otherwise be re-pickled for every chunk sent to the workers.
        # The tuple layout expected by compute_single_user is unchanged.
        test_parameters = [
            (uid, ratings, {uid: except_U2I.get(uid, ())}, {uid: test_U2I[uid]}, self.topks)
            for uid, ratings in zip(uid_list, rating_mat)
        ]
        if self.core_num is not None and self.core_num <= 1:
            # A single worker gains nothing from a process pool; run in-process.
            result_list = [Evaluate.compute_single_user(args) for args in test_parameters]
        else:
            with mp.Pool(processes=self.core_num) as pool:
                result_list = pool.map(Evaluate.compute_single_user, test_parameters)
        perf_mat = np.vstack([item['perf'] for item in result_list])
        perf_dict = {item['uid']: item['perf'] for item in result_list}
        return perf_mat, perf_dict

    @staticmethod
    def largest_indices(score, topks):
        """Return indices of the ``max(topks)`` largest entries of ``score``, best first.

        If ``score`` has fewer entries than ``max(topks)``, all of its indices
        are returned, sorted by descending score.
        """
        score = np.asarray(score)
        k = min(int(np.max(topks)), score.size)
        if k <= 0:
            # argpartition(..., -0)[-0:] would return every index, not none.
            return np.empty(0, dtype=np.intp)
        # argpartition selects the k best in linear time; only those k get sorted.
        indices = np.argpartition(score, -k)[-k:]
        return indices[np.argsort(-score[indices])]

    @classmethod
    def compute_single_user(cls, args):
        """Evaluate a single user.

        :param args: tuple ``(uid, ratings, except_U2I, test_U2I, topks)``.
            The two mappings only need to contain ``uid``, and ``uid`` may be
            absent from ``except_U2I``.
        :return: ``{'uid': uid, 'perf': vector}``. The column layout of
            ``vector`` is described in ``compute_single_user_metric``.
        """
        uid, ratings, except_U2I, test_U2I, topks = args
        # Work on a float copy. Assigning -inf into an integer row would raise,
        # and without a copy the in-process path would overwrite the caller's matrix.
        scores = np.array(ratings, dtype=np.float64)
        # A list/intp array indexes correctly even when the stored value is a set.
        excluded = np.asarray(list(except_U2I.get(uid, ())), dtype=np.intp)
        scores[excluded] = -np.inf
        score_indices = cls.largest_indices(scores, topks)
        return {
            'uid': uid,
            'perf': cls.compute_single_user_metric(score_indices, test_U2I[uid], topks),
        }

    @staticmethod
    def compute_single_user_metric(rank, uid_test_pos_items, topks):
        """Compute NDCG, Precision and Recall at every cutoff in ``topks``.

        :param rank: item ids ordered best first.
        :param uid_test_pos_items: the user's test item ids.
        :param topks: cutoffs, in the order their columns should appear.
        :return: float64 vector laid out as
            ``[NDCG@k0, Precision@k0, Recall@k0, NDCG@k1, ...]``.
        """
        test_items = set(uid_test_pos_items)  # build once instead of per metric call
        topk_eval = np.zeros(3 * len(topks), dtype=np.float64)
        for i, topk in enumerate(topks):
            top_rank = rank[:topk]
            topk_eval[i * 3 + 0] = ndcg_k(top_rank, test_items)
            topk_eval[i * 3 + 1] = precision_k(top_rank, test_items)
            topk_eval[i * 3 + 2] = recall_k(top_rank, test_items)
        return topk_eval
