import os
import random
from typing import Union, List

from pymoo.util.ref_dirs import get_reference_directions
import matplotlib.pyplot as plt
import numpy as np


def extrema_weights(dim) -> List[np.ndarray]:
    """Build the ``dim`` one-hot weight vectors (the vertices of the weight simplex).

    Args:
        dim: length of each weight vector.

    Returns:
        A list of ``dim`` float arrays; the k-th array holds 1.0 at index k and 0.0 elsewhere.
    """

    def _one_hot(k):
        vec = np.zeros(dim)
        vec[k] = 1.0
        return vec

    return [_one_hot(k) for k in range(dim)]


def unique_tol(a: List[np.ndarray], tol=1e-4, atol=1e-8) -> List[np.ndarray]:
    """Returns unique elements of a list of arrays, with a tolerance.

    Arrays are compared pairwise with ``np.allclose``. The first occurrence of each
    group of near-equal arrays is kept, and the original order is preserved.

    Note that ``tol`` is the *relative* tolerance (``rtol``) of ``np.allclose``,
    exactly as the previous positional call used it. ``atol`` is the absolute
    tolerance; its default matches numpy's, and it dominates for entries close
    to zero.

    Args:
        a: list of arrays with a common shape.
        tol: relative tolerance passed to ``np.allclose`` as ``rtol``.
        atol: absolute tolerance passed to ``np.allclose``.

    Returns:
        A list of the kept arrays (rows of ``np.array(a)``). An empty input is returned as-is.
    """
    if len(a) == 0:
        return a
    arr = np.array(a)
    keep = np.ones(len(arr), dtype=bool)
    for i in range(len(arr)):
        if not keep[i]:
            continue
        for j in range(i + 1, len(arr)):
            # An entry that is already marked for deletion stays deleted; skip re-testing it.
            if keep[j] and np.allclose(arr[i], arr[j], rtol=tol, atol=atol):
                keep[j] = False
    return list(arr[keep])


