import numpy as np

import importlib

from utils import *
from utils import matrix_discover
import expt_utils as expt_utils
# importlib.reload(expt_utils)


# Algorithms compared in this experiment, keyed by name. Each value is a list:
#   [solver callable, solver keyword arguments, 10, legend label, plot format]
# The integer third field and the exact use of each field are defined by
# expt_utils.algos_real_data -- TODO confirm there. The last field looks like
# a matplotlib format string (line style / marker / colour) -- presumably for
# plotting; verify against expt_utils.
# IDS and Greedy share the same solver; Greedy differs by gamma=3 and
# greedy=True.
algos_dict = dict(
  IDS=[IDS_matrix, dict(gamma=2, sigma=0.01, update=5), 10, 'IDS', '-og'],
  Greedy=[
    IDS_matrix,
    dict(gamma=3, sigma=0.01, update=5, greedy=True),
    10,
    'Greedy',
    '-sm',
  ],
  Random=[matrix_solver, dict(), 10, 'Random', '^y'],
)

# algos_dict = {
#   'IDS50': [IDS_matrix,
#     {"gamma": 3, 'update': 5, 'sigma':50.0}, 10, 'IDS50', '-og'],
#   'IDS20': [IDS_matrix,
#     {"gamma": 3, 'update': 5, 'sigma':20.0}, 10, 'IDS20', '-sm'],
#   'IDS10': [IDS_matrix,
#     {"gamma": 3, 'update': 5, 'sigma':10.0}, 10, 'IDS10', '^y'],
#   'IDS5': [IDS_matrix,
#     {"gamma": 3, 'update': 5, 'sigma':5.0}, 10, 'IDS5', '-oc'],
#   'IDS1': [IDS_matrix,
#     {"gamma": 3, 'update': 5, 'sigma':1.0}, 10, 'IDS1', '-ok'],
#   'Random': [matrix_solver, {}, 10, 'Random', '^r']
# }


# algos_dict = {
#   'IDS10': [IDS_matrix,
#     {"gamma": 3, 'update': 5, 'sigma':0.001}, 10, 'IDS001', '^y'],
#   'IDS5': [IDS_matrix,
#     {"gamma": 3, 'update': 5, 'sigma':0.1}, 10, 'IDS1', '-oc'],
#   'IDS1': [IDS_matrix,
#     {"gamma": 3, 'update': 5, 'sigma':0.05}, 10, 'IDS05', '-ok'],
#   'IDS01': [IDS_matrix,
#     {"gamma": 3, 'update': 5, 'sigma':0.01}, 10, 'IDS01', '-ob'],
#   'Random': [matrix_solver, {}, 10, 'Random', '^r']
# }

# Configuration for the matrix_discover problem. m1 x m2 (23 x 15) is the
# shape of the real-data matrix x_dat loaded further down. R presumably
# denotes the target rank -- TODO confirm in utils.matrix_discover.
# The noise level given here is overridden with prob.noise = 0 once the real
# data has been attached to the problem.
prob_dict = dict(m1=23, m2=15, R=3, noise=0.1)
prob = matrix_discover(**prob_dict)
# x_dat = np.array([
#     [14.01,18.72,4.11,2.86,1.54,1.12,2.31,4.90,2.92,6.54,6.46,4.84],
#     [0,1.40,0.01,0.58,0.10,1.10,0.66,0.04,1.70,0.29,3.65,0],
#     [3.97,6.50,1.09,2.29,2.52,1.03,1.19,4.20,5.13,2.47,9.66,8.45],
#     [16.21,12.24,6.09,2.62,4.36,7.43,3.76,2.45,1.38,4.41,11.66,4.68],
#     [6.67,9.50,3.38,1.60,4.45,2.10,1.98,0.17,0.76,0.47,0.50,4.11],
#     [1.07,2.93,1.42,0.09,0.31,4.70,4.29,0.00,2.07,4.13,0.00,0.33],
#     [0.00,2.39,0.28,4.45,0.05,1.01,1.84,0.07,1.73,1.70,3.85,0.02],
#     [0.11,3.12,0.60,2.40,0.17,0.00,2.52,0.37,1.11,0.61,9.50,0.03]
# ])
# x_dat = np.array(
#   [
#     [ 85.,  43.,  58.,  86.,  60.,  20.,  25.,  44.,  59., 100.,  30., 40.],
#     [ 80.,  38.,  18.,  51.,  66.,  63.,  28.,  50.,  78., 100.,  40., 20.],
#     [ 46., 32., 23., 48., 57., 53., 38., 57., 25., 74., 54., 29.],
#     [ 95.,  19.,  95.,  47.,  51., 100.,  32.,  29.,  14., 100.,  53., 26.],
#     [78., 25.,  4., 13., 17., 29., 19.,  7., 32., 66., 18., 32.],
#     [66.,  4., 23.,  9., 18., 24.,  6.,  8.,  0., 48.,  0.,  8.],
#     [43.,  0.,  3., 38.,  8.,  7.,  3.,  3.,  0., 33.,  0.,  0.],
#     [44.,  0.,  5.,  4.,  3.,  0.,  2.,  3.,  0., 27.,  0.,  0.]])

