"""Evaluation utils."""

import itertools
import time

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from scipy.cluster.hierarchy import linkage, dendrogram
from scipy.spatial.distance import squareform
import sknetwork as skn
from tqdm import tqdm

import utils.tree as tree_utils
from utils.linkage import sl
from mst.mst import reorder


def den_purity_recursive(tree, gt_clusters):
    """Dendrogram purity as formulated in the gHHC paper (recursive version).

    For every pair of leaves that share a ground-truth class, the purity of
    their lowest common ancestor (the fraction of that ancestor's leaves in
    the pair's class) is accumulated; the result is the average over all such
    pairs.

    Args:
        tree: directed tree; ``tree.neighbors(v)`` yields the children of
            ``v``. The root is located with ``tree_utils.get_leaves_root``.
        gt_clusters: ground-truth class of each leaf, indexed by leaf label.

    Returns:
        The dendrogram purity. Raises ZeroDivisionError when no two leaves
        share a class. Recursion depth equals the tree height; use
        ``den_purity`` for very deep trees.
    """
    labels = np.unique(gt_clusters)

    def _visit(v):
        # Summary of v's subtree:
        # (same-class pair count, per-class leaf counts, leaf count,
        #  unnormalized purity).
        kids = list(tree.neighbors(v))
        if not kids:
            counts = {lab: 0 for lab in labels}
            counts[gt_clusters[v]] = 1
            return 0, counts, 1, 0.0
        if len(kids) == 1:
            # A unary node creates no new pairs, so it is transparent.
            return _visit(kids[0])

        summaries = [_visit(k) for k in kids]
        kid_pairs = [s[0] for s in summaries]
        kid_counts = [s[1] for s in summaries]
        kid_sizes = [s[2] for s in summaries]
        kid_purities = [s[3] for s in summaries]

        size = sum(kid_sizes)
        counts = {lab: sum(kc[lab] for kc in kid_counts) for lab in labels}
        # Same-class pairs split across different children, i.e. whose LCA is v:
        # sum_{a<b} n_a * n_b == ((sum n)^2 - sum n^2) / 2.
        fresh = {
            lab: (counts[lab] ** 2 - sum(kc[lab] ** 2 for kc in kid_counts)) // 2
            for lab in labels
        }
        pairs = sum(kid_pairs) + sum(fresh[lab] for lab in labels)
        purity = sum(kid_purities) + sum(
            (counts[lab] / size) * fresh[lab] for lab in labels
        )
        return pairs, counts, size, purity

    _, top = tree_utils.get_leaves_root(tree)
    total_pairs, _, _, total_purity = _visit(top)
    return total_purity / total_pairs

# @profile
def den_purity(tree, gt_clusters):
    """ The dendrogram purity formulation from the gHHC paper

    Stack-based analog of den_purity_recursive to avoid Python recursion
    limits; computes the same value. Works on trees of any arity (unary
    nodes are transparent, as in the recursive version).

    Args:
        tree: directed tree; ``tree.neighbors(v)`` yields the children of
            ``v``. The root is located with ``tree_utils.get_leaves_root``.
        gt_clusters: ground-truth class of each leaf, indexed by leaf label.

    Returns:
        Sum over same-class leaf pairs of the purity of their lowest common
        ancestor, divided by the number of same-class pairs. Raises
        ZeroDivisionError when no two leaves share a class.
    """
    all_classes = np.unique(gt_clusters)
    _, root = tree_utils.get_leaves_root(tree)

    # Children not yet visited, per node. Keyed by node label (instead of a
    # list sized 2 * n_leaves - 1) so non-binary trees are handled too.
    unvisited = {node: list(tree.neighbors(node)) for node in tree.nodes()}
    # Finished subtree summaries:
    # node -> (same-class pair count, per-class leaf counts, leaf count,
    #          unnormalized purity).
    # A child's entry is popped as soon as its parent is finished, so
    # summaries of already-merged subtrees are released.
    summary = {}
    stack = [root]
    while stack:
        node = stack[-1]
        if unvisited[node]:
            stack.append(unvisited[node].pop())
            continue
        # All children are finished. This pop must not live inside an
        # ``assert``: asserts are stripped under ``python -O``, which would
        # leave the node on the stack forever.
        stack.pop()

        children_ = list(tree.neighbors(node))
        if not children_:
            class_count = {c: 0 for c in all_classes}
            class_count[gt_clusters[node]] = 1
            summary[node] = (0, class_count, 1, 0.0)
            continue

        pair_counts, class_counts, leaf_counts, puritys = zip(
            *[summary.pop(c) for c in children_]
        )
        leaf_count = sum(leaf_counts)
        class_count = {c: sum([child[c] for child in class_counts]) for c in all_classes}
        # Same-class pairs split across different children (LCA == node).
        new_pairs = {
            c: (class_count[c] ** 2 - sum([child[c] ** 2 for child in class_counts])) // 2
            for c in all_classes
        }
        pair_count = sum(pair_counts) + sum([new_pairs[c] for c in all_classes])
        purity = sum(puritys) + sum(
            [(class_count[c] / leaf_count) * new_pairs[c] for c in all_classes]
        )
        summary[node] = (pair_count, class_count, leaf_count, purity)

    pair_count, _, _, purity = summary[root]
    return purity / pair_count


