"""
This script launches a script in order to make AUC ROC curves on semi-real data
"""

import numpy as np
import time
from joblib import Parallel, delayed
from itertools import product


from clar.solvers import solver, get_path
from clar.utils import get_alpha_max, get_sigma_min
from expes.utils import check_and_create_dirs, get_path_expe
from data.semi_real import get_semi_real_data_v2

# import socket
import os


if __name__ == '__main__':
    # Read the experiment configuration and expose it as module-level names.
    # NOTE(review): parallel_function reads these names as globals, and they
    # only exist when this file is run as a script; worker processes must
    # therefore inherit them (true for fork-started workers) -- confirm on
    # platforms where workers are spawned instead.
    import argparse
    import importlib

    # The text is passed as ``description``: the first positional argument of
    # ArgumentParser is ``prog``, which would turn the sentence into the
    # program name displayed in the usage line.
    parser = argparse.ArgumentParser(
        description='Main script for maximal sparsity')
    parser.add_argument('--expe', type=str, default='expe4B',
                        help='Choose the parameters for the experiement.')
    args = parser.parse_args()
    # The parameters live in the module expes.expe4.params_<expe>.
    expe = importlib.import_module("expes.expe4.params_{}".format(args.expe))

    list_whiten = expe.list_whiten
    # parameters of the problem
    n_times = expe.n_times
    list_n_epochs = expe.list_n_epochs
    n_jobs = expe.n_jobs

    # n_samples for ROC curves
    n_repet_roc_curves = expe.n_repet_roc_curves

    # parameters of the solver
    tol = expe.tol
    n_iter = expe.n_iter
    S_freq = expe.S_freq
    active_set_freq = expe.active_set_freq
    list_pb_name = expe.list_pb_name

    # parameters to store results: the directory name is suffixed with the
    # experiment name given on the command line.
    name_dir_raw_res = expe.name_dir_raw_res
    name_dir_raw_res = name_dir_raw_res + "_" + args.expe

    name_expe = expe.name_expe
    path_expe = "sgcl/expes/%s/" % name_expe
    check_and_create_dirs(
        name_expe=name_expe, name_dir_raw_res=name_dir_raw_res)

    # per-solver settings (dictionaries indexed by solver name) and the lists
    # spanning the grid of configurations to run
    dict_gap_freq = expe.dict_gap_freq
    list_n_dipoles = expe.list_n_dipoles
    list_seed = expe.list_seed
    list_amplitudes = expe.list_amplitudes
    dict_list_p_alpha = expe.dict_list_p_alpha
    resolution = expe.resolution

# for n_dipoles, seed, whiten, n_epochs, amplitude \
#     in product(list_n_dipoles, list_seed, list_whiten, list_n_epochs, list_amplitudes):
#         path_X = name_dir_raw_res  + ("/X_n_dipole_%i_seed_%i.npy" % (n_dipoles, seed))
#         path_B_star = name_dir_raw_res  + ("/B_star_n_dipoles_%i_seed_%i_ampli_%.2f.npy" \
#             % (n_dipoles, seed, amplitude))
#         # path_X = path_expe + name_dir_raw_res  + ("/X_n_dipole_%i_seed_%i.npy" % (n_dipoles, seed))
#         path_all_epochs = name_dir_raw_res  + \
#             ("/all_epochs_star_n_epochs_%i_n_dipoles_%i_seed_%i_ampli_%.2f.npy" \
#                  % (n_epochs, n_dipoles, seed, amplitude))
#         if not (os.path.isfile(path_X) and \
#                 os.path.isfile(path_B_star) and os.path.isfile(path_all_epochs)):
#             X, all_epochs, B_star, cov_data, (colorer, source_colorer), stc = \
#                 get_semi_real_data(
#                     n_times=n_times, n_epochs=n_epochs, n_dipoles=n_dipoles,
#                     whiten=whiten, seed=seed, amplitude=amplitude * 1e-9)
#             np.save(path_X, X)
#             np.save(path_B_star, B_star)
#             np.save(path_all_epochs, all_epochs)


