import tempfile
import imageio
import matplotlib.pyplot as plt
import numpy as np
from baselines.her.goal_sampler import FUN_NAME_TO_FUN
import pylab
import matplotlib.patches as patches
import matplotlib.colorbar as cbar
import time


def save_image(fig=None, fname=None):
    """Render a matplotlib figure to an image array via an intermediate file.

    Saves ``fig`` (or, when ``fig`` is None, the current pyplot figure as PNG)
    into ``fname``, closes every open pyplot figure, then reads the saved file
    back with ``imageio.imread``.

    Args:
        fig: Figure to save. When None, ``plt.savefig`` is used instead.
        fname: Destination. When None, an anonymous ``tempfile.TemporaryFile``
            is used. May be a seekable binary file object (it is rewound, read
            and closed) or a filesystem path (the image is read back from that
            path and the file is left on disk).

    Returns:
        The image array returned by ``imageio.imread``.
    """
    if fname is None:
        fname = tempfile.TemporaryFile()
    # Paths (str / PathLike) have no seek()/close(); they are re-read by name.
    is_file_obj = hasattr(fname, 'seek')
    try:
        try:
            if fig is not None:
                fig.savefig(fname)
            else:
                plt.savefig(fname, format='png')
        finally:
            # Free figure memory even when saving fails.
            plt.close('all')
        if is_file_obj:
            fname.seek(0)
        return imageio.imread(fname)
    finally:
        if is_file_obj:
            fname.close()


def my_square_scatter(axes, x_array, y_array, z_array, min_z=None, max_z=None, size=(0.5, 0.5), cmap=None, **kwargs):
    """Draw one colour-coded square per (x, y) point and attach a colorbar.

    Args:
        axes: matplotlib Axes to draw on.
        x_array, y_array: Square centres.
        z_array: Values mapped to colours. Any array-like is accepted.
        min_z, max_z: Colour-scale bounds. They default to the min/max of
            ``z_array``.
        size: ``(width, height)`` of every square, in data units.
        cmap: Colormap used for both the squares and the colorbar. It
            defaults to ``pylab.cm.jet``.
        **kwargs: Forwarded to every ``pylab.Rectangle``.
    """
    # Accept lists/tuples as well as ndarrays (asanyarray keeps masked arrays).
    z_array = np.asanyarray(z_array)
    if cmap is None:
        cmap = pylab.cm.jet

    if min_z is None:
        min_z = z_array.min()
    if max_z is None:
        max_z = z_array.max()

    norm = pylab.Normalize(min_z, max_z)
    colors = cmap(norm(z_array))

    width, height = size[0], size[1]
    for x, y, color in zip(x_array, y_array, colors):
        # Rectangle is anchored at its lower-left corner, so shift by half a tile.
        square = pylab.Rectangle((x - width / 2, y - height / 2), width, height, color=color, **kwargs)
        axes.add_patch(square)

    axes.autoscale()

    cax, _ = cbar.make_axes(axes)
    cbar.ColorbarBase(cax, cmap=cmap, norm=norm)


# Wall layouts overlaid by plot_heatmap, keyed by maze_id.
# Each entry is (extra Rectangle kwargs, [((x, y), width, height), ...]).
_MAZE_WALLS = {
    0: (dict(alpha=0.2), [
        ((-3, -3), 10, 2),
        ((-3, -3), 2, 10),
        ((-3, 5), 10, 2),
        ((5, -3), 2, 10),
        ((-1, 1), 4, 2),
    ]),
    11: ({}, [
        ((-7, 5), 14, 2),
        ((5, -7), 2, 14),
        ((-7, -7), 14, 2),
        ((-7, -7), 2, 14),
        ((-3, 1), 10, 2),
        ((-3, -3), 2, 6),
        ((-3, -3), 6, 2),
    ]),
    12: ({}, [
        ((-7, 5), 14, 2),
        ((5, -7), 2, 14),
        ((-7, -7), 14, 2),
        ((-7, -7), 2, 14),
    ]),
}


def plot_heatmap(z, goals, limit, center, spacing, min_z=None, max_z=None, show_heatmap=True, maze_id=None):
    """Plot per-goal values ``z`` as a square-tile heatmap over the goal plane.

    Args:
        z: One value per goal, mapped to colour.
        goals: Array-like of goals. Columns 0 and 1 are used as x and y.
        limit: Half-extent of the plot window. Either a scalar (the same for x
            and y) or a sequence whose entries 0 and 1 are the x and y
            half-extents.
        center: Window centre. Entries 0 and 1 are the x and y centre.
        spacing: ``(width, height)`` of each tile.
        min_z, max_z: Colour-scale bounds. They default to the range of ``z``.
        show_heatmap: If True, call ``plt.show()`` before returning.
        maze_id: Wall layout to overlay (0, 11 or 12, see ``_MAZE_WALLS``).
            Any other value draws no walls.

    Returns:
        The created matplotlib Figure.

    Raises:
        AssertionError: If any goal lies outside the plot window.
    """
    # Convert once so list inputs work for both the scatter and the bounds checks.
    goals = np.asarray(goals)
    limit = np.asarray(limit)
    if limit.ndim == 0:
        # A scalar half-extent applies to both axes.
        limit = np.repeat(limit, 2)

    fig, ax = plt.subplots()

    x_goal, y_goal = goals[:, :2].T
    my_square_scatter(axes=ax, x_array=x_goal, y_array=y_goal, z_array=z, min_z=min_z, max_z=max_z, size=spacing)

    wall_style, walls = _MAZE_WALLS.get(maze_id, ({}, ()))
    for xy, width, height in walls:
        ax.add_patch(patches.Rectangle(xy, width, height, fill=True, edgecolor="none", facecolor='0.4', **wall_style))

    x_lo, x_hi = center[0] - limit[0], center[0] + limit[0]
    y_lo, y_hi = center[1] - limit[1], center[1] + limit[1]
    ax.set_xlim(x_lo, x_hi)
    ax.set_ylim(y_lo, y_hi)

    # Every goal must fall inside the plotted window, otherwise it would not be shown.
    assert np.all(np.logical_and(np.less_equal(x_lo, goals[:, 0]), np.less_equal(goals[:, 0], x_hi)))
    assert np.all(np.logical_and(np.less_equal(y_lo, goals[:, 1]), np.less_equal(goals[:, 1], y_hi)))

    if show_heatmap:
        plt.show()
    return fig


