import os

import numpy as np
# from scripts.plot_brains import plot_blobs
from surfer import Brain
from mayavi import mlab

from mne.datasets import sample
from mne.inverse_sparse.mxne_inverse import _make_sparse_stc
from main_expe7 import (
    list_pb_name, times, info, fwd, list_event_id, list_decim)


# Output directory for the rendered brain figures. It is a path relative to
# the current working directory, and it is passed as ``fig_dir`` to
# ``plot_blob`` below.
fig_dir = "../../../latex/NeurIPS2019/prebuiltimages/"


def plot_blob(
        stc, subject="sample", surface="white", s=18, save_fname="",
        data_path=None, subject_name='/subjects',
        fig_dir="", figsize=(800, 800), event_id=1):
    """Draw the active sources of a sparse source estimate on the brain.

    One PySurfer ``Brain`` figure is opened per hemisphere ("lh", then
    "rh"). Every vertex listed in ``stc.vertices`` for that hemisphere is
    drawn as a red sphere on the chosen surface.

    Parameters
    ----------
    stc : SourceEstimate-like
        Object exposing ``vertices``. ``stc.vertices[0]`` holds the
        left-hemisphere vertex indices and ``stc.vertices[1]`` the right.
    subject : str
        Subject name passed to ``Brain``.
    surface : str
        Surface name passed to ``Brain`` (default "white").
    s : float
        ``scale_factor`` of the sphere glyphs.
    save_fname : str
        If non-empty, each hemisphere is saved with ``brain.save_montage``
        to ``os.path.join(fig_dir, hemi + save_fname)``. The current figure
        is then closed. If empty, the figures are left open.
    data_path : str | path-like | None
        Root of the MNE sample dataset. ``None`` (the default) resolves to
        ``mne.datasets.sample.data_path()`` at call time.
    subject_name : str
        String appended to ``data_path`` to form ``subjects_dir``.
    fig_dir : str
        Directory where the saved figures are written.
    figsize : tuple of int
        Size of each Mayavi figure.
    event_id : int
        Currently unused. Kept for backward compatibility: it used to pick
        the montage view, but every branch used the lateral view.
    """
    if data_path is None:
        # Resolve lazily. A call in the signature would run (and possibly
        # download the dataset) as soon as this module is imported.
        data_path = sample.data_path()
    # str() makes this work whether data_path is a str or a pathlib.Path.
    # Newer MNE versions return a Path, and `Path + str` raises TypeError.
    # os.path.join is not used here because subject_name starts with '/',
    # and join would then discard data_path entirely.
    subjects_dir = str(data_path) + subject_name

    for i, hemi in enumerate(["lh", "rh"]):
        figure = mlab.figure(size=figsize)
        brain = Brain(
            subject, hemi, surface, subjects_dir=subjects_dir,
            offscreen=False, figure=figure)
        surf = brain.geo[hemi]
        vertices = np.asarray(stc.vertices[i])  # 0 for lh, 1 for rh
        if vertices.size:
            # A single vectorised glyph call replaces one call (and one
            # re-render) per vertex. Skip it for a hemisphere with no
            # active sources.
            mlab.points3d(
                surf.x[vertices], surf.y[vertices], surf.z[vertices],
                color=(1, 0, 0), scale_factor=s, opacity=1.,
                transparent=True)
        if save_fname:
            fname = os.path.join(fig_dir, hemi + save_fname)
            brain.save_montage(fname, order=['lat'])
            # Close the figure once it is saved, so batch runs do not
            # accumulate open windows.
            mlab.close(mlab.gcf())

# Loop-invariant: the sampling interval of the recordings.
tstep = 1. / info['sfreq']

for decim in list_decim:
    for event_id in list_event_id:
        # Each .npy file appears to hold a dict saved with np.save, i.e. a
        # 0-d object array that .item() unwraps. Object arrays are pickled,
        # and NumPy >= 1.16.3 refuses to load them unless
        # allow_pickle=True. Only load files produced by this project:
        # unpickling untrusted data can execute arbitrary code.
        dict_dns = np.load(
            "event_id_%i_decim_%i_dns.npy" % (event_id, decim),
            allow_pickle=True).item()
        dict_supp = np.load(
            "event_id_%i_decim_%i_supp.npy" % (event_id, decim),
            allow_pickle=True).item()

        for pb_name in list_pb_name:
            B_dns = dict_dns[pb_name]
            supp = dict_supp[pb_name]
            stc = _make_sparse_stc(
                B_dns, supp, fwd, tmin=times[0], tstep=tstep)
            save_fname = "_%s_event_id_%i_decim_%i.png" % (
                pb_name, event_id, decim)
            plot_blob(stc, save_fname=save_fname, fig_dir=fig_dir,
                      event_id=event_id)


# plot_sparse_source_estimates(fwd['src'], stc, bgcolor=(1, 1, 1),
#                         opacity=0.1)
    # # fig_name = "%s, lambda / lambdamax = %0.2f" % (pb_name, p_alpha)
    # # plot_sparse_source_estimates(fwd['src'], stc, bgcolor=(1, 1, 1),
    # #                              opacity=0.1, fig_name=fig_name)