# Real-data matrix used as the ground truth for this experiment:
# 23 rows x 15 columns (the same m1 x m2 declared in prob_dict).
# Raw values in this literal lie between 0 and roughly 83; they are divided
# by 100 and shifted by -0.5 before use, which suggests a 0-100 (percentage-
# like) scale -- NOTE(review): confirm units with the data source.
x_dat = np.array([[14.56538038, 18.74350865, 19.60839775,  2.06058185, 16.39030586,
        19.65840226,  6.85348014, 26.75645727, 27.04861196, 31.29909936,
        35.2030282 , 40.08669181, 68.76495224, 74.00970333, 70.58038327],
       [15.4176361 , 18.97417112, 16.81589184,  2.66335504, 16.04854773,
        21.9766332 ,  7.64191788, 28.39231834, 26.6748995 , 29.97543001,
        36.30636065, 36.58798902, 10.54082024, 26.25688388, 28.58964812],
       [34.28672739, 51.07128287, 44.71241275,  1.11250509, 50.90618574,
        54.23706048,  4.69506106, 83.3036144 , 70.00785114, 52.85443039,
        72.15476775, 73.13874552, 18.412337  ,  1.57607478,  1.83416125],
       [ 5.98097959, 50.58289676, 45.08048043,  0.        , 42.11946098,
        49.11143206,  0.70522529, 70.86219062, 66.44951759, 65.93988128,
        66.07082051, 66.20793624,  6.11069159, 67.81850094, 72.63064909],
       [31.76101887, 47.94242943, 42.81292825,  0.39886034, 45.57576479,
        49.32093179,  3.62246647, 75.31115449, 67.19184926, 54.63079808,
        63.14714159, 63.83709974, 20.39347896,  0.        ,  0.        ],
       [45.04873884, 45.78100103, 44.65371267,  0.92454162, 44.55887608,
        43.35721014,  3.42016634, 72.84831426, 65.31185167, 56.83280355,
        76.33726193, 73.22050209, 36.47724246, 68.55764886, 69.0890435 ],
       [10.83765987, 12.50803772, 12.73955434,  2.39499309, 11.43188302,
        12.30347995,  6.79059046, 20.80036721, 21.68053196, 17.93331506,
        19.21909567, 20.19717377,  0.        ,  0.        ,  0.        ],
       [ 9.61611288, 13.66496862, 11.85429164,  2.57933242, 16.4358293 ,
        18.29616546,  2.88933036, 19.25650723, 19.42489921, 30.38899776,
        38.88072897, 40.27887601, 11.94269708, 21.93776353, 23.04134387],
       [ 0.        ,  0.        ,  0.        ,  0.        ,  0.80108652,
         2.97427716,  0.        ,  1.59945874,  2.60159062,  3.51009437,
         8.00843211,  9.01199267,  0.        ,  0.        ,  1.17105597],
       [10.05139356, 12.33593597, 10.49048013,  0.97098329, 13.33701887,
        17.14482537,  4.52223888, 14.62829911, 13.95628888, 19.47556274,
        21.66443641, 18.99441923, 12.14875689, 14.44407911, 11.72782605],
       [13.02655322, 16.09836574, 15.87410565,  1.87941995, 14.85699263,
        19.37897224,  4.57865157, 22.0084522 , 22.22340072, 21.68221926,
        35.06478929, 30.72459608, 10.88472418, 21.37540077, 23.67114189],
       [21.18588073, 23.00642082, 39.51659226,  2.04861583, 44.92520247,
        47.05254638, 13.82100293, 72.6088082 , 69.38989345, 62.426823  ,
        76.21686958, 70.35430221, 59.99235473, 75.38356191, 78.29631742],
       [ 1.00926941,  5.03497482,  4.76528177,  0.35922575,  4.62111491,
         3.87682972,  0.40001516,  5.42561659,  8.36792795,  3.45194952,
         4.84059533, 10.03634536,  1.30939332,  2.68976617,  4.41601986],
       [ 8.91218676, 10.72342611, 11.81277895,  1.84960131,  9.06564945,
        13.83920399,  4.2788345 , 16.48404533, 17.45236147, 18.18990312,
        23.65087488, 26.69194458,  8.19409063, 14.81646107, 16.67869351],
       [20.05913492, 34.2208852 , 40.30244069,  1.19400693, 27.35092133,
        26.91242644,  5.62826137, 51.85571731, 49.62220517, 53.94447324,
        65.48767084, 66.1445595 , 38.1181996 , 64.2716813 , 65.73471929],
       [13.16893511, 16.79477369, 16.24533971,  2.11866009, 15.85894672,
        15.94900479,  7.10820448, 24.76005147, 28.11243479, 27.24876847,
        31.76624566, 37.29673945, 16.87891597, 22.41656552, 23.46293461],
       [31.73318055, 33.92093167, 38.41251828,  1.27760451, 46.04956103,
        43.335395  ,  9.85031511, 63.8388463 , 64.60187181, 54.53336611,
        55.70352222, 57.7580975 , 51.9557832 , 71.32414041, 78.36249844],
       [10.16290719, 10.72653134, 11.56918225,  1.92409218,  9.62708941,
        11.09901507,  6.58126665, 16.8324099 , 21.40006849, 42.78380779,
        53.50463767, 60.63710952, 13.25100229, 18.48421566, 20.6070389 ],
       [33.44709826, 36.56526068, 39.55792883,  8.38964842, 65.92742304,
        63.53284655,  9.77180899, 64.80559966, 63.41514905, 62.46560879,
        64.57811697, 63.222267  , 65.3367146 , 72.70586315, 76.79526891],
       [30.59884697, 33.91258635, 36.80129147,  2.49563605, 47.27988184,
        44.44212826, 12.51112995, 66.1081674 , 63.53604944, 53.37807808,
        62.99169098, 63.44404456, 52.55158103, 68.74074216, 77.66689194],
       [17.50108516, 38.94339394, 40.11142329,  0.        , 42.74477554,
        44.8781348 ,  0.64911054, 65.57301901, 66.67600965, 70.24905586,
        69.97972982, 75.73582011,  8.77663053, 74.07785582, 75.28558302],
       [15.75654156, 40.52594956, 38.16592165,  0.39071503, 45.77689221,
        41.36745101,  0.86029566, 63.69608906, 64.24487271, 63.65319734,
        73.26731184, 70.29330199,  8.47025224, 70.27327326, 75.8997955 ],
       [24.17024028, 35.80552258, 40.2947746 ,  3.12093389, 42.30149711,
        40.31193463, 10.90397812, 61.05645885, 64.44051661, 70.14701034,
        70.25926715, 73.95050649, 27.32089436, 72.8784855 , 74.62339174]])

