from __future__ import print_function

import sys
sys.path.insert(0, '../src')
sys.path.insert(0, '../graph_methods')

import argparse
import pandas as pd
import csv
import numpy as np
import json
import sys
import random
from function import *
import time


# Per-target "hit ratio" for each Tox21 task. Every run_* function below
# copies the entry for the selected target into the model config as
# conf['hit_ratio'].
# NOTE(review): these values presumably are the fraction of active
# (positive) compounds per target — confirm against dataset statistics.
target_name_to_hit_ratio = {
    'NR-AR': 0.0415885461053,
    'NR-AR-LBD': 0.034836817015,
    'NR-AhR': 0.118862559242,
    'NR-Aromatase': 0.0510356609011,
    'NR-ER': 0.125826487678,
    'NR-ER-LBD': 0.0495367070563,
    'NR-PPAR-gamma': 0.0286263208453,
    'SR-ARE': 0.161624709118,
    'SR-ATAD5': 0.0367582706109,
    'SR-HSE': 0.0577032946106,
    'SR-MMP': 0.156383890317,
    'SR-p53': 0.0608950843727,
}


def run_single_deep_classification():
    """Train and evaluate a SingleClassification DNN on fingerprint features.

    Fold ``running_index`` of ``file_list`` is the test split; every other
    fold is the training split. Reads the module-level globals
    ``config_json_file``, ``target_name``, ``running_index``, ``file_list``
    and ``weight_file`` (assigned in the ``__main__`` block).

    Side effect: writes ``output_on_test.npz`` (arrays ``y_test`` and
    ``y_pred``) into the current working directory.
    """
    from deep_classification import SingleClassification

    with open(config_json_file, 'r') as f:
        conf = json.load(f)
    conf['hit_ratio'] = target_name_to_hit_ratio[target_name]
    label_name_list = [target_name]
    print('label_name_list ', label_name_list)

    # read data
    # The train indices must be a concrete list: on Python 3 ``filter``
    # returns a lazy iterator, which numpy rejects as an array index.
    # Deriving the fold count from file_list avoids a hard-coded 5.
    test_index = [running_index]
    train_index = [i for i in range(len(file_list)) if i not in test_index]
    train_file_list = file_list[train_index]
    test_file_list = file_list[test_index]

    print('train files ', train_file_list)
    print('test files ', test_file_list)

    # Drop rows whose label for the target is missing before extracting features.
    train_pd = filter_out_missing_values(read_merged_data(train_file_list), label_list=label_name_list)
    test_pd = filter_out_missing_values(read_merged_data(test_file_list), label_list=label_name_list)

    X_train, y_train = extract_feature_and_label(train_pd,
                                                 feature_name='Fingerprints',
                                                 label_name_list=label_name_list)
    X_test, y_test = extract_feature_and_label(test_pd,
                                               feature_name='Fingerprints',
                                               label_name_list=label_name_list)
    print('done data preparation')

    task = SingleClassification(conf=conf)
    task.train_and_predict(X_train, y_train, X_test, y_test, weight_file)
    task.eval_with_existing(X_train, y_train, X_test, y_test, weight_file)
    y_pred_on_test = task.predict_with_existing(X_test, weight_file)
    # np.savez appends the ".npz" suffix; the __main__ block moves this file.
    np.savez('output_on_test', y_test=y_test, y_pred=y_pred_on_test)
    return


