import numpy as np
from joblib import Parallel, delayed
from pyriemann.tangentspace import TangentSpace
from skada.base import DAEstimator
from sklearn.linear_model import Ridge


def score(y_true, y_pred):
    """Return the negated mean squared error between ``y_true`` and ``y_pred``.

    The sign is flipped so that larger values mean better predictions.
    """
    residuals = y_true - y_pred
    return -np.mean(np.square(residuals))


class Dummy(DAEstimator):
    """Baseline that predicts a constant value per domain.

    Every sample whose domain id is ``d`` is predicted as
    ``y_mean[abs(d)]``; the training data is ignored. Because ``abs`` is
    applied, a domain id and its negation share the same ``y_mean`` entry.

    Parameters
    ----------
    y_mean : dict
        Mapping from absolute domain id to the value predicted for that
        domain.
    """

    def __init__(self, y_mean):
        # Raise instead of ``assert``: assertions are stripped under
        # ``python -O``, and a non-dict such as a list would then be indexed
        # silently by domain id in ``predict``.
        if not isinstance(y_mean, dict):
            raise TypeError('y_mean should be a dictionary')
        self.y_mean = y_mean

    def fit(self, X, y=None, sample_domain=None):
        """Mark the estimator as fitted; ``X``, ``y`` and ``sample_domain`` are unused.

        Returns
        -------
        self
        """
        self.fitted_ = True
        return self

    def predict(self, X, sample_domain=None):
        """Predict ``y_mean[abs(d)]`` for every sample of domain ``d``.

        Parameters
        ----------
        X : array
            Only ``X.shape[0]`` (the number of samples) is used.
        sample_domain : array-like of shape (n_samples,)
            Domain id of each sample. Required.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,), float

        Raises
        ------
        ValueError
            If ``sample_domain`` is missing or has the wrong length.
        KeyError
            If the absolute value of a domain id is not a key of ``y_mean``.
        """
        if sample_domain is None:
            raise ValueError('sample_domain is required to look up y_mean')
        sample_domain = np.asarray(sample_domain)
        n_samples = X.shape[0]
        if sample_domain.shape != (n_samples,):
            raise ValueError(
                f'sample_domain should have shape ({n_samples},), '
                f'got {sample_domain.shape}'
            )

        y_pred = np.zeros(n_samples)
        for domain in np.unique(sample_domain):
            y_pred[sample_domain == domain] = self.y_mean[np.abs(domain)]
        return y_pred

    def score(self, X, y, sample_domain=None):
        """Return the negated mean squared error of ``predict`` against ``y``."""
        y_pred = self.predict(X, sample_domain)
        return score(y, y_pred)