# @profile
def dasgupta_cost_recursive(tree, similarities):
    """Dasgupta cost of a hierarchy, computed by top-down recursion.

    For every internal node v and every pair of distinct children (a, b) of
    v (a listed before b), the similarity between the leaves under a and the
    leaves under b is weighted by the number of leaves under v. Those are
    exactly the leaf pairs whose lowest common ancestor is v. Nodes of any
    arity are supported; unary nodes contribute nothing themselves.

    Args:
        tree: directed tree whose root is its largest node label;
            ``tree.neighbors(v)`` yields the children of ``v``. Leaf labels
            are used directly as indices into ``similarities``.
        similarities: square matrix of pairwise leaf similarities.

    Returns:
        Twice the accumulated weighted cross-similarity (for a symmetric
        matrix, the sum over ordered leaf pairs). Recursion depth equals the
        tree height; use ``dasgupta_cost_iterative`` for very deep trees.
    """
    root = max(list(tree.nodes()))

    def _leaves_below(top):
        """Map every node under ``top`` to the list of leaves in its subtree."""
        leaves = {}
        stack = [(top, False)]
        while stack:
            node, expanded = stack.pop()
            kids = list(tree.neighbors(node))
            if not kids:
                leaves[node] = [node]
            elif expanded:
                leaves[node] = [leaf for k in kids for leaf in leaves[k]]
            else:
                # Revisit this node once all of its children are resolved.
                stack.append((node, True))
                stack.extend((k, False) for k in kids)
        return leaves

    parent2desc = _leaves_below(root)

    def _dc_top_down(current, cost):
        children = list(tree.neighbors(current))
        # Similarity between leaves under different children of ``current``.
        cross = sum(
            similarities[parent2desc[a]].T[parent2desc[b]].sum()
            for i, a in enumerate(children)
            for b in children[i + 1:]
        )
        cost += cross * len(parent2desc[current])
        for child in children:
            cost = _dc_top_down(child, cost)
        return cost

    return 2 * _dc_top_down(root, cost=0.0)


