import os
import numpy as np
import scipy.sparse as sp
from itertools import product
import scipy
from scipy.optimize import fmin_l_bfgs_b

class LinearWithGradientCorrection():
    """Elastic-net linear regression with a per-sample gradient-correction term.

    ``fit`` minimises, over ``coef`` and ``intercept``::

        mean(sw * (0.5 * (y - X @ coef - intercept)**2 + gc * (X @ coef)))
            + alpha_l1 * sum(r * (pos + neg))
            + 0.5 * alpha_l2 * sum(r * (pos**2 + neg**2))

    where ``coef = pos - neg`` with ``pos, neg >= 0``, ``sw`` are the sample
    weights, ``gc`` the gradient corrections and ``r`` a 0/1 mask selecting
    the penalised features. Splitting ``coef`` into non-negative parts makes
    the L1 term smooth, so the problem is solved with bound-constrained
    L-BFGS-B. For a penalised feature with a positive penalty, at the exact
    optimum at most one of ``pos[j]``/``neg[j]`` is non-zero, so the penalty
    equals the usual ``|coef[j]|`` / ``coef[j]**2`` term.

    Parameters
    ----------
    alpha_l1, alpha_l2 : float
        Strength of the L1 and L2 penalties.
    tol : float
        Projected-gradient tolerance passed to L-BFGS-B as ``pgtol``.
    reg_feature_inds : None, sequence of int, or boolean mask of length n_features
        Features to penalise; ``None`` penalises all of them and an empty
        sequence penalises none.
    fit_intercept : bool
        If False the intercept is held fixed at 0.
    """

    def __init__(self, alpha_l1=0.1, alpha_l2=0.1, tol=1e-6,
                 reg_feature_inds=None, fit_intercept=True):
        self._alpha_l1 = alpha_l1
        self._alpha_l2 = alpha_l2
        self._tol = tol
        self._coef = None
        # Initialised here so ``intercept_`` behaves like ``coef_`` on an
        # unfitted model (returns None instead of raising AttributeError).
        self._intercept = None
        self._feature_inds = reg_feature_inds
        self.fit_intercept = fit_intercept

    @staticmethod
    def _as_column(values, n_samples, name):
        """Return ``values`` as an ``(n_samples, 1)`` float array, checking its length."""
        column = np.asarray(values, dtype=float).reshape(-1, 1)
        if column.shape[0] != n_samples:
            raise ValueError("%s has %d entries but X has %d rows"
                             % (name, column.shape[0], n_samples))
        return column

    def _reg_weights(self, n_features):
        """Return the 0/1 penalty mask over the split ``[pos, neg]`` coefficient vector."""
        if self._feature_inds is None:
            return np.ones(2 * n_features)
        reg_weights = np.zeros(2 * n_features)
        feature_inds = np.asarray(self._feature_inds)
        if feature_inds.dtype == bool:
            if feature_inds.size != n_features:
                raise ValueError("boolean reg_feature_inds has %d entries but X has "
                                 "%d columns" % (feature_inds.size, n_features))
            feature_inds = np.flatnonzero(feature_inds)
        # An empty sequence becomes a float array, which numpy refuses to use
        # as an index; it simply means "penalise nothing".
        if feature_inds.size == 0:
            return reg_weights
        reg_weights[feature_inds] = 1
        # Same features in the ``neg`` half. A negative index k also works:
        # k lands in the neg half and n_features + k in the pos half of the
        # same feature.
        reg_weights[n_features + feature_inds] = 1
        return reg_weights

    def fit(self, X, y, grad_corrections=None, sample_weights=None):
        """Fit ``coef_`` and ``intercept_`` with L-BFGS-B.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
        y : array-like with n_samples entries
        grad_corrections : array-like with n_samples entries, optional
            Per-sample coefficient ``gc`` of the linear term ``gc * (X @ coef)``
            added to the loss; defaults to zeros.
        sample_weights : array-like with n_samples entries, optional
            Per-sample weights; defaults to ones.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``y``, ``grad_corrections`` or ``sample_weights`` does not have
            one entry per row of ``X``.
        """
        X = np.asarray(X)
        n_samples, n_features = X.shape

        y = self._as_column(y, n_samples, "y")
        if grad_corrections is None:
            grad_corrections = np.zeros(n_samples)
        grad_corrections = self._as_column(grad_corrections, n_samples, "grad_corrections")
        if sample_weights is None:
            sample_weights = np.ones(n_samples)
        sample_weights = self._as_column(sample_weights, n_samples, "sample_weights")

        # Loop-invariant quantities: computed once per fit rather than on
        # every objective evaluation.
        reg_weights = self._reg_weights(n_features)
        alpha_l1 = self._alpha_l1
        alpha_l2 = self._alpha_l2
        fit_intercept = self.fit_intercept

        def loss_and_jac(extended_coef):
            # extended_coef = [pos (n_features), neg (n_features), intercept]
            split_coef = extended_coef[:-1]
            intercept = extended_coef[-1]
            coef = split_coef[:n_features] - split_coef[n_features:]
            index = np.matmul(X, coef.reshape(-1, 1))
            residual = index + intercept - y
            penalised = reg_weights * split_coef
            loss = (np.mean(sample_weights
                            * (0.5 * residual ** 2 + grad_corrections * index))
                    + alpha_l1 * np.sum(penalised)
                    + 0.5 * alpha_l2 * np.sum(penalised ** 2))
            # d loss / d coef; the ``neg`` half enters with the opposite sign.
            grad = np.mean(sample_weights * (residual + grad_corrections) * X, axis=0)
            jac = (np.concatenate((grad, -grad))
                   + alpha_l1 * reg_weights
                   + alpha_l2 * penalised)
            intercept_grad = np.mean(sample_weights * residual) if fit_intercept else 0.0
            return loss, np.append(jac, intercept_grad)

        # pos/neg must be non-negative; the intercept is free, or pinned at 0
        # (lower == upper bound) when it is not fitted.
        intercept_bounds = (None, None) if fit_intercept else (0, 0)
        bounds = [(0, None)] * (2 * n_features) + [intercept_bounds]
        w, _, _ = fmin_l_bfgs_b(loss_and_jac, np.zeros(2 * n_features + 1),
                                bounds=bounds, pgtol=self._tol)

        self._coef = w[:n_features] - w[n_features:-1]
        self._intercept = w[-1]
        return self

    def predict(self, X):
        """Return ``X @ coef_ + intercept_`` for a fitted model."""
        return np.matmul(X, self.coef_) + self.intercept_

    @property
    def coef_(self):
        """Fitted coefficients of shape (n_features,), or None before ``fit``."""
        return self._coef

    @property
    def intercept_(self):
        """Fitted intercept (0.0 when ``fit_intercept`` is False), or None before ``fit``."""
        return self._intercept
