from multiprocessing.pool import ThreadPool
from typing import Callable, List, Union

import numpy as np
import torch
import torch.nn.functional as F
import tqdm

# Array-like input accepted by the helpers in this module: either a plain
# Python list of floats or a NumPy array.
Array = Union[List[float], np.ndarray]


class QuantileFunction:
    """Empirical quantile function of a collection of samples.

    The samples are sorted once at construction time; evaluating the quantile
    function at a probability level then amounts to an index lookup into the
    sorted array.

    Args
        samples (numpy array): Empirical distribution of data e.g. scores belonging to an algorithm.
    """

    def __init__(self, samples: Array):
        data = np.array(samples) if isinstance(samples, list) else samples
        self.samples = np.sort(data)

    def __len__(self):
        return len(self.samples)

    def __call__(self, p: Union[float, Array]) -> np.ndarray:
        levels = np.array(p) if isinstance(p, list) else p
        n = len(self)
        # 1-based rank of the order statistic associated with each level.
        ranks = np.ceil(n * levels).astype(int)
        # Convert to 0-based positions and keep them inside the sample array,
        # so that p = 0 maps to the smallest sample.
        positions = np.clip(ranks - 1, 0, n - 1)
        return self.samples[positions]


def num_integrate(func: Callable, x0: float, x1: float, dx: float) -> float:
    r""" Numerical integration

    Integrate a function func from x0 to x1 with a right Riemann sum whose
    step is (approximately) dx.

    The interval is split into ``round((x1 - x0) / dx)`` cells (at least one)
    and the width actually used is ``(x1 - x0) / n_cells``, so the cells tile
    ``[x0, x1]`` exactly even when the interval is not a multiple of dx.

    Args:
        func (Callable): function to integrate; it is called once with a
            numpy array holding the right end point of every cell
        x0 (float): start point
        x1 (float): end point
        dx (float): requested step size, must be positive

    Returns:
        float: the approximate value of the integral

    Raises:
        ValueError: if dx is not positive or x1 is smaller than x0
    """
    if dx <= 0:
        raise ValueError(f"dx must be positive, but got {dx}.")
    if x1 < x0:
        raise ValueError(f"x1 must not be smaller than x0, but got x0={x0} and x1={x1}.")
    if x1 == x0:
        # Empty integration range.
        return 0.0

    # round() instead of int(): ratios such as 0.3 / 0.1 evaluate to 2.999...
    # and would otherwise silently lose a cell.
    n_cells = max(int(round((x1 - x0) / dx)), 1)
    # n_cells cells need n_cells + 1 grid points.
    grid = np.linspace(x0, x1, n_cells + 1)
    step = (x1 - x0) / n_cells

    y = func(grid[1:])  # right end points only, x0 is skipped
    return np.sum(y * step).item()


