import os, sys
import numpy as np
import tensorflow as tf
from sklearn.externals import joblib
from DataModule import MnistModule, NewsModule, AdultModule
import SGDInfluence

def settings(key):
    """Return the data module and training hyper-parameters for dataset ``key``.

    Args:
        key: dataset name, one of ``'mnist'``, ``'20news'`` or ``'adult'``.

    Returns:
        ``(module, (n_tr, n_val, n_test), m, alpha, (lr, decay, num_epoch, batch_size))``
        where ``module`` is the dataset's data-module instance, ``m`` is the
        list of hidden-layer widths handed to ``build_dnn`` and ``alpha`` is
        forwarded to ``SGDInfluence`` (its meaning is defined there). All
        three datasets currently share the same hyper-parameters.

    Raises:
        ValueError: if ``key`` is not a supported dataset. This is an explicit
            raise rather than an ``assert`` so the check is not stripped under
            ``python -O``.
    """
    # Only the data module differs between datasets. It is constructed inside
    # the matching branch so only the requested dataset's module is built.
    if key == 'mnist':
        module = MnistModule()
    elif key == '20news':
        module = NewsModule()
    elif key == 'adult':
        module = AdultModule(csv_path='./data')
    else:
        raise ValueError(
            "unknown key %r; expected one of 'mnist', '20news', 'adult'" % (key,))

    n_tr, n_val, n_test = 200, 200, 200
    m = [8, 8]  # fresh list per call, so callers may mutate it safely
    alpha = 0.001
    lr, decay, num_epoch, batch_size = 0.1, False, 10, 20
    return module, (n_tr, n_val, n_test), m, alpha, (lr, decay, num_epoch, batch_size)
    
def _dense(x, in_dim, out_dim, idx, seed, activation=None):
    """Create fully connected layer ``fc<idx>``; return ``(output, [w, b])``.

    The weight ``w_fc<idx>`` uses He-normal init with op seed ``seed``. The
    bias ``b_fc<idx>`` uses truncated-normal init with op seed ``seed + 1``.
    ``activation``, if given, is applied to ``x @ w + b`` inside the same
    name scope.
    """
    with tf.name_scope('fc%d' % idx):
        w = tf.get_variable('w_fc%d' % idx, shape=(in_dim, out_dim), dtype=tf.float32,
                            initializer=tf.initializers.he_normal(seed=seed))
        b = tf.get_variable('b_fc%d' % idx, shape=(out_dim,), dtype=tf.float32,
                            initializer=tf.initializers.truncated_normal(seed=seed + 1))
        out = tf.nn.bias_add(tf.matmul(x, w), b)
        if activation is not None:
            out = activation(out)
    return out, [w, b]


def build_dnn(input_tensor, output_tensor, m=(32, 64), seed=0):
    """Build a 3-layer fully connected network on ``input_tensor``.

    The architecture is ``fc1 (ReLU, m[0]) -> fc2 (ReLU, m[1]) -> fc3 (linear, p)``,
    where ``p`` is the static size of ``output_tensor``'s second dimension.
    ``output_tensor`` is used only for that shape. Only ``m[0]`` and ``m[1]``
    are read; any extra entries in ``m`` are ignored.

    Variables are created with ``tf.get_variable`` as ``w_fc1, b_fc1, ...,
    b_fc3``. Their initializer op seeds are ``seed, seed+1, ..., seed+5`` in
    that order. Presumably ``input_tensor`` is 2-D ``(batch, d)``, since it
    is fed to ``tf.matmul`` and its dim 1 is read (TODO confirm for other
    callers).

    Returns:
        ``(sigmoid, logit, params)``: the element-wise sigmoid of the logits,
        the raw logits, and ``[w_fc1, b_fc1, w_fc2, b_fc2, w_fc3, b_fc3]``.
    """
    d = input_tensor.get_shape()[1]
    p = output_tensor.get_shape()[1]

    # Each layer consumes two op seeds (weight, bias), hence the +2 steps.
    fc1, params1 = _dense(input_tensor, d, m[0], 1, seed, tf.nn.relu)
    fc2, params2 = _dense(fc1, m[0], m[1], 2, seed + 2, tf.nn.relu)
    logit, params3 = _dense(fc2, m[1], p, 3, seed + 4)  # no activation: raw logits

    # Sigmoid (not softmax) output, matching the sigmoid cross-entropy loss
    # applied to the logits by the caller.
    sigmoid = tf.nn.sigmoid(logit)

    return sigmoid, logit, params1 + params2 + params3
    