def generate_weights(count=1, n=3, m=1):
    """Draw a sequence of random weight vectors, optionally interpolated between targets.

    Source: https://github.com/axelabels/DynMORL/blob/db15c29bc2cf149c9bda6b8890fee05b1ac1e19e/utils.py#L281

    Targets are sampled from a flat Dirichlet distribution over ``n`` components
    using the global ``np.random`` state. One extra target is drawn up front as the
    initial anchor. Then, for each of ``count // m`` new targets:
      - if ``m == 1`` the target itself is emitted;
      - otherwise ``m`` vectors are emitted that step linearly from the previous
        target to the new one (the last of them equals the new target).

    Args:
        count: requested number of vectors; only ``(count // m) * m`` are produced.
        n: length of each weight vector.
        m: number of interpolation steps per target.

    Returns:
        A list of ``(count // m) * m`` arrays of length ``n``.
    """

    def draw():
        return np.random.dirichlet(np.ones(n), 1)[0]

    weights = []
    previous = draw()  # initial anchor; drawn even when m == 1 so RNG consumption is unchanged
    for _ in range(count // m):
        current = draw()
        if m == 1:
            weights.append(current)
        else:
            weights.extend(
                current * (k + 1) / float(m) + previous * (m - k - 1) / float(m)
                for k in range(m)
            )
        previous = current
    return weights


def random_weights(dim, seed=None, n=1):
    """Sample weight vectors uniformly from the probability simplex.

    Each vector is drawn from a Dirichlet distribution whose concentration
    parameters are all 1.

    Args:
        dim: size of each weight vector.
        seed: if given, a fresh ``np.random.default_rng(seed)`` is used; otherwise
            the global ``np.random`` state is used.
        n: number of vectors to draw.

    Returns:
        A single array when ``n == 1``, otherwise a list of ``n`` arrays.
    """
    rng = np.random if seed is None else np.random.default_rng(seed)
    alpha = np.ones(dim)
    samples = [rng.dirichlet(alpha) for _ in range(n)]
    return samples[0] if n == 1 else samples


def equally_spaced_weights(dim: int, num_weights: int, seed: int = 42) -> np.ndarray:
    """Return ``num_weights`` well-spread weight vectors of length ``dim``.

    This delegates to pymoo's ``get_reference_directions`` with the "energy"
    method. pymoo returns a 2-D numpy array (one weight vector per row), not a
    Python list, so the return annotation reflects that. Iterating over the result
    still yields one 1-D array per weight vector.

    Args:
        dim: length of each weight vector.
        num_weights: number of weight vectors to generate.
        seed: seed forwarded to pymoo for reproducibility.

    Returns:
        The array produced by ``get_reference_directions`` (presumably of shape
        ``(num_weights, dim)``).
    """
    return get_reference_directions("energy", dim, num_weights, seed=seed)


def moving_average(interval: Union[np.ndarray, List], window_size: int) -> np.ndarray:
    """Smooth a 1-D sequence with a centered moving average.

    Each output element is the mean of the input samples that fall inside a window
    of ``window_size`` samples around it. When ``window_size <= len(interval)``, the
    window alignment is the same as ``np.convolve(..., "same")``. Near the ends the
    window is truncated, and the mean is taken over the samples that actually
    exist. This avoids treating missing samples as zeros, which previously pulled
    the first and last values toward 0.

    Args:
        interval: 1-D sequence of numbers to smooth.
        window_size: number of samples in the averaging window (must be >= 1).

    Returns:
        ``interval`` itself (unchanged) when ``window_size == 1``. Otherwise, a float
        array with exactly ``len(interval)`` elements.

    Raises:
        ValueError: if ``window_size`` is smaller than 1.
    """
    if window_size == 1:
        return interval
    size = int(window_size)
    if size < 1:
        raise ValueError(f"window_size must be >= 1, got {window_size}")
    values = np.asarray(interval, dtype=float)
    n = len(values)
    if n == 0:
        return values
    kernel = np.ones(size)
    # Slicing the "full" convolution at this offset reproduces mode "same"'s alignment,
    # but always yields len(values) outputs. Mode "same" returns max(len(values), size).
    start = (size - 1) // 2
    sums = np.convolve(values, kernel, "full")[start:start + n]
    # Number of real samples inside each window (== size away from the edges).
    counts = np.convolve(np.ones(n), kernel, "full")[start:start + n]
    return sums / counts


def linearly_decaying_epsilon(initial_epsilon, decay_period, step, warmup_steps, final_epsilon):
    """Returns the current epsilon for the agent's epsilon-greedy policy.

    This follows the Nature DQN schedule of a linearly decaying epsilon (Mnih et
    al., 2015), generalized to an arbitrary starting value. The schedule is:
      - stay at ``initial_epsilon`` until ``warmup_steps`` steps have been taken; then
      - linearly decay from ``initial_epsilon`` to ``final_epsilon`` over
        ``decay_period`` steps; and then
      - use ``final_epsilon`` from there on.

    Args:
        initial_epsilon: float, epsilon value during warmup and at the start of the decay.
        decay_period: float, the number of steps over which epsilon is decayed.
            A value <= 0 switches directly from ``initial_epsilon`` to
            ``final_epsilon`` once warmup ends.
        step: int, the number of training steps completed so far.
        warmup_steps: int, the number of steps taken before epsilon is decayed.
        final_epsilon: float, the final value to which epsilon decays
            (assumed <= ``initial_epsilon``).

    Returns:
        A float, the current epsilon value computed according to the schedule.
    """
    span = initial_epsilon - final_epsilon
    if decay_period <= 0:
        # No decay window: avoid dividing by zero and step straight to the final value.
        return final_epsilon + (span if step < warmup_steps else 0.0)
    steps_left = decay_period + warmup_steps - step
    bonus = span * steps_left / decay_period
    # Clip to [0, span] so epsilon never exceeds initial_epsilon during warmup.
    # The previous upper bound (1 - final_epsilon) was only correct when initial_epsilon == 1.
    bonus = np.clip(bonus, 0.0, span)
    return final_epsilon + bonus


def seed_everything(seed: int = 42):
    """Seed Python's ``random`` module and NumPy's global RNG, and export PYTHONHASHSEED.

    Note: setting the ``PYTHONHASHSEED`` environment variable here only affects
    Python processes started afterwards (e.g. subprocesses). The hash seed of the
    running interpreter is fixed at startup.

    Args:
        seed: integer seed applied to every generator.
    """
    os.environ["PYTHONHASHSEED"] = str(seed)
    for seeder in (random.seed, np.random.seed):
        seeder(seed)


