import itertools

import gensim.downloader as api
import numpy as np

def load_embedding_model(model):
    """
    Load a pretrained embedding model by name using gensim's downloader.

    Args:
        model: Model name passed straight to ``gensim.downloader.load``
            (presumably a GloVe model name; confirm against callers).

    Returns:
        wv_from_bin: The loaded embeddings object. It is expected to expose
        ``index_to_key`` (used below for the vocabulary size), as gensim
        ``KeyedVectors`` does.

    Side effects:
        Prints the model name and its vocabulary size.
    """
    wv_from_bin = api.load(model)
    # index_to_key is already a sequence; len() on it directly avoids copying the vocabulary.
    print(f"Loaded {model}, with vocabulary size {len(wv_from_bin.index_to_key)}")
    return wv_from_bin

def normalize_vector(vec):
    """
    Scale ``vec`` to unit Euclidean length.

    A zero-length vector cannot be normalized, so it is returned as-is
    (the very same object, not a copy).
    """
    magnitude = np.linalg.norm(vec)
    return vec if magnitude == 0 else vec / magnitude

def phrase_vector(wv_from_bin, phrase):
    """
    Embed a whitespace-separated phrase as a single unit-length vector.

    Tokens are matched exactly (case-sensitive) against ``wv_from_bin.key_to_index``.
    Tokens not found there are ignored. The vectors of the known tokens are summed
    and the sum is normalized, which gives the same direction as normalizing their mean.

    Returns:
        The normalized vector, or None if no token of the phrase is in the vocabulary.
    """
    vocab = wv_from_bin.key_to_index
    known = [wv_from_bin[token] for token in phrase.split() if token in vocab]
    if not known:
        return None
    return normalize_vector(np.sum(known, axis=0))

def cosine_similarity(wv_from_bin, phrase1, phrase2):
    """
    Return the cosine similarity between two phrases.

    Both phrases are embedded with ``phrase_vector``, which returns unit-length
    vectors, so their dot product is the cosine similarity. Returns None if
    either phrase has no in-vocabulary token.
    """
    first = phrase_vector(wv_from_bin, phrase1)
    second = phrase_vector(wv_from_bin, phrase2)
    if first is None or second is None:
        return None
    return np.dot(first, second)
    
def max_cosine_similarity(wv_from_bin, predicted_phrase, object_set):
    """
    Return the best cosine similarity between a predicted phrase and any phrase in a set.

    Args:
        wv_from_bin: Embeddings object accepted by ``phrase_vector``.
        predicted_phrase: Phrase to score.
        object_set: Iterable of candidate phrases. Candidates with no in-vocabulary
            token are skipped.

    Returns:
        The maximum similarity, floored at 0. Returns 0 if the predicted phrase has
        no in-vocabulary token, if no candidate can be embedded, or if every
        candidate's similarity is <= 0. Negative similarities are never returned.
    """
    predicted_vector = phrase_vector(wv_from_bin, predicted_phrase)
    if predicted_vector is None:
        return 0
    max_similarity = 0
    for obj_name in object_set:
        obj_vector = phrase_vector(wv_from_bin, obj_name)
        if obj_vector is None:
            continue
        # Both vectors are already unit-normalized; reuse them instead of
        # re-embedding both phrases through cosine_similarity on every iteration.
        similarity = np.dot(predicted_vector, obj_vector)
        if similarity > max_similarity:
            max_similarity = similarity
    return max_similarity

def find_highest_similarity(wv_from_bin, words):
    """
    Find the highest cosine similarity among all unique pairs of a list of words.

    Each word (or whitespace-separated phrase) is embedded once with ``phrase_vector``.
    Words with no in-vocabulary token take part in no pair.

    Args:
        wv_from_bin: Gensim KeyedVectors object
        words: Iterable of words (or phrases) as strings

    Returns:
        A tuple of (max_similarity, word_pair) where max_similarity is the highest similarity found,
        and word_pair is the tuple of words for that similarity, in input order.
        On ties, the earliest pair in input order wins. If no pair scores above -1
        (e.g. fewer than two words can be embedded), returns (-1, None).
    """
    # Embed every word once up front rather than re-embedding both words for every pair.
    embedded = [(word, phrase_vector(wv_from_bin, word)) for word in words]
    scorable = [(word, vec) for word, vec in embedded if vec is not None]

    max_similarity = -1
    best_pair = None
    # combinations() yields pairs (i, j) with i < j in input order, the same order as
    # the nested index loop this replaces, so tie-breaking is unchanged.
    for (word_a, vec_a), (word_b, vec_b) in itertools.combinations(scorable, 2):
        # Vectors are unit-normalized, so the dot product is the cosine similarity.
        similarity = np.dot(vec_a, vec_b)
        if similarity > max_similarity:
            max_similarity = similarity
            best_pair = (word_a, word_b)

    return max_similarity, best_pair