def run_single_deep_classification_flatten_random_projection():
    """Train and evaluate a SingleClassification DNN on n-gram random-projection features.

    Fold ``running_index`` of ``file_list`` is the test split; every other
    fold is the training split. Reads the module-level globals
    ``config_json_file``, ``target_name``, ``running_index``, ``file_list``,
    ``weight_file``, ``n_gram_num`` and ``random_projection_dimension``.

    Side effect: writes ``output_on_test.npz`` (arrays ``y_test`` and
    ``y_pred``) into the current working directory.
    """
    from deep_classification import SingleClassification

    with open(config_json_file, 'r') as f:
        conf = json.load(f)
    conf['hit_ratio'] = target_name_to_hit_ratio[target_name]
    # NOTE(review): segmentation_num is not defined in this file; presumably
    # it comes from `from function import *` — confirm.
    conf['input_layer_dimension'] = n_gram_num * random_projection_dimension * segmentation_num
    label_name_list = [target_name]
    print('label_name_list ', label_name_list)

    # read data
    # The train indices must be a concrete list: on Python 3 ``filter``
    # returns a lazy iterator, which numpy rejects as an array index.
    test_index = [running_index]
    train_index = [i for i in range(len(file_list)) if i not in test_index]
    train_file_list = file_list[train_index]
    test_file_list = file_list[test_index]

    print('train files ', train_file_list)
    print('test files ', test_file_list)

    # NOTE(review): the label key is the literal 'label_name', not the target
    # name; presumably the .npz files store labels under that key — confirm.
    X_train, y_train = extract_feature_and_label_npy(train_file_list,
                                                     feature_name='random_projected_list',
                                                     label_name_list=['label_name'],
                                                     n_gram_num=n_gram_num)
    X_test, y_test = extract_feature_and_label_npy(test_file_list,
                                                   feature_name='random_projected_list',
                                                   label_name_list=['label_name'],
                                                   n_gram_num=n_gram_num)
    print('done data preparation')

    task = SingleClassification(conf=conf)
    task.train_and_predict(X_train, y_train, X_test, y_test, weight_file)
    task.eval_with_existing(X_train, y_train, X_test, y_test, weight_file)
    y_pred_on_test = task.predict_with_existing(X_test, weight_file)
    # np.savez appends the ".npz" suffix; the __main__ block moves this file.
    np.savez('output_on_test', y_test=y_test, y_pred=y_pred_on_test)
    return


def run_xgboost_classification():
    """Train and evaluate an XGBoostClassification model on fingerprint features.

    Fold ``running_index`` of ``file_list`` is the test split; every other
    fold is the training split. Reads the module-level globals
    ``config_json_file``, ``target_name``, ``running_index``, ``file_list``
    and ``weight_file`` (assigned in the ``__main__`` block).

    Side effect: writes ``output_on_test.npz`` (arrays ``y_test`` and
    ``y_pred``) into the current working directory.
    """
    from xgboost_classification import XGBoostClassification

    with open(config_json_file, 'r') as f:
        conf = json.load(f)
    conf['hit_ratio'] = target_name_to_hit_ratio[target_name]
    label_name_list = [target_name]
    print('label_name_list ', label_name_list)

    # read data
    # The train indices must be a concrete list: on Python 3 ``filter``
    # returns a lazy iterator, which numpy rejects as an array index.
    test_index = [running_index]
    train_index = [i for i in range(len(file_list)) if i not in test_index]
    train_file_list = file_list[train_index]
    test_file_list = file_list[test_index]

    print('train files ', train_file_list)
    print('test files ', test_file_list)

    # Drop rows whose label for the target is missing before extracting features.
    train_pd = filter_out_missing_values(read_merged_data(train_file_list), label_list=label_name_list)
    test_pd = filter_out_missing_values(read_merged_data(test_file_list), label_list=label_name_list)

    X_train, y_train = extract_feature_and_label(train_pd,
                                                 feature_name='Fingerprints',
                                                 label_name_list=label_name_list)
    X_test, y_test = extract_feature_and_label(test_pd,
                                               feature_name='Fingerprints',
                                               label_name_list=label_name_list)
    print('done data preparation')

    task = XGBoostClassification(conf=conf)
    task.train_and_predict(X_train, y_train, X_test, y_test, weight_file)
    task.eval_with_existing(X_train, y_train, X_test, y_test, weight_file)
    y_pred_on_test = task.predict_with_existing(X_test, weight_file)
    # np.savez appends the ".npz" suffix; the __main__ block moves this file.
    np.savez('output_on_test', y_test=y_test, y_pred=y_pred_on_test)
    return


