#!/usr/bin/env python
# coding: utf-8

import pandas as pd
import sys
import random as rn
import os
# os.environ['PYTHONHASHSEED'] = '0'

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
from sklearn.model_selection import KFold
import matplotlib.pyplot as plt
import numpy as np


sys.path.append('./src/')
from Utils import makeDir

from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from sklearn.model_selection import GridSearchCV
from tensorflow.keras.wrappers.scikit_learn import KerasClassifier
from sklearn.metrics import accuracy_score


# np.random.seed(37)
# rn.seed(1254)
# tf.set_random_seed(89)

# from tensorflow.keras import backend as K
# session_conf = tf.compat.v1.ConfigProto(intra_op_parallelism_threads=1, inter_op_parallelism_threads=1)
# sess = tf.compat.v1.Session(graph=tf.compat.v1.get_default_graph(), config=session_conf)
# K.set_session(sess)


# Module-level state shared between the functions below.
# `all_history` is not referenced by any function in this file.
all_history = pd.DataFrame()

# Per-epoch metric traces accumulated across calls to plot_history().
mean_squared_error, val_mean_squared_error = [], []

# Network shape written by train() and read by build_fn(): build_fn() only
# takes `lr`, so the architecture reaches it through these globals.
feature_length = layer_size = hidden_size = 0

def make_cv_data(dataset_path, n_folds, data_dir, fold_data):
    """Load and encode the Adult data, then build ``n_folds`` train/test splits.

    Each fold holds out a random 20% of the cleaned rows as its test set. The
    split for fold ``i`` is cached as ``<fold_data><i>/train_data.csv`` and
    ``<fold_data><i>/test_data.csv``. When both files already exist they are
    read back instead of re-sampling.

    Note: the folds are independent random hold-out splits seeded with the
    fold index, not disjoint k-fold partitions, so test sets may overlap.

    Args:
        dataset_path: comma-separated raw data file without a header row.
        n_folds: number of splits to produce.
        data_dir: base directory; ``<data_dir><i>/`` is created for each fold.
        fold_data: base directory the per-fold CSV files are read from /
            written to (string prefix, expected to end with '/').

    Returns:
        ``(train_dataset, train_labels, test_dataset, test_labels, min_max_dict)``.
        The first four are lists with one entry per fold: feature frames with
        the 'Income Range' column removed, and the matching label series.
        ``min_max_dict`` is getMinMaxRangeOfFeatures() of the whole cleaned,
        encoded dataset.
    """
    column_names = ['Age', 'WorkClass', 'fnlwgt', 'Education', 'Education Num',
                    'Marital Status', 'Occupation', 'Relationship', 'Race', 'Sex',
                    'Capital-Gain', 'Capital Loss', 'Hours Per Week',
                    'Native Country', 'Income Range']
    raw_dataset = pd.read_csv(dataset_path, names=column_names,
                              na_values="?", comment='\t',
                              sep=",", skipinitialspace=True)
    # "?" is parsed as missing; drop any row with a missing field.
    dataset = raw_dataset.copy().dropna()

    # Numeric columns: coerce to a fixed dtype.
    for column, dtype in (('Age', float), ('Education Num', int),
                          ('Capital-Gain', float), ('Capital Loss', float),
                          ('Hours Per Week', int)):
        dataset[column] = dataset[column].astype(dtype)
    # Categorical columns (including the label): replace each string value
    # by its integer category code.
    for column in ('WorkClass', 'Marital Status', 'Occupation', 'Relationship',
                   'Race', 'Sex', 'Native Country', 'Income Range'):
        dataset[column] = dataset[column].astype('category').cat.codes

    # These two columns are not used as model inputs.
    dataset.pop('Education')
    dataset.pop('fnlwgt')

    min_max_dict = getMinMaxRangeOfFeatures(dataset)

    train_dataset, train_labels = [], []
    test_dataset, test_labels = [], []
    for i in range(n_folds):
        print(data_dir)
        makeDir(data_dir + '%d/' % (i))
        train_path = fold_data + '%d/train_data.csv' % (i)
        test_path = fold_data + '%d/test_data.csv' % (i)

        # Only trust the cache when both halves are present; a lone train
        # file would otherwise fail on reading the missing test file.
        if os.path.isfile(train_path) and os.path.isfile(test_path):
            print("Reading from exisiting location " + str(train_path))
            trainX = pd.read_csv(train_path, index_col=0)
            testX = pd.read_csv(test_path, index_col=0)
        else:
            # Seed with the fold index so every fold gets its own split
            # (fold 0 uses seed 0).
            testX = dataset.sample(frac=0.20, random_state=i)
            trainX = dataset.drop(testX.index)
            trainX.to_csv(train_path)
            testX.to_csv(test_path)

        # pop() removes the label column in place, so the stored frames hold
        # features only.
        trainY = trainX.pop('Income Range')
        testY = testX.pop('Income Range')
        train_dataset.append(trainX)
        test_dataset.append(testX)
        train_labels.append(trainY)
        test_labels.append(testY)

    return train_dataset, train_labels, test_dataset, test_labels, min_max_dict



