from sdv.metadata import MultiTableMetadata
import json
import pandas as pd
from sdv.evaluation.single_table import evaluate_quality as single_table_evaluate_quality
from sdv.evaluation.multi_table import evaluate_quality

# CHILD_PRIMARY_KEY = 'SalesID'
# PARENT_PRIMARY_KEY = 'CustomerID'
# PARENT_TABLE_NAME = 'customers'
# CHILD_TABLE_NAME = 'sales'

# parent_domain_path = 'salesDB_v1/parent_recovered_domain_dict.json'
# child_domain_path = 'salesDB_v1/Sales_domain.json'
# parent_df_path = 'salesDB_v1/parent_recovered_original_df.csv'
# child_df_path = 'salesDB_v1/Sales.csv'

# synthetic_parent_path = 'salesDB_v1/parent_recovered_generated_df.csv'
# synthetic_child_path = 'salesDB_v1/child_final.csv'

# CHILD_PRIMARY_KEY = 'user_rates_movie_id'
# PARENT_PRIMARY_KEY = 'movie_id'
# PARENT_TABLE_NAME = 'movie'
# CHILD_TABLE_NAME = 'user_rates_movie'

# parent_domain_path = './movie_lens_1m/movie_domain_continuous.json'
# child_domain_path = 'movie_lens_1m/user_rates_movie_domain_continuous.json'
# parent_df_path = 'movie_lens_1m/movie.csv'
# child_df_path = 'movie_lens_1m/user_rates_movie.csv'

# synthetic_parent_path = 'movie_lens_1m/parent_final.csv'
# synthetic_child_path = 'movie_lens_1m/child_final.csv'

# --- Alternative configuration: privLava California household/individual ---
# Kept for reference; uncomment (and comment out the active block below) to use.
# CHILD_PRIMARY_KEY = 'INDIVIDUAL'
# PARENT_PRIMARY_KEY = 'HOUSEHOLD'
# CHILD_TABLE_NAME = 'individual'
# PARENT_TABLE_NAME = 'household'
# ALL_NUMERICAL = False
#
# parent_domain_path = 'privLava_data/California/household_domain.json'
# child_domain_path = 'privLava_data/California/individual_domain.json'
# parent_df_path = 'privLava_data/California/household.csv'
# child_df_path = 'privLava_data/California/individual.csv'
#
# synthetic_parent_path = 'parent_final.csv'
# synthetic_child_path = 'child_final.csv'
# synthetic_parent_path = 'privLava_data/California/save/001/parent_final.csv'
# synthetic_child_path = 'privLava_data/California/save/001/child_final.csv'

# --- Active configuration: Berka account (parent) / disp (child) ---
# Key columns. The parent's primary key also appears in the child table, where it
# is used as the foreign key linking each child row to its parent row.
CHILD_PRIMARY_KEY = 'disp_id'
PARENT_PRIMARY_KEY = 'account_id'
CHILD_TABLE_NAME = 'disp'
PARENT_TABLE_NAME = 'account'
# When True, every domain column of both tables is re-typed as 'continuous'.
ALL_NUMERICAL = False
# Domain descriptions (JSON) and the real tables (CSV).
parent_domain_path = 'berka/preprocessed/account_domain.json'
child_domain_path = 'berka/preprocessed/disp_domain.json'
parent_df_path = 'berka/preprocessed/account.csv'
child_df_path = 'berka/preprocessed/disp.csv'

# Synthetic tables to be evaluated against the real ones.
synthetic_parent_path = 'berka/preprocessed/save/account_disp/parent_final.csv'
synthetic_child_path = 'berka/preprocessed/save/account_disp/merged_disp_2_non_unique.csv'

# Load the domain descriptions. Each maps a column name to a spec dict whose
# 'type' entry is 'discrete' for categorical columns. Use context managers so the
# handles are closed, and read as UTF-8 (the JSON standard encoding) rather than
# the platform's locale default.
with open(child_domain_path, encoding='utf-8') as domain_file:
    individual_domain_dict = json.load(domain_file)
with open(parent_domain_path, encoding='utf-8') as domain_file:
    household_domain_dict = json.load(domain_file)

# Load the real tables.
individual_df = pd.read_csv(child_df_path)
household_df = pd.read_csv(parent_df_path)

