"""
Implementations of the Embroid method and baseline methods (FlyingSquid, majority vote, Liger).
"""

from collections import Counter

import numpy as np
from flyingsquid.label_model import LabelModel
from sklearn.cluster import KMeans
from sklearn.metrics import f1_score




def baseline_flying_squid(votes, labels, return_preds=True):
    """
    Runs FlyingSquid as a baseline.

    **Assumes votes and labels are in -1/1 space**

    votes (n, m): Predictions from m sources over n samples. Must contain
        exactly the values -1 and 1 (both present, nothing else).
    labels (n,): Ground-truth labels. Must contain exactly the values -1 and 1.
    return_preds: If True (the default), return the label model's predictions.
        If False, return the macro-averaged F1 score of those predictions
        against ``labels``.

    Returns the flattened predictions of a FlyingSquid ``LabelModel`` fit on
    ``votes`` when ``return_preds`` is True, otherwise the macro F1 score.
    """
    m = votes.shape[1]
    assert sorted(np.unique(votes)) == [-1, 1]
    assert sorted(np.unique(labels)) == [-1, 1]

    # One label-model source per column of votes.
    label_model = LabelModel(m)
    label_model.fit(votes)
    preds = label_model.predict(votes).ravel()
    if return_preds:
        return preds
    return f1_score(labels, preds, average="macro")


def baseline_majority_vote(votes, labels):
    """
    Runs majority vote baseline. Returns the macro-averaged F1 score.

    votes (n, m): Predictions from m sources over n samples. No -1/1 check is
        done here; the most frequent value in each row is used as-is.
    labels (n,): Ground-truth labels in the same label space as ``votes``.

    Ties are broken in favor of the tied value that appears first in the
    sample's row, because ``Counter.most_common`` orders equal counts by
    first occurrence.
    """
    if np.ndim(votes) != 2:
        raise ValueError(f"votes must be 2-D (n, m), got shape {np.shape(votes)}")
    preds = [Counter(row).most_common(1)[0][0] for row in votes]
    return f1_score(labels, preds, average="macro")


def baseline_liger(votes, embeddings, labels, n_clusters=5, random_state=None):
    """
    Runs Liger baseline. Returns the macro-averaged F1 score.

    **Assumes votes and labels are in -1/1 space**

    Samples are partitioned with KMeans on ``embeddings``, a separate
    FlyingSquid ``LabelModel`` is fit on the votes of each cluster, and the
    per-cluster predictions are scored against ``labels``.

    votes (n, m): Predictions from m sources over n samples. Must contain
        exactly the values -1 and 1.
    embeddings: Feature matrix passed to KMeans; presumably one row per
        sample aligned with ``votes`` (confirm against callers).
    labels (n,): Ground-truth labels. Must contain exactly the values -1 and 1.
    n_clusters: Number of KMeans clusters.
    random_state: Forwarded to ``KMeans`` so the clustering (and therefore the
        score) can be made reproducible. ``None`` keeps the previous
        unseeded behavior.
    """
    n, m = votes.shape
    assert sorted(np.unique(votes)) == [-1, 1]
    assert sorted(np.unique(labels)) == [-1, 1]

    clusters = KMeans(n_clusters=n_clusters, random_state=random_state).fit_predict(embeddings)
    preds = np.zeros(n)
    for c in range(n_clusters):
        idxs = np.flatnonzero(clusters == c)
        if idxs.size == 0:
            # KMeans can leave a cluster empty (e.g. fewer distinct embeddings
            # than n_clusters); there are no samples to label, and fitting a
            # label model on zero rows is not meaningful.
            continue
        c_votes = votes[idxs]
        label_model = LabelModel(m)
        label_model.fit(c_votes)
        preds[idxs] = label_model.predict(c_votes).ravel()

    return f1_score(labels, preds, average="macro")


def embroid(votes, nn_info, knn=10, thresholds=[[0.5, 0.5]]):
    """
    Implements embroid method.

    **Assumes votes are in -1/1 space**

    votes (n, m): Predictions from m prompts over n samples, with values in
        {-1, 1} (a single-valued matrix is also accepted).
    nn_info: list of numpy arrays, one per embedding space. Each array has
        shape (n, k); row i holds the indices of the samples closest to
        sample i. Column 0 is skipped (presumably sample i itself -- confirm
        against how nn_info is built) and the next ``knn`` columns form the
        neighborhood.
    knn: Number of neighbors per sample used for the neighborhood vote.
    thresholds (m, 2) or (1, 2): Per-prompt shrinkage thresholds tau. For
        prompt j and a given sample, let p be the fraction of its neighbors
        voting 1. The smoothed vote is -1 if (1 - p) >= thresholds[j][0],
        otherwise 1 if p >= thresholds[j][1], otherwise 0 (abstain). A single
        pair is applied to every prompt. The default argument is never
        mutated.

    Returns the flattened predictions of a FlyingSquid ``LabelModel`` fit on
    the original votes stacked column-wise with one (n, m) block of smoothed
    votes per embedding space.
    """
    assert sorted(np.unique(votes)) in [[1], [-1], [-1, 1]], np.unique(votes)

    n, m = votes.shape

    thr = np.atleast_2d(np.asarray(thresholds, dtype=float))
    if thr.shape[0] == 1:
        # Broadcast one (neg, pos) pair to all prompts; indexing the
        # single-row default per prompt used to raise IndexError for m > 1.
        thr = np.repeat(thr, m, axis=0)
    elif thr.shape[0] < m:
        raise ValueError(f"thresholds must have 1 or {m} rows, got {thr.shape[0]}")
    neg_thr = thr[:m, 0]
    pos_thr = thr[:m, 1]

    # Map votes from {-1, 1} to {0, 1} so a neighborhood mean is the
    # fraction of neighbors voting for the positive class.
    pos_indicator = (votes + 1) / 2

    inputs = [votes]
    for nn in nn_info:
        neighbors = nn[:, 1:1 + knn]
        # pos_indicator[neighbors] has shape (n, knn, m); average over neighbors.
        neighbor_pos_frac = pos_indicator[neighbors].mean(axis=1)
        S = np.zeros((n, m))
        S[neighbor_pos_frac >= pos_thr] = 1
        # Assigned second so -1 wins when both thresholds are met.
        S[(1 - neighbor_pos_frac) >= neg_thr] = -1
        inputs.append(S)

    # Stack original votes with the smoothed votes from every embedding space.
    mod_votes = np.concatenate(inputs, axis=1)
    n_sources = m * len(inputs)
    assert mod_votes.shape[1] == n_sources

    label_model = LabelModel(n_sources)
    label_model.fit(mod_votes)
    return label_model.predict(mod_votes).ravel()