def num_integrate_func(func: Callable, x0: float, x1: float, dx: float) -> Callable:
    r""" Numerical integration

    Integrate a function func from x0 to x1 with step dx and returns
    the result as the indefinite integral function $F(x) = \int_{x0}^{x} f(s) ds$

    The integral is tabulated once with a right Riemann sum on a uniform grid
    of ``round((x1 - x0) / dx)`` cells (at least one); the width actually used
    is ``(x1 - x0) / n_cells`` so that the grid ends exactly at x1. The
    returned function interpolates linearly between the tabulated values, so
    it is exact at the grid points and satisfies $F(x0) = 0$.

    Args:
        func (Callable): function to integrate; it is called once with a
            numpy array holding the right end point of every cell
        x0 (float): start point
        x1 (float): end point
        dx (float): requested step size, must be positive

    Returns:
        function, a function that returns the integral until a given point x

    Raises:
        ValueError: if dx is not positive or x1 is smaller than x0. The
            returned function raises ValueError when evaluated outside
            ``[x0, x1]``.
    """
    if dx <= 0:
        raise ValueError(f"dx must be positive, but got {dx}.")
    if x1 < x0:
        raise ValueError(f"x1 must not be smaller than x0, but got x0={x0} and x1={x1}.")

    if x1 > x0:
        # round() instead of int(): ratios such as 0.3 / 0.1 evaluate to
        # 2.999... and would otherwise silently lose a cell.
        n_cells = max(int(round((x1 - x0) / dx)), 1)
        step = (x1 - x0) / n_cells
        # n_cells cells need n_cells + 1 grid points.
        grid = np.linspace(x0, x1, n_cells + 1)
        y = func(grid[1:])  # right end points only, x0 is skipped
        res = np.cumsum(y * step)
        res = np.insert(res, 0, 0)  # add x0 so that $\int_{x0}^{x0} f(x) dx = 0$
    else:
        # Empty integration range: the integral is identically zero.
        grid = np.array([x0], dtype=float)
        res = np.zeros(1)

    def integral_func(x: Union[float, Array]) -> np.ndarray:
        if isinstance(x, list):
            x = np.array(x)
        # Written so that NaN entries are rejected as well.
        if not np.all((x >= x0) & (x <= x1)):
            raise ValueError(f"x must be between {x0} and {x1}, but got {x}.")
        # res[k] is the integral up to grid[k]; interpolate in between.
        return np.interp(x, grid, res)

    return integral_func


class SecondQuantileFunction:
    """Second quantile function $F_X^{(-2)}(p)$ of a set of samples.

    It is defined as the integral of the quantile function $F_X^{(-1)}(p)$.
    For efficiency the integral is evaluated numerically once, over the
    probability range 0 to 1.0 with interval given by dp, and the resulting
    callable is stored for later look-ups.

    Args
        samples (numpy array): Empirical distribution of data e.g. scores belonging to an algorithm.
        dp (float): Integration step over the probability axis.
    """

    def __init__(self, samples: Array, dp: float = 0.01):
        self.dp = dp
        self.samples = samples

        # Numerically integrate the empirical quantile function over [0, 1].
        self.second_quantile_func = num_integrate_func(QuantileFunction(samples), 0.0, 1.0, dp)

    def __call__(self, p: Union[float, Array]) -> np.ndarray:
        integrated_quantile = self.second_quantile_func
        return integrated_quantile(p)


class ECDF:
    """Empirical cumulative distribution function of a collection of scores.

    Evaluating the instance at ``x`` gives the fraction of stored samples that
    are less than or equal to ``x``.

    Args
        samples (numpy array): Empirical distribution of data e.g. scores belonging to an algorithm.

    Returns
        Callable: the empirical CDF belonging to an empirical score distribution.
    """

    def __init__(self, samples: Array):
        data = np.array(samples) if isinstance(samples, list) else samples
        self.samples = np.sort(data)

    def __len__(self):
        return len(self.samples)

    def __call__(self, x: Union[float, Array]) -> Union[float, np.ndarray]:
        query = np.array(x) if isinstance(x, list) else x
        # side="right" counts every sample that is <= the query value.
        n_at_or_below = np.searchsorted(self.samples, query, side="right")
        return n_at_or_below / len(self)


class IntegratedECDF:
    """Computes the integrated ECDF $F_X^{(2)}(x)$, i.e. the integral of the ECDF $F_X^{(1)}(x)$.

    The ECDF is a step function, so its integral from minus infinity up to x
    has the exact closed form ``mean(max(x - samples, 0))``. With ``k`` the
    number of samples that are <= x this equals
    ``(k * x - sum of the k smallest samples) / n``, which is evaluated here
    with one binary search and one look-up into precomputed prefix sums.

    Args
        samples (numpy array): Empirical distribution of data e.g. scores belonging to an algorithm.
    """

    def __init__(self, samples: Array):
        if isinstance(samples, list):
            samples = np.array(samples)
        self.samples = np.sort(samples)
        # ECDF value reached at each sorted sample; not used by __call__ but
        # kept as a public attribute for existing users of the class.
        self.cdf_values = np.arange(1, len(samples) + 1) / len(samples)
        # _prefix_sums[k] is the sum of the k smallest samples (0 for k = 0).
        self._prefix_sums = np.concatenate(([0.0], np.cumsum(self.samples, dtype=float)))

    def __call__(self, x: Union[float, Array]) -> Union[float, np.ndarray]:
        if isinstance(x, list):
            x = np.array(x)
        # Number of samples <= x; only those contribute to the integral.
        k = np.searchsorted(self.samples, x, side="right")
        # sum_{i < k} (x - s_i) / n, i.e. the exact integral of the ECDF up to x.
        return (k * x - self._prefix_sums[k]) / len(self.samples)


