import fix_imports
import argparse

import numpy as np
import torch
import matplotlib
from sys import platform as sys_pf
if sys_pf == 'darwin':
    matplotlib.use("TkAgg")
import matplotlib.pyplot as plt
import matplotlib.collections as mc
import matplotlib.style as style

from model import model_from_state


def sample_vectors(state, scale=1.0):
    """Build one line segment per data sample for drawing in the (u, v) plane.

    For every sample ``x_i`` in ``state['uv']`` the segment runs from
    ``(1, -x_i)`` to ``(-1, x_i)``, i.e. along ``v = -x_i * u``, the set of
    points where ``u * x_i + v == 0``. Segments are ordered by ascending
    ``x_i``.

    Args:
        state: Checkpoint dict; only ``state['uv']`` (array of samples) is
            read. It is flattened to 1-D before use.
        scale: Factor applied to every coordinate, so the segments span
            ``u in [-scale, scale]``.

    Returns:
        Float array of shape ``(N, 2, 2)`` indexed ``[sample, endpoint, (u, v)]``.
        Endpoint 0 lies at ``u = +scale``; endpoint 1 lies at ``u = -scale``.
    """
    # ravel rather than squeeze: squeezing a single sample yields a 0-d array,
    # which has no len() and made this function crash for N == 1.
    x = np.sort(np.ravel(state['uv']))

    lines = np.empty((x.size, 2, 2))  # every entry is assigned below
    lines[:, 0, 0] = 1.0
    lines[:, 0, 1] = -x
    lines[:, 1, 0] = -1.0
    lines[:, 1, 1] = x

    return lines * scale


def plot_bg(lines, ax, alpha, s=1.0):
    """Shade the square ``[-s, s]^2`` and every band between adjacent lines.

    The first ``fill_between`` call covers the whole square. Each further call
    fills the band between line ``k`` and line ``k + 1``, so every band takes
    the next colour from the axes' colour cycle.

    Args:
        lines: Array of shape ``(N, 2, 2)`` as returned by ``sample_vectors``.
            Only the v-coordinates ``lines[k, 1, 1]`` (left end) and
            ``lines[k, 0, 1]`` (right end) are read.
        ax: Matplotlib axes to draw on.
        alpha: Transparency passed to every ``fill_between`` call.
        s: Half-width of the square. The band edges are drawn at u = -s and
            u = +s, which matches ``lines`` built with ``sample_vectors(...,
            scale=s)``.
    """
    span = [-s, s]
    ax.fill_between(span, [s, s], [-s, -s], alpha=alpha)

    # Walk over consecutive pairs (line k, line k + 1); fewer than two lines
    # means there are no bands to fill.
    for cur, nxt in zip(lines[:-1], lines[1:]):
        cur_edge = [cur[1, 1], cur[0, 1]]
        nxt_edge = [nxt[1, 1], nxt[0, 1]]
        ax.fill_between(span, cur_edge, nxt_edge, alpha=alpha)


def _load_state(path):
    """Load a checkpoint dict written by fit_model.py."""
    # The checkpoint stores NumPy arrays (e.g. state['uv']) next to tensors.
    # Since torch 2.6, torch.load defaults to weights_only=True, which refuses
    # to unpickle them, so full unpickling is requested explicitly.
    # NOTE: full unpickling can execute arbitrary code; only load checkpoints
    # you produced yourself.
    try:
        return torch.load(path, weights_only=False)
    except TypeError:
        # torch < 1.13 has no weights_only argument.
        return torch.load(path)


def _gradient_field(x, res, s, num_samples):
    """Sample the unit-length residual-weighted ReLU gradient on a (u, v) grid.

    For each grid point (u, v) in [-s, s]^2 this computes
        g(u, v) = sum_i res_i * 1[x_i * u + v > 0] * (x_i, 1)
    and then divides by its Euclidean norm.

    Args:
        x: Sample inputs; flattened to N values.
        res: N residuals, one per sample, in the same order as ``x``.
        s: Half-width of the grid.
        num_samples: Number of grid points per axis, endpoints included.

    Returns:
        ``(uv, grad_dir)``, both of shape ``(2, num_samples, num_samples)``.
        ``uv[0]`` and ``uv[1]`` are the u and v grid coordinates. ``grad_dir``
        is NaN wherever g == 0; quiver masks those arrows.
    """
    n = num_samples * 1j  # complex step => number of points, endpoints included
    uv = np.mgrid[-s:s:n, -s:s:n]
    x = np.ravel(x)
    x_tilde = np.stack([x, np.ones(x.size)])  # (2, N): rows are (x_i, 1)

    # Pre-activation x_i * u + v for every sample and grid point: (N, n, n).
    m = np.einsum('ci,cjk->ijk', x_tilde, uv)
    tau = (m > 0).astype(float)  # ReLU derivative

    # einsum contracts over samples directly, avoiding a (2, N, n, n) temporary.
    grad = np.einsum('i,ijk,ci->cjk', res, tau, x_tilde)
    norm = np.linalg.norm(grad, axis=0, keepdims=True)
    with np.errstate(invalid='ignore', divide='ignore'):
        grad_dir = grad / norm
    return uv, grad_dir


def main():
    """Command-line entry point: plot a fitted model's gradient direction field.

    Loads a checkpoint written by fit_model.py and restores the weights saved
    for ``--epoch``. It then computes residuals ``model(state['uv']) -
    state['x']`` and draws over ``[-scale, scale]^2``:
      * shaded bands between the per-sample lines of ``sample_vectors``;
      * those lines;
      * unit arrows of the field computed by ``_gradient_field``.
    """
    # Matplotlib 3.6 renamed the bundled 'seaborn' style to 'seaborn-v0_8'
    # and later releases dropped the old name.
    style.use('seaborn-v0_8' if 'seaborn-v0_8' in style.available else 'seaborn')

    argparser = argparse.ArgumentParser()
    argparser.add_argument("state", type=str, help="Fitted model (out.pt) generated with fit_model.py")
    argparser.add_argument("-e", "--epoch", default=-1, type=int, help="Which epoch to plot the model at")
    argparser.add_argument("-s", "--scale", default=1.0, type=float, help="Scale factor for the figure")
    argparser.add_argument("-n", "--num-samples", default=100, type=int, help="Number of neurons to plot")
    args = argparser.parse_args()

    state = _load_state(args.state)
    model = model_from_state(state)
    # Only element [1] of each saved_states entry (the state dict) is used here.
    model.load_state_dict(state['saved_states'][args.epoch][1])

    x = state['uv']  # model inputs
    y = state['x']   # compared against the model output

    device = state['device']
    with torch.no_grad():
        res = model(torch.from_numpy(x).to(device)) - torch.from_numpy(y).to(device)
    # reshape(-1) instead of squeeze() so a single sample stays 1-D.
    res = res.cpu().numpy().reshape(-1)

    # The grid, axes, lines and background all use the same extent, so
    # --scale zooms the whole figure. Previously s was reset to 1.0 here,
    # which pushed most of the grid out of view.
    s = args.scale
    uv, grad_dir = _gradient_field(x, res, s, args.num_samples)

    _, ax = plt.subplots()
    ax.axis([-s, s, -s, s])
    lines = sample_vectors(state, scale=s)
    plot_bg(lines, ax, alpha=0.4, s=s)
    ax.add_collection(mc.LineCollection(lines, linewidths=3, label="$x_i$"))

    ax.quiver(uv[0], uv[1], grad_dir[0], grad_dir[1])
    plt.show()


if __name__ == "__main__":
    # Run the command-line tool only when executed as a script, not on import.
    main()