class TSRidge(DAEstimator):
    """Ridge regression on per-frequency Riemannian tangent-space features.

    ``X`` must be a 4-D array of shape
    ``(n_samples, n_freqs, n_channels, n_channels)``. Each frequency slice is
    mapped with a ``pyriemann`` ``TangentSpace(metric='riemann',
    tsupdate=False)``. The per-frequency vectors of a sample are concatenated
    and regressed with an intercept-free ``sklearn`` ``Ridge``. The intercept
    is handled separately, as described for ``fit_intercept_per_domain``.

    Parameters
    ----------
    recenter : bool
        If True, a separate TangentSpace is fitted on each domain and
        frequency of whatever data is being transformed, in both ``fit`` and
        ``predict``. If False, one TangentSpace per frequency is fitted on
        the source data in ``fit`` and reused in ``predict``.
    fit_intercept_per_domain : bool
        If True, targets are centred with the per-domain values of
        ``y_mean`` during ``fit``. Predictions are then shifted so that
        their mean over each domain equals that domain's ``y_mean`` entry.
        If False, a single intercept (the mean of the source targets) is
        used.
    y_mean : dict or None
        Mapping from absolute domain id to a target value. It must be a dict
        when ``fit_intercept_per_domain`` is True.
    lambda_ : float
        Ridge regularisation strength (passed as ``alpha``).
    n_jobs : int
        Number of joblib workers used to process frequencies in parallel.
    """

    def __init__(self, recenter, fit_intercept_per_domain, y_mean=None,
                 lambda_=1, n_jobs=1):
        # Raise instead of ``assert`` so the check survives ``python -O``.
        if fit_intercept_per_domain and not isinstance(y_mean, dict):
            raise TypeError('y_mean should be a dictionary')
        self.recenter = recenter
        self.fit_intercept_per_domain = fit_intercept_per_domain
        self.y_mean = y_mean
        self.lambda_ = lambda_
        self.n_jobs = n_jobs

    @staticmethod
    def _check_X(X):
        """Return ``X`` as an ndarray, raising ``ValueError`` unless it is 4-D."""
        X = np.asarray(X)
        if X.ndim != 4:
            raise ValueError('X should be a 4D tensor: (n_samples, n_freqs, n_channels, n_channels)')
        return X

    @staticmethod
    def _check_sample_domain(sample_domain, n_samples):
        """Return ``sample_domain`` as an ndarray of shape ``(n_samples,)``."""
        if sample_domain is None:
            raise ValueError(
                'sample_domain is required by this estimator configuration'
            )
        sample_domain = np.asarray(sample_domain)
        if sample_domain.shape != (n_samples,):
            raise ValueError(
                f'sample_domain should have shape ({n_samples},), '
                f'got {sample_domain.shape}'
            )
        return sample_domain

    def _recentered_tangent_space(self, X, sample_domain):
        """Map ``X`` with a freshly fitted TangentSpace per domain and frequency.

        Returns an array of shape ``(n_samples, n_freqs, n_ts_features)``
        with rows in the same order as ``X``.
        """
        n_freqs = X.shape[1]
        blocks, positions = [], []
        for domain in np.unique(sample_domain):
            mask = sample_domain == domain
            per_freq = Parallel(n_jobs=self.n_jobs)(
                delayed(
                    TangentSpace(metric='riemann',
                                 tsupdate=False).fit_transform
                )(X[mask, freq]) for freq in range(n_freqs)
            )
            blocks.append(np.stack(per_freq, axis=1))
            positions.append(np.flatnonzero(mask))
        # Rows were grouped by domain; restore the original sample order.
        return np.concatenate(blocks)[np.argsort(np.concatenate(positions))]

    def _fitted_tangent_space(self, X):
        """Map ``X`` with the per-frequency TangentSpaces stored in ``ts_``.

        Returns an array of shape ``(n_samples, n_freqs, n_ts_features)``.
        """
        Z = Parallel(n_jobs=self.n_jobs)(
            delayed(ts.transform)(X[:, freq])
            for freq, ts in enumerate(self.ts_)
        )
        return np.stack(Z, axis=1)

    def _tangent_space(self, X, sample_domain):
        """Return tangent-space features flattened to ``(n_samples, -1)``."""
        if self.recenter:
            Z = self._recentered_tangent_space(X, sample_domain)
        else:
            Z = self._fitted_tangent_space(X)
        return Z.reshape(X.shape[0], -1)

    def fit(self, X, y, sample_domain=None):
        """Fit on the source samples, i.e. those with ``sample_domain >= 0``.

        Samples with a negative domain id are discarded before fitting.

        Parameters
        ----------
        X : array of shape (n_samples, n_freqs, n_channels, n_channels)
        y : array-like of shape (n_samples,)
        sample_domain : array-like of shape (n_samples,)
            Domain id of each sample. Required.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If ``X`` is not 4-D or ``sample_domain`` is missing or has the
            wrong length.
        """
        X = self._check_X(X)
        sample_domain = self._check_sample_domain(sample_domain, X.shape[0])
        y = np.asarray(y)

        # Keep only source samples.
        source = sample_domain >= 0
        X, y, sample_domain = X[source], y[source], sample_domain[source]

        if not self.recenter:
            # One TangentSpace per frequency, fitted on the source data and
            # reused unchanged by ``predict``.
            self.ts_ = Parallel(n_jobs=self.n_jobs)(
                delayed(
                    TangentSpace(metric='riemann', tsupdate=False).fit
                )(X[:, freq]) for freq in range(X.shape[1])
            )
        Z = self._tangent_space(X, sample_domain)

        # Center the targets; Ridge itself is fitted without an intercept.
        if self.fit_intercept_per_domain:
            y_centered = np.zeros(y.shape)
            for k in np.unique(sample_domain):
                mask = sample_domain == k
                y_centered[mask] = y[mask] - self.y_mean[np.abs(k)]
        else:
            self.intercept_ = np.mean(y)
            y_centered = y - self.intercept_

        self.ridge_ = Ridge(alpha=self.lambda_, fit_intercept=False)
        self.ridge_.fit(Z, y_centered)
        return self

    def predict(self, X, sample_domain=None):
        """Predict targets for ``X``.

        Parameters
        ----------
        X : array of shape (n_samples, n_freqs, n_channels, n_channels)
        sample_domain : array-like of shape (n_samples,), optional
            Domain id of each sample. Required when ``recenter`` or
            ``fit_intercept_per_domain`` is True; otherwise unused.

        Returns
        -------
        y_pred : ndarray of shape (n_samples,)

        Raises
        ------
        ValueError
            If ``X`` is not 4-D, or ``sample_domain`` is needed but missing
            or of the wrong length.
        """
        X = self._check_X(X)
        if self.recenter or self.fit_intercept_per_domain:
            sample_domain = self._check_sample_domain(sample_domain, X.shape[0])

        y_pred = self.ridge_.predict(self._tangent_space(X, sample_domain))
        if self.fit_intercept_per_domain:
            # Shift predictions so that their mean over each domain equals
            # that domain's y_mean entry.
            for k in np.unique(sample_domain):
                mask = sample_domain == k
                y_pred[mask] += self.y_mean[np.abs(k)] - np.mean(y_pred[mask])
        else:
            y_pred += self.intercept_
        return y_pred

    def score(self, X, y, sample_domain=None):
        """Return the negated mean squared error of ``predict`` against ``y``."""
        y_pred = self.predict(X, sample_domain)
        return score(y, y_pred)
