#!/usr/bin/env python3

import os
import urllib.request
import pandas as pd

from jax import numpy as jnp


from fairgym.utils.logistic_regression_jax import LogisticRegression
from sklearn import preprocessing


this_dir = os.path.dirname(os.path.realpath(__file__))


def _fetch_if_missing(url, path):
    """Download *url* to *path* unless something already exists at *path*."""
    if os.path.exists(path):
        return
    urllib.request.urlretrieve(url, path)


# Both splits are cached next to this file, so the download only happens
# the first time the module is imported.

# Train (3.8M)
train_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.data"
train_path = os.path.join(this_dir, "adult.data")
_fetch_if_missing(train_url, train_path)

# Test (1.9M)
test_url = "https://archive.ics.uci.edu/ml/machine-learning-databases/adult/adult.test"
test_path = os.path.join(this_dir, "adult.test")
_fetch_if_missing(test_url, test_path)


# Load data
################################################################################

# Column names, in file order (the files themselves carry no header row).
# The trailing numbers are the original author's category counts.
all_features = [
    "Age",
    "Workclass",  # 9
    "fnlwgt",
    "Education",  # 16
    "Education-Num",  # 16
    "Marital Status",  # 7
    "Occupation",  # 15
    "Relationship",  # 6
    "Race",  # 5
    "Sex",  # 2
    "Capital Gain",
    "Capital Loss",
    "Hours per week",
    "Country",  # 42
    "Target",
]

# Fields are separated by a comma with optional surrounding whitespace
# (hence the regex separator and the python engine); "?" marks a missing
# value and is read as NaN.
train = pd.read_csv(
    train_path, names=all_features, sep=r"\s*,\s*", engine="python", na_values="?"
)
# skiprows=1: the first line of the test file is skipped
# (presumably a non-data banner line -- verify against the file).
test = pd.read_csv(
    test_path,
    names=all_features,
    sep=r"\s*,\s*",
    engine="python",
    na_values="?",
    skiprows=1,
)

# Remap labels to 0, 1.  Both spellings of each label (with and without a
# trailing period) are accepted.
Target_dict = {">50K.": 1, ">50K": 1, "<=50K.": 0, "<=50K": 0}
df = pd.concat([train, test], ignore_index=True)
# Map and cast explicitly instead of relying on DataFrame.replace to
# silently downcast the object column to integers: that downcasting is
# deprecated in recent pandas, and an object-dtype Target would propagate
# into Y and the sample weights.  A label missing from Target_dict maps to
# NaN, so the cast raises instead of letting a string label through.
df["Target"] = df["Target"].map(Target_dict).astype("int64")

################################################################################

# Column roles.
X_real_cols = ["Age", "Capital Gain", "Capital Loss", "Hours per week"]
X_cat_cols = ["Workclass", "Education", "Marital Status", "Occupation"]
ignore_cols = ["Race", "Country"]
G_cols = ["Sex"]
Y_col = "Target"

#####
# X
# Real-valued columns first, followed by one block of indicator columns per
# categorical column (including an indicator for missing values).
_feature_blocks = [df[X_real_cols]]
_feature_blocks.extend(
    pd.get_dummies(df[cat_col], dummy_na=True)
    for cat_col in X_cat_cols
    if cat_col not in ignore_cols
)
X = pd.concat(_feature_blocks, axis=1)

# Standardise every column; the fitted scaler is kept so that subsets of X
# can be transformed the same way later.
scaler = preprocessing.StandardScaler()
scaler.fit(X.values)
X_scaled = scaler.transform(X.values)

#####
# G
# First indicator column of the group attribute.
G = pd.get_dummies(df[G_cols[0]]).to_numpy()[:, 0]  # Female = 1

#####
# Y
Y = df[Y_col].to_numpy()


################################################################################

# Boolean row masks.  Group 0 is the rows with G == 1 (female) and group 1
# is the rows with G == 0 (male); every "G0"/"G1" name below follows this
# mask convention, NOT the raw value of G.
g0 = G == 1  # female
g1 = G == 0  # male
y1 = Y == 1
y0 = Y == 0


# Empirical (reference) frequencies in the dataset.
# ref_prG<g>      : marginal frequency of group g
# ref_prY<y>G<g>  : joint frequency of (Y = y, group = g); the four joint
#                   cells sum to 1.
ref_prG0 = g0.mean()
ref_prG1 = 1 - ref_prG0
ref_prY1G0 = (y1 & g0).mean()
ref_prY0G0 = (y0 & g0).mean()
ref_prY1G1 = (y1 & g1).mean()
ref_prY0G1 = (y0 & g1).mean()


