import random
from functools import lru_cache
from typing import Mapping, Optional, Sequence, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from pytorch_lightning import seed_everything
from sklearn.cluster import KMeans
from torch import cosine_similarity, nn
from torch_cluster import fps
from torchmetrics.functional import pearson_corrcoef, spearman_corrcoef

from rel2abs.openfaiss import FaissIndex


class LayerEncoding(nn.Module):
    """Module that simply carries a layer identifier together with a tensor."""

    def __init__(self, layer: int, encoding: torch.Tensor):
        """Store ``layer`` and ``encoding`` as instance attributes.

        Args:
            layer: Integer layer identifier, kept as-is.
            encoding: Tensor kept as-is (no copy, no explicit
                ``register_buffer`` call is made here).
        """
        super().__init__()
        self.layer, self.encoding = layer, encoding


@torch.no_grad()
def anchor_augmentation(
    encoding_anchors: torch.Tensor,
    decoding_anchors: torch.Tensor,
    centering: bool,
    inverse_dtype: torch.dtype = torch.float32,
    std_correction: bool = True,
):
    """Create an identity-basis anchor set and its image in the decoding space.

    The rows of an identity matrix (one row per column of
    ``encoding_anchors``) are expressed relative to ``encoding_anchors`` with
    ``relative_angle`` and mapped back to absolute coordinates through
    ``rel2abs_angle`` using ``decoding_anchors``. The mapped rows are then
    optionally shifted/scaled by statistics of ``decoding_anchors`` and
    L2-normalized along the last dimension.

    Args:
        encoding_anchors: Anchors of the encoding space; ``size(1)`` fixes the
            size of the identity basis.
        decoding_anchors: Anchors of the decoding space.
        centering: If True, subtract ``decoding_anchors.mean(dim=0)``.
        inverse_dtype: Dtype forwarded to ``rel2abs_angle`` for the
            pseudo-inverse computation.
        std_correction: If True, divide by ``decoding_anchors.std(dim=0)``.

    Returns:
        A dict with keys ``"encoding"`` (the identity basis) and
        ``"decoding"`` (the normalized mapped basis).
    """
    basis = torch.eye(n=encoding_anchors.size(1), device=encoding_anchors.device)

    relative_basis = relative_angle(
        x=basis,
        anchors=encoding_anchors,
        abs_norm=True,
        rel_norm=False,
    )
    decoded_basis = rel2abs_angle(
        rel_x=relative_basis,
        anchors=decoding_anchors,
        inverse_dtype=inverse_dtype,
        normalize_anchors=False,
    )

    # Neutral elements (0 / 1) keep the arithmetic uniform when a correction is disabled.
    shift = decoding_anchors.mean(dim=0) if centering else 0
    scale = decoding_anchors.std(dim=0) if std_correction else 1
    decoded_basis = F.normalize((decoded_basis - shift) / scale, p=2, dim=-1)

    return {"encoding": basis, "decoding": decoded_basis}


@torch.no_grad()
def anchor_augmentation_list(
    encoding_anchors: Sequence[torch.Tensor],
    decoding_anchors: Sequence[torch.Tensor],
    centering: bool,
    inverse_dtype: torch.dtype = torch.float32,
    std_correction: bool = True,
):
    """Apply ``anchor_augmentation`` to paired encoding/decoding anchor sets.

    The two sequences are walked in lockstep with ``zip``, so extra entries in
    the longer one are ignored.

    Args:
        encoding_anchors: Sequence of encoding anchor tensors.
        decoding_anchors: Sequence of decoding anchor tensors.
        centering: Forwarded to ``anchor_augmentation``.
        inverse_dtype: Forwarded to ``anchor_augmentation``.
        std_correction: Forwarded to ``anchor_augmentation``.

    Returns:
        A list with one ``anchor_augmentation`` result per pair.
    """
    augmented = []
    for enc_subspace, dec_subspace in zip(encoding_anchors, decoding_anchors):
        augmented.append(
            anchor_augmentation(
                encoding_anchors=enc_subspace,
                decoding_anchors=dec_subspace,
                centering=centering,
                inverse_dtype=inverse_dtype,
                std_correction=std_correction,
            )
        )
    return augmented


