import numpy as np
import torch
import torch.distributions as td


def calc_hill(reg_conc, half_response, coop_state, repressive):
    """Evaluate the Hill response of a target gene to each of its regulators.

    Computes ``x**n / (K**n + x**n)`` element-wise, where ``x`` is
    ``reg_conc``, ``K`` is ``half_response`` and ``n`` is ``coop_state``.
    Entries with ``x == 0`` are forced to 0; this also removes the 0/0 case
    when ``K == 0``. For entries where ``repressive`` is True the value is
    replaced by ``1 - response``.

    ``half_response``, ``coop_state`` and ``repressive`` receive a trailing
    singleton axis, so they broadcast along the last axis of ``reg_conc``.
    """
    exponent = coop_state.unsqueeze(-1)
    conc_pow = torch.pow(reg_conc, exponent)
    half_pow = torch.pow(half_response.unsqueeze(-1), exponent)
    activation = conc_pow / (half_pow + conc_pow)
    # Absent regulator: no activation (and avoid NaN from 0 / 0).
    activation = torch.where(reg_conc == 0, 0, activation)
    # Repressors act through the complement of the activation curve.
    return torch.where(repressive.unsqueeze(-1), 1 - activation, activation)


def _production_rate(reg_conc, half_response, coop_state, repressive, weights, basal):
    """Return the production rate of the current gene, shape (n_parallel, number_bins).

    ``reg_conc`` holds regulator concentrations masked to the gene's parents.
    ``weights`` holds the matching ``|k|`` interaction strengths, with a
    trailing singleton axis and zero for non-parents, so only true regulators
    contribute to the sum over the regulator axis. ``basal`` is added on top;
    the caller fills it only for master regulators.
    """
    hill_vals = calc_hill(reg_conc, half_response, coop_state, repressive)
    return basal + (hill_vals * weights).sum(-2)


