"""
    getEst(x, y)

Fit a logistic regression of `y` on the design matrix `x` by Newton–Raphson,
starting from β = 0. No intercept column is added here.

Iteration stops when the squared length of the Newton step drops below 1e-6
(`"Successful convergence"`), when solving with the Hessian throws
(`"H is singular"`), or after 100 iterations (`"Maximum iteration reached"`).
In the last two cases every returned coefficient is `NaN`.

# Returns
`(beta, msg, loop, H, SᵀS)`:
- `beta` — coefficient vector of length `size(x, 2)`;
- `msg`  — status message (see above);
- `loop` — iteration counter at exit;
- `H`    — `xᵀ diag(p(1 - p)) x` from the last iteration, evaluated at the
  coefficients *before* the final Newton step;
- `SᵀS`  — sum of outer products of the per-row score contributions
  `xᵢ (yᵢ - pᵢ)` from the last iteration.
"""
function getEst(x, y)
    (n, d) = size(x)
    beta = zeros(d, 1)
    p = Array{Float64}(undef, n, 1)
    S = similar(x)                          # per-row score contributions
    xs = similar(x)                         # rows of x scaled by p(1 - p)
    H = Array{Float64}(undef, d, d)
    loop = 1
    Loop = 100
    msg = "NA"
    while loop <= Loop
        pred!(x, beta, p)
        S .= x .* (y .- p)
        ss = sum(S, dims=1)                 # total score (1 × d)
        xs .= x .* p .* (1.0 .- p)
        mul!(H, x', xs)
        local shs
        try
            shs = H \ ss'
        catch err
            # Never swallow a user interrupt; any other failure of the solve
            # is treated as a singular Hessian.
            err isa InterruptException && rethrow()
            msg = "H is singular"; println(msg)
            beta .= NaN
            break
        end
        beta .+= shs
        tlr = sum(abs2, shs)
        if tlr < 0.000001
            msg = "Successful convergence"
            break
        end
        if loop == Loop
            msg = "Maximum iteration reached"; println(msg)
            beta .= NaN
            break
        end
        loop += 1
    end
    return vec(beta), msg, loop, H, S'S
end

"""
    getWEst(x, y, w)

Weighted logistic regression of `y` on `x` by Newton–Raphson, starting from
β = 0. Row i's contribution to both the score and the Hessian is multiplied
by `w[i]` (`getPilotB` passes inverse sampling probabilities `1 ./ π`).

Iteration stops when the squared length of the Newton step drops below 1e-6
(`"Successful convergence"`), when solving with the Hessian throws
(`"H is singular"`), or after 100 iterations (`"Maximum iteration reached"`).
In the last two cases every returned coefficient is `NaN`.

# Returns
`(beta, msg, loop, H, SᵀS)` where `H = xᵀ diag(w p(1 - p)) x` from the last
iteration (before the final step) and `S` holds the weighted per-row score
contributions `wᵢ xᵢ (yᵢ - pᵢ)` from the last iteration.
"""
function getWEst(x, y, w)
    (n, d) = size(x)
    beta = zeros(Float64, d, 1)
    p = Array{Float64}(undef, n, 1)
    S = similar(x)                          # weighted per-row score contributions
    xs = similar(x)                         # rows of w .* x scaled by p(1 - p)
    H = Array{Float64}(undef, d, d)
    loop = 1
    Loop = 100
    msg = "NA"
    wx = w .* x                             # weighted design, loop-invariant
    while loop <= Loop
        pred!(x, beta, p)
        S .= wx .* (y .- p)
        ss = sum(S, dims=1)                 # total weighted score (1 × d)
        xs .= wx .* p .* (1.0 .- p)
        mul!(H, x', xs)
        local shs
        try
            shs = H \ ss'
        catch err
            # Never swallow a user interrupt; any other failure of the solve
            # is treated as a singular Hessian.
            err isa InterruptException && rethrow()
            msg = "H is singular"; println(msg)
            beta .= NaN
            break
        end
        beta .+= shs
        tlr = sum(abs2, shs)                # no temporary, unlike shs.^2
        if tlr < 0.000001
            msg = "Successful convergence"
            break
        end
        if loop == Loop
            msg = "Maximum iteration reached"; println(msg)
            beta .= NaN
            break
        end
        loop += 1
    end
    return vec(beta), msg, loop, H, S'S
end

"""
    getMSLE(x, y, π, pilot)

Newton–Raphson fit on a subsample using the log-odds-corrected model
`p = 1 / (1 + exp(-xβ) * π)`, where `π[i]` is the correction factor for row i.
It starts from the coefficient vector `pilot`, which is copied and not
modified.

Iteration stops when the squared length of the Newton step drops below 1e-6
(`"Successful convergence"`), when solving with the Hessian throws
(`"H is singular"`), or after 100 iterations (`"Maximum iteration reached"`).
In the last two cases every returned coefficient is `NaN`.

# Returns
`(beta, msg, loop, H, SᵀS)` where `H = xᵀ diag(p(1 - p)) x` from the last
iteration (before the final step) and `S` holds the per-row score
contributions `xᵢ (yᵢ - pᵢ)` from the last iteration.
"""
function getMSLE(x, y, π, pilot)
    (n, d) = size(x)
    beta = copy(pilot)
    p = Array{Float64}(undef, n, 1)
    S = similar(x)                          # per-row score contributions
    xs = similar(x)                         # rows of x scaled by p(1 - p)
    H = Array{Float64}(undef, d, d)
    loop = 1
    Loop = 100
    msg = "NA"
    while loop <= Loop
        pred!(x, beta, π, p)                # log-odds corrected probabilities
        S .= x .* (y .- p)
        ss = sum(S, dims=1)                 # total score (1 × d)
        xs .= x .* p .* (1.0 .- p)
        mul!(H, x', xs)
        local shs
        try
            shs = H \ ss'
        catch err
            # Never swallow a user interrupt; any other failure of the solve
            # is treated as a singular Hessian.
            err isa InterruptException && rethrow()
            msg = "H is singular"; println(msg)
            beta .= NaN
            break
        end
        beta .+= shs
        tlr = sum(abs2, shs)
        if tlr < 0.000001
            msg = "Successful convergence"
            break
        end
        if loop == Loop
            msg = "Maximum iteration reached"; println(msg)
            beta .= NaN
            break
        end
        loop += 1
    end
    return vec(beta), msg, loop, H, S'S
end

"""
    ccb(Y::BitVector, xv, n::Int64)

Draw a balanced case-control pilot sample by Poisson sampling.

Each case starts with probability `1 / (2 N1)` and each control with
`1 / (2 N0)`. `balance!` then rescales these by a per-row factor computed
from `xv` and renormalises them to sum to one. Row i is kept independently
with probability `n * PI[i]`.

Returns `(idx, π, N1, N0)`: the kept row indices, their inclusion
probabilities `min(n * PI[i], 1)`, and the case and control counts.
"""
function ccb(Y::BitVector, xv, n::Int64)
    N = length(Y)
    N1 = sum(Y)
    N0 = N - N1
    base_ctrl = 1 / (2 * N0)
    case_gap = 1 / (2 * N1) - base_ctrl
    PI = case_gap .* Y .+ base_ctrl
    balance!(PI, xv)
    # rand.() sits inside the fused broadcast, so each row gets its own draw.
    idx = findall(rand.() .<= n .* PI)
    π = min.(n .* PI[idx], 1)
    return idx, π, N1, N0
end

"""
    cc(Y::BitVector, n::Int64)

Draw an equal-size case-control pilot sample: `n ÷ 2` cases and `n ÷ 2`
controls, each drawn with `sample(…, n ÷ 2)` (StatsBase's default samples
with replacement).

Returns `(idx, π, N1, N0)`. `idx` lists the cases first, then the controls.
`π` is `(n ÷ 2) / N1` for each case entry and `(n ÷ 2) / N0` for each control
entry. `N1` and `N0` are the case and control counts.
"""
function cc(Y::BitVector, n::Int64)
    N1 = sum(Y)
    N0 = length(Y) - N1
    half = n ÷ 2
    case_rows = findall(Y)
    ctrl_rows = findall(.!Y)
    idx = vcat(sample(case_rows, half), sample(ctrl_rows, half))
    π = repeat([half / N1, half / N0], inner=half)
    return idx, π, N1, N0
end

"""
    Pilot

Result of a pilot-sample fit, consumed by `calHS`, `opt` and `os` and built by
`getPilot` / `getPilotB`.

# Fields
- `x`   — design rows of the pilot subsample (`X[idx, :]`).
- `y`   — responses of the pilot subsample (`Y[idx]`).
- `β`   — pilot coefficient estimate.
- `π`   — sampling probabilities attached to the pilot rows (`calHS` divides
          the Hessian weights by them).
- `idx` — row indices of the pilot subsample within the full data.
- `H`   — Hessian matrix returned by the pilot fit (`getEst` / `getWEst`).
"""
struct Pilot 
    x::Matrix{Float64}
    y::BitVector
    β::Vector{Float64}
    π::Vector{Float64}
    idx::Vector{Int64}
    H::Matrix{Float64}
end

"""
    getPilotB(X, Y, n0)

Balanced pilot fit. Draw a pilot sample with `ccb`, balancing on columns 2:3
of `X`, then fit it by inverse-probability-weighted logistic regression
(`getWEst` with weights `1 ./ π`).

Returns a `Pilot`.
"""
function getPilotB(X, Y, n0)
    idx, π, _, _ = ccb(Y, view(X, :, 2:3), n0)
    xp = X[idx, :]
    yp = Y[idx]
    β, _, _, Hp, _ = getWEst(xp, yp, 1 ./ π)
    return Pilot(xp, yp, β, π, idx, Hp)
end

"""
    getPilot(X, Y, n0)

Case-control pilot fit.

1. Draw `n0 ÷ 2` cases and `n0 ÷ 2` controls with `cc`.
2. Fit an unweighted logistic regression with `getEst`.
3. Shift the intercept (first coefficient) by `-log(N0 / N1)` to undo the
   equal-size case-control design.

Returns a `Pilot`.
"""
function getPilot(X, Y, n0)
    idx, π, N1, N0 = cc(Y, n0)
    xp = X[idx, :]
    yp = Y[idx]
    β, _, _, Hp, _ = getEst(xp, yp)
    β[1] -= log(N0 / N1)
    return Pilot(xp, yp, β, π, idx, Hp)
end

"""
    opt(X, Y, plt::Pilot, cri::Function)

Compute optimal subsampling probabilities for every row of `X` from a pilot
fit.

For row i, the score is `mᵢ = pᵢ ‖Xdᵢ‖`, where:
- `pᵢ` is the pilot predicted probability;
- `Xd = cri(X, plt)` is the criterion-specific transformed design (e.g.
  `Lopt`, `Aopt`, `IVopt`).

Scores are divided by their sum over the controls (`Y` false), so the returned
vector sums to one over the controls.
"""
function opt(X, Y, plt::Pilot, cri::Function)
    # Size the buffer from X itself rather than relying on a global `N`.
    π = Vector{Float64}(undef, size(X, 1))
    pred!(X, plt.β, π)                              # π ← pilot pᵢ
    Xd = cri(X, plt)
    π .*= sqrt.(vec(sum(abs2, Xd, dims=2)))         # π ← mᵢ = pᵢ‖Xdᵢ‖
    estM = sum(π[.!Y])                              # normaliser over controls
    π ./= estM
    return π
end

"""
    balance!(π, x)

Rescale `π` in place and return it.

Let `m` be the column means of `x` and let
`D = x .* (1 .- m) .+ (1 .- x) .* m`. Each `π[i]` is multiplied by the
Euclidean norm of row i of `D`. `π` is then normalised to unit 1-norm.
"""
function balance!(π, x)
    colmeans = mean(x, dims=1)
    dev = @. x * (1 - colmeans) + (1 - x) * colmeans
    rownorms = sqrt.(sum(abs2, dev, dims=2))
    π .*= rownorms
    return normalize!(π, 1)
end

"""
    calHS(plt::Pilot)

Return `(H, S)` evaluated on the pilot sample. Let `p` be the pilot fitted
probabilities on `plt.x`, `ϕ = p (1 - p)` and `π = plt.π`. Then

- `H = Σᵢ (ϕᵢ / πᵢ) xᵢ xᵢᵀ`
- `S` is the matrix square root `sqrt(Hermitian(Σᵢ (ϕᵢ² / πᵢ) xᵢ xᵢᵀ))`.
"""
function calHS(plt::Pilot)
    x, π = plt.x, plt.π
    ϕ = similar(π)
    pred!(x, plt.β, ϕ)              # ϕ ← p
    @. ϕ = ϕ * (1.0 - ϕ)            # ϕ ← p(1 - p)
    wx = x .* ϕ ./ π                # rows scaled by ϕ/π
    H = x' * wx
    wx .*= ϕ                        # rows scaled by ϕ²/π
    S = sqrt(Hermitian(x' * wx))
    return H, S
end

"""
    IVopt(X, plt::Pilot)

IV-optimality transform of the design. Returns `X * A`, where `A` solves
`H A = S` and `(H, S)` come from `calHS(plt)`.
"""
function IVopt(X, plt::Pilot)
    H, S = calHS(plt)
    A = H \ S
    out = similar(X)
    return mul!(out, X, A)
end

"""
    Aopt(X, plt::Pilot)

A-optimality transform of the design. Returns `X H⁻¹`, where `H` comes from
`calHS(plt)`.
"""
function Aopt(X, plt::Pilot)
    H, _ = calHS(plt)
    # Right division solves Xd * H = X through a factorization of H instead of
    # forming inv(H) explicitly, which is both cheaper and more accurate.
    return X / H
end

"""
    Lopt(X, plt::Pilot)

L-optimality transform of the design. The identity, so `opt` scores row i by
`pᵢ ‖xᵢ‖`. `plt` is unused.
"""
Lopt(X, plt::Pilot) = X

"""
    os(X, Y, n, plt::Pilot, cri::Function)

Optimal-subsampling fit.

1. Compute probabilities with `opt(X, Y, plt, cri)`.
2. Poisson-sample with `sample_poi`: every case is kept, and control i is
   kept with probability `n πᵢ`.
3. Fit the subsample with `getMSLE`, using correction factors
   `min(n πᵢ, 1)` and starting from the pilot coefficients.

Returns `(getMSLE result tuple, subsample size)`.
"""
function os(X, Y, n, plt::Pilot, cri::Function)
    probs = opt(X, Y, plt, cri)
    kept = sample_poi(Y, probs, n)
    corr = min.(n .* view(probs, kept), 1)
    fit = getMSLE(X[kept, :], Y[kept], corr, plt.β)
    return fit, length(kept)
end

"""
    sample_poi(Y, pi_P, n)

Poisson subsampling. Returns the indices of kept rows. Every case (`Y` true)
is kept; each control is kept independently with probability `n * pi_P`
(`pi_P` may be a scalar or a per-row vector).
"""
function sample_poi(Y, pi_P, n)
    # rand.() is fused into the broadcast, so every element gets its own draw.
    keep = rand.() .<= Y .+ (1 .- Y) .* n .* pi_P
    return findall(keep)
end

"""
    uni(X, Y, n)

Uniform control-subsampling baseline.

Keeps every case and each control with probability `n / N0`, then fits
`getEst` on the kept rows. The intercept (first coefficient) is shifted by
`log(n / N0)` to undo the control subsampling.

Returns the full `getEst` tuple, with the corrected coefficients.
"""
function uni(X, Y, n)
    N0 = length(Y) - sum(Y)
    kept = sample_poi(Y, 1.0 / N0, n)
    fit = getEst(X[kept, :], Y[kept])
    fit[1][1] += log(n / N0)
    return fit
end

"""
    Uni(X, Y, nss)

Uniform control-subsampling baseline evaluated at several subsample sizes.

For each target size `n` in `nss`:
- keep every case (`Y == 1`) and each control independently with probability
  `n / N0`;
- fit `getEst` on the kept rows;
- shift the intercept (first coefficient) by `log(n / N0)`.

Returns `(Betas, n_star)`. `Betas[:, k]` holds the corrected coefficients for
`nss[k]`, and `n_star[k]` is the number of controls actually kept.
"""
function Uni(X, Y, nss)
    N, d = size(X)
    isctrl = Y .== 0
    N0 = sum(isctrl)
    K = length(nss)
    Betas = fill(NaN, d, K)
    n_star = Vector{Int64}(undef, K)
    for (k, nk) in enumerate(nss)
        draws = rand(N)
        probs = ones(N)
        probs[isctrl] .= nk / N0
        keep = draws .<= probs
        Betas[:, k] = getEst(X[keep, :], Y[keep])[1]
        n_star[k] = count(keep .& isctrl)
    end
    Betas[1, :] .+= log.(nss ./ N0)
    return Betas, n_star
end

"""
    s(x)

Logistic (inverse-logit) function `1 / (1 + exp(-x))`.
"""
s(x) = inv(1 + exp(-x))

"""
    pred!(X, β, p)

Overwrite `p` with the fitted logistic probabilities `s.(X * β)` and return
`p`.
"""
function pred!(X, β, p)
    mul!(p, X, β)           # p ← linear predictor
    @. p = s(p)             # p ← inverse logit
    return p
end

"""
    pred!(X, β, π, p)

Overwrite `p` with log-odds-corrected probabilities and return `p`:
`p = 1 / (1 + exp(-η) * π)` with `η = X * β`.
"""
function pred!(X, β, π, p)
    mul!(p, X, β)                           # p ← η
    @. p = 1.0 / (1.0 + exp(-p) * π)        # corrected inverse logit
    return p
end
# Log-odds correction used by the 4-argument pred!: p = 1 / (1 + exp(-η) * π), with η = Xβ.
"""
    predSqErr!(p, p0)

Return the sum of squared differences between `p` and `p0`. Neither argument
is modified, despite the `!` in the name.
"""
function predSqErr!(p, p0)
    resid = p .- p0
    return sum(abs2, resid)
end

"""
    dropsum(A; dims=:)

Sum `A` over `dims` and drop the reduced singleton dimensions. With the
default `dims=:` the result is the scalar `sum(A)`. Previously this case
called `dropdims` on a scalar and threw a `MethodError`.
"""
dropsum(A; dims=:) = dims isa Colon ? sum(A) : dropdims(sum(A; dims=dims); dims=dims)

# NaN-skipping summaries. The one-argument forms reduce a whole collection and
# ignore NaN entries. The two-argument forms apply the same reduction to each
# slice along dimension(s) `y` via `mapslices` (e.g. `y = 2` summarises each
# row).
function nanmean(x)
    kept = filter(!isnan, x)
    return mean(kept)
end
nanmean(x, y) = mapslices(nanmean, x; dims=y)
function nanvar(x)
    kept = filter(!isnan, x)
    return var(kept)
end
nanvar(x, y) = mapslices(nanvar, x; dims=y)

# Monte-Carlo summaries of replicated estimates. Each column of `Betas` is one
# replicate. NaN entries are skipped per coefficient (row), and each result is
# summed over coefficients.
#   emse  — empirical mean squared error about `beta0`
#   evar  — empirical variance (`beta0` is unused)
#   ebias — squared bias of the replicate mean about `beta0`
function emse(Betas, beta0)
    sqerr = (Betas .- beta0) .^ 2
    return sum(nanmean(sqerr, 2))
end
evar(Betas, beta0) = sum(nanvar(Betas, 2))
function ebias(Betas, beta0)
    centre = nanmean(Betas, 2)
    return sum((centre .- beta0) .^ 2)
end

# Same summaries measured against the scaled truth `beta0 .* scl`.
# `evar` ignores both `beta0` and `scl`.
function emse(Betas, beta0, scl)
    target = beta0 .* scl
    return sum(nanmean((Betas .- target) .^ 2, 2))
end
evar(Betas, beta0, scl) = sum(nanvar(Betas, 2))
function ebias(Betas, beta0, scl)
    target = beta0 .* scl
    return sum((nanmean(Betas, 2) .- target) .^ 2)
end

"""
    simu!(X, Y, beta0, case, rpt, n, n0, scl, P0, P_tmp)

Run `rpt` simulation replicates, regenerating `X` and `Y` in place each time.

Each replicate:
1. Generate covariates with `genX!(N, case, beta0 .* scl, X)`, divide the
   columns of `X` by `scl`, then generate responses with `genY!`.
2. Fit uniform subsampling (`uni`) and optimal subsampling (`os`) under the
   A-, L- and IV-criteria. All use subsample size `n`; the optimal fits use a
   case-control pilot of size `n0`.

`P0` receives the true probabilities, and `P_tmp` is scratch space for each
fit's predicted probabilities. Both need one entry per row of `X`.

Returns a NamedTuple of coefficient matrices (`d × rpt`, one column per
replicate), summed squared prediction errors against `P0`, and optimal
subsample sizes. Entries remain `NaN` for failed fits and for replicates
whose pilot fit failed.
"""
function simu!(X, Y, beta0, case, rpt, n, n0, scl, P0, P_tmp)
    betat = beta0 .* scl
    N, d = size(X)
    Beta_uni = fill(NaN, d, rpt)
    pred_uni = fill(NaN, rpt)
    Beta_osA = fill(NaN, d, rpt)
    pred_osA = fill(NaN, rpt)
    n_osA = fill(NaN, rpt)
    Beta_osL = fill(NaN, d, rpt)
    pred_osL = fill(NaN, rpt)
    n_osL = fill(NaN, rpt)
    Beta_osIV = fill(NaN, d, rpt)
    pred_osIV = fill(NaN, rpt)
    n_osIV = fill(NaN, rpt)

    # Store `beta` as column `rr` of `Beta`. If it has no NaN, also store its
    # summed squared prediction error against the true probabilities P0.
    function record!(Beta, pred, rr, beta)
        Beta[:, rr] = beta
        if !isnan(sum(beta))
            pred!(X, beta, P_tmp)
            pred[rr] = sum(abs2, P_tmp .- P0)
        end
        return nothing
    end

    # Run one optimal-subsampling fit with criterion `cri` and record it.
    function record_os!(Beta, pred, nsub, rr, plt, cri)
        fit, nsize = os(X, Y, n, plt, cri)
        nsub[rr] = nsize
        record!(Beta, pred, rr, fit[1])
        return nothing
    end

    for rr in 1:rpt
        genX!(N, case, betat, X)
        X ./= scl'
        genY!(betat, X, Y)
        pred!(X, betat, P0)
        record!(Beta_uni, pred_uni, rr, uni(X, Y, n)[1])
        plt = getPilot(X, Y, n0)
        # A NaN pilot β would only push NaN (or errors) through the
        # optimal-subsampling fits, so skip them for this replicate
        # (the same guard mimic! uses).
        isnan(plt.β[1]) && continue
        record_os!(Beta_osA, pred_osA, n_osA, rr, plt, Aopt)
        record_os!(Beta_osL, pred_osL, n_osL, rr, plt, Lopt)
        record_os!(Beta_osIV, pred_osIV, n_osIV, rr, plt, IVopt)
    end
    return (
    Beta_uni = Beta_uni,
    pred_uni = pred_uni,
    Beta_osA = Beta_osA,
    pred_osA = pred_osA,
    n_osA = n_osA,
    Beta_osL = Beta_osL,
    pred_osL = pred_osL,
    n_osL = n_osL,
    pred_osIV = pred_osIV,
    Beta_osIV = Beta_osIV,
    n_osIV = n_osIV
    )
end

"""
    calRes(res, beta0, scl)

Summarise the output of `simu!` / `mimic!`. Each vector below is ordered
uniform, A-, L-, IV-optimal (the `navg` vector omits uniform).

- `mse`  — total empirical MSE of the coefficients about `beta0 .* scl`
  (NaN replicates skipped);
- `pred` — mean summed squared prediction error (NaN replicates skipped);
- `navg` — mean optimal subsample size for A, L and IV (NaN replicates
  skipped).
"""
function calRes(res, beta0, scl)
    mse = [emse(res.Beta_uni, beta0, scl),
            emse(res.Beta_osA, beta0, scl),
            emse(res.Beta_osL, beta0, scl),
            emse(res.Beta_osIV, beta0, scl)]
    pred = [nanmean(res.pred_uni),
             nanmean(res.pred_osA),
             nanmean(res.pred_osL),
             nanmean(res.pred_osIV)]
    # Replicates skipped after a failed pilot fit leave NaN sizes. Plain
    # `mean` would turn the whole average into NaN, so skip them, as for pred.
    navg = [nanmean(res.n_osA),
            nanmean(res.n_osL),
            nanmean(res.n_osIV)]
    return (mse=mse, pred=pred, navg=navg)
end

"""
    mimic!(mη, μ, X, Y, beta0, case, rpt, n, n0, scl, P0, P_tmp)

Run `rpt` simulation replicates, regenerating data in place each time.

Each replicate:
1. Generate covariates with `genX!(mη, μ, X)` and responses with
   `genY!(beta0, X, Y)`.
2. Fit uniform subsampling (`uni`).
3. Fit optimal subsampling (`os`) under the A-, L- and IV-criteria. All fits
   use subsample size `n`, with a case-control pilot of size `n0`.

`P0` receives the true probabilities, and `P_tmp` is scratch space for each
fit's predicted probabilities. Both need one entry per row of `X`. `case` and
`scl` are accepted but unused.

Returns a NamedTuple of coefficient matrices (`d × rpt`, one column per
replicate), summed squared prediction errors against `P0`, and optimal
subsample sizes. Entries remain `NaN` for failed fits, and the optimal
entries also stay `NaN` when the pilot fit failed.
"""
function mimic!(mη, μ, X, Y, beta0, case, rpt, n, n0, scl, P0, P_tmp)
    betat = beta0
    _, d = size(X)
    Beta_uni, pred_uni = fill(NaN, d, rpt), fill(NaN, rpt)
    Beta_osA, pred_osA, n_osA = fill(NaN, d, rpt), fill(NaN, rpt), fill(NaN, rpt)
    Beta_osL, pred_osL, n_osL = fill(NaN, d, rpt), fill(NaN, rpt), fill(NaN, rpt)
    Beta_osIV, pred_osIV, n_osIV = fill(NaN, d, rpt), fill(NaN, rpt), fill(NaN, rpt)

    # Store `beta` as column `rr` of `Beta`. If it has no NaN, also store its
    # summed squared prediction error against the true probabilities P0.
    function record!(Beta, pred, rr, beta)
        Beta[:, rr] = beta
        if !isnan(sum(beta))
            pred!(X, beta, P_tmp)
            pred[rr] = sum(abs2, P_tmp .- P0)
        end
        return nothing
    end

    # Run one optimal-subsampling fit with criterion `cri` and record it.
    function record_os!(Beta, pred, nsub, rr, plt, cri)
        fit, nsize = os(X, Y, n, plt, cri)
        nsub[rr] = nsize
        record!(Beta, pred, rr, fit[1])
        return nothing
    end

    @showprogress for rr in 1:rpt
        genX!(mη, μ, X)
        genY!(betat, X, Y)
        pred!(X, betat, P0)
        record!(Beta_uni, pred_uni, rr, uni(X, Y, n)[1])
        plt = getPilot(X, Y, n0)
        isnan(plt.β[1]) && continue     # failed pilot: skip optimal fits
        record_os!(Beta_osA, pred_osA, n_osA, rr, plt, Aopt)
        record_os!(Beta_osL, pred_osL, n_osL, rr, plt, Lopt)
        record_os!(Beta_osIV, pred_osIV, n_osIV, rr, plt, IVopt)
    end
    return (
    Beta_uni = Beta_uni,
    pred_uni = pred_uni,
    Beta_osA = Beta_osA,
    pred_osA = pred_osA,
    n_osA = n_osA,
    Beta_osL = Beta_osL,
    pred_osL = pred_osL,
    n_osL = n_osL,
    pred_osIV = pred_osIV,
    Beta_osIV = Beta_osIV,
    n_osIV = n_osIV
    )
end

"""
    mimicFull!(mη, μ, X, Y, beta0, case, rpt, scl, P0, P_tmp)

Full-data reference for `mimic!`.

For each of `rpt` replicates:
- regenerate data with `genX!(mη, μ, X)` and `genY!(beta0, X, Y)`;
- fit all rows with `getEst`;
- record the coefficients and, when they contain no NaN, the summed squared
  prediction error against the true probabilities `P0`.

`P_tmp` is scratch space. `case` and `scl` are unused.

Returns `(Beta = d × rpt coefficients, pred = per-replicate errors)`.
"""
function mimicFull!(mη, μ, X, Y, beta0, case, rpt, scl, P0, P_tmp)
    betat = beta0
    d = size(X, 2)
    Beta_full = fill(NaN, d, rpt)
    pred_full = fill(NaN, rpt)
    for rr in 1:rpt
        genX!(mη, μ, X)
        genY!(betat, X, Y)
        pred!(X, betat, P0)
        est = first(getEst(X, Y))
        Beta_full[:, rr] = est
        isnan(sum(est)) && continue
        pred!(X, est, P_tmp)
        pred_full[rr] = sum(abs2, P_tmp .- P0)
    end
    return (Beta = Beta_full, pred = pred_full)
end

"""
    calResF(res, beta0, scl)

Summarise `mimicFull!` output. Returns the total empirical MSE of the
coefficients about `beta0 .* scl` and the mean prediction error, both
skipping NaN replicates.
"""
function calResF(res, beta0, scl)
    total_mse = emse(res.Beta, beta0, scl)
    mean_pred = nanmean(res.pred)
    return (mse=total_mse, pred=mean_pred)
end
