# Partial-dependence analysis (no plot is produced): per-feature decision thresholds vs. Flexible SEV on the German dataset
import sys
sys.path.append("../SEV/")
from Encoder import DataEncoder
import pandas as pd
import os
import numpy as np
import torch
from data_loader import data_loader
from sklearn.model_selection import train_test_split
import argparse
from FlexibleSEV_new import FlexibleSEV
from torch.utils.data import DataLoader
from sklearn.ensemble import GradientBoostingClassifier
from tqdm import tqdm
import time
from sklearn.metrics import accuracy_score,roc_auc_score
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier

# Load the German dataset. data_loader returns the features X, the labels y,
# and a third frame X_neg that the encoder is fitted on below
# (presumably the negative-outcome rows — confirm in data_loader).
X, y, X_neg = data_loader("german")
print("Working on the dataset {}".format("german"))

# Fit the encoder on X_neg only, then encode the full feature set and X_neg.
encoder = DataEncoder(standard=True)
encoder.fit(X_neg)
encoded_X, encoded_X_neg = encoder.transform(X), encoder.transform(X_neg)

# Hold out 20% of the encoded rows as a test split (fixed seed for reproducibility).
X_train, X_test, y_train, y_test = train_test_split(
    encoded_X, y, test_size=0.2, random_state=42
)

# Two hidden layers of 128 units, fitted on the training split only.
model = MLPClassifier(hidden_layer_sizes=(128, 128), random_state=42)
model.fit(X_train, y_train)


# For each encoded feature, sweep it across its training range while holding
# every other feature at 0, and record where the model's predicted
# probability of class 1 crosses PROBA_CUTOFF (a 1-D partial-dependence probe).
#   thresholds[col]   -> grid values at which the label flips (ascending,
#                        because the grid is increasing)
#   positiveness[col] -> label of each region: entry 0 is the label before the
#                        first threshold, entry k the label from threshold k on
# Features whose probe never crosses the cutoff get no entry in either dict.
# NOTE(review): holding the other features at 0 presumably means "at the mean
# of the data the encoder was fitted on" (DataEncoder(standard=True)) — confirm.
# NOTE(review): the cutoff is 0.6, whereas model.predict (used to pick rows for
# SEV below) presumably labels at 0.5 — confirm this difference is intended.
N_GRID_POINTS = 100
PROBA_CUTOFF = 0.6

thresholds = {}
positiveness = {}
for col_idx, column in enumerate(X_train.columns):
    grid = np.linspace(X_train[column].min(), X_train[column].max(), N_GRID_POINTS)
    probe = np.zeros((N_GRID_POINTS, X_train.shape[1]))
    probe[:, col_idx] = grid
    # Keep the training column names so sklearn does not warn about missing
    # feature names (the model was fitted on a DataFrame).
    probe_df = pd.DataFrame(probe, columns=X_train.columns)
    above = model.predict_proba(probe_df)[:, 1] > PROBA_CUTOFF
    # Grid indices whose label differs from the previous grid point.
    flips = np.flatnonzero(above[1:] != above[:-1]) + 1
    if flips.size == 0:
        continue
    thresholds[column] = list(grid[flips])
    # Label before the first flip, then the label entered at each flip.
    positiveness[column] = [above[0]] + list(above[flips])

# For every feature that has at least one decision threshold, label each
# training row with the probe's outcome (positiveness) for the region that
# row's value of the feature falls in.
X_cate_columns = []
for column in X_train.columns:
    if column not in thresholds:
        continue  # the probe never crossed the cutoff for this feature
    # thresholds[column] is ascending and positiveness[column] has one more
    # entry than it: the label of each region between consecutive thresholds.
    # A threshold is the first grid point that already carries the new label,
    # so a value equal to it belongs to the region on its right; side="right"
    # counts thresholds <= value, which is exactly that region index.
    # (Previously rows below the first threshold were given the threshold value
    # itself instead of a label, and values equal to a threshold were assigned
    # to the left region.)
    region = np.searchsorted(
        np.asarray(thresholds[column]), X_train[column].to_numpy(), side="right"
    )
    X_cate_columns.append(np.asarray(positiveness[column])[region])