def rescale_latent(x: torch.Tensor, mean: Union[torch.Tensor, float], std: Union[torch.Tensor, float]):
    """Multiply ``x`` by a factor sampled from a normal distribution.

    Args:
        x: Tensor to rescale.
        mean: Mean of the sampling distribution, as a tensor or a Python number.
        std: Standard deviation of the sampling distribution, as a tensor or a
            Python number.

    Returns:
        ``x`` multiplied (with broadcasting) by the sampled factor. When both
        ``mean`` and ``std`` are Python numbers the factor has a single element.
    """
    # The previous check ``mean is float`` compared the value with the type
    # object itself and was therefore always False; plain Python numbers were
    # never converted and ``torch.normal`` rejected them (no ``size`` given).
    if isinstance(mean, (int, float)):
        mean = torch.tensor([mean], dtype=torch.float, device=x.device)
    if isinstance(std, (int, float)):
        std = torch.tensor([std], dtype=torch.float, device=x.device)

    rescaling_factor = torch.normal(mean=mean, std=std)
    return x * rescaling_factor


def self_sim_analysis(
    space1: torch.Tensor,
    space2: torch.Tensor,
    normalize: bool = False,
    spearman: bool = True,
    pearson: bool = True,
    cosine: bool = True,
):
    """Compare two spaces through their self-similarity (Gram) matrices.

    Each space is multiplied by its own transpose; the two resulting matrices
    are then compared with the requested measures, each reduced with ``mean``.

    Args:
        space1: First space, one sample per row.
        space2: Second space, one sample per row.
        normalize: If True, L2-normalize the rows of both spaces first.
        spearman: Include ``spearman_corrcoef`` of the transposed matrices.
        pearson: Include ``pearson_corrcoef`` of the transposed matrices.
        cosine: Include row-wise ``cosine_similarity`` of the matrices.

    Returns:
        A dict holding the enabled measures under the keys ``"spearman"``,
        ``"pearson"`` and ``"cosine"``.
    """
    spaces = [space1, space2]
    if normalize:
        spaces = [F.normalize(space, p=2, dim=-1) for space in spaces]

    gram1, gram2 = [space @ space.T for space in spaces]

    metrics = {}
    if spearman:
        metrics["spearman"] = spearman_corrcoef(gram1.T, gram2.T).mean()
    if pearson:
        metrics["pearson"] = pearson_corrcoef(gram1.T, gram2.T).mean()
    if cosine:
        metrics["cosine"] = cosine_similarity(gram1, gram2).mean()
    return metrics


@torch.no_grad()
def relative_angle(
    x: torch.Tensor,
    anchors: torch.Tensor,
    abs_norm: bool,
    rel_norm: bool,
    p: int = 2,
) -> torch.Tensor:
    """Express each row of ``x`` through its inner products with the anchors.

    Args:
        x: Samples, one per row.
        anchors: Anchors, one per row, sharing the last dimension with ``x``.
        abs_norm: If True, p-normalize the rows of ``x`` and ``anchors`` before
            taking inner products.
        rel_norm: If True, p-normalize the rows of the result.
        p: Order of the norm used by both normalizations.

    Returns:
        A tensor with one row per sample and one column per anchor.
    """
    samples, references = x, anchors
    if abs_norm:
        samples = F.normalize(samples, p=p, dim=-1)
        references = F.normalize(references, p=p, dim=-1)

    similarities = torch.einsum("bn,an -> ba", samples, references)

    return F.normalize(similarities, p=p, dim=-1) if rel_norm else similarities