# Keep only the columns described by each domain, framed by the key columns:
#   child  -> [child PK, domain columns..., parent PK (foreign key)]
#   parent -> [parent PK, domain columns...]
individual_df_non_id_cols = [col for col in individual_df.columns if col in individual_domain_dict]
household_df_non_id_cols = [col for col in household_df.columns if col in household_domain_dict]
individual_cols = [CHILD_PRIMARY_KEY] + individual_df_non_id_cols + [PARENT_PRIMARY_KEY]
household_cols = [PARENT_PRIMARY_KEY] + household_df_non_id_cols

individual_df = individual_df[individual_cols]
household_df = household_df[household_cols]

# Load the synthetic tables and project them onto the same column layout
# (pandas raises KeyError if a synthetic file lacks one of these columns).
individual_cluster_df = pd.read_csv(synthetic_child_path)
household_cluster_df = pd.read_csv(synthetic_parent_path)

individual_cluster_df = individual_cluster_df[individual_cols]
household_cluster_df = household_cluster_df[household_cols]

if ALL_NUMERICAL:
    # Re-type every column of both domains as continuous so that all downstream
    # metadata treats them as numerical. The spec dicts are mutated in place.
    for domain_dict in (individual_domain_dict, household_domain_dict):
        for column_spec in domain_dict.values():
            column_spec['type'] = 'continuous'


# Columns present in both domains would collide once child and parent rows are
# merged, so prefix them with 'child_' / 'parent_' in every frame and domain dict.
# Build the rename maps once and rename each frame a single time. The original
# issued four whole-DataFrame rename() calls per overlapping column.
cols_to_change = [col for col in individual_domain_dict if col in household_domain_dict]
child_renames = {col: 'child_' + col for col in cols_to_change}
parent_renames = {col: 'parent_' + col for col in cols_to_change}

individual_df = individual_df.rename(columns=child_renames)
individual_cluster_df = individual_cluster_df.rename(columns=child_renames)
household_df = household_df.rename(columns=parent_renames)
household_cluster_df = household_cluster_df.rename(columns=parent_renames)

# Re-key the domain dicts to match. pop + reinsert moves the renamed keys to the end.
for col in cols_to_change:
    individual_domain_dict[child_renames[col]] = individual_domain_dict.pop(col)
    household_domain_dict[parent_renames[col]] = household_domain_dict.pop(col)

# Partition each domain's columns by declared type: 'discrete' columns are
# categorical, and any other type is treated as numerical. Domain order is kept.
individual_cat_cols = [
    name for name, spec in individual_domain_dict.items() if spec['type'] == 'discrete'
]
individual_num_cols = [
    name for name, spec in individual_domain_dict.items() if spec['type'] != 'discrete'
]
household_cat_cols = [
    name for name, spec in household_domain_dict.items() if spec['type'] == 'discrete'
]
household_num_cols = [
    name for name, spec in household_domain_dict.items() if spec['type'] != 'discrete'
]

    
        

# individual_cluster_df = individual_cluster_df.rename(columns={'add_numerical': 'add_numerical_child'})
# household_cluster_df = household_cluster_df.rename(columns={'add_numerical': 'add_numerical_parent'})

# Denormalise each dataset: attach the parent row's attributes to every child row
# (left join on the parent key), then drop both key columns so only attribute
# columns remain for the single-table comparison.
key_columns = [PARENT_PRIMARY_KEY, CHILD_PRIMARY_KEY]
original_joint_df = individual_df.merge(
    household_df, how='left', on=PARENT_PRIMARY_KEY
).drop(columns=key_columns)
cluster_joint_df = individual_cluster_df.merge(
    household_cluster_df, how='left', on=PARENT_PRIMARY_KEY
).drop(columns=key_columns)

from sdv.metadata import SingleTableMetadata

# --- Single-table evaluation of the joined (child + parent attributes) data ---
# Detect a baseline schema, then override each domain column's sdtype and cast both
# frames so that real and synthetic values are directly comparable.
metadata = SingleTableMetadata()
metadata.detect_from_dataframe(original_joint_df)
# One loop over both domains. Overlapping names were already prefixed, so the two
# key sets are disjoint.
for column_name, column_spec in [*individual_domain_dict.items(), *household_domain_dict.items()]:
    if column_spec['type'] == 'discrete':
        metadata.update_column(column_name=column_name, sdtype='categorical')
        original_joint_df[column_name] = original_joint_df[column_name].astype('str')
        # Round synthetic codes to the nearest integer before casting, matching the
        # multi-table coercion. A bare astype(int) would truncate (e.g. 2.9 -> 2).
        cluster_joint_df[column_name] = (
            cluster_joint_df[column_name].round().astype(int).astype('str')
        )
    else:
        metadata.update_column(column_name=column_name, sdtype='numerical')
        original_joint_df[column_name] = original_joint_df[column_name].astype('float')
        cluster_joint_df[column_name] = cluster_joint_df[column_name].astype('float')