def run_xgboost_classification_flatten_random_projection():
    """Train and evaluate an XGBoostClassification model on n-gram random-projection features.

    Fold ``running_index`` of ``file_list`` is the test split; every other
    fold is the training split. Reads the module-level globals
    ``config_json_file``, ``target_name``, ``running_index``, ``file_list``,
    ``weight_file`` and ``n_gram_num``.

    Side effect: writes ``output_on_test.npz`` (arrays ``y_test`` and
    ``y_pred``) into the current working directory.
    """
    from xgboost_classification import XGBoostClassification

    with open(config_json_file, 'r') as f:
        conf = json.load(f)
    conf['hit_ratio'] = target_name_to_hit_ratio[target_name]
    label_name_list = [target_name]
    print('label_name_list ', label_name_list)

    # read data
    # The train indices must be a concrete list: on Python 3 ``filter``
    # returns a lazy iterator, which numpy rejects as an array index.
    test_index = [running_index]
    train_index = [i for i in range(len(file_list)) if i not in test_index]
    train_file_list = file_list[train_index]
    test_file_list = file_list[test_index]

    print('train files ', train_file_list)
    print('test files ', test_file_list)

    # NOTE(review): the label key is the literal 'label_name', not the target
    # name; presumably the .npz files store labels under that key — confirm.
    X_train, y_train = extract_feature_and_label_npy(train_file_list,
                                                     feature_name='random_projected_list',
                                                     label_name_list=['label_name'],
                                                     n_gram_num=n_gram_num)
    X_test, y_test = extract_feature_and_label_npy(test_file_list,
                                                   feature_name='random_projected_list',
                                                   label_name_list=['label_name'],
                                                   n_gram_num=n_gram_num)
    print('done data preparation')

    task = XGBoostClassification(conf=conf)
    print(X_train.shape, '\t', y_train.shape, '\t', X_test.shape, '\t', y_test.shape)
    task.train_and_predict(X_train, y_train, X_test, y_test, weight_file)
    task.eval_with_existing(X_train, y_train, X_test, y_test, weight_file)
    y_pred_on_test = task.predict_with_existing(X_test, weight_file)
    # np.savez appends the ".npz" suffix; the __main__ block moves this file.
    np.savez('output_on_test', y_test=y_test, y_pred=y_pred_on_test)
    return


def run_random_forest_classification():
    """Train and evaluate a RandomForestClassification model on fingerprint features.

    Fold ``running_index`` of ``file_list`` is the test split; every other
    fold is the training split. Reads the module-level globals
    ``config_json_file``, ``target_name``, ``running_index``, ``file_list``
    and ``weight_file`` (assigned in the ``__main__`` block).

    Side effect: writes ``output_on_test.npz`` (arrays ``y_test`` and
    ``y_pred``) into the current working directory.
    """
    from random_forest_classification import RandomForestClassification

    with open(config_json_file, 'r') as f:
        conf = json.load(f)
    conf['hit_ratio'] = target_name_to_hit_ratio[target_name]
    label_name_list = [target_name]
    print('label_name_list ', label_name_list)

    # read data
    # The train indices must be a concrete list: on Python 3 ``filter``
    # returns a lazy iterator, which numpy rejects as an array index.
    test_index = [running_index]
    train_index = [i for i in range(len(file_list)) if i not in test_index]
    train_file_list = file_list[train_index]
    test_file_list = file_list[test_index]

    print('train files ', train_file_list)
    print('test files ', test_file_list)

    # Drop rows whose label for the target is missing before extracting features.
    train_pd = filter_out_missing_values(read_merged_data(train_file_list), label_list=label_name_list)
    test_pd = filter_out_missing_values(read_merged_data(test_file_list), label_list=label_name_list)

    X_train, y_train = extract_feature_and_label(train_pd,
                                                 feature_name='Fingerprints',
                                                 label_name_list=label_name_list)
    X_test, y_test = extract_feature_and_label(test_pd,
                                               feature_name='Fingerprints',
                                               label_name_list=label_name_list)
    print('done data preparation')

    task = RandomForestClassification(conf=conf)
    task.train_and_predict(X_train, y_train, X_test, y_test, weight_file)
    task.eval_with_existing(X_train, y_train, X_test, y_test, weight_file)
    y_pred_on_test = task.predict_with_existing(X_test, weight_file)
    # np.savez appends the ".npz" suffix; the __main__ block moves this file.
    np.savez('output_on_test', y_test=y_test, y_pred=y_pred_on_test)
    return