def make_plotter(init_ob, policy, value_ensemble, goals, disagreement_str, plotter_worker, gamma, report,
                 plot_heatmap_fun, eval_policy):
    """Build a callback that renders diagnostic heatmaps over a fixed set of goals.

    Args:
        init_ob: Initial observation dict. Only ``init_ob['observation']`` is read.
        policy: Object exposing ``get_actions(o=..., ag=..., g=..., compute_Q=True)``,
            which returns a pair ``(u, Q)``.
        value_ensemble: Object exposing ``get_values(o=..., ag=..., g=..., u=...)``.
        goals: Goals at which every quantity is evaluated and plotted.
        disagreement_str: Key into ``FUN_NAME_TO_FUN`` selecting the function
            that reduces the ensemble values to a per-goal disagreement score.
        plotter_worker: Rollout worker used for the optional ground-truth evaluation.
        gamma: Discount factor used for the ground-truth discounted return.
        report: Sink exposing ``add_image(img, caption)`` and ``new_row()``.
        plot_heatmap_fun: Heatmap function, called as
            ``plot_heatmap_fun(z, goals=..., ...)``.
        eval_policy: If True, also run ``plotter_worker`` rollouts for each goal
            and plot the mean discounted return and ``success_history``.

    Returns:
        ``plotter(epoch, goal_history)``, which appends one row of images to
        ``report`` per call.
    """
    if plotter_worker.rollout_batch_size > 1:
        # NOTE(review): presumably batched rollouts must be non-greedy so that
        # averaging over the batch is informative. Confirm.
        assert not plotter_worker.exploit

    def _fixed_goal_sampler(goal):
        """Return a one-argument goal sampler that ignores its input and yields ``goal``."""
        # A factory binds `goal` now; a lambda created directly in the loop would
        # late-bind the loop variable.
        return lambda obs_dict: goal

    def plotter(epoch, goal_history):
        o = init_ob['observation'][np.newaxis, ...]
        # NOTE(review): `ag` is built from 'observation', not from an
        # achieved-goal entry. Confirm this is intended.
        ag = init_ob['observation'][np.newaxis, ...]
        # Tile the single initial observation once per goal.
        input_o = np.repeat(o[np.newaxis, ...], repeats=len(goals), axis=0)
        input_ag = np.repeat(ag[np.newaxis, ...], repeats=len(goals), axis=0)

        # Goals seen so far, coloured by their position in the history.
        plot_heatmap_fun(z=np.arange(len(goal_history)) / len(goal_history), goals=goal_history, spacing=(0.01, 0.01))
        report.add_image(save_image(), f'goal history: {np.mean(goal_history), np.std(goal_history)}')

        # Q estimated by the policy for each goal.
        u, Q_pi = policy.get_actions(o=input_o, ag=input_ag, g=goals, compute_Q=True)
        plot_heatmap_fun(Q_pi.squeeze(axis=1), goals=goals)
        report.add_image(save_image(), f'epoch {epoch}, policy Q: {np.mean(Q_pi)}')

        vals = value_ensemble.get_values(o=input_o, ag=input_ag, g=goals, u=u)
        vals = np.squeeze(vals, axis=2)  # (size_ensemble, n_candidates)

        # Ensemble-mean value per goal.
        q = np.mean(vals, axis=0)
        plot_heatmap_fun(q, goals=goals)
        report.add_image(save_image(), f'mean Q: {np.mean(q)}')

        # Ensemble disagreement per goal.
        disagreement = FUN_NAME_TO_FUN[disagreement_str](vals)
        plot_heatmap_fun(disagreement, goals=goals)
        report.add_image(save_image(), f'disagreement: {np.mean(disagreement)}, {np.std(disagreement)}')

        if eval_policy:
            plotter_worker.clear_history()
            # Per-step discount weights gamma**k; identical for every goal, so computed once.
            discount_factors = [gamma ** k for k in range(plotter_worker.T - 1)]
            q_ground_truth = []
            start_time = time.time()

            for goal in goals:
                plotter_worker.envs_op('update_goal_sampler', goal_sampler=_fixed_goal_sampler(goal))
                episode = plotter_worker.generate_rollouts(expose_reward=True)
                disc_rewards = np.sum(episode['r'] * discount_factors, axis=1)  # (rollout_batch_size, T-1,) -> (rollout_batch_size,)
                q_ground_truth.append(np.mean(disc_rewards))

            # NOTE(review): the return value is discarded. Confirm this call is
            # needed only for its side effects.
            plotter_worker.logs('plot')
            plot_heatmap_fun(np.asarray(q_ground_truth), goals=goals)
            report.add_image(save_image(), f'q gt: {np.mean(q_ground_truth)}, use time {time.time() - start_time} for {len(goals)} goals')

            binary_success = plotter_worker.success_history
            assert len(binary_success) == len(goals)
            plot_heatmap_fun(np.asarray(binary_success), goals=goals)
            report.add_image(save_image(), f'binary coverage: {np.mean(binary_success)}')

        report.new_row()

    return plotter