@torch.no_grad()
def relative_projection(
    x: torch.Tensor,
    anchors: torch.Tensor,
    abs_norm: bool,
    rel_norm: bool,
    p: int = 2,
) -> torch.Tensor:
    """Project each row of ``x`` onto the anchors via inner products.

    Args:
        x: Samples, one per row.
        anchors: Anchors, one per row, sharing the last dimension with ``x``.
        abs_norm: If True, p-normalize the rows of ``x`` and ``anchors`` first.
        rel_norm: If True, p-normalize the rows of the projection.
        p: Order of the norm used by both normalizations.

    Returns:
        A tensor with one row per sample and one column per anchor.
    """
    points, basis = x, anchors
    if abs_norm:
        points = F.normalize(points, p=p, dim=-1)
        basis = F.normalize(basis, p=p, dim=-1)

    # Single einsum contraction over the shared feature dimension.
    projection = torch.einsum("nd,ad -> na", points, basis)

    if not rel_norm:
        return projection
    return F.normalize(projection, p=p, dim=-1)


@torch.no_grad()
def relative_projection_list(
    x: torch.Tensor,
    anchors: Sequence[torch.Tensor],
    abs_norm: bool,
    rel_norm: bool,
    p: int = 2,
) -> torch.Tensor:
    """Run ``relative_projection`` of ``x`` against every anchor subspace.

    NOTE(review): the return annotation says ``torch.Tensor`` but a Python
    list of tensors is returned (one entry per element of ``anchors``).

    Args:
        x: Samples to project.
        anchors: Sequence of anchor tensors, one per subspace.
        abs_norm: Forwarded to ``relative_projection``.
        rel_norm: Forwarded to ``relative_projection``.
        p: Forwarded to ``relative_projection``.

    Returns:
        A list with the projection of ``x`` for each anchor subspace, in order.
    """
    projections = []
    for subspace in anchors:
        projections.append(
            relative_projection(x=x, anchors=subspace, abs_norm=abs_norm, rel_norm=rel_norm, p=p)
        )
    return projections