def parallel_function(pb_name, n_dipoles, whiten, n_epochs, seed, amplitude):
    """Run one solver on one semi-real data configuration and save the output.

    The data are generated, ``B_star`` is saved, and, unless both result
    files already exist, ``get_path`` is called and the two dictionaries it
    returns are saved.

    The function reads the module-level experiment settings ``name_expe``,
    ``path_expe``, ``name_dir_raw_res``, ``n_times``, ``resolution``,
    ``dict_list_p_alpha``, ``dict_gap_freq``, ``n_iter``, ``tol``,
    ``active_set_freq`` and ``S_freq``.

    Parameters
    ----------
    pb_name : str
        Name of the solver; also the key used in ``dict_list_p_alpha`` and
        ``dict_gap_freq``.
    n_dipoles, whiten, n_epochs, seed
        Forwarded to ``get_semi_real_data_v2`` and used in the file names.
    amplitude : float
        Multiplied by 1e-9 before being given to ``get_semi_real_data_v2``.

    Returns
    -------
    None
    """
    print("-----------------------------------------------------")
    print("solver name: %s " % pb_name)

    # The tuple identifying the run lists n_epochs before whiten.
    params = (pb_name, n_dipoles, n_epochs, whiten, seed, amplitude)
    result_paths = {
        obj: get_path_expe(
            name_expe, path_expe, name_dir_raw_res, params,
            extension='npy', obj=obj)
        for obj in ("dense_Bs", "masks_Bs")
    }
    path_dense_Bs = result_paths["dense_Bs"]
    path_masks_Bs = result_paths["masks_Bs"]

    # The data are generated on every call, since B_star is saved even when
    # the solver results are already on disk.
    data = get_semi_real_data_v2(
        n_times=n_times, n_epochs=n_epochs, n_dipoles=n_dipoles,
        whiten=whiten, seed=seed, amplitude=amplitude * 1e-9,
        meg="grad", eeg=False, SNR=None, resolution=resolution)
    X, all_epochs, B_star, _cov_data, (_colorer, _source_colorer), _stc = data

    b_star_file = "/B_star_n_dipoles_%i_seed_%i_ampli_%.2f.npy" % (
        n_dipoles, seed, amplitude)
    np.save(name_dir_raw_res + b_star_file, B_star)

    list_p_alpha = dict_list_p_alpha[pb_name]

    # Nothing left to do when both result files are already present.
    if os.path.isfile(path_masks_Bs) and os.path.isfile(path_dense_Bs):
        return

    Y = all_epochs.mean(axis=0)
    # These solvers are given every epoch; the others get the epoch average.
    takes_all_epochs = pb_name in ("CLaR", "MTLME", "MRCER", "MLER")
    measurement = all_epochs if takes_all_epochs else Y

    sigma_min = get_sigma_min(Y)
    alpha_max = get_alpha_max(X, measurement, sigma_min, pb_name)
    print("alpha_max = %.2f" % alpha_max)
    print("sigma_min = %.2f" % sigma_min)

    gap_freq = dict_gap_freq[pb_name]
    assert name_expe == "expe4"
    solver_options = dict(
        B0=None,
        n_iter=n_iter, tol=tol, gap_freq=gap_freq,
        active_set_freq=active_set_freq,
        S_freq=S_freq, pb_name=pb_name, use_accel=False,
        heur_stop=True)
    dict_masks, dict_dense_Bs = get_path(
        X, measurement, list_p_alpha, alpha_max, sigma_min,
        **solver_options)

    np.save(path_masks_Bs, dict_masks)
    np.save(path_dense_Bs, dict_dense_Bs)


if __name__ == '__main__':

    print("enter parallel")
    # Use at most 48 workers, whatever the experiment parameters ask for.
    n_jobs = np.minimum(n_jobs, 48)

    # Grid of every configuration to run; each tuple is ordered like the
    # positional parameters of parallel_function.
    grid = product(
        list_pb_name, list_n_dipoles, list_whiten,
        list_n_epochs, list_seed, list_amplitudes)
    pool = Parallel(n_jobs=n_jobs, verbose=100, backend='multiprocessing')
    pool(delayed(parallel_function)(*config) for config in grid)
    print('OK finished')
    # To run serially, call parallel_function(*config) in a plain loop over
    # the same grid.
