"""Test the limit function."""
import os

import matplotlib.pyplot as plt
import numpy as np

from action_masking.util.util import Algorithm, ActionSpace, Approach, TransitionTuple  # noqa: F401
from action_masking.provably_safe_env.envs.long_quadrotor_env import LongQuadrotorEnv  # noqa: F401


# Directory that receives the diagnostic plots written by ``test_limit_function``.
_PLOT_DIR = "safe_region"

# Value stored in every limit column when the safe-space function raises
# ``ValueError`` for a state. It lies far below the plotted y-range
# ([-0.2, 0.5]), so such states appear as gaps in the curves.
_NO_LIMIT_SENTINEL = -1000.0

# Nominal state around which each sweep varies a single component. The
# component order (x, dx, z, dz, theta, dtheta) matches the x-axis labels of the
# sweeps in ``test_limit_function``.
_NOMINAL_STATE = (0.0, 0.0, 1.0, 0.0, 0.0, 0.0)


def _sweep_state_component(env, safe_space_fn, safe_region, index, values):
    """Evaluate the safe action limits while varying one state component.

    Args:
        env: Unwrapped environment; its ``state`` attribute is overwritten.
        safe_space_fn: Callable ``(env, safe_region) -> (limit, _, _)``. ``limit``
            is read as ``limit[0, :]`` (u_g min/max) and ``limit[1, :]``
            (u_d min/max).
        safe_region: Safe region forwarded unchanged to ``safe_space_fn``.
        index: Position in the state vector that is swept.
        values: 1-D sequence of values assigned to ``state[index]``.

    Returns:
        Array of shape ``(len(values), 4)`` with columns
        ``[u_g min, u_g max, u_d min, u_d max]``. Rows for which
        ``safe_space_fn`` raised ``ValueError`` are filled with
        ``_NO_LIMIT_SENTINEL``.
    """
    limits = np.zeros([len(values), 4])
    for row, value in enumerate(values):
        state = np.array(_NOMINAL_STATE)
        state[index] = value
        env.state = state
        # Only the safe-space call may legitimately fail. A ValueError raised by
        # the assignments below (shape mismatch) is a real bug and must not be
        # recorded as an infeasible state.
        try:
            limit, _, _ = safe_space_fn(env, safe_region)
        except ValueError:
            limits[row, :] = _NO_LIMIT_SENTINEL
            continue
        limits[row, 0:2] = limit[0, :]
        limits[row, 2:4] = limit[1, :]
    return limits


def _plot_limits(values, limits, xlabel, filename):
    """Plot the four limit curves against ``values`` and save them as ``_PLOT_DIR/filename``.

    Args:
        values: Swept state values (x-axis).
        limits: ``(len(values), 4)`` array as returned by ``_sweep_state_component``.
        xlabel: Label of the swept state component.
        filename: Name of the PNG file written inside ``_PLOT_DIR``.
    """
    # savefig does not create missing directories.
    os.makedirs(_PLOT_DIR, exist_ok=True)
    fig, ax = plt.subplots()
    try:
        ax.plot(values, limits[:, 0], 'r--', label="u_g min")
        ax.plot(values, limits[:, 1], 'r-', label="u_g max")
        ax.plot(values, limits[:, 2], 'g--', label="u_d min")
        ax.plot(values, limits[:, 3], 'g-', label="u_d max")
        ax.legend()
        ax.set_xlabel(xlabel)
        ax.set_ylabel("u")
        ax.set_ylim([-0.2, 0.5])
        fig.savefig(os.path.join(_PLOT_DIR, filename))
    finally:
        # Always release the figure, even if saving fails, so figures do not
        # accumulate across tests.
        plt.close(fig)


def test_limit_function(create_env, continuous_safe_space_fn, SAFE_REGION):
    """Test the limit function.

    Varies each state component in turn around ``_NOMINAL_STATE`` and evaluates
    ``continuous_safe_space_fn`` at every sample. For each component, one plot of
    the u_g/u_d min/max limits is written to ``safe_region/``. The test has no
    assertions. It fails only if the safe-space function raises something other
    than ``ValueError``, or if its output cannot be stored or plotted.
    """
    env = create_env(ActionSpace.Continuous, Approach.Sample, TransitionTuple.Naive, 0)
    # Reach through the vectorized-env and wrapper layers to the raw environment
    # so that its ``state`` can be set directly.
    env = env.envs[0].env.env.unwrapped

    # (state index, sweep values, x-axis label, output file name)
    sweeps = [
        (0, np.arange(-0.1, 0.1, 0.001), "x [m]", "x_limit.png"),
        (1, np.arange(-1, 1, 0.01), "dx [m/s]", "dx_limit.png"),
        (2, np.arange(0.9, 1.1, 0.001), "z [m]", "z_limit.png"),
        (3, np.arange(-1, 1, 0.01), "dz [m/s]", "dz_limit.png"),
        (4, np.arange(-0.2, 0.2, 0.001), "theta [rad]", "theta_limit.png"),
        (5, np.arange(-2.0, 2.0, 0.01), "d/dt theta [rad/s]", "dtheta_limit.png"),
    ]
    for index, values, xlabel, filename in sweeps:
        limits = _sweep_state_component(
            env, continuous_safe_space_fn, SAFE_REGION, index, values
        )
        _plot_limits(values, limits, xlabel, filename)