def bootstrap_multiprocessing(func: Callable,
                              num_workers: int,
                              n_bootstrap: int,
                              desc: str = "",
                              verbose: bool = True) -> List:
    """Bootstrap a function in parallel

    ``func`` is called once for every index in ``range(n_bootstrap)`` on a
    pool of worker threads (a ``ThreadPool``, not separate processes). The
    results are returned in index order.

    Args:
        func (Callable): function to bootstrap; it receives the bootstrap index
        num_workers (int): number of workers
        n_bootstrap (int): number of bootstrap samples
        desc (str, optional): description for tqdm. Defaults to empty string.
        verbose (bool, optional): show the tqdm progress bar. Defaults to True.

    Returns:
        List: list of results

    Example:
        >>> from soe.utils import bootstrap_multiprocessing
        >>> def func(i):
        ...     return i
        >>> bootstrap_multiprocessing(func, 4, 10)
        [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]

    """
    # The context manager releases the worker threads even when func raises;
    # all results are collected before the block is left.
    with ThreadPool(processes=num_workers) as pool:
        return list(
            tqdm.tqdm(pool.imap(func, range(n_bootstrap)),
                      desc=desc,
                      total=n_bootstrap,
                      disable=not verbose))


def pdist(X: np.ndarray | torch.Tensor,
          Y: np.ndarray | torch.Tensor,
          beta: float = 1.0,
          metric: str = 'euclidean') -> np.ndarray | torch.Tensor:
    """Compute the pairwise distance between two sets of points

    The element-wise loss selected by ``metric`` is evaluated for every pair
    ``(X[i], Y[j])`` through broadcasting and then summed over the last axis.

    Args:
        X (array): first set of points
        Y (array): second set of points
        beta (float, optional): scale applied to the difference inside the
            'logistic' and 'logistic_sym' metrics. Defaults to 1.0.
        metric (str, optional): distance metric. Defaults to 'euclidean', corresponding to the squared euclidean distance.
            Other choices are 'hinge' (squared difference where y > x, zero
            elsewhere), 'logistic' (softplus of beta * (y - x)) and
            'logistic_sym' (sum of the logistic loss in both directions).

    Returns:
        numpy or torch tensor: pairwise distance matrix

    Raises:
        ValueError: if ``metric`` is not one of the supported names
    """
    # The backend (torch or numpy) is chosen from the type of X only.
    use_torch = isinstance(X, torch.Tensor)

    def softplus(z):
        # log(1 + exp(z)) evaluated without overflow for large z.
        if use_torch:
            return -F.logsigmoid(-z)
        return np.logaddexp(0.0, z)

    if metric == 'euclidean':
        def fn(x, y):
            return (x - y)**2

    elif metric == 'hinge':
        def fn(x, y):
            return (y - x)**2 * (y - x > 0)

    elif metric == 'logistic':
        def fn(x, y):
            return softplus(beta * (y - x))

    elif metric == 'logistic_sym':
        def fn(x, y):
            return softplus(beta * (y - x)) + softplus(beta * (x - y))

    else:
        raise ValueError(f"Unknown metric {metric}")

    # X[:, None, ...] against Y[None, ...] broadcasts to all (i, j) pairs.
    return fn(X[:, None, ...], Y[None, ...]).sum(axis=-1)  # type: ignore