def retrain(target_prG0, target_prY1_G0, target_prY1_G1):
    """
    Retrain the logistic regressor,
    reweighting samples to match target distribution

    Each sample in cell (group g, label y) is weighted by
    target_pr(Y=y, G=g) / ref_pr(Y=y, G=g), so that the weighted dataset has
    the requested group and label frequencies.

    Parameters:
        target_prG0: target frequency of group 0 (rows selected by `g0`).
        target_prY1_G0: target frequency of Y=1 within group 0.
        target_prY1_G1: target frequency of Y=1 within group 1.

    Returns:
        The LogisticRegression fitted on (X_scaled, Y) with those weights.
    """

    # Complements of the requested marginal / conditionals.
    target_prG1 = 1 - target_prG0
    target_prY0_G0 = 1 - target_prY1_G0
    target_prY0_G1 = 1 - target_prY1_G1

    # Target joint frequencies: pr(Y, G) = pr(G) * pr(Y | G).
    target_prY0G0 = target_prG0 * target_prY0_G0
    target_prY0G1 = target_prG1 * target_prY0_G1
    target_prY1G0 = target_prG0 * target_prY1_G0
    target_prY1G1 = target_prG1 * target_prY1_G1

    # Importance weights per cell, indexed by group, then by Y (m<g><y>).
    m00 = target_prY0G0 / ref_prY0G0
    m01 = target_prY1G0 / ref_prY1G0
    m10 = target_prY0G1 / ref_prY0G1
    m11 = target_prY1G1 / ref_prY1G1

    # Exactly one of the four masks is set for every row, so this picks the
    # weight of the row's own cell.  The masks are used (rather than the raw
    # indicator G) so that "group 0" means the same rows here as everywhere
    # else in this module.
    m = (g0 & y0) * m00 + (g0 & y1) * m01 + (g1 & y0) * m10 + (g1 & y1) * m11

    # retain same average weight
    m = m / m.mean()
    clf = LogisticRegression(random_state=0).fit(X_scaled, Y, sample_weight=m)

    return clf


def observe_distribution(prG0, prY1_G0, prY1_G1, num_bins):
    """
    Retrain the classifier for the given target distribution and summarise
    its scores ("XX", the learned feature) as per-group histograms.

    Parameters:
        prG0: frequency of group 0; only forwarded to `retrain`.
        prY1_G0: frequency of Y=1 within group 0.
        prY1_G1: frequency of Y=1 within group 1.
        num_bins: number of equal-width score bins on [0, 1].

    Returns:
        A nested list ``[[prXX_G0, prXX_G1], [prY1_XXG0, prY1_XXG1]]``:
        first the distribution of XX per group, then the probability of Y=1
        given XX per group; each entry is an array with `num_bins` values.
        A bin with zero mass in a group has a zero denominator in the
        second entry; that case is not handled here.
    """
    clf = retrain(prG0, prY1_G0, prY1_G1)

    # Interior bin edges only, so that digitize yields indices
    # 0 .. num_bins - 1.
    edges = jnp.linspace(0, 1, num_bins + 1)[1:-1]

    def score_histogram(row_mask):
        # Scores of the selected rows, binned and normalised so that the
        # histogram sums to 1.
        # NOTE(review): assumes predict_proba returns one score in [0, 1]
        # per row -- confirm against fairgym's LogisticRegression.
        scores = clf.predict_proba(scaler.transform(X[row_mask].values))
        counts = jnp.bincount(jnp.digitize(scores, edges), length=num_bins)
        return counts / len(scores)

    # pr(XX | Y, G) for each of the four (Y, G) cells.
    hist_y0_g0 = score_histogram(y0 & g0)
    hist_y0_g1 = score_histogram(y0 & g1)
    hist_y1_g0 = score_histogram(y1 & g0)
    hist_y1_g1 = score_histogram(y1 & g1)

    pr_xx = []  # pr(XX | G), indexed by g then x
    pr_y1_given_xx = []  # pr(Y=1 | XX, G), indexed by g then x
    for pr_y1, hist_y0, hist_y1 in (
        (prY1_G0, hist_y0_g0, hist_y1_g0),
        (prY1_G1, hist_y0_g1, hist_y1_g1),
    ):
        pr_y0 = 1 - pr_y1
        # Mix the two class-conditional histograms by the class frequencies.
        marginal = hist_y0 * pr_y0 + hist_y1 * pr_y1
        # Bayes' rule: pr(Y | X) = pr(X | Y) * pr(Y) / pr(X).
        posterior = hist_y1 * pr_y1 / marginal
        pr_xx.append(marginal)
        pr_y1_given_xx.append(posterior)

    return [pr_xx, pr_y1_given_xx]


if __name__ == "__main__":
    import time

    # Smoke run: time a single call with balanced targets and 32 bins, then
    # print the elapsed seconds followed by the result.
    started_ns = time.time_ns()
    observed = observe_distribution(0.5, 0.5, 0.5, 32)
    elapsed_seconds = (time.time_ns() - started_ns) / 1e9
    print(elapsed_seconds)
    print(observed)
