import time
import numpy as np
import mne
import mayavi.mlab
from joblib import Memory

from mne.viz import plot_sparse_source_estimates
from mne.inverse_sparse.mxne_inverse import _make_sparse_stc

from data.real import get_data
from clar.solvers import wrap_solver
from clar.utils import get_alpha_max, get_sigma_min
from expes.utils import configure_plt
from mne.datasets import sample
from surfer import Brain
from mayavi import mlab
from surfer import Brain

# On-disk joblib cache (rooted in the current directory) for the expensive
# data-loading step below.
mem = Memory(location=".")

##############################################################################
##############################################################################
################  Choose the optimization problem you want to solve
pb_name = "CLaR"
# pb_name = "MLER"
# pb_name = "MTL"
# pb_name = "SGCL"
# Comment/uncomment the lines above to run MLER, MTL, etc.
##############################################################################
##############################################################################


# Load real, preprocessed data.
snty_chk = False
eeg = False
meg = "mag"
event_id = 1
# Events 1 and 2 use a longer time window than the other events.
tmax = 0.3 if event_id in (1, 2) else 0.100

resolution = 6
X, all_epochs, src_wgth, fwd, info, times = mem.cache(get_data)(
    resolution=resolution,
    snty_chk=snty_chk,
    eeg=eeg,
    meg=meg,
    event_id=event_id,
    tmax=tmax,
)
# Average over axis 0 of all_epochs (presumably the epoch/repetition axis —
# confirm against get_data).
Y = all_epochs.mean(axis=0)

# Observations handed to each solver: either the averaged data `Y` or the
# full `all_epochs` array.
dict_data = {
    "MTL": Y,
    "MLE": Y,
    "MLER": all_epochs,
    "SGCL": Y,
    "CLaR": all_epochs,
    "MRCE": Y,
    "MRCER": all_epochs,
    "NNCVX": all_epochs,
}

# Per-event, per-solver `p_alpha` values passed to wrap_solver (presumably the
# fraction of alpha_max used as regularization — confirm in wrap_solver).
# NOTE(review): "NNCVX" has no entry here, so selecting it raises KeyError.
if event_id == 1:
    dict_p_alpha = {
        "CLaR": 0.9, "MLE": 0.85, "MLER": 0.9, "MRCE": 0.9,
        "MRCER": 0.96, "SGCL": 0.99, "MTL": 0.9,
    }
elif event_id == 2:
    dict_p_alpha = {
        "MTL": 0.925, "CLaR": 0.815, "MLE": 0.9, "MLER": 0.99,
        "MRCE": 0.9, "MRCER": 0.96, "SGCL": 0.99,
    }
elif event_id == 3:
    dict_p_alpha = {
        "MTL": 0.999, "CLaR": 0.95, "MLE": 0.85, "MLER": 0.99,
        "MRCE": 0.9, "MRCER": 0.96, "SGCL": 0.99,
    }
else:
    dict_p_alpha = {
        "MTL": 0.999, "CLaR": 0.95, "MLE": 0.85, "MLER": 0.99,
        "MRCE": 0.9, "MRCER": 0.96, "SGCL": 0.9999,
    }

# Per-solver heuristic-stopping flags, all disabled.
# NOTE(review): not read anywhere in the visible script; `heur_stop` is
# computed separately from pb_name.
dict_heur_stop = dict.fromkeys(
    ("CLaR", "SGCL", "MTL", "MLE", "MLER", "MRCER"), False)

p_alpha = dict_p_alpha[pb_name]

# Parameters of the algorithm, forwarded to wrap_solver.
tol = 10**-2
active_set_freq = 20
S_freq = 1
sigma_min = get_sigma_min(Y)
n_iter = 10 ** 3
obs = dict_data[pb_name]
# NOTE(review): alpha_max is not used in the visible part of the script.
alpha_max = get_alpha_max(
    X, obs, sigma_min=sigma_min, pb_name=pb_name)

# The `heur_stop` flag passed to wrap_solver is enabled only for these
# solvers. A single membership test replaces the previously garbled,
# duplicated `if ... or if ...` line, which was a SyntaxError.
heur_stop = pb_name in ("MLER", "MRCER", "MLE", "SGCL")

# Run the selected solver and time it.
t_start = time.time()
B_dns, supp = wrap_solver(
    X, obs, p_alpha,
    pb_name=pb_name,
    tol=tol,
    heur_stop=heur_stop,
    n_iter=n_iter,
    active_set_freq=active_set_freq,
    S_freq=S_freq,
)
# Scale each row of B_dns in place by the entry of src_wgth selected by supp
# (presumably undoing a source weighting applied to X — confirm in get_data).
B_dns *= src_wgth[supp][:, None]
# Elapsed wall-clock seconds for the solve plus the rescaling.
# NOTE(review): not read in the visible part of the script.
t_end = time.time() - t_start


########################################
# plot brains
# Build a sparse source estimate from the recovered coefficients and support.
stc = _make_sparse_stc(
    B_dns, supp, fwd, tmin=times[0], tstep=1. / info['sfreq'])
print("X.shape = ", X.shape)

background = "black"
foreground = "black"
save_fname = pb_name + '.pdf'  # only used by the commented-out save below
subject = "sample"
surface = "white"
s = 8  # scale_factor of the spheres drawn at active sources
data_path = sample.data_path()
# Recent MNE versions return a pathlib.Path here, and `Path + str` raises
# TypeError; converting to str first works for both str and Path.
subjects_dir = str(data_path) + '/subjects'
list_hemi = ["lh", "rh"]

# One figure per hemisphere; stc.vertices[i] holds the active vertices of
# list_hemi[i] (index 0 is the left hemisphere).
for i, hemi in enumerate(list_hemi):
    # Alternative view: also pass views="ventral" to Brain.
    brain = Brain(
        subject, hemi, surface, subjects_dir=subjects_dir,
        offscreen=False, background=background, foreground=foreground)
    surf = brain.geo[hemi]
    sources_h = stc.vertices[i]
    # Mark each active vertex of this hemisphere with a red sphere.
    for sources in sources_h:
        mlab.points3d(
            surf.x[sources], surf.y[sources],
            surf.z[sources], color=(1, 0, 0),
            scale_factor=s, opacity=0.7, transparent=True)
    # brain.save_montage(filename=save_fname)
    # Called once per hemisphere, so the hemispheres are shown one after the
    # other (presumably blocking until each window is closed).
    mlab.show()
