import time
import warnings
import numpy as np
from numpy.linalg import norm
import matplotlib.pyplot as plt
from libsvmdata import fetch_libsvm
from sklearn.exceptions import ConvergenceWarning
from sklearn.linear_model import Lasso as Lasso_sklearn


try:
    from sksparse import Lasso
except ImportError as err:
    # The benchmark cannot run without sksparse. Fail here with an actionable
    # message instead of printing and later crashing with a NameError at the
    # first ``Lasso(...)`` call.
    raise ImportError(
        "Module sksparse not found, did you run `pip install -e .` in this "
        "directory?"
    ) from err


def lasso_loss(X, y, beta, lmbda):
    """Evaluate the Lasso objective at ``beta``.

    Computes ``||y - X @ beta||_2 ** 2 / (2 * n) + lmbda * ||beta||_1``
    with ``n = len(y)``.

    Parameters
    ----------
    X : array-like or sparse matrix
        Design matrix; only needs to support ``X @ beta``.
    y : ndarray
        Target vector, one entry per row of ``X``.
    beta : ndarray
        Coefficient vector (1-D, since its l1 norm is taken).
    lmbda : float
        Regularization strength (passed as ``alpha`` to the estimators).

    Returns
    -------
    float
        Objective value.
    """
    n_samples = len(y)
    residual = y - X @ beta
    datafit = norm(residual) ** 2 / (2 * n_samples)
    penalty = lmbda * norm(beta, ord=1)
    return datafit + penalty


def _benchmark(estimator, max_iters, X, y, lmbda):
    """Fit ``estimator`` once per iteration budget; record time and loss.

    For each value in ``max_iters``, the estimator's ``max_iter`` attribute
    is set, ``estimator.fit(X, y)`` is timed with ``time.perf_counter``
    (monotonic and high-resolution, unlike ``time.time``), and the Lasso
    objective of the resulting ``coef_`` is evaluated with ``lasso_loss``.

    Returns
    -------
    times : list of float
        Duration of each ``fit`` call, in seconds.
    losses : list of float
        Objective value reached after each ``fit`` call.
    """
    times, losses = [], []
    for max_iter in max_iters:
        estimator.max_iter = max_iter
        t0 = time.perf_counter()
        estimator.fit(X, y)
        times.append(time.perf_counter() - t0)
        losses.append(lasso_loss(X, y, estimator.coef_, lmbda))
    return times, losses


X, y = fetch_libsvm("rcv1.binary")
# Smallest penalty for which beta = 0 minimizes `lasso_loss`; solve at a
# fraction of it so that the solution is non-trivial.
lmbda_max = norm(X.T @ y, ord=np.inf) / len(y)
lmbda = lmbda_max / 50

# Shared by both solvers. `tol` is tiny, presumably so that `max_iter` is
# what stops each run; `warm_start=False` so each fit does not reuse the
# previous solution (scikit-learn semantics; assumed to match for sksparse).
params = dict(alpha=lmbda, fit_intercept=False, tol=1e-10, max_iter=2,
              warm_start=False)

######################## sksparse #############################################
clf = Lasso(**params)
clf.fit(X, y)  # compile numba code (warm-up, so compilation is not timed)
sks_time, sks_loss = _benchmark(clf, range(1, 20), X, y, lmbda)

######################## sklearn ##############################################
clf = Lasso_sklearn(**params)
# Small iteration budgets stop before `tol` is met by design. Silence the
# resulting ConvergenceWarning for this loop only, not process-wide.
with warnings.catch_warnings():
    warnings.filterwarnings("ignore", category=ConvergenceWarning)
    sklearn_time, sklearn_loss = _benchmark(
        clf, range(1, 200, 10), X, y, lmbda)


sks_loss = np.array(sks_loss)
sklearn_loss = np.array(sklearn_loss)
# The best loss seen by either solver stands in for f(beta_hat). Points equal
# to it become exactly 0, which a log axis cannot show (matplotlib masks
# non-positive values by default), so they are absent from the curves.
min_loss = min(sks_loss.min(), sklearn_loss.min())

plt.figure()
plt.title(f"Lasso problem, rcv1.binary dataset, (n, p) = {X.shape}")
plt.semilogy(sks_time, sks_loss - min_loss, label='sksparse (ours)')
plt.semilogy(sklearn_time, sklearn_loss - min_loss, label='sklearn')
plt.xlabel("Time (s)")
plt.ylabel(r"$f(\beta) - f(\hat \beta)$")
plt.legend()
# Non-blocking: returns immediately. When run as a plain script the window
# presumably closes when the interpreter exits; suited to interactive use.
plt.show(block=False)
