import numpy as np
from helpers import extend_basis_function

def h_linear(z, epsilon, coef=None, intercept=0.0) -> np.array:
    """
    Generate the variable 'x' from the input variable 'z' through a linear map plus noise:
    x = z @ coef + intercept + epsilon.

    If 'coef' is not provided it defaults to the (dz, dz) identity matrix. This default is
    only valid when dz == dx; otherwise 'coef' must be passed explicitly with shape (dz, dx).

    :param z: np.array, the input variable with shape (n_samples, dz)
    :param epsilon: np.array, the noise variable with shape (n_samples, dx)
    :param coef: np.array, the coefficient matrix defining the linear relationship, shape (dz, dx)
    :param intercept: float, the intercept term added to every entry
    :return: np.array, the generated variable 'x' with shape (n_samples, dx)
    :raises AssertionError: if the shapes of 'z', 'coef' and 'epsilon' are inconsistent
    """
    if coef is None:
        # Build the default before any shape check reads `coef` (it is None until here).
        # The epsilon check below rejects this default when dz != dx.
        coef = np.eye(z.shape[1])
    assert coef.shape[0] == z.shape[1], "The number of columns in 'z' should match the number of rows in 'coef'."
    assert coef.shape[1] == epsilon.shape[1], "The number of columns in 'coef' should match the number of columns in 'epsilon'."
    return z @ coef + intercept + epsilon


def f_linear(x, coef, intercept=0.0) -> np.array:
    """
    Generate the variable 'y' from the input variable 'x' through a linear relationship:
    y = x @ coef + intercept.

    :param x: np.array, the input variable with shape (n_samples, dx)
    :param coef: np.array, the coefficient vector with shape (dx,)
    :param intercept: float, the intercept term added to every sample
    :return: np.array, the generated variable 'y' with shape (n_samples,)
    :raises AssertionError: if the number of columns in 'x' differs from len(coef)
    """
    assert x.shape[1] == coef.shape[0], "The number of columns in 'x' should match the number of rows in 'coef'."
    # Out-of-place addition: an in-place `+=` fails with a casting error when
    # x/coef are integer arrays and the intercept is a float (including the 0.0 default).
    return np.dot(x, coef) + intercept


def h_nonlinear_additive(z, epsilon, interaction=False, coef=None) -> np.array:
    """
    Generate the variable 'x' from the input variable 'z' using a nonlinear additive relationship:
    x = (sin(z) [+ 0.5 * cos(z * z_0 * z_1)]) @ coef + epsilon.

    If 'coef' is not provided, it defaults to a (dz, dx) matrix filled with ones.

    :param z: np.array, the input variable with shape (n_samples, dz)
    :param epsilon: np.array, the noise variable with shape (n_samples, dx)
    :param interaction: bool, whether to add the interaction term 0.5 * cos(z * z_0 * z_1)
                        (with a single column of 'z' the factor z_0 * z_1 is replaced by 1)
    :param coef: np.array, the coefficient matrix mapping the dz features to dx outputs, shape (dz, dx)
    :return: np.array, the generated variable 'x' with shape (n_samples, dx)
    :raises AssertionError: if the transformed features do not have as many columns as 'epsilon'
    """
    dx = epsilon.shape[1]
    x_base = np.sin(z)  # element-wise, shape (n_samples, dz)

    if interaction:
        # Row-wise product of the first two columns, kept as a column so it broadcasts over all dz
        # columns. With a single column the factor is 1, so the term reduces to 0.5 * cos(z).
        inter_effect = np.prod(z[:, :2], axis=1, keepdims=True) if z.shape[1] >= 2 else np.ones((z.shape[0], 1))
        x_base += 0.5 * np.cos(z * inter_effect)

    if coef is None:
        # Default: every output column is the plain sum of the dz transformed features.
        coef = np.ones((z.shape[1], dx))
    x_base = np.dot(x_base, coef)
    assert x_base.shape[1] == dx, "The number of columns in 'x_base' should match the number of columns in 'epsilon'."
    return x_base + epsilon

def h_nonlinear_multiplicative(z, confounder, interaction=False, coef=None) -> np.array:
    """
    Generate the variable 'x' from the input variable 'z' using a nonlinear multiplicative relationship:
    x = (sin(z) * (1 + confounder) [* (1 + 0.5 * cos(z * z_0 * z_1))]) @ coef.

    No additive noise term is applied. If 'coef' is not provided, it defaults to a (dz, dz)
    matrix filled with ones, so dx == dz. With an explicit 'coef' of shape (dz, dx) the
    output has dx columns.

    :param z: np.array, the input variable with shape (n_samples, dz)
    :param confounder: np.array, the confounding variable scaling every column of sin(z),
                       with shape (n_samples, 1); a 1-D array of shape (n_samples,) is
                       treated as a column
    :param interaction: bool, whether to multiply by the interaction factor 1 + 0.5 * cos(z * z_0 * z_1)
                        (with a single column of 'z' the factor z_0 * z_1 is replaced by 1)
    :param coef: np.array, the coefficient matrix defining the output mixing, shape (dz, dx)
    :return: np.array, the generated variable 'x' with shape (n_samples, dx)
    :raises AssertionError: if 'coef' does not have dz rows
    """
    dz = z.shape[1]
    # The output width is set by `coef` when it is given; only the ones-matrix default forces dx == dz.
    dx = dz if coef is None else coef.shape[1]
    x_base = np.sin(z)  # element-wise, shape (n_samples, dz)

    # A 1-D confounder of shape (n_samples,) would broadcast along the column axis
    # (wrong and silent when n_samples == dz). Turn it into a column so each row of
    # x_base is scaled by its own sample's confounder.
    confounder = np.asarray(confounder)
    if confounder.ndim == 1:
        confounder = confounder[:, None]
    x_confounded = x_base * (1 + confounder)

    if interaction:
        inter_effect = np.prod(z[:, :2], axis=1, keepdims=True) if dz >= 2 else np.ones((z.shape[0], 1))
        x_confounded *= (1 + 0.5 * np.cos(z * inter_effect))

    if coef is None:
        coef = np.ones((dz, dx))
    assert x_confounded.shape[1] == coef.shape[0], "The number of columns in 'x_confounded' should match the number of rows in 'coef'."
    # The (n_samples, dz) @ (dz, dx) product already has dx columns.
    return np.dot(x_confounded, coef)