# @profile
def dasgupta_cost_iterative(tree, similarities):
    """ Non-recursive version of DC. Also works on non-binary trees

    Args:
        tree: directed tree whose nodes are labeled ``0..n-1`` with the root
            at ``n - 1``; ``tree.neighbors(v)`` yields the children of ``v``.
            Leaf labels are used directly as indices into ``similarities``.
        similarities: square matrix of pairwise leaf similarities.

    Returns:
        Twice the sum, over internal nodes v, of |leaves(v)| times the
        similarity between leaves lying under different children of v (for
        each child pair (a, b) with a listed before b).
    """
    n = len(list(tree.nodes()))
    root = n - 1

    cost = [0] * n  # cost[v]: (half) Dasgupta cost of v's subtree
    desc = [None] * n  # desc[v]: leaves under v, kept until v's parent is done

    children = [list(tree.neighbors(node)) for node in range(n)]  # children remaining to process
    stack = [root]
    while stack:
        node = stack[-1]
        if children[node]:
            stack.append(children[node].pop())
            continue
        # All children are finished. Pop explicitly: inside an ``assert`` this
        # would be stripped under ``python -O`` and the loop would never end.
        stack.pop()

        children_ = list(tree.neighbors(node))
        if not children_:
            desc[node] = [node]
            continue

        desc[node] = [d for c in children_ for d in desc[c]]

        # Sum only the cross-child blocks. This is much faster on imbalanced
        # trees than summing the whole subtree block and subtracting each
        # child's block.
        cost_ = sum([
            similarities[desc[c0]].T[desc[c1]].sum()
            for i, c0 in enumerate(children_)
            for c1 in children_[i + 1:]
        ])
        cost_ *= len(desc[node])
        cost[node] = cost_ + sum([cost[c] for c in children_])

        # Free children's leaf lists (otherwise up to n^2 space overall).
        for c in children_:
            desc[c] = None

    return 2 * cost[root]


# @profile
# TODO this should probably be moved to tree.py since they're generic tree utils
def descendants_traversal(tree):
    """ Get all leaves non-recursively, in left-to-right traversal order

    Only leaves are returned (internal nodes are skipped). "Left-to-right"
    follows the order in which ``tree.neighbors`` lists each node's children.

    Args:
        tree: directed tree whose nodes are labeled ``0..n-1`` with the root
            at ``n - 1``.

    Returns:
        List of leaf labels reachable from the root.
    """
    n = len(list(tree.nodes()))
    root = n - 1

    traversal = []

    children = [list(tree.neighbors(node)) for node in range(n)]  # children remaining to process
    is_leaf = [len(children[node]) == 0 for node in range(n)]
    stack = [root]
    while stack:
        node = stack[-1]
        if children[node]:
            # Popping from the end visits the last child first.
            stack.append(children[node].pop())
        else:
            # Pop explicitly: inside an ``assert`` this would be stripped under
            # ``python -O`` and the loop would never terminate.
            stack.pop()
            if is_leaf[node]:
                traversal.append(node)

    # Leaves were collected right-to-left; reverse to get left-to-right.
    return traversal[::-1]

# @profile
def descendants_count(tree):
    """ For every node, count its number of descendant leaves, and the number of leaves before it

    Leaves are numbered left-to-right (children in ``tree.neighbors`` order),
    matching the order produced by ``descendants_traversal``. The leaves under
    any node therefore occupy positions ``left[v] : left[v] + desc[v]``.

    Args:
        tree: directed tree whose nodes are labeled ``0..n-1`` with the root
            at ``n - 1``.

    Returns:
        ``(desc, left)``: per-node leaf counts and per-node index of the first
        leaf in the subtree. Nodes unreachable from the root keep ``0``.
    """
    n = len(list(tree.nodes()))
    root = n - 1

    left = [0] * n
    desc = [0] * n
    leaf_idx = 0

    # Reversed so that .pop() yields the first child first (left-to-right).
    children = [list(tree.neighbors(node))[::-1] for node in range(n)]  # children remaining to process
    stack = [root]
    while stack:
        node = stack[-1]
        if children[node]:
            stack.append(children[node].pop())
            continue
        # Pop explicitly: inside an ``assert`` this would be stripped under
        # ``python -O`` and the loop would never terminate.
        stack.pop()

        children_ = list(tree.neighbors(node))
        if not children_:
            desc[node] = 1
            left[node] = leaf_idx
            leaf_idx += 1
        else:
            desc[node] = sum([desc[c] for c in children_])
            # The first child was visited first, so it holds the smallest index.
            left[node] = left[children_[0]]

    return desc, left


