'''
Script to reproduce the SNR figure of the article.
'''
import numpy as np

import matplotlib.pyplot as plt
from sklearn.linear_model import RidgeCV
from sklearn.dummy import DummyRegressor
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import cross_val_score

from pyriemann.tangentspace import TangentSpace as GeomVect
from wasserstein_tangent import TangentSpace as WassVect
from generation import generate_covariances
from utils import logDiag


print('Running SNR experiment...')

# Passed unchanged to every generate_covariances call below, so each noise
# level reuses the same value (presumably a random seed — confirm in
# generation.py).
rng = 4

# Noise levels: 10 values log-spaced between 1e-2 and 1e1
sigmas = np.logspace(-2, 1, 10)

# Parameters
n_matrices = 100  # Number of matrices
n_channels = 5   # Number of channels
n_sources = 2  # Number of sources
distance_A_id = .3  # Parameter 'mu': distance from A to Id
f_powers = 'log'  # link function between the y and the source powers


# Choose embeddings. A ``None`` entry marks the chance-level baseline, which
# is evaluated with a DummyRegressor on log-powers instead of a real model.
embeddings = [None, logDiag(), WassVect(n_channels), GeomVect()]
names = ['Chance level', 'Log-powers', 'Wasserstein', 'Geometric']

# Ridge penalty grid searched by RidgeCV (identical for every method/sigma,
# so it is built once rather than inside the loop).
ridge_alphas = np.logspace(-7, 3, 100)

# Run simulation: results[i, j] holds the 10-fold cross-validated mean
# absolute error of method names[i] at noise level sigmas[j].
results = np.zeros((len(names), len(sigmas)))
for j, sigma in enumerate(sigmas):
    X, y = generate_covariances(n_matrices, n_channels, n_sources, sigma=sigma,
                                distance_A_id=distance_A_id, f_p=f_powers,
                                rng=rng)
    for i, (name, embedding) in enumerate(zip(names, embeddings)):
        print('sigma = {}, {} method'.format(sigma, name))
        if embedding is None:
            # Chance level: DummyRegressor ignores the features (by default
            # it predicts the mean of the training targets).
            steps = [('emb', logDiag()),
                     ('sc', StandardScaler()),
                     ('lr', DummyRegressor())]
        else:
            steps = [('emb', embedding),
                     ('sc', StandardScaler()),
                     ('lr', RidgeCV(alphas=ridge_alphas,
                                    scoring='neg_mean_absolute_error'))]
        pipeline = Pipeline(steps)

        sc = cross_val_score(pipeline, X, y,
                             scoring='neg_mean_absolute_error',
                             cv=10, n_jobs=3)
        # The scorer returns negated errors; flip the sign to get the M.A.E.
        results[i, j] = -np.mean(sc)

# Plot: each curve is a method's M.A.E. divided by the chance-level M.A.E.
# at the same sigma, so the chance-level curve is equal to 1.
f, ax = plt.subplots(figsize=(4, 3))
# Copy the chance-level row before normalizing in place. Dividing ``results``
# by a view of its own first row only works because NumPy >= 1.13 treats
# overlapping ufunc operands as if copied; older NumPy would normalize the
# later rows by an already-normalized row of ones.
chance_mae = results[0].copy()
results /= chance_mae
for name, curve in zip(names, results):
    # Chance level is dashed; ``None`` leaves matplotlib's default style.
    ls = '--' if name == 'Chance level' else None
    ax.plot(sigmas, curve,
            label=name,
            linewidth=3,
            linestyle=ls)


ax.set_xlabel('sigma')
ax.set_xscale('log')
ax.grid()
ax.set_ylabel('Normalized M.A.E.')
# Reference line at zero error, i.e. a perfect predictor.
ax.hlines(0, sigmas[0], sigmas[-1], label='Perfect',
          color='k', linestyle='--', linewidth=3)
ax.legend(loc='lower right')
f.tight_layout()
plt.show()