def f_nonlinear(x, interaction=False) -> np.array:
    """
    Map 'x' to the outcome 'y' via damped sines: y = sum_j exp(-x_j) * sin(x_j).

    When 'interaction' is set and 'x' has at least two columns, every summand is
    additionally scaled by the row-wise product x_0 * x_1. Otherwise no interaction
    is applied.

    :param x: np.array, the input variable with shape (n_samples, dx)
    :param interaction: bool, whether to scale the sum by x_0 * x_1
    :return: np.array, the generated variable 'y' with shape (n_samples, 1)
    """
    damped = np.exp(-x) * np.sin(x)

    if not interaction or x.shape[1] < 2:
        return np.sum(damped, axis=1, keepdims=True)

    # Column vector of x_0 * x_1, broadcast across every damped component.
    pair_product = np.prod(x[:, :2], axis=1)[:, None]
    return np.sum(damped * pair_product, axis=1, keepdims=True)


def grad_f_nonlinear(x, comp, interaction=False) -> np.array:
    """
    Partial derivative of y = f_nonlinear(x, interaction) with respect to column 'comp' of 'x'.

    Without interaction, y = sum_j g(x_j) with g(t) = exp(-t) * sin(t), so
    dy/dx_i = g'(x_i) = exp(-x_i) * (cos(x_i) - sin(x_i)).
    With interaction (and at least two columns), y = x_0 * x_1 * sum_j g(x_j), and the
    product rule adds x_1 * sum_j g(x_j) for i == 0 and x_0 * sum_j g(x_j) for i == 1.

    :param x: np.array, the input variable with shape (n_samples, dx)
    :param comp: int, index of the column to differentiate with respect to (negative indices allowed)
    :param interaction: bool, must match the 'interaction' flag used for f_nonlinear; defaults to False
    :return: np.array, the partial derivatives with shape (n_samples, 1)
    """
    x_i = x[:, comp]

    # Derivative of the damped sine exp(-t) * sin(t).
    partial_derivatives = np.exp(-x_i) * (np.cos(x_i) - np.sin(x_i))

    if interaction and x.shape[1] >= 2:
        # Product rule on y = (x_0 * x_1) * S(x), where S(x) = sum_j exp(-x_j) sin(x_j).
        pair_product = np.prod(x[:, :2], axis=1)
        damped_sum = np.sum(np.exp(-x) * np.sin(x), axis=1)
        idx = comp % x.shape[1]  # normalize negative indices (x[:, comp] already rejected out-of-range ones)
        partial_derivatives = pair_product * partial_derivatives
        if idx == 0:
            partial_derivatives = partial_derivatives + x[:, 1] * damped_sum
        elif idx == 1:
            partial_derivatives = partial_derivatives + x[:, 0] * damped_sum

    # Column vector, matching the (n_samples, 1) output of f_nonlinear.
    return partial_derivatives.reshape(-1, 1)



def h_basisfct(z, e_x, interaction=False, coef=None) -> np.array:
    """
    Generate the variable 'x' from 'z' through a cubic polynomial basis that shares one
    coefficient matrix across degrees: x = z @ coef + z**2 @ coef + z**3 @ coef + e_x.

    If 'coef' is not provided it defaults to the (dz, dz) identity matrix, as in h_linear.
    This default is only valid when 'e_x' has dz columns.

    :param z: np.array, the input variable with shape (n_samples, dz)
    :param e_x: np.array, the noise variable added to the output
    :param interaction: bool, currently unused; kept so the signature matches the other h_* generators
    :param coef: np.array, the coefficient matrix (or vector) with dz rows
    :return: np.array, the generated variable 'x'
    """
    if coef is None:
        # Previously `z @ None` raised TypeError; fall back to the identity map instead.
        coef = np.eye(z.shape[1])
    x = z @ coef + z**2 @ coef + z**3 @ coef + e_x
    return x
    

def f_basisfct(x, theta, k_terms, extension="polynomial") -> np.array:
    """
    Evaluate a linear model on a basis expansion of 'x', with an intercept column prepended.

    The expansion is delegated to helpers.extend_basis_function using degrees 1..k_terms and
    the requested 'extension'. The resulting design matrix [1, features] is multiplied by 'theta'.

    :param x: np.array, the input variable; its first axis indexes samples
    :param theta: np.array, coefficients for the design matrix; the first entry multiplies
                  the intercept column (presumably of length 1 + number of expanded
                  features — depends on extend_basis_function, confirm)
    :param k_terms: int, highest basis degree; degrees 1..k_terms are requested
    :param extension: str, basis family forwarded to extend_basis_function
    :return: np.array, the design matrix times 'theta'
    """
    degrees = np.arange(1, k_terms + 1)
    features = extend_basis_function(x, degree_array=degrees, extension=extension)

    intercept_column = np.ones((x.shape[0], 1))
    design_matrix = np.hstack([intercept_column, features])
    return design_matrix @ theta