def build(train_dataset, layer_size, hidden_size, learning_rate):
    """Build and compile a fully connected binary classifier.

    The network has ``layer_size - 1`` ReLU hidden layers of ``hidden_size``
    units followed by one sigmoid output unit. It is compiled with Adam and
    binary cross-entropy, reporting accuracy.

    Args:
        train_dataset: training features; only its column count
            (``len(train_dataset.keys())``) is used, as the input width.
        layer_size: total number of Dense layers, including the output layer.
        hidden_size: units in each hidden layer.
        learning_rate: Adam learning rate.

    Returns:
        ``(model, activation_types, layer_size)``. ``activation_types`` has one
        entry per Dense layer: 'relu' for the hidden layers and 'linear' for
        the output layer.
    """
    input_shape = [len(train_dataset.keys())]
    n_hidden = max(layer_size - 1, 0)

    layer_array = []
    for i in range(n_hidden):
        if i == 0:
            layer_array.append(layers.Dense(hidden_size, activation=tf.nn.relu, input_shape=input_shape))
        else:
            layer_array.append(layers.Dense(hidden_size, activation=tf.nn.relu))
    if layer_array:
        layer_array.append(layers.Dense(1, activation='sigmoid'))
    else:
        # No hidden layers: the output layer must declare the input width.
        layer_array.append(layers.Dense(1, activation='sigmoid', input_shape=input_shape))
    model = keras.Sequential(layer_array)

    # `learning_rate` is the current keyword (`lr` is a deprecated alias).
    # epsilon and decay stay at their defaults; newer Keras optimizers reject
    # an explicit `decay` argument.
    optimizer = keras.optimizers.Adam(learning_rate=learning_rate, beta_1=0.9, beta_2=0.999, amsgrad=False)
    model.compile(loss='binary_crossentropy',
                  optimizer=optimizer,
                  metrics=['accuracy'])

    # The output layer is labelled 'linear' although it uses a sigmoid.
    # Presumably consumers work with the pre-sigmoid value -- TODO confirm.
    activation_types = ['relu'] * n_hidden + ['linear']
    return model, activation_types, layer_size


# In[6]:
def build_fn(lr):
    """Model factory handed to KerasClassifier in train().

    Builds the same layer stack as build(): ``layer_size - 1`` ReLU hidden
    layers of ``hidden_size`` units and one sigmoid output unit. The shape is
    read from the module globals ``layer_size``, ``hidden_size`` and
    ``feature_length``, which train() sets before the grid search.

    Args:
        lr: Adam learning rate; the hyper-parameter GridSearchCV tunes.

    Returns:
        A compiled ``keras.Sequential`` model.
    """
    # The globals are only read here, so no `global` statement is required.
    input_shape = [feature_length]
    layer_array = []
    for i in range(max(layer_size - 1, 0)):
        if i == 0:
            layer_array.append(layers.Dense(hidden_size, activation=tf.nn.relu, input_shape=input_shape))
        else:
            layer_array.append(layers.Dense(hidden_size, activation=tf.nn.relu))
    if layer_array:
        layer_array.append(layers.Dense(1, activation='sigmoid'))
    else:
        # No hidden layers: the output layer must declare the input width.
        layer_array.append(layers.Dense(1, activation='sigmoid', input_shape=input_shape))
    model = keras.Sequential(layer_array)

    # `learning_rate` replaces the deprecated `lr` keyword. epsilon and decay
    # stay at their defaults; newer Keras optimizers reject `decay`.
    optimizer = keras.optimizers.Adam(learning_rate=lr, beta_1=0.9, beta_2=0.999, amsgrad=False)
    model.compile(loss='binary_crossentropy',
                  optimizer=optimizer,
                  metrics=['accuracy'])
    return model


def train(train_dataset, train_labels, fold, data_dir, layer_size_, hidden_size_,
          n_jobs=25, cv=5):
    """Grid-search a Keras classifier for one fold and export the best model.

    Searches batch size, epoch count and learning rate with ``cv``-fold
    cross-validation. The best model is saved to
    ``<data_dir><fold>/initial/model.h5``, and each layer's weights and bias
    go to ``weights_layer<i>.csv`` / ``bias_layer<i>.csv`` in that directory.
    The best score and every grid point's mean/std score are printed.

    Args:
        train_dataset: training feature frame.
        train_labels: training labels aligned with ``train_dataset``.
        fold: fold index; only used to name the output directory.
        data_dir: base output directory (string prefix, should end with '/').
        layer_size_: total number of Dense layers (hidden + output).
        hidden_size_: units per hidden layer.
        n_jobs: parallel jobs for GridSearchCV.
        cv: number of cross-validation splits for GridSearchCV.

    Returns:
        ``(model, MIP_model)``:

        - ``model`` is the best ``KerasClassifier``, refitted on all of
          ``train_dataset``; its Keras network is ``model.model``.
        - ``MIP_model`` is a list with one ``(activation, weight, bias)``
          tuple per Dense layer. Activation is 'relu' for hidden layers and
          'linear' for the output layer.
    """
    # build_fn() only receives `lr`, so the architecture goes through globals.
    # NOTE(review): with n_jobs > 1 the CV fits run in worker processes. If
    # build_fn is pickled by reference (module imported rather than run as a
    # script), workers may see the default globals instead -- confirm.
    global feature_length
    global layer_size
    global hidden_size
    hidden_size = hidden_size_
    layer_size = layer_size_
    feature_length = len(train_dataset.keys())
    estimator = KerasClassifier(build_fn=build_fn, verbose=0)

    param_grid = dict(batch_size=[128, 256, 1024],
                      epochs=[100, 500, 1000],
                      lr=[0.1, 0.01])
    grid = GridSearchCV(estimator=estimator, param_grid=param_grid, n_jobs=n_jobs, cv=cv)

    data_dir = data_dir + str(fold) + "/initial/"
    makeDir(data_dir)
    model_file = data_dir + "model.h5"
    grid_result = grid.fit(train_dataset, train_labels)
    print(model_file)

    # GridSearchCV fits clones, so `estimator` itself stays unfitted. With the
    # default refit=True, best_estimator_ is the winning configuration
    # retrained on the full training set.
    best_estimator = grid.best_estimator_
    best_model = best_estimator.model
    best_model.save(model_file)

    MIP_model = []
    for i in range(layer_size_):
        layer_weights = best_model.layers[i].get_weights()
        weight, bias = layer_weights[0], layer_weights[1]
        np.savetxt(data_dir + "weights_layer%d.csv" % (i), weight, delimiter=",")
        np.savetxt(data_dir + "bias_layer%d.csv" % (i), bias, delimiter=",")
        activation_type = 'linear' if i == layer_size_ - 1 else 'relu'
        MIP_model.append((activation_type, weight, bias))

    print("Best for : %d" % (layer_size_))
    print("Best: [%f] Layer Size [%d] and Neuron Size [%d] using %s" % (grid_result.best_score_, layer_size, hidden_size_, grid_result.best_params_))
    means = grid_result.cv_results_['mean_test_score']
    stds = grid_result.cv_results_['std_test_score']
    params = grid_result.cv_results_['params']
    for mean, stdev, param in zip(means, stds, params):
        print("%f (%f) with: %r" % (mean, stdev, param))
    return best_estimator, MIP_model


def evaluate(model, test_dataset, test_labels):
    """Return the accuracy of ``model`` on the given test data.

    Models in this file are compiled with ``metrics=['accuracy']``, so Keras'
    ``evaluate`` yields ``[loss, accuracy]``; the value at index 1 is returned.
    """
    results = model.evaluate(test_dataset, test_labels, verbose=0)
    accuracy = results[1]  # results[0] is the loss
    return accuracy


def update_batch(model, mip_model, batch_data, batch_label, fold, data_dir, batch_size, epochs=1):
    """Train ``model`` on one batch and refresh the MIP layer description.

    Fits for ``epochs`` epoch(s), with 20% of the batch held out for
    validation. Then, for each of the first ``len(mip_model)`` layers, writes
    the weights and bias to ``<data_dir>/weights_layer<i>.csv`` and
    ``<data_dir>/bias_layer<i>.csv``.

    Args:
        model: compiled Keras model; trained in place.
        mip_model: list of per-layer tuples whose first element is the
            activation type; only that element is reused.
        batch_data: training inputs for this update.
        batch_label: training targets for this update.
        fold: unused; kept for interface compatibility.
        data_dir: directory that receives the CSV files.
        batch_size: mini-batch size passed to ``model.fit``.
        epochs: number of epochs to fit (default 1).

    Returns:
        ``(model, updated_MIP_model)`` where ``updated_MIP_model`` holds
        ``(activation, weight, bias)`` tuples with the freshly trained values.
    """
    model.fit(batch_data, batch_label, epochs=epochs, batch_size=batch_size,
              validation_split=0.2, verbose=0)
    updated_MIP_model = []
    for i, layer_info in enumerate(mip_model):
        layer_weights = model.layers[i].get_weights()
        weight, bias = layer_weights[0], layer_weights[1]
        np.savetxt(data_dir + "/weights_layer%d.csv" % (i), weight, delimiter=",")
        np.savetxt(data_dir + "/bias_layer%d.csv" % (i), bias, delimiter=",")
        updated_MIP_model.append((layer_info[0], weight, bias))
    return model, updated_MIP_model

def update(model, mip_model, batch_data, batch_label, fold, data_dir, epochs=10, batch_size=64):
    """Fine-tune ``model`` on new data and refresh the MIP layer description.

    Fits for ``epochs`` epochs with ``batch_size`` samples per step, holding
    out 20% for validation and printing progress. Then, for each of the first
    ``len(mip_model)`` layers, writes the weights and bias to
    ``<data_dir>/weights_layer<i>.csv`` and ``<data_dir>/bias_layer<i>.csv``.

    Args:
        model: compiled Keras model; trained in place.
        mip_model: list of per-layer tuples whose first element is the
            activation type; only that element is reused.
        batch_data: training inputs.
        batch_label: training targets.
        fold: unused; kept for interface compatibility.
        data_dir: directory that receives the CSV files.
        epochs: number of epochs to fit (default 10).
        batch_size: mini-batch size (default 64).

    Returns:
        ``(model, updated_MIP_model)`` where ``updated_MIP_model`` holds
        ``(activation, weight, bias)`` tuples with the freshly trained values.
    """
    model.fit(batch_data, batch_label, epochs=epochs, batch_size=batch_size,
              validation_split=0.2, verbose=1)
    updated_MIP_model = []
    for i, layer_info in enumerate(mip_model):
        layer_weights = model.layers[i].get_weights()
        weight, bias = layer_weights[0], layer_weights[1]
        np.savetxt(data_dir + "/weights_layer%d.csv" % (i), weight, delimiter=",")
        np.savetxt(data_dir + "/bias_layer%d.csv" % (i), bias, delimiter=",")
        updated_MIP_model.append((layer_info[0], weight, bias))
    return model, updated_MIP_model


# In[9]:


def output(model, datapoint):
    """Predict the class label of a single data point.

    Args:
        model: Keras model whose ``predict`` returns class probabilities.
        datapoint: one sample's feature values, in anything ``pd.DataFrame``
            turns into a single column (e.g. a list or Series).

    Returns:
        Integer class array. For a single-unit (sigmoid) output it is 1 where
        the probability exceeds 0.5, else 0. For a multi-unit output it is the
        argmax over the last axis.
    """
    # DataFrame(datapoint) is one column; transpose it into a single row.
    x_point = pd.DataFrame(datapoint).transpose()
    # Same rule as Sequential.predict_classes, which was removed in TF 2.6.
    proba = model.predict(x_point, verbose=0)
    if proba.shape[-1] > 1:
        return proba.argmax(axis=-1)
    return (proba > 0.5).astype('int32')