def run_random_forest_classification_flatten_random_projection():
    """Train and evaluate a RandomForestClassification model on n-gram random-projection features.

    Fold ``running_index`` of ``file_list`` is the test split; every other
    fold is the training split. Reads the module-level globals
    ``config_json_file``, ``target_name``, ``running_index``, ``file_list``,
    ``weight_file`` and ``n_gram_num``.

    Side effect: writes ``output_on_test.npz`` (arrays ``y_test`` and
    ``y_pred``) into the current working directory.
    """
    from random_forest_classification import RandomForestClassification

    with open(config_json_file, 'r') as f:
        conf = json.load(f)
    conf['hit_ratio'] = target_name_to_hit_ratio[target_name]
    label_name_list = [target_name]
    print('label_name_list ', label_name_list)

    # read data
    # The train indices must be a concrete list: on Python 3 ``filter``
    # returns a lazy iterator, which numpy rejects as an array index.
    test_index = [running_index]
    train_index = [i for i in range(len(file_list)) if i not in test_index]
    train_file_list = file_list[train_index]
    test_file_list = file_list[test_index]

    print('train files ', train_file_list)
    print('test files ', test_file_list)

    # NOTE(review): the label key is the literal 'label_name', not the target
    # name; presumably the .npz files store labels under that key — confirm.
    X_train, y_train = extract_feature_and_label_npy(train_file_list,
                                                     feature_name='random_projected_list',
                                                     label_name_list=['label_name'],
                                                     n_gram_num=n_gram_num)
    X_test, y_test = extract_feature_and_label_npy(test_file_list,
                                                   feature_name='random_projected_list',
                                                   label_name_list=['label_name'],
                                                   n_gram_num=n_gram_num)
    print('done data preparation')

    task = RandomForestClassification(conf=conf)
    task.train_and_predict(X_train, y_train, X_test, y_test, weight_file)
    task.eval_with_existing(X_train, y_train, X_test, y_test, weight_file)
    y_pred_on_test = task.predict_with_existing(X_test, weight_file)
    # np.savez appends the ".npz" suffix; the __main__ block moves this file.
    np.savez('output_on_test', y_test=y_test, y_pred=y_pred_on_test)
    return


if __name__ == '__main__':
    # Command-line entry point: parse arguments, publish them as the
    # module-level globals the run_* functions read, run the selected model
    # on one fold, then move its output into ../output/<index>/<model>/.
    import os

    parser = argparse.ArgumentParser()
    parser.add_argument('--config_json_file', required=True)
    parser.add_argument('--weight_file', required=True)
    parser.add_argument('--model',  required=True)
    parser.add_argument('--target_name', default='NR-AhR')
    parser.add_argument('--n_gram_num', type=int, default=6)
    parser.add_argument('--random_projection_dimension', type=int, default=100)
    parser.add_argument('--running_index', type=int, default=0)
    given_args = parser.parse_args()

    config_json_file = given_args.config_json_file
    weight_file = given_args.weight_file
    n_gram_num = given_args.n_gram_num
    random_projection_dimension = given_args.random_projection_dimension
    running_index = given_args.running_index

    # Number of cross-validation folds (one data file per fold).
    K = 5
    target_name = given_args.target_name
    model = given_args.model

    # A negative index would be accepted by numpy (e.g. -1 -> last fold)
    # while that fold also stays in the training list, leaking test data
    # into training; an index >= K would fail later with an opaque
    # IndexError. Reject both up front.
    if not 0 <= running_index < K:
        parser.error('--running_index must be in [0, {}), got {}'.format(K, running_index))

    if 'n_gram' in model:
        directory = '../node2vec/datasets/tox21/{}/{}/{{}}_grammed_cbow_{}_graph.npz'.format(target_name, running_index, random_projection_dimension)
    else:
        directory = '../datasets/tox21/{}/{{}}.csv.gz'.format(target_name)
    file_list = np.array([directory.format(i) for i in range(K)])

    model_runners = {
        'morgan_dnn': run_single_deep_classification,
        'n_gram_dnn': run_single_deep_classification_flatten_random_projection,
        'morgan_xgb': run_xgboost_classification,
        'n_gram_xgb': run_xgboost_classification_flatten_random_projection,
        'morgan_rf': run_random_forest_classification,
        'n_gram_rf': run_random_forest_classification_flatten_random_projection,
    }

    start_time = time.time()
    if model not in model_runners:
        raise Exception('No such model! Should be among [{}, {}, {}, {}, {}, {}].'.format(
            'morgan_dnn',
            'n_gram_dnn',
            'morgan_xgb',
            'n_gram_xgb',
            'morgan_rf',
            'n_gram_rf'
        ))
    model_runners[model]()
    end_time = time.time()
    print('Running time: {}'.format(end_time - start_time))

    # The run_* functions write output_on_test.npz into the CWD. Create the
    # destination directory if needed, since os.rename does not create it.
    output_dir = '../output/{}/{}'.format(running_index, model)
    if not os.path.isdir(output_dir):
        os.makedirs(output_dir)
    os.rename('output_on_test.npz', os.path.join(output_dir, '{}.npz'.format(target_name)))