def sim_sergio_pytorch(
    *,
    graph,
    toporder,
    number_bins,
    number_sc,
    noise_params,
    decays,
    basal_rates,
    k,
    hill,
    targets=None,
    interv_type="kout",
    sampling_state=15,
    dt=0.01,
    safety_steps=1
):
    """Simulate single-cell expression for a batch of regulatory graphs.

    SERGIO-style simulator (as the name indicates) that runs ``n_parallel``
    independent simulations at once. Genes are processed in topological
    order. A gene's production rate ``P`` is its basal rate if it has no
    parent in ``graph`` (master regulator). Otherwise it is a sum of Hill
    terms over its parents weighted by ``|k|``. In both cases ``P`` is then
    scaled by the intervention factor.

    The concentration starts at ``P / decay`` and is integrated with
    Euler-Maruyama steps of

        dx = (P - decay * x) dt
             + sqrt(dt) * q * (sqrt(P) dW_p + sqrt(decay * x) dW_d),

    clipped at zero after every step. Cells are then sampled uniformly from
    the last ``sampling_state * number_sc`` time points.

    Shapes below are inferred from the indexing in this function:

    Args:
        graph: (n_parallel, d, d); a non-zero ``graph[b, j, g]`` means gene
            ``j`` regulates gene ``g`` in simulation ``b``.
        toporder: (d, n_parallel); row ``i`` is the i-th gene of each
            simulation's topological order.
        number_bins: number of cell types / bins simulated per gene.
        number_sc: number of cells sampled per bin.
        noise_params: (n_parallel, d) noise amplitude ``q`` per gene.
        decays: (n_parallel, d) decay rate per gene.
        basal_rates: (n_parallel, d, number_bins) production rate of master
            regulators.
        k: (n_parallel, d, d) interaction strengths; negative means
            repressive.
        hill: (n_parallel, d, d) Hill coefficients.
        targets: (n_parallel, d) intervention mask, or ``None`` for no
            intervention.
        interv_type: ``"kout"`` sets production of targeted genes to zero.
            Any other value halves it where ``targets == 1``.
        sampling_state: cells are drawn from the last
            ``sampling_state * number_sc`` steps.
        dt: integration step size.
        safety_steps: extra steps per remaining topological position. Earlier
            genes run longer, so later genes can read their parents'
            trajectories at the same time index.

    Returns:
        Tensor of shape (n_parallel, d, number_bins, number_sc).
    """
    n_parallel = graph.shape[0]
    device = graph.device
    graph = graph.bool()
    n_p_range = torch.arange(n_parallel, device=device)
    d = graph.shape[-1]
    n_keep = sampling_state * number_sc  # window cells are sampled from
    sqrt_dt = dt**0.5

    # -1 marks "not simulated yet"; only parents (already simulated) are read.
    mean_expr = torch.full((n_parallel, d, number_bins), -1.0, device=device)
    # Gene at topological position i writes n_keep + (d - i) * safety_steps + 1
    # time points; position 0 needs the most.
    conc = torch.zeros(
        n_parallel, d, number_bins, n_keep + d * safety_steps + 1, device=device
    )
    sc_expression = torch.zeros(n_parallel, d, number_bins, number_sc, device=device)

    assert (
        toporder.shape[-1] == n_parallel
    ), "toporder needs one column per parallel simulation"
    for i, gene in enumerate(toporder):
        # ``gene`` holds one gene index per parallel simulation.
        parent_mask = graph[n_p_range, :, gene]  # (n_parallel, d)
        is_mr = ~parent_mask.any(-1)  # master regulators have no parents
        n_req_steps = n_keep + (d - i) * safety_steps

        # Half-response of each parent = its mean expression over bins.
        half_response = torch.zeros(n_parallel, d, device=device)
        half_response[parent_mask] = mean_expr[parent_mask].mean(-1)

        if targets is None:
            interv_factor = torch.ones(n_parallel, device=device)
        elif interv_type == "kout":
            # Cast first: ``~`` on an int tensor is a bitwise NOT, not a negation.
            interv_factor = (~targets[n_p_range, gene].bool()).float()
        else:
            interv_factor = torch.where(targets[n_p_range, gene] == 1.0, 0.5, 1.0)
        interv_factor = interv_factor.unsqueeze(-1)

        # Per-gene quantities that stay fixed during the time integration.
        coop_state = hill[n_p_range, :, gene]
        repressive = k[n_p_range, :, gene] < 0
        weights = (parent_mask * torch.abs(k[n_p_range, :, gene])).unsqueeze(-1)
        parent_mask_b = parent_mask.unsqueeze(-1)
        basal = torch.zeros(n_parallel, number_bins, device=device)
        basal[is_mr] = basal_rates[is_mr, gene[is_mr]]
        decay_rate = decays[n_p_range, gene].unsqueeze(-1)
        noise_scale = noise_params[n_p_range, gene].unsqueeze(-1)

        # Initial state: deterministic steady state from parents' mean expression.
        rate = _production_rate(
            mean_expr * parent_mask_b,
            half_response,
            coop_state,
            repressive,
            weights,
            basal,
        )
        init_conc = torch.div(interv_factor * rate, decay_rate)
        conc[n_p_range, gene, :, 0] = torch.where(init_conc < 0, 0.0, init_conc)

        # Each gene is visited once, so its trajectory starts at index 0.
        # Parents are read at the same time index (they were simulated longer).
        for step in range(1, n_req_steps + 1):
            prev = conc[n_p_range, gene, :, step - 1]
            rate = _production_rate(
                conc[..., step - 1] * parent_mask_b,
                half_response,
                coop_state,
                repressive,
                weights,
                basal,
            )
            prod_rate = interv_factor * rate
            decay_ = decay_rate * prev
            dw_p = torch.randn_like(prev)
            dw_d = torch.randn_like(prev)
            amplitude_p = noise_scale * torch.pow(prod_rate, 0.5)
            amplitude_d = noise_scale * torch.pow(decay_, 0.5)
            noise = (amplitude_p * dw_p) + (amplitude_d * dw_d)
            dxdt = (dt * (prod_rate - decay_)) + (sqrt_dt * noise)
            next_conc = prev + dxdt
            conc[n_p_range, gene, :, step] = torch.where(next_conc < 0, 0.0, next_conc)

        # Sample cells uniformly from the last n_keep written time points.
        n_written = n_req_steps + 1
        offsets = torch.randint(
            low=-n_keep,
            high=0,
            size=(n_parallel, number_bins, number_sc),
            device=device,
        )
        sampled_expr = torch.gather(conc[n_p_range, gene], -1, n_written + offsets)
        mean_expr[n_p_range, gene] = sampled_expr.mean(-1)
        sc_expression[n_p_range, gene] = sampled_expr
    return sc_expression