# Rescale the raw data (0..~83 in the literal above) by 1/100 and shift by
# -0.5, centring it around zero.
x_dat = x_dat / 100 - 0.5

# The real data replaces the matrix that matrix_discover was configured for,
# so its shape must agree with the dimensions declared in prob_dict.
# Fail loudly instead of running an experiment on mismatched dimensions.
if x_dat.shape != (prob_dict["m1"], prob_dict["m2"]):
    raise ValueError(
        f"x_dat has shape {x_dat.shape}, but prob_dict declares "
        f"({prob_dict['m1']}, {prob_dict['m2']})"
    )

# Use the real data as both the ground truth and the observed values, with
# no synthetic noise. x_noise gets its own copy so that an in-place change to
# one array can never silently alter the other (previously both attributes
# referenced the very same ndarray).
prob.x = x_dat
prob.x_noise = x_dat.copy()
prob.noise = 0

script_file = 'expt_real_data.py'

# T used to be hard-coded to 345 == 23 * 15 == x_dat.size. Deriving it from
# the data keeps the same value today and keeps it in sync if the matrix
# changes. (Its exact meaning, presumably the number of rounds, is defined by
# expt_utils.algos_real_data -- TODO confirm.)
expt_utils.algos_real_data(
    prob, algos_dict, matrix_discover, T=x_dat.size,
    results_dir='results/matrix', script_file=script_file
    )
