from itertools import product



def compute_marginal(density, var_interest):
    """Return the marginal mass of a partial assignment of (x1, x2, y1, y2).

    Args:
        density: Mapping from ``(x1, x2, y1, y2)`` tuples to a numeric mass.
            Every coordinate that is summed out ranges over ``0`` and ``1``.
        var_interest: Dict mapping a subset of the variable names
            ``'x1'``, ``'x2'``, ``'y1'``, ``'y2'`` to the value each is fixed at.

    Returns:
        The sum of ``density`` over all 0/1 values of the variables that are
        not named in ``var_interest``.

    Raises:
        TypeError: If ``var_interest`` names a variable other than the four
            listed above (the lookup is done through a keyword call).
        KeyError: If a required tuple is missing from ``density``.
    """
    names = ['x1', 'x2', 'y1', 'y2']

    def lookup(x1, x2, y1, y2):
        # Keyword arguments put the values back into canonical tuple order.
        return density[(x1, x2, y1, y2)]

    free_names = [name for name in names if name not in var_interest]

    total = 0
    # Accumulate term by term (rather than with sum()) so the floating-point
    # addition order stays exactly left-to-right over the enumeration.
    for free_values in product((0, 1), repeat=len(free_names)):
        assignment = {name: value for name, value in zip(free_names, free_values)}
        assignment.update(var_interest)
        total += lookup(**assignment)
    return total

def compute_conditional(density, conditioned_var, key):
    """Return ``density[key]`` divided by the marginal of ``conditioned_var``.

    Args:
        density: Mapping from ``(x1, x2, y1, y2)`` tuples to a numeric mass.
        conditioned_var: Dict of variable name -> value describing the
            conditioning event; passed unchanged to ``compute_marginal``.
        key: Full ``(x1, x2, y1, y2)`` tuple whose mass is the numerator.
            The function does not check that ``key`` agrees with
            ``conditioned_var``; that is the caller's responsibility.

    Returns:
        The ratio ``density[key] / compute_marginal(density, conditioned_var)``.

    Raises:
        ZeroDivisionError: If the conditioning marginal is zero (for plain
            Python numbers).
    """
    # The denominator is evaluated first, as before, so any error raised by
    # the marginal computation surfaces ahead of a missing-key error.
    denominator = compute_marginal(density, conditioned_var)
    return density[key] / denominator


def exch_causal_effect_computation(joint_density, intervened_var, intervened_val, dag):
    """Tabulate the distribution over (x1, x2, y1, y2) after intervening on one variable.

    For every 0/1 tuple ``key = (x1, x2, y1, y2)`` whose intervened coordinate
    equals ``intervened_val`` the entry is a product of two factors taken from
    ``joint_density``; every other tuple keeps the value ``0``.

    The factors depend on ``dag`` and on which variable is intervened:

    * ``dag == 'xtoy'``
        - do(x1) / do(x2): ``P(other x) * P(y1, y2 | x1, x2)``
        - do(y1): ``P(x1, x2) * P(y2 | x2)``
        - do(y2): ``P(x1, x2) * P(y1 | x1)``
    * ``dag == 'ytox'``
        - do(x1): ``P(y1, y2) * P(x2 | y2)``
        - do(x2): ``P(y1, y2) * P(x1 | y1)``
        - do(y1) / do(y2): ``P(other y) * P(x1, x2 | y1, y2)``
    * any other ``dag`` value
        - do(x1) / do(x2): ``P(other x) * P(y1, y2)``
        - do(y1) / do(y2): ``P(x1, x2) * P(other y)``

    NOTE(review): the names suggest x are causes and y effects of two
    exchangeable units (x1, y1) and (x2, y2) -- confirm against the caller.

    Args:
        joint_density: Mapping from ``(x1, x2, y1, y2)`` tuples (entries 0/1)
            to a numeric mass.
        intervened_var: One of ``'x1'``, ``'x2'``, ``'y1'``, ``'y2'``.
        intervened_val: Value the intervened variable is set to.
        dag: ``'xtoy'``, ``'ytox'``, or anything else for the third case above.

    Returns:
        ``{'do(<var>)=<val>': table}`` where ``table`` maps each of the 16
        tuples to its value.

    Raises:
        ValueError: If ``intervened_var`` is not one of the four names.
        ZeroDivisionError: If a conditioning marginal needed for a tuple that
            is consistent with the intervention is zero.
    """
    intervention_desc = 'do(' + intervened_var + ')' + '=' + str(intervened_val)
    keys = list(product([0, 1], repeat=4))
    effect_table = dict.fromkeys(keys, 0)
    position = ['x1', 'x2', 'y1', 'y2'].index(intervened_var)

    def marginal(**assignment):
        # Shorthand for a marginal of the observed joint density.
        return compute_marginal(joint_density, assignment)

    for key in keys:
        # Tuples that contradict the intervention have probability zero.
        # Skipping them also avoids dividing by marginals that are only
        # relevant to those excluded tuples and may legitimately be zero.
        if not key[position] == intervened_val:
            continue

        x1, x2, y1, y2 = key
        if dag == 'xtoy':
            if position <= 1:
                # Only the non-intervened cause keeps its observed marginal.
                causes = marginal(x2=x2) if position == 0 else marginal(x1=x1)
                peffect = compute_conditional(joint_density, {'x1': x1, 'x2': x2}, key)
            elif position == 2:
                causes = marginal(x1=x1, x2=x2)
                peffect = marginal(x2=x2, y2=y2) / marginal(x2=x2)
            else:
                causes = marginal(x1=x1, x2=x2)
                peffect = marginal(x1=x1, y1=y1) / marginal(x1=x1)
        elif dag == 'ytox':
            if position == 0:
                causes = marginal(y1=y1, y2=y2)
                peffect = marginal(x2=x2, y2=y2) / marginal(y2=y2)
            elif position == 1:
                causes = marginal(y1=y1, y2=y2)
                peffect = marginal(x1=x1, y1=y1) / marginal(y1=y1)
            else:
                # Only the non-intervened y keeps its observed marginal.
                causes = marginal(y2=y2) if position == 2 else marginal(y1=y1)
                peffect = compute_conditional(joint_density, {'y1': y1, 'y2': y2}, key)
        else:
            # No conditional factors here: the x block and the y block are
            # combined as a plain product of marginals.
            if position <= 1:
                causes = marginal(x2=x2) if position == 0 else marginal(x1=x1)
                peffect = marginal(y1=y1, y2=y2)
            else:
                causes = marginal(x1=x1, x2=x2)
                peffect = marginal(y2=y2) if position == 2 else marginal(y1=y1)

        effect_table[key] = peffect * causes

    return {intervention_desc: effect_table}