import numpy as np
from scipy.stats import chi2
from itertools import product
import random
import math


def gtest(X, Y, Z):
    r"""
    G-test of conditional independence between two binary variables.

    Tests the null hypothesis X \perp Y | Z, where X, Y and every column
    of Z are binary (0/1) variables. One 2x2 contingency table is built
    per joint configuration of Z, and the G statistic is summed over them.

    Parameters
    ----------
    X : array_like
        Binary samples of the first variable, length n.
    Y : array_like
        Binary samples of the second variable, length n.
    Z : array_like
        n*d matrix with d conditioning variables. A 1-D array of length n
        is treated as a single conditioning variable; an array of shape
        (n, 0) gives the unconditional test.

    Returns
    -------
    g : float
        The G test statistic.
    p : float
        The p-value of the test (chi-square survival function with
        2**d degrees of freedom).

    Raises
    ------
    ValueError
        If X and Y differ in length, or Z is not a matrix with one row
        per sample.

    Notes
    -----
    Samples in which X, Y or any entry of Z is neither 0 nor 1 fall in no
    cell of the contingency tables and are ignored. Empty cells contribute
    zero to the statistic.
    """
    X = np.asarray(X).reshape(-1)
    Y = np.asarray(Y).reshape(-1)
    Z = np.asarray(Z)
    if Z.ndim == 1:
        Z = Z.reshape(-1, 1)
    if len(X) != len(Y):
        raise ValueError("X and Y must have the same number of samples")
    if Z.ndim != 2 or Z.shape[0] != len(X):
        raise ValueError("Z must be a matrix with one row per sample of X")

    n = Z.shape[1]
    df = 2**n

    # Only samples whose values are all 0/1 belong to a contingency cell.
    binary = ((X == 0) | (X == 1)) & ((Y == 0) | (Y == 1))
    binary &= ((Z == 0) | (Z == 1)).all(axis=1)
    x = X[binary].astype(int)
    y = Y[binary].astype(int)
    z = Z[binary].astype(int)

    # Encode each row of Z as a stratum index, first column most
    # significant, i.e. the order of itertools.product([0, 1], repeat=n).
    place_values = 2 ** np.arange(n - 1, -1, -1)
    stratum = z @ place_values
    flat_cell = (2 * x + y) * df + stratum
    EmpTable = (
        np.bincount(flat_cell, minlength=4 * df).reshape(2, 2, df).astype(float)
    )

    # Expected counts under conditional independence:
    # (row total * column total) / stratum total, per stratum.
    x_margin = EmpTable.sum(axis=1, keepdims=True)
    y_margin = EmpTable.sum(axis=0, keepdims=True)
    totals = EmpTable.sum(axis=(0, 1))
    ExpTable = np.zeros_like(EmpTable)
    np.divide(x_margin * y_margin, totals, out=ExpTable, where=totals > 0)

    # Restrict to occupied cells: an empty cell contributes 0 * log(0) := 0,
    # and an occupied cell always has a positive expected count.
    occupied = EmpTable > 0
    observed = EmpTable[occupied]
    g = 2 * np.sum(observed * np.log(observed / ExpTable[occupied]))
    sig = chi2.sf(g, df)
    return g, sig


def simulate_data(nenv, nsample, alpha, beta, graph):
    """
    Simulate binary (x, y) samples from several environments.

    For each environment two Bernoulli parameters, theta and psi, are drawn
    from Beta(alpha, beta). Samples are then generated according to `graph`:

    - "xtoy":  x ~ Bernoulli(theta), y = x XOR Bernoulli(psi)
    - "ytox":  y ~ Bernoulli(psi),   x = y XOR Bernoulli(theta)
    - "xindy": x ~ Bernoulli(theta), y ~ Bernoulli(psi), independently

    Parameters
    ----------
    nenv : int
        Number of environments.
    nsample : int
        Number of samples drawn per environment. Must be at least 2,
        because the first two samples of each environment are returned.
    alpha, beta : float
        Parameters of the Beta distribution for theta and psi.
    graph : {"xtoy", "ytox", "xindy"}
        Data-generating structure.

    Returns
    -------
    dict
        "x1", "x2": first and second x sample of every environment
        (arrays of length nenv); "y1", "y2": likewise for y;
        "true_dag": the `graph` argument.

    Raises
    ------
    ValueError
        If `graph` is not one of the supported structures.
    """
    if graph not in ("xtoy", "ytox", "xindy"):
        raise ValueError(
            f"graph must be 'xtoy', 'ytox' or 'xindy', got {graph!r}"
        )

    theta = np.random.beta(alpha, beta, nenv)
    psi = np.random.beta(alpha, beta, nenv)
    xs = []
    ys = []
    for theta_i, psi_i in zip(theta, psi):
        # The order of the random draws in each branch is kept fixed so that
        # seeded runs stay reproducible.
        if graph == "xtoy":
            x = np.random.binomial(1, theta_i, size=nsample)
            flip = np.random.binomial(1, psi_i, size=nsample)
            y = (flip != x).astype(int)
        elif graph == "ytox":
            y = np.random.binomial(1, psi_i, size=nsample)
            flip = np.random.binomial(1, theta_i, size=nsample)
            x = (flip != y).astype(int)
        else:  # "xindy"
            x = np.random.binomial(1, theta_i, size=nsample)
            y = np.random.binomial(1, psi_i, size=nsample)
        xs.append(x)
        ys.append(y)

    # The reshape keeps a 2-D (nenv, nsample) layout even when nenv == 0.
    xs = np.array(xs).reshape(-1, nsample)
    ys = np.array(ys).reshape(-1, nsample)
    return {
        "x1": xs[:, 0],
        "x2": xs[:, 1],
        "y1": ys[:, 0],
        "y2": ys[:, 1],
        "true_dag": graph,
    }


def arbitary_causal_effect_generator(list_vars, nsample):
    """
    Generate random causal-effect queries over a set of variables.

    Each query splits `list_vars` into three disjoint groups: between one
    and three randomly chosen target variables, exactly one intervened
    variable, and all remaining variables as the conditioning set.

    Parameters
    ----------
    list_vars : list
        Variable names to draw from. At least two are needed whenever
        nsample > 0 (one target plus one intervened variable).
    nsample : int
        Number of queries to generate.

    Returns
    -------
    dict
        Keys "target_vars", "interve_vars" and "condition_vars", each
        mapping to a list of length nsample whose k-th entry is the list
        of variables playing that role in the k-th query.

    Raises
    ------
    ValueError
        If queries are requested but `list_vars` has fewer than two
        variables.
    """
    queries = {"target_vars": [], "interve_vars": [], "condition_vars": []}
    if nsample > 0 and len(list_vars) < 2:
        raise ValueError(
            "list_vars must contain at least two variables "
            "(one target and one intervened variable)"
        )

    # Never pick so many targets that no variable is left to intervene on.
    # With four or more variables this is the plain 1..3 range.
    max_targets = min(3, len(list_vars) - 1)
    for _ in range(nsample):
        ntargets = random.randint(1, max_targets)
        target_vars = random.sample(list_vars, ntargets)
        remain_vars = [v for v in list_vars if v not in target_vars]
        # Exactly one of the remaining variables is intervened on.
        interve_vars = random.sample(remain_vars, 1)
        condition_vars = [v for v in remain_vars if v not in interve_vars]
        queries["target_vars"].append(target_vars)
        queries["interve_vars"].append(interve_vars)
        queries["condition_vars"].append(condition_vars)
    return queries