# @profile
def dasgupta_cost(tree, similarities):
    """ Non-recursive version of DC for binary trees.

    Optimized for speed by reordering similarity matrix for locality: leaves
    are laid out in left-to-right traversal order, so the leaves under any
    node occupy one contiguous index range, and the similarity between a
    node's two children is the sum of a single rectangular block.

    Args:
        tree: directed tree whose nodes are labeled ``0..n-1`` with the root
            at ``n - 1``; ``tree.neighbors(v)`` yields the children of ``v``.
        similarities: square matrix of pairwise leaf similarities.

    Returns:
        Twice the sum, over binary nodes v, of (similarity between the leaves
        of v's first and second child) * (number of leaves under v).

    Raises:
        ValueError: if a node reachable from the root has more than two
            children.
    """
    n = len(list(tree.nodes()))
    root = n - 1
    n_leaves = len(similarities)

    leaves = descendants_traversal(tree)
    n_desc, left_desc = descendants_count(tree)

    cost = [0] * n  # local cost for every node (similarity between its two children)

    # Reorder the similarity matrix into traversal order for locality.
    # This is the bottleneck; is there a faster way?
    similarities = reorder(similarities, np.array(leaves), n_leaves)

    children = [list(tree.neighbors(node)) for node in range(n)]  # children remaining to process
    stack = [root]
    while stack:
        node = stack[-1]
        if children[node]:
            stack.append(children[node].pop())
            continue
        # Pop explicitly: inside an ``assert`` this would be stripped under
        # ``python -O`` and the loop would never terminate.
        stack.pop()

        children_ = list(tree.neighbors(node))
        if len(children_) > 2:
            # A real exception, not an assert, so -O cannot silently skip it.
            raise ValueError("tree must be binary")
        if len(children_) < 2:
            continue  # leaves and unary nodes contribute no local cost

        left_c, right_c = children_
        left_range = [left_desc[left_c], left_desc[left_c] + n_desc[left_c]]
        right_range = [left_desc[right_c], left_desc[right_c] + n_desc[right_c]]
        # Nested reduceat over index [0] collapses the block to a 1x1 array
        # holding its total.
        cost_ = np.add.reduceat(
            np.add.reduceat(
                similarities[
                    left_range[0]:left_range[1],
                    right_range[0]:right_range[1]
                ], [0], axis=1
            ), [0], axis=0
        )
        cost[node] = cost_[0, 0]

    return 2 * sum(np.array(cost) * np.array(n_desc))


def dc_bounds(similarities, lb=True, n_triples=10000):
    """Estimate a lower (or upper) bound on the Dasgupta cost from leaf triples.

    For each triple {i, j, k} the three possible "two-pair" similarity sums
    (s_ij + s_ik, s_ij + s_jk, s_ik + s_jk) are formed, and the smallest
    (``lb=True``) or largest (``lb=False``) is kept. The triple total is
    doubled, rescaled by ``C(n, 3) / n_triples``, and added to twice the
    total off-diagonal similarity.

    Args:
        similarities: square 2-D array of pairwise leaf similarities.
        lb: take the per-triple minimum (lower bound) if True, else maximum.
        n_triples: number of random triples to sample (three distinct nodes
            per draw, via ``np.random``). A negative value enumerates all
            C(n, 3) triples exactly.

    Returns:
        The bound estimate. For fewer than three leaves no triples exist and
        only the pair term is returned.

    Raises:
        ValueError: if ``n_triples == 0`` while triples exist.
    """
    n = similarities.shape[0]
    nodes = np.arange(n)
    expected_tcost = 0
    if n >= 3:
        total = n * (n - 1) * (n - 2) / 6
        if n_triples < 0:
            triples = list(itertools.combinations(range(n), 3))
            n_triples = total
        elif n_triples == 0:
            raise ValueError("n_triples must be nonzero")
        else:
            triples = [
                tuple(np.random.choice(nodes, 3, replace=False))
                for _ in range(n_triples)
            ]
        pick = min if lb else max
        for i, j, k in tqdm(triples):
            sij = similarities[i, j]
            sik = similarities[i, k]
            sjk = similarities[j, k]
            expected_tcost += pick(sij + sik, sij + sjk, sik + sjk)
        expected_tcost *= 2
        # Rescale the sampled sum to the full population of triples.
        expected_tcost *= total / n_triples
    cost = 2 * (np.sum(similarities) - np.sum(np.diag(similarities))) + expected_tcost
    return cost