def outlier_effect(scData, outlier_prob, mean, scale):
    """Randomly scale whole genes by a log-normal factor.

    For every simulation and gene (axis -3), a Bernoulli draw with probability
    ``outlier_prob`` decides whether the gene is an outlier. An outlier gene
    has all of its entries along the trailing axes multiplied by one
    ``LogNormal(mean, scale)`` sample.

    The input tensor is not modified; a new tensor is returned.

    Assumes ``scData`` is 4-D with the simulation axis first, and that
    ``outlier_prob``, ``mean`` and ``scale`` have one entry per simulation
    (shape (n_parallel,)) — TODO confirm against callers.
    """
    d = scData.shape[-3]
    # Both samples have shape (d, n_parallel); draw order matches the
    # previous implementation.
    is_outlier = td.binomial.Binomial(probs=outlier_prob).sample((d,)) == 1
    factors = td.log_normal.LogNormal(mean, scale).sample((d,))
    # Non-outlier genes get a factor of exactly 1, leaving their values unchanged.
    gene_factors = torch.where(is_outlier, factors, torch.ones_like(factors)).t()
    return scData * gene_factors[..., None, None]


def lib_size_effect(scData, mean, scale):
    """Rescale every cell to a randomly drawn library size.

    A library size is drawn per (simulation, bin, cell) from
    ``LogNormal(mean, scale)``. Each cell's expression vector (axis -3) is
    divided by its current total and multiplied by that library size. Cells
    whose total is zero keep a divisor of 1 and therefore stay zero.
    """
    n_bins, n_cells = scData.shape[-2:]
    lib_sizes = td.log_normal.LogNormal(mean, scale).sample((n_bins, n_cells))
    # Move the distribution's batch axis to the front:
    # (bins, cells, n_parallel) -> (n_parallel, bins, cells)
    lib_sizes = lib_sizes.permute(2, 0, 1)
    cell_totals = scData.sum(-3)
    safe_totals = torch.where(cell_totals == 0, 1.0, cell_totals)
    per_cell_factor = (lib_sizes / safe_totals).unsqueeze(-3)
    return per_cell_factor * scData


def dropout_indicator(scData, shape=1, percentile=65):
    """Draw a 0/1 Bernoulli mask for every entry of ``scData``.

    Each entry is 1 with probability ``sigmoid(shape * (log1p(x) - m))``.
    ``m`` is the ``percentile``-th percentile of ``log1p(scData)``, computed
    separately for each simulation (leading axis), so higher expression gives
    a higher probability of 1.

    Args:
        scData: expression tensor; the broadcasting below assumes a 4-D
            layout with the simulation axis first — TODO confirm.
        shape: logistic steepness; a scalar or a 1-D tensor with one value per
            simulation.
        percentile: percentile in [0, 100] locating the logistic midpoint; a
            scalar, or a sequence/tensor of which only the first entry is used.

    Returns:
        Float tensor of 0/1 samples with the same shape as ``scData``.
    """
    sc_log_data = torch.log1p(scData)
    # NOTE(review): only the first percentile is applied to every simulation,
    # even though ``shape`` may differ per simulation — confirm intended.
    q = float(torch.as_tensor(percentile).reshape(-1)[0]) / 100
    log_mid_point = torch.quantile(
        sc_log_data.reshape(scData.shape[0], -1), q, dim=1
    )
    if not torch.is_tensor(shape):
        shape = torch.tensor(shape, dtype=sc_log_data.dtype, device=sc_log_data.device)
    steepness = shape.reshape(-1, 1, 1, 1)
    prob_ber = torch.sigmoid(
        steepness * (sc_log_data - log_mid_point.reshape(-1, 1, 1, 1))
    )
    binary_ind = td.binomial.Binomial(probs=prob_ber).sample()
    return binary_ind