# Shape (n_train_rows, n_features_with_thresholds). Keep it 2-D even when no
# feature crossed the cutoff so per-row sums over axis=1 still work.
if X_cate_columns:
    X_cate = np.column_stack(X_cate_columns)
else:
    X_cate = np.empty((len(X_train), 0), dtype=bool)

# Per training row: count the X_cate entries above 0.6 (positive) and below
# 0.6 (negative); entries exactly equal to 0.6 are counted in neither.
positive_num = (X_cate > 0.6).sum(axis=1)
negative_num = (X_cate < 0.6).sum(axis=1)

# Build the Flexible SEV explainer over the encoded feature space.
# NOTE: explanations are computed for the TRAINING rows (X_train), not the
# test split, so they line up row-for-row with positive_num / negative_num.
sev = FlexibleSEV(model, encoder, encoded_X.columns, encoded_X_neg, tol=0, k=1)

X_train_array = np.array(X_train)
# Predict every training row once, instead of once per row in each pass.
train_is_positive = model.predict(X_train_array) == 1


def _explain_training_rows(mode):
    """Run ``sev.sev_cal`` with the given ``mode`` on every training row predicted as 1.

    Rows not predicted as class 1 get an SEV of 0 and are not explained.
    Returns four lists: the SEV value for every training row, and — for
    explained rows only — the max absolute entry of ``original_diff``, the
    wall-clock seconds spent in ``sev_cal``, and the ``used`` value it returned.
    """
    sev_values, l_inf, seconds, used_values = [], [], [], []
    for xi, is_positive in zip(tqdm(X_train_array), train_is_positive):
        if not is_positive:
            sev_values.append(0)
            continue
        start_time = time.perf_counter()
        sev_num, original_diff, used = sev.sev_cal(xi, mode=mode)
        seconds.append(time.perf_counter() - start_time)
        sev_values.append(sev_num)
        l_inf.append(np.max(np.abs(original_diff)))
        used_values.append(used)
    return sev_values, l_inf, seconds, used_values


flexible_sev, L_inf_minus, time_lst_minus, used_lst_minus = _explain_training_rows("minus")
flexible_sev_plus, L_inf_plus, time_lst_plus, used_lst_plus = _explain_training_rows("plus")

# Combined per-explanation statistics (all "minus" entries first, then all
# "plus" entries), kept for compatibility; use the *_minus / *_plus lists to
# tell the two modes apart.
L_inf = L_inf_minus + L_inf_plus
time_lst = time_lst_minus + time_lst_plus
used_lst = used_lst_minus + used_lst_plus


# Combine the four per-training-row arrays into one table with columns
# Flexible_SEV_Plus, Flexible_SEV, Positive_Num, Negative_Num.
flexible_sev = np.array(flexible_sev)
positive_num = np.array(positive_num)
negative_num = np.array(negative_num)
flexible_sev_plus = np.array(flexible_sev_plus)
# Sanity check: all four should have one entry per X_train row for np.c_.
print(flexible_sev.shape,positive_num.shape,negative_num.shape,flexible_sev_plus.shape)
output = np.c_[flexible_sev_plus,flexible_sev,positive_num,negative_num]

output_df = pd.DataFrame(output,columns=["Flexible_SEV_Plus","Flexible_SEV","Positive_Num","Negative_Num"])
# Every row is written, including rows with SEV 0 (rows the model did not
# predict as class 1).
# NOTE(review): this spot was previously commented "drop all 0 row", but no
# rows are dropped — confirm whether those rows should be filtered out.
output_df.to_csv("analysis_german.csv")

# For each distinct SEV value (in order of first appearance), print the mean
# positive / negative region counts of the rows with that value.
for i in output_df["Flexible_SEV_Plus"].unique():
    selected = output_df[output_df["Flexible_SEV_Plus"]==i]
    print("Flexible_SEV_Plus:{}".format(i))
    print("positive_num:{}".format(np.mean(selected["Positive_Num"])))
    print("Negative_Num:{}".format(np.mean(selected["Negative_Num"])))

print("=====================================")
for i in output_df["Flexible_SEV"].unique():
    selected = output_df[output_df["Flexible_SEV"]==i]
    print("Flexible_SEV:{}".format(i))
    print("Positive_Num:{}".format(np.mean(selected["Positive_Num"])))
    print("Negative_Num:{}".format(np.mean(selected["Negative_Num"])))