# evaluate joint dataframe; keep the report object instead of discarding it
joint_quality_report = single_table_evaluate_quality(
    original_joint_df,
    cluster_joint_df,
    metadata,
)

# --- Two-table (parent/child) metadata for the multi-table quality report ---
metadata = MultiTableMetadata()
for table_name, table_df in ((CHILD_TABLE_NAME, individual_df), (PARENT_TABLE_NAME, household_df)):
    metadata.detect_table_from_dataframe(table_name=table_name, data=table_df)

# Key columns get the 'id' sdtype: the child's own key, the child's foreign key,
# and the parent's key.
id_columns = (
    (CHILD_TABLE_NAME, CHILD_PRIMARY_KEY),
    (CHILD_TABLE_NAME, PARENT_PRIMARY_KEY),
    (PARENT_TABLE_NAME, PARENT_PRIMARY_KEY),
)
for table_name, column_name in id_columns:
    metadata.update_column(table_name=table_name, column_name=column_name, sdtype='id')

for table_name, key_column in ((CHILD_TABLE_NAME, CHILD_PRIMARY_KEY), (PARENT_TABLE_NAME, PARENT_PRIMARY_KEY)):
    metadata.set_primary_key(table_name=table_name, column_name=key_column)

# The child's PARENT_PRIMARY_KEY column references the parent's primary key.
metadata.add_relationship(
    parent_table_name=PARENT_TABLE_NAME,
    child_table_name=CHILD_TABLE_NAME,
    parent_primary_key=PARENT_PRIMARY_KEY,
    child_foreign_key=PARENT_PRIMARY_KEY
)

# Attribute columns: 'discrete' domain entries are categorical, everything else numerical.
for table_name, domain_dict in ((CHILD_TABLE_NAME, individual_domain_dict), (PARENT_TABLE_NAME, household_domain_dict)):
    for column_name, column_spec in domain_dict.items():
        sdtype = 'categorical' if column_spec['type'] == 'discrete' else 'numerical'
        metadata.update_column(table_name=table_name, column_name=column_name, sdtype=sdtype)

def _aligned_copies(synthetic_df, real_df, cat_cols, num_cols):
    """Return dtype-aligned copies ``(synthetic, real)`` for the multi-table report.

    Every column of ``synthetic_df`` is cast in both frames:
    - categorical columns become int, with synthetic values rounded first;
    - numerical columns become float;
    - any remaining column (the key columns) becomes int.
    The inputs are not modified.
    """
    synthetic_out = synthetic_df.copy()
    real_out = real_df.copy()
    for name in synthetic_out.columns:
        if name in cat_cols:
            synthetic_out[name] = synthetic_out[name].round().astype(int)
            real_out[name] = real_out[name].astype(int)
        elif name in num_cols:
            synthetic_out[name] = synthetic_out[name].astype(float)
            real_out[name] = real_out[name].astype(float)
        else:  # ids
            synthetic_out[name] = synthetic_out[name].astype(int)
            real_out[name] = real_out[name].astype(int)
    return synthetic_out, real_out


individual_test_df, individual_benchmark_df = _aligned_copies(
    individual_cluster_df, individual_df, individual_cat_cols, individual_num_cols
)
household_test_df, household_benchmark_df = _aligned_copies(
    household_cluster_df, household_df, household_cat_cols, household_num_cols
)


# evaluate multitable
# evaluate multitable: real tables first, synthetic tables second, keyed by table name
real_tables = {
    CHILD_TABLE_NAME: individual_benchmark_df,
    PARENT_TABLE_NAME: household_benchmark_df,
}
synthetic_tables = {
    CHILD_TABLE_NAME: individual_test_df,
    PARENT_TABLE_NAME: household_test_df,
}
cluster_quality_report = evaluate_quality(real_tables, synthetic_tables, metadata)

print('finished')
