# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

"""Double ML IV for Heterogeneous Treatment Effects.

A Double/Orthogonal machine learning approach to estimating heterogeneous
treatment effects with an endogenous treatment and an instrument.

"""

import numpy as np
from sklearn.model_selection import KFold, train_test_split, StratifiedKFold
from econml.utilities import hstack
from ortho_linear_regression import LinearWithGradientCorrection
from sklearn.base import clone


class _BaseOrthoDMLIV:
    """
    Orthogonal variant of DMLIV for estimating a CATE theta(X).

    The nuisances q(X) = E[Y | X], p(X) = E[T | X] and h(X, Z) = E[T | X, Z]
    are estimated in a cross-fitting manner, and the final model regresses the
    residual outcome Y - q(X) on theta(X) * (h(X, Z) - p(X)). In addition, a
    preliminary CATE model is cross-fitted and used to build the correction term
    prel_theta(X) * (h(X, Z) - p(X)) * (T - h(X, Z)), which is handed to the
    final model's ``fit``. NOTE(review): this correction is presumably what makes
    the final loss orthogonal with respect to h as well (plain DMLIV is
    orthogonal to p and q but not h) -- confirm against the reference algorithm.
    """

    def __init__(self, model_Y_X, model_T_X, model_T_XZ, prel_model_effect, model_effect,
                 n_splits=2, binary_instrument=False, binary_treatment=False,
                 random_state=None):
        """
        Parameters
        ----------
        model_Y_X : model to predict E[Y | X]
        model_T_X : model to predict E[T | X]
        model_T_XZ : model to predict E[T | X, Z]
        prel_model_effect : preliminary CATE estimator. At fit time takes (y, T, X, Z)
            and supports method .effect(X); its cross-fitted predictions are only used
            to build the correction terms passed to model_effect
        model_effect : final model that at fit time takes as input (residual Y),
            (h(X, Z) - p(X) as a column vector), X and the correction terms, and
            supports method .predict(X) that produces the cate at X
        n_splits : number of splits to use in cross-fitting (n_splits=1 means the
            nuisance models are trained and evaluated on the full sample)
        binary_instrument : whether to stratify cross-fitting splits by instrument
        binary_treatment : whether to stratify cross-fitting splits by treatment
        random_state : None, int or numpy RandomState, optional
            Forwarded to the KFold / StratifiedKFold splitters (which always shuffle)
            so that cross-fitting can be made reproducible. None keeps the
            unseeded shuffling used previously.
        """
        # One independent clone of every nuisance model per fold, so the per-fold
        # fitted models remain available afterwards (see fitted_nuisances).
        self.model_T_XZ = [clone(model_T_XZ, safe=False) for _ in range(n_splits)]
        self.model_Y_X = [clone(model_Y_X, safe=False) for _ in range(n_splits)]
        self.model_T_X = [clone(model_T_X, safe=False) for _ in range(n_splits)]
        self.prel_model_effect = [clone(prel_model_effect, safe=False) for _ in range(n_splits)]
        self.model_effect = model_effect
        self.n_splits = n_splits
        self.binary_instrument = binary_instrument
        self.binary_treatment = binary_treatment
        self.random_state = random_state

    def _make_splits(self, T, X, Z):
        """Return an iterable of (train, test) index pairs used for cross-fitting."""
        if self.n_splits == 1:
            # No cross-fitting: every model is trained and evaluated on all samples.
            all_idx = np.arange(X.shape[0])
            return [(all_idx, all_idx)]
        # Instances created before random_state existed (e.g. unpickled ones) lack
        # the attribute; treat them as unseeded.
        random_state = getattr(self, 'random_state', None)
        # TODO. Deal with multi-class instrument
        if self.binary_instrument or self.binary_treatment:
            # Stratify on the joint (treatment, instrument) cell; for 0/1-coded T and Z
            # the factor 2 keeps the four combinations distinct.
            group = 2 * T * self.binary_treatment + Z.flatten() * self.binary_instrument
            return StratifiedKFold(n_splits=self.n_splits, shuffle=True,
                                   random_state=random_state).split(X, group)
        return KFold(n_splits=self.n_splits, shuffle=True,
                     random_state=random_state).split(X)

    def fit(self, y, T, X, Z):
        """
        Fit the cross-fitted nuisance models and then the final effect model.

        Parameters
        ----------
        y : outcome (single dimensional; an (n, 1) array is flattened)
        T : treatment (single dimensional; an (n, 1) array is flattened)
        X : features/controls
        Z : instrument; a 1-d array is reshaped into a single column

        Returns
        -------
        self

        Raises
        ------
        AssertionError
            If T or y has more than one column, or if binary_instrument is True
            and Z has more than one column.
        """
        if len(T.shape) > 1 and T.shape[1] > 1:
            raise AssertionError("Can only accept single dimensional treatment")
        if len(y.shape) > 1 and y.shape[1] > 1:
            raise AssertionError("Can only accept single dimensional outcome")
        if len(Z.shape) == 1:
            Z = Z.reshape(-1, 1)
        if (Z.shape[1] > 1) and self.binary_instrument:
            raise AssertionError("Binary instrument flag is True, but instrument is multi-dimensional")
        T = T.flatten()
        y = y.flatten()

        n_samples = y.shape[0]
        proj_t = np.zeros(n_samples)      # cross-fitted h(X, Z) = E[T | X, Z]
        pred_t = np.zeros(n_samples)      # cross-fitted p(X) = E[T | X]
        res_y = np.zeros(n_samples)       # Y - q(X), with q(X) = E[Y | X]
        prel_theta = np.zeros(n_samples)  # cross-fitted preliminary CATE

        for idx, (train, test) in enumerate(self._make_splits(T, X, Z)):
            # Estimate h(Z, X) = E[T | Z, X] in cross-fitting manner
            proj_t[test] = self.model_T_XZ[idx].fit(
                hstack([X[train], Z[train]]), T[train]).predict(hstack([X[test], Z[test]]))
            # Estimate residual Y_res = Y - q(X) = Y - E[Y | X] in cross-fitting manner
            res_y[test] = y[test] - self.model_Y_X[idx].fit(X[train], y[train]).predict(X[test])
            # Estimate p(X) = E[T | X] in cross-fitting manner
            pred_t[test] = self.model_T_X[idx].fit(X[train], T[train]).predict(X[test])
            # Preliminary theta(X), used only to build the correction term below
            prel_theta[test] = self.prel_model_effect[idx].fit(
                y[train], T[train], X[train], Z[train]).effect(X[test]).flatten()

        # h(X, Z) - p(X); all arrays here are already 1-d
        res_t = proj_t - pred_t
        # Orthogonal correction prel_theta(X) * (h(X, Z) - p(X)) * (T - h(X, Z))
        corrections = prel_theta * res_t * (T - proj_t)
        self.corrections = corrections
        # Estimate theta by minimizing square loss (Y_res - theta(X) * (h(Z, X) - p(X)))^2,
        # with the correction terms forwarded to the final model
        self.model_effect.fit(res_y, res_t.reshape(-1, 1), X, corrections)

        return self

    def effect(self, X):
        """
        Return the CATE predicted by the final model at X.

        Parameters
        ----------
        X : features
        """
        return self.model_effect.predict(X)

    @property
    def coef_(self):
        """Coefficients of the final effect model."""
        return self.effect_model.coef_

    @property
    def intercept_(self):
        """Intercept of the final effect model."""
        return self.effect_model.intercept_

    @property
    def effect_model(self):
        """The final CATE model."""
        return self.model_effect

    @property
    def fitted_nuisances(self):
        """Per-split lists of the preliminary effect model and the nuisance models."""
        return {'prel_model_effect': self.prel_model_effect,
                'model_Y_X': self.model_Y_X,
                'model_T_X': self.model_T_X,
                'model_T_XZ': self.model_T_XZ}