def test(key, seed=0, gpu_index=0):
    """Run SGD once on all training data and once per left-out sample; save results.

    For dataset ``key``, fetches the data split with ``seed`` and builds the
    DNN from ``build_dnn``. It then calls ``SGDInfluence.run_sgd`` ``n_tr + 1``
    times: once skipping nothing (stored under ``res['sgd']['noskip']``) and
    once skipping each training index ``i`` (stored under ``res['sgd'][i]``).
    The ``res`` dict is written with ``joblib.dump`` to
    ``./<key>_dnn/sgd<seed:03d>.dat``.

    Args:
        key: dataset name accepted by ``settings``.
        seed: data-split / initialization / SGD seed.
        gpu_index: index used for the ``/gpu:<gpu_index>`` device scope.
    """
    dn = './%s_dnn' % (key,)
    fn = '%s/sgd%03d.dat' % (dn, seed)
    if not os.path.exists(dn):
        os.mkdir(dn)

    # fetch data
    module, (n_tr, n_val, n_test), m, alpha, (lr, decay, num_epoch, batch_size) = settings(key)
    z_tr, z_val, _, _ = module.fetch(n_tr, n_val, n_test, seed)
    (x_tr, y_tr), (x_val, y_val) = z_tr, z_val
    # Reshape labels to column vectors so they match the (None, 1) output placeholder.
    y_tr = y_tr[:, np.newaxis]
    y_val = y_val[:, np.newaxis]

    # fit & save - sgd
    # i == -1 is the reference run with no sample skipped; for i >= 0 SGD is
    # re-run with training sample i skipped.
    res = {'sgd': {}}
    for i in range(-1, n_tr):
        skip_index = [] if i < 0 else [i]
        tf.reset_default_graph()
        with tf.device('/gpu:%d' % (gpu_index,)):
            input_tensor = tf.placeholder(tf.float32, shape=(None, x_tr.shape[1]))
            output_tensor = tf.placeholder(tf.float32, shape=(None, 1))
            sigmoid, logit, params = build_dnn(input_tensor, output_tensor, m=m, seed=seed)
            loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(labels=output_tensor, logits=logit))
            graph = tf.get_default_graph()
            sess = tf.Session(graph=graph)
            try:
                inf = SGDInfluence.SGDInfluence(graph, sess, input_tensor, output_tensor, sigmoid, loss, params, alpha=alpha)
                # NOTE(review): the literal '0' before %03d gives e.g.
                # 'seed0000'. It is kept unchanged in case existing outputs
                # depend on this prefix.
                inf.run_sgd(x_tr, y_tr, x_val, y_val,
                            lr=lr, decay=decay, num_epoch=num_epoch, batch_size=batch_size,
                            seed=seed, prefix='%s_dnn_seed0%03d' % (key, seed), skip_index=skip_index)
                # Read the results while the session is still open.
                if i < 0:
                    res['sgd']['noskip'] = {'a': inf.a_, 'info': inf.info_}
                else:
                    res['sgd'][i] = {'a': inf.a_, 'info': []}
            finally:
                # One session is created per run. Without an explicit close,
                # all n_tr + 1 sessions would stay open until the process exits.
                sess.close()
    joblib.dump(res, fn, compress=9)
    
if __name__ == '__main__':
    # Command line: <script> KEY SEED GPU_INDEX
    # Fail with a usage message instead of an IndexError traceback when
    # arguments are missing.
    if len(sys.argv) < 4:
        sys.exit('usage: %s {mnist,20news,adult} SEED GPU_INDEX' % (sys.argv[0],))
    key = sys.argv[1]
    seed = int(sys.argv[2])
    gpu_index = int(sys.argv[3])
    test(key, seed, gpu_index)
    
