import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numba
import numpy as np


@numba.njit
def sampler_d(n, d, bounds=None):
    """Draw ``n`` points uniformly from an axis-aligned box in ``d`` dimensions.

    Parameters
    ----------
    n : int
        Number of points to draw.
    d : int
        Dimension of each point (number of output columns).
    bounds : 2-D float array, optional
        Row ``k`` is ``[low, high]`` for dimension ``k``, so the expected
        shape is ``(d, 2)``. A single ``[low, high]`` row (shape ``(1, 2)``)
        is broadcast to every dimension. Defaults to the unit box ``[0, 1]^d``.

    Returns
    -------
    ndarray of shape ``(n, d)``
        ``low + u * (high - low)`` with ``u`` drawn from ``np.random.uniform(0, 1)``.
    """
    if bounds is None:
        # Unit box: low = 0 and high = 1 in every dimension.
        bounds = np.zeros((d, 2))
        bounds[:, 1] = 1.
    lower = bounds[:, 0]
    span = bounds[:, 1] - lower
    # One flat draw of n*d values, reshaped so each row is one point.
    unit = np.random.uniform(0, 1, size=n * d).reshape(-1, d)
    return unit * span + lower


@numba.njit
def sample_around_center(centers, j_max, d, n):
    """Draw ``n`` uniform samples in a box of half-width ``1 / 2**(j_max + 1)``
    around each center.

    Parameters
    ----------
    centers : 2-D array
        One center per row; the first ``d`` columns are used.
    j_max : int
        Sets the half-width ``1 / 2**(j_max + 1)``, so each box has side
        ``1 / 2**j_max``.
    d : int
        Dimension of the samples.
    n : int
        Number of samples per center.

    Returns
    -------
    ndarray of shape ``(len(centers) * n, d)``
        Rows ``i * n`` to ``(i + 1) * n - 1`` hold the samples for center ``i``.
    """
    half = 1 / (2 ** (j_max + 1))
    n_centers = len(centers)
    out = np.zeros((n_centers * n, d))

    for k in range(n_centers):
        # Per-dimension [low, high] rows in the layout sampler_d expects.
        box = np.empty((d, 2))
        box[:, 0] = centers[k, :d] - half
        box[:, 1] = centers[k, :d] + half
        out[k * n:(k + 1) * n, :] = sampler_d(n, d, bounds=box)
    return out


def test_sampler_d():
    """Check the output shapes and bounds of ``sampler_d`` and ``sample_around_center``.

    The ``sample_around_center`` results are also drawn as a scatter plot,
    together with the square around each center, for visual inspection.
    The figure is left open so it can be shown interactively.
    """
    # Default bounds: the unit box, one column per dimension.
    assert sampler_d(10, 2).shape == (10, 2)

    # A single [low, high] row is broadcast to every dimension.
    y2 = sampler_d(100, 2, np.array([[0., 0.2]]))
    assert y2.shape == (100, 2)
    assert y2.min() >= 0.
    assert y2[:, 1].max() < 0.2

    y3 = sampler_d(100, 3, np.array([[0., 0.2]]))
    assert y3.shape == (100, 3)
    assert y3.min() >= 0.
    assert y3.max() < 0.2

    j_max = 3
    d = 2
    n = 300
    centers = np.array([(0, 1), (0, 0.80)])
    x = sample_around_center(centers, d=d, j_max=j_max, n=n)
    a = 1 / 2 ** (j_max + 1)

    # Each center owns a contiguous run of n rows, and every sample in that
    # run must lie in [c - a, c + a] in each coordinate. The upper check uses
    # <= because low + u * (high - low) can round up to high exactly.
    assert x.shape == (len(centers) * n, d)
    for i, c in enumerate(centers):
        block = x[i * n:(i + 1) * n]
        assert np.all(block >= c - a)
        assert np.all(block <= c + a)

    _, ax = plt.subplots(figsize=(6, 6))
    for c in centers:
        rect = patches.Rectangle(
            (c[0] - a, c[1] - a),
            2 * a, 2 * a, linewidth=1, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
    ax.scatter(x[:, 0], x[:, 1])