class OrthoDMLIV(_BaseOrthoDMLIV):
    """
    A child of the _BaseOrthoDMLIV class that specifies a particular effect model
    where the treatment effect is linear in some featurization of the variable X.
    The features are created by a provided featurizer that supports fit_transform
    (and, preferably, transform). A LinearWithGradientCorrection model is then fit
    on the featurized X multiplied by the treatment residual, using the correction
    terms computed by the base class as gradient corrections.
    """

    def __init__(self, model_Y_X, model_T_X, model_T_XZ, prel_model_effect,
                 featurizer, reg_feature_inds=None, fit_intercept=True,
                 alpha_l1=0.1, alpha_l2=0.1, tol=1e-6, n_splits=2,
                 binary_instrument=False, binary_treatment=False):
        """
        Parameters
        ----------
        model_Y_X : model to predict E[Y | X]
        model_T_X : model to predict E[T | X]
        model_T_XZ : model to predict E[T | X, Z]
        prel_model_effect : preliminary CATE estimator; at fit time takes (y, T, X, Z)
            and supports method .effect(X)
        featurizer : create features of X to use for effect model. It is fitted
            (fit_transform) when the effect model is fit, and reused (transform)
            at prediction time.
        reg_feature_inds : forwarded to LinearWithGradientCorrection
            (NOTE(review): presumably the indices of the regularized features -- confirm)
        fit_intercept : forwarded to LinearWithGradientCorrection
        alpha_l1 : l1 regularization
        alpha_l2 : l2 regularization
        tol : optimization tolerance
        n_splits : number of cross fitting splits (n_splits=1 means no-split)
        binary_instrument : whether to stratify cross-fitting splits by instrument
        binary_treatment : whether to stratify cross-fitting splits by treatment
        """
        class ModelEffect:
            """
            A wrapper class that takes as input X, T, y and estimates an effect model of the form
            y = theta(X) * T + epsilon, with theta linear in featurizer(X).
            """

            def __init__(self):
                """
                Build the wrapped LinearWithGradientCorrection from the enclosing
                constructor's arguments and keep a reference to the featurizer.
                """
                self.model_effect = LinearWithGradientCorrection(reg_feature_inds=reg_feature_inds,
                                                                 fit_intercept=fit_intercept,
                                                                 alpha_l1=alpha_l1, alpha_l2=alpha_l2, tol=tol)
                self.featurizer = featurizer

            def fit(self, y, T, X, corrections):
                """
                Parameters
                ----------
                y : (residual) outcome
                T : (residual) treatment as a column vector; it multiplies every
                    feature column of featurizer(X)
                X : features
                corrections : correction terms forwarded as grad_corrections
                """
                self.model_effect.fit(self.featurizer.fit_transform(X) * T, y, grad_corrections=corrections)
                return self

            def _featurize(self, X):
                """Apply the featurizer fitted in `fit` to X."""
                # Calling fit_transform here would refit the featurizer on the
                # prediction inputs, so data-dependent featurizers would produce
                # features inconsistent with those the coefficients were fit on.
                # Featurizers that only offer fit_transform keep the old behavior.
                transform = getattr(self.featurizer, 'transform', None)
                if transform is None:
                    return self.featurizer.fit_transform(X)
                return transform(X)

            def predict(self, X):
                """
                Parameters
                ----------
                X : features
                """
                return self.model_effect.predict(self._featurize(X))

            def __getattr__(self, name):
                # Only reached when normal lookup fails. Never delegate the wrapped
                # model's own slot (an instance without it, e.g. one being rebuilt by
                # copy.deepcopy, would otherwise recurse forever) nor dunder protocol
                # hooks (e.g. __setstate__, __sklearn_clone__), which must describe
                # this wrapper rather than the wrapped model.
                if name == 'model_effect' or (name.startswith('__') and name.endswith('__')):
                    raise AttributeError(name)
                return getattr(self.model_effect, name)

        super(OrthoDMLIV, self).__init__(model_Y_X, model_T_X, model_T_XZ,
                                         prel_model_effect,
                                         ModelEffect(),
                                         n_splits=n_splits,
                                         binary_instrument=binary_instrument,
                                         binary_treatment=binary_treatment)