@torch.no_grad()
def invert_anchors(
    anchors: torch.Tensor,
    p: int = 2,
    normalize_anchors: bool = True,
    inverse_dtype: torch.dtype = torch.float32,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Return the pseudo-inverse of the transposed anchor matrix.

    Args:
        anchors: Anchor matrix, one anchor per row.
        p: Order of the norm used when ``normalize_anchors`` is True.
        normalize_anchors: If True, p-normalize the anchor rows first.
        inverse_dtype: Dtype in which the pseudo-inverse is computed.

    Returns:
        ``pinv(anchors^T)``, cast back to the dtype of ``anchors``.
    """
    if normalize_anchors:
        anchors = F.normalize(anchors, p=p, dim=-1)

    result_dtype = anchors.dtype
    transposed = torch.transpose(anchors, 1, 0).type(inverse_dtype)
    pseudo_inverse = torch.linalg.pinv(transposed)
    return pseudo_inverse.type(result_dtype)


@torch.no_grad()
def invert_anchors_list(
    anchors: Sequence[torch.Tensor],
    p: int = 2,
    normalize_anchors: bool = True,
    inverse_dtype: torch.dtype = torch.float32,
):
    """Apply ``invert_anchors`` to every anchor subspace.

    Args:
        anchors: Sequence of anchor tensors, one per subspace.
        p: Forwarded to ``invert_anchors``.
        normalize_anchors: Forwarded to ``invert_anchors``.
        inverse_dtype: Forwarded to ``invert_anchors``.

    Returns:
        A list with one inverted anchor tensor per subspace, in order.
    """
    inverses = []
    for subspace in anchors:
        inverses.append(
            invert_anchors(
                anchors=subspace,
                p=p,
                normalize_anchors=normalize_anchors,
                inverse_dtype=inverse_dtype,
            )
        )
    return inverses


@torch.no_grad()
def absolute_projection(
    rel_x: torch.Tensor,
    anchor_inverse: torch.Tensor,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Map relative coordinates back to absolute ones.

    Args:
        rel_x: Relative representation, one sample per row and one column per
            anchor.
        anchor_inverse: Matrix with one row per anchor, e.g. as produced by
            ``invert_anchors``.

    Returns:
        The product of ``rel_x`` and ``anchor_inverse``, one sample per row.
    """
    return torch.einsum("na,ad -> nd", rel_x, anchor_inverse)


@torch.no_grad()
def absolute_projection_list(
    rel_x: Sequence[torch.Tensor],
    anchor_inverse: Sequence[torch.Tensor],
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Apply ``absolute_projection`` per subspace and stack the results.

    The two sequences are paired with ``zip``; the reconstructions are stacked
    along a new leading dimension, so they must all share the same shape.

    Args:
        rel_x: Relative representations, one tensor per subspace.
        anchor_inverse: Inverted anchors, one tensor per subspace.

    Returns:
        A tensor whose first dimension indexes the subspaces.
    """
    reconstructions = []
    for relative_subspace, inverse_subspace in zip(rel_x, anchor_inverse):
        reconstructions.append(
            absolute_projection(rel_x=relative_subspace, anchor_inverse=inverse_subspace)
        )
    return torch.stack(reconstructions, dim=0)


@torch.no_grad()
def relative_radius(
    x: torch.Tensor,
    anchors: torch.Tensor,
    rel_norm: bool,
    p: int = 2,
) -> torch.Tensor:
    """Compute the ratio between every sample norm and every anchor norm.

    Entry ``(i, j)`` of the result is ``||x_i||_p / ||anchors_j||_p``.

    Args:
        x: Samples, one per row (norms are taken along dimension 1).
        anchors: Anchors, one per row (norms are taken along the last dimension).
        rel_norm: If True, p-normalize the rows of the result.
        p: Order of the norm.

    Returns:
        A tensor with one row per sample and one column per anchor.
    """
    sample_norms = torch.norm(x, p=p, dim=1, keepdim=True)
    inverse_anchor_norms = torch.norm(anchors, p=p, dim=-1, keepdim=True).reciprocal()

    # Outer product of a column of sample norms and a row of inverse anchor norms.
    ratios = torch.matmul(sample_norms, inverse_anchor_norms.T)

    return F.normalize(ratios, p=p, dim=-1) if rel_norm else ratios


@torch.no_grad()
def to_radius(x: torch.Tensor, p: int = 2):
    """Pair every row index of ``x`` with the p-norm of that row.

    Args:
        x: Samples; norms are taken along the last dimension.
        p: Order of the norm.

    Returns:
        A tensor stacking ``[index, norm]`` along a new last dimension; for a
        2D ``x`` column 0 holds the row index and column 1 the radius.
    """
    radius: torch.Tensor = x.norm(p=p, dim=-1)

    # Create the index on the same device as the radii: a default-device
    # ``arange`` cannot be stacked with a tensor living on another device.
    indices = torch.arange(radius.size(0), device=radius.device)

    return torch.stack([indices, radius], dim=-1)


@torch.no_grad()
def rel2abs_angle(
    rel_x: torch.Tensor,
    anchors: torch.Tensor,
    p: int = 2,
    return_inverse: bool = False,
    normalize_anchors: bool = True,
    inverse_dtype: torch.dtype = torch.float,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
    """Reconstruct absolute vectors from an anchor-relative representation.

    The reconstruction multiplies ``rel_x`` by the pseudo-inverse of the
    transposed anchor matrix.

    Args:
        rel_x: Relative representation, one sample per row and one column per
            anchor.
        anchors: Anchor matrix, one anchor per row.
        p: Order of the norm used when ``normalize_anchors`` is True.
        return_inverse: If True, also return the pseudo-inverse.
        normalize_anchors: If True, p-normalize the anchor rows first.
        inverse_dtype: Dtype used for the pseudo-inverse and the product.

    Returns:
        The reconstruction cast to ``float32``; when ``return_inverse`` is
        True, a ``(reconstruction, pseudo_inverse)`` tuple where the
        pseudo-inverse keeps ``inverse_dtype``.
    """
    if normalize_anchors:
        anchors = F.normalize(anchors, p=p, dim=-1)

    pseudo_inverse: torch.Tensor = torch.linalg.pinv(anchors.T.type(inverse_dtype))
    reconstruction: torch.Tensor = torch.matmul(rel_x.type(inverse_dtype), pseudo_inverse).float()

    return (reconstruction, pseudo_inverse) if return_inverse else reconstruction


@torch.no_grad()
def rel2abs_radius(rel_x: torch.Tensor, anchors: torch.Tensor, p: int = 2) -> torch.Tensor:
    """Recover absolute radii from a relative-radius representation.

    ``rel_x`` is multiplied by the pseudo-inverse of the row of inverse anchor
    norms, and the result is passed through ``to_radius``.

    Args:
        rel_x: Relative radii, one sample per row and one column per anchor.
        anchors: Anchor matrix, one anchor per row.
        p: Order of the norm.

    Returns:
        The output of ``to_radius`` for the reconstructed norms.
    """
    inverse_anchor_norms = torch.norm(anchors, p=p, dim=-1, keepdim=True).reciprocal()
    reconstructed_norms = torch.matmul(rel_x, torch.linalg.pinv(inverse_anchor_norms.T))

    return to_radius(x=reconstructed_norms, p=p)


@torch.no_grad()
def rel2abs_polar(
    rel_angle_x: torch.Tensor, rel_radius_x: torch.Tensor, anchors: torch.Tensor, p: int = 2
) -> torch.Tensor:
    """Rebuild absolute vectors from relative angles and relative radii.

    Args:
        rel_angle_x: Relative-angle representation, decoded by ``rel2abs_angle``.
        rel_radius_x: Relative-radius representation, decoded by
            ``rel2abs_radius``.
        anchors: Anchor matrix shared by both decodings.
        p: Order of the norm forwarded to both decodings.

    Returns:
        The decoded angles scaled by the decoded radius of each sample.
    """
    directions = rel2abs_angle(rel_x=rel_angle_x, anchors=anchors, p=p)
    indexed_radius = rel2abs_radius(rel_x=rel_radius_x, anchors=anchors, p=p)

    # Column 1 is sliced as ``1:2`` so it keeps a trailing dimension for broadcasting.
    magnitudes = indexed_radius[:, 1:2]
    return directions * magnitudes


class LatentSpace:
    """A collection of latent vectors addressable by string key.

    Besides the vectors, an instance records which encoder produced them and
    the kind of encoding (e.g. ``"relative"`` for ``RelativeSpace``).
    """

    def __init__(
        self,
        encoding_type: str,
        encoder: str,
        keys: Sequence[str] = None,
        vectors: torch.Tensor = None,
    ):
        """Initialize the space.

        Args:
            encoding_type: Label describing the encoding of the vectors.
            encoder: Name of the encoder that produced the vectors.
            keys: One key per row of ``vectors``; must be given together with
                ``vectors`` or not at all.
            vectors: Tensor with one row per key.

        Raises:
            AssertionError: If only one of ``keys`` / ``vectors`` is given, or
                if their lengths differ.
        """
        self.encoding_type: str = encoding_type
        # Either both ``keys`` and ``vectors`` are provided, or neither is.
        assert ((keys is None) + (vectors is None)) % 2 == 0
        assert keys is None or len(keys) == vectors.size(0)

        # Bidirectional key <-> row-index lookup tables (None when no keys are given).
        self.key2index: Mapping[str, int] = keys if keys is None else {key: index for index, key in enumerate(keys)}
        self.index2key: Mapping[int, str] = None if keys is None else {index: key for index, key in enumerate(keys)}
        self.vectors: torch.Tensor = vectors
        self.encoder: str = encoder

    def __repr__(self) -> str:
        return f"LatentSpace(encoding_type={self.encoding_type}, encoder={self.encoder})"

    # @lru_cache
    def to_faiss(self, normalize: bool, keys: Sequence[str]) -> FaissIndex:
        """Build a ``FaissIndex`` holding the vectors of this space.

        Args:
            normalize: Forwarded to ``FaissIndex.add_vectors``.
            keys: Keys paired positionally with the rows of ``self.vectors``.

        Returns:
            The populated index.
        """
        index: FaissIndex = FaissIndex(d=self.vectors.size(1))

        index.add_vectors(
            embeddings=list(zip(keys, self.vectors.cpu().numpy())),
            normalize=normalize,
        )

        return index

    def to_relative(
        self,
        mode: str,
        abs_norm: bool,
        rel_norm: bool,
        rel_radius_norm: bool,
        rel_angle_norm: bool,
        anchor_choice: str = None,
        seed: int = None,
        anchors: Optional[Mapping[str, torch.Tensor]] = None,
        num_anchors: int = None,
    ) -> "RelativeSpace":
        """Re-express this space relative to a set of anchors.

        Args:
            mode: ``"angle"``, ``"radius"`` or ``"polar"`` (angles and radii
                concatenated along the last dimension).
            abs_norm: Forwarded to ``relative_angle``.
            rel_norm: Only used in ``"polar"`` mode: L2-normalize the
                concatenated representation.
            rel_radius_norm: Forwarded to ``relative_radius`` as ``rel_norm``.
            rel_angle_norm: Forwarded to ``relative_angle`` as ``rel_norm``.
            anchor_choice: Anchor selection strategy, used when ``anchors`` is
                not given (see ``get_anchors``).
            seed: Seed for the anchor selection.
            anchors: Explicit mapping from anchor key to anchor vector; mutually
                exclusive with ``num_anchors``.
            num_anchors: Number of anchors to select when ``anchors`` is None.

        Returns:
            A ``RelativeSpace`` with the relative vectors (on CPU) and the
            anchor keys.

        Raises:
            NotImplementedError: If ``mode`` is not supported.
            AssertionError: If this space is already relative, or if both
                ``anchors`` and ``num_anchors`` are given.
        """
        assert self.encoding_type != "relative"  # TODO: for now
        assert (anchors is None) or (num_anchors is None)

        anchors = (
            self.get_anchors(anchor_choice=anchor_choice, seed=seed, num_anchors=num_anchors)
            if anchors is None
            else anchors
        )

        anchor_keys, anchor_latents = list(zip(*anchors.items()))
        # Keep the anchors on the same device as the vectors they are combined
        # with; forcing them onto the CPU broke the products below whenever
        # ``self.vectors`` lived elsewhere. For CPU vectors nothing changes.
        anchor_latents = torch.stack(anchor_latents, dim=0).to(self.vectors.device)

        if mode == "angle":
            relative_vectors = relative_angle(
                x=self.vectors, anchors=anchor_latents, abs_norm=abs_norm, rel_norm=rel_angle_norm
            ).cpu()
        elif mode == "radius":
            relative_vectors = relative_radius(x=self.vectors, anchors=anchor_latents, rel_norm=rel_radius_norm).cpu()
        elif mode == "polar":
            rel_angles = relative_angle(
                x=self.vectors, anchors=anchor_latents, abs_norm=abs_norm, rel_norm=rel_angle_norm
            ).cpu()
            rel_radius = relative_radius(x=self.vectors, anchors=anchor_latents, rel_norm=rel_radius_norm).cpu()
            relative_vectors = torch.cat([rel_angles, rel_radius], dim=-1)

            if rel_norm:
                relative_vectors = F.normalize(relative_vectors, p=2, dim=-1)
        else:
            raise NotImplementedError()

        return RelativeSpace(
            keys=self.key2index.keys(),
            vectors=relative_vectors,
            encoder=self.encoder,
            anchors=anchor_keys,
        )

    # NOTE(review): ``lru_cache`` on an instance method includes ``self`` in the
    # cache key, so cached instances stay alive as long as the cache does, and a
    # cache hit skips the ``seed_everything`` call below. Left unchanged because
    # callers may rely on the memoization.
    @lru_cache()
    def get_anchors(self, anchor_choice: str, seed: int, num_anchors: int) -> Mapping[str, torch.Tensor]:
        """Select ``num_anchors`` anchors from this space.

        Args:
            anchor_choice: Selection strategy:
                ``"uniform"`` samples keys at random; ``"top_<k>"`` samples at
                random among the first ``k`` keys; ``"fps"`` uses
                ``torch_cluster.fps`` on the L2-normalized vectors;
                ``"kmeans"`` clusters the normalized vectors and, for each
                cluster mean, takes the single best match returned by a
                ``FaissIndex`` search.
            seed: Seed passed to ``seed_everything`` (and to ``KMeans``).
            num_anchors: Number of anchors (or clusters) requested.

        Returns:
            A mapping from anchor key to its vector, built in sorted key order.

        Raises:
            NotImplementedError: If ``anchor_choice`` is not a known strategy.
        """
        # Select anchors
        seed_everything(seed)

        if anchor_choice == "uniform" or anchor_choice.startswith("top_"):
            limit: int = len(self.key2index.keys()) if anchor_choice == "uniform" else int(anchor_choice[4:])
            anchor_set: Sequence[str] = random.sample(list(self.key2index.keys())[:limit], num_anchors)
        elif anchor_choice == "fps":
            anchor_fps = F.normalize(self.vectors, p=2, dim=-1)
            anchor_fps = fps(anchor_fps, random_start=True, ratio=num_anchors / len(self.key2index.keys()))
            anchor_set: Sequence[str] = [self.index2key[word_index] for word_index in anchor_fps.cpu().tolist()]
        elif anchor_choice == "kmeans":
            vectors = F.normalize(self.vectors)
            clustered = KMeans(n_clusters=num_anchors, random_state=seed).fit_predict(vectors.cpu().numpy())

            all_targets = sorted(set(clustered))
            cluster2embeddings = {target: vectors[clustered == target] for target in all_targets}
            cluster2centroid = {
                cluster: centroid.mean(dim=0).cpu().numpy() for cluster, centroid in cluster2embeddings.items()
            }
            centroids = np.array(list(cluster2centroid.values()), dtype="float32")

            # Cluster means are generally not actual samples: look up the
            # closest stored vector for each of them.
            index: FaissIndex = FaissIndex(d=vectors.shape[1])
            index.add_vectors(list(zip(self.key2index.keys(), vectors.cpu().numpy())), normalize=False)
            centroids = index.search_by_vectors(query_vectors=centroids, k_most_similar=1, normalize=True)

            anchor_set = [list(word2score.keys())[0] for word2score in centroids]
        else:
            # ``assert NotImplementedError`` asserted a truthy class and never
            # fired, so an unknown strategy crashed later on the undefined
            # ``anchor_set``. Fail explicitly instead.
            raise NotImplementedError(f"Unknown anchor_choice: {anchor_choice!r}")

        return {anchor_key: self.vectors[self.key2index[anchor_key]] for anchor_key in sorted(anchor_set)}


class RelativeSpace(LatentSpace):
    """A ``LatentSpace`` whose vectors are expressed relative to anchors."""

    def __init__(
        self,
        keys: Sequence[str],
        vectors: torch.Tensor,
        anchors: Sequence[str],
        encoder: str = None,
    ):
        """Initialize the relative space.

        Args:
            keys: One key per row of ``vectors``.
            vectors: Relative representation, one row per key.
            anchors: Keys of the anchors the representation refers to.
            encoder: Name of the encoder that produced the original vectors.
        """
        # The encoding type is fixed to "relative" for this subclass.
        super().__init__(
            encoding_type="relative",
            encoder=encoder,
            keys=keys,
            vectors=vectors,
        )
        self.anchors: Sequence[str] = anchors