def plot_history(history):
    """Accumulate a fit's MSE traces into the module lists and plot them.

    Reads the 'mean_squared_error' and 'val_mean_squared_error' columns of
    ``history.history``. It extends the module-level ``mean_squared_error``
    and ``val_mean_squared_error`` lists, prints the accumulated train trace
    and the epoch indices, and shows both curves on a single figure.

    NOTE(review): the models in this file are compiled with
    ``metrics=['accuracy']`` only, so these keys are presumably absent from
    their histories -- confirm before calling this.
    """
    frame = pd.DataFrame(history.history)
    # extend() mutates the module lists in place; no rebinding is needed.
    mean_squared_error.extend(frame['mean_squared_error'].tolist())
    val_mean_squared_error.extend(frame['val_mean_squared_error'].tolist())
    print(mean_squared_error)
    epochs = [*range(len(mean_squared_error))]
    print(epochs)

    plt.figure()
    plt.xlabel('Epoch')
    plt.ylabel('Mean Square Error [$MPG^2$]')
    for trace, curve_name in ((mean_squared_error, 'Train Error'),
                              (val_mean_squared_error, 'Val Error')):
        plt.plot(epochs, trace, label=curve_name)
    plt.ylim([0, 60])
    plt.legend()
    plt.show()

def getMinMaxRangeOfFeatures(dataset, column_names=None):
    """Return ``{column: [min, max]}`` for the selected columns of ``dataset``.

    Args:
        dataset: frame containing at least every column in ``column_names``.
        column_names: columns to summarise. Defaults to the encoded Adult
            feature columns plus the 'Income Range' label.

    Returns:
        Dict mapping each column name, in ``column_names`` order, to a
        two-element list ``[min, max]`` computed with the built-in min/max.
    """
    if column_names is None:
        column_names = ['Age', 'WorkClass', 'Education Num', 'Marital Status',
                        'Occupation', 'Relationship', 'Race', 'Sex',
                        'Capital-Gain', 'Capital Loss', 'Hours Per Week',
                        'Native Country', 'Income Range']
    return {name: [min(dataset[name]), max(dataset[name])] for name in column_names}

def getMinMaxRange(dataset, index):
    """Return ``(min, max, name)`` for the ``index``-th encoded Adult column.

    ``index`` addresses the fixed column list below (negative indices count
    from the end). min and max are computed with the built-in functions over
    the column's values.
    """
    feature_names = ['Age', 'WorkClass', 'Education Num', 'Marital Status',
                     'Occupation', 'Relationship', 'Race', 'Sex',
                     'Capital-Gain', 'Capital Loss', 'Hours Per Week',
                     'Native Country', 'Income Range']
    name = feature_names[index]
    values = dataset[name]
    return min(values), max(values), name


def run(fold, n_folds, size=16, n_layers=3, data_path='./Data/adult.data'):
    """Prepare the cross-validation splits and grid-search a model for one fold.

    Outputs go under ``./Models/adult/<n_layers>_<size>/``.

    Args:
        fold: index of the fold to train; must satisfy ``0 <= fold < n_folds``.
        n_folds: number of splits created by make_cv_data().
        size: units per hidden layer (default 16).
        n_layers: total Dense layers including the output layer (default 3).
        data_path: raw Adult data file.

    Returns:
        ``(model, MIP_model)`` as returned by train().

    Raises:
        ValueError: if ``fold`` is outside ``range(n_folds)``.
    """
    if not 0 <= fold < n_folds:
        raise ValueError("fold must be in [0, %d), got %d" % (n_folds, fold))
    print("Running for size %d and layers %d" % (size, n_layers))
    data_dir = './Models/adult/' + str(n_layers) + '_' + str(size) + '/'
    makeDir(data_dir, True)
    train_dataset, train_labels, _, _, _ = make_cv_data(data_path, n_folds, data_dir, data_dir)
    # Train on the split belonging to `fold`, matching the output directory
    # train() writes to (<data_dir><fold>/).
    return train(train_dataset[fold], train_labels[fold], fold, data_dir, n_layers, size)
    model, MIP_model = train(train_dataset[0],train_labels[0],fold,data_dir,i,size)

if __name__ == '__main__':
    # Usage: python <script> FOLD N_FOLDS
    if len(sys.argv) < 3:
        sys.exit("usage: %s FOLD N_FOLDS" % sys.argv[0])
    fold = int(sys.argv[1])
    n_folds = int(sys.argv[2])
    run(fold, n_folds)