from __future__ import print_function, division
import tensorflow as tf
import numpy as np
import datetime
import os
from shutil import copyfile
import matplotlib
matplotlib.use('agg')
import matplotlib.pyplot as pltf
import sys
import scipy.io



os.environ['CUDA_VISIBLE_DEVICES'] = "0"  # expose only GPU 0 to TensorFlow

#####################################  ECC config ########################################
# BCH(63, 51) code: parity-check and generator matrices come from .npy files.
code_parityCheckMatrix = np.load('/ECC_MATRIX/BCH_63_51_PCM.npy')
code_generatorMatrix = np.load('/ECC_MATRIX/BCH_63_51_GM.npy')
code_n, code_k = 63, 51
code_rate = code_k / float(code_n)


# init the AWGN: SNR grid in dB (stop_snr included) and the RNG seeds
start_snr, step, stop_snr = 1, 1, 8
snr_db = np.arange(start_snr, stop_snr + step, step, dtype=np.float32)
word_seed, noise_seed = 786000, 345000

## folders
# Each run writes into its own timestamped directory. Time fields are joined
# with '-' rather than ':' because ':' is not allowed in Windows file names
# (and is misread as a host separator by tools such as scp/rsync).
dir_name = './output_' + datetime.datetime.now().strftime('%Y-%m-%d_%H-%M-%S')
# Creates dir_name and its weights/ sub-directory in one call.
os.makedirs(dir_name + '/weights')
weights_path_format = dir_name + '/weights/weights_epoch_%03d.ckpt'
results_path = dir_name + "/results.txt"
# Keep a copy of the exact script that produced the results.
src_script = sys.argv[0]
dst_script = dir_name + '/script.py'
copyfile(src_script, dst_script)
load_weights = False
weights_path = 'PATH_TO_WEIGHTS'

#################################  Neural Network config #################################
input_output_layer_size = code_n
num_hidden_layers = 5 #5 means that we have 10 layers of parity and check nodes
numOfWordSim_train = 15   # codewords per SNR point in every training batch
# create_mix_epoch emits numOfWordSim_train words for every SNR point and the
# x/y placeholders are built with exactly batch_size rows, so derive the batch
# size instead of repeating the constant.
batch_size = numOfWordSim_train*len(snr_db)
batches_for_val_per_snr = np.array((100, 100, 100, 100, 1000, 5000, 10000, 10000))  # one entry per SNR point
if len(batches_for_val_per_snr) != len(snr_db):
    # Fail at start-up instead of with an IndexError in the first validation pass.
    raise ValueError('batches_for_val_per_snr has %d entries but snr_db has %d SNR points'
                     % (len(batches_for_val_per_snr), len(snr_db)))
multiloss = 1              # intermediate loss on hidden layers whose index is a multiple of this
batch_in_epoch = 500       # training batches between validation / checkpoint passes
num_of_batch = 100000000
learning_rate = 0.0001
gpu_mem_fraction = 0.99
train_on_zero_word = True
test_on_zero_word = False


# Variable Node Network
neurons_per_odd_layer = int(np.sum(code_parityCheckMatrix))   # one neuron per 1-entry of the parity-check matrix
neurons_per_even_layer = neurons_per_odd_layer
n_input = neurons_per_odd_layer + 1 # plus 1 for skip connection
n_hidden_1 = 16
n_hidden_2 = 16
n_classes = 1

# SF network
sf_n_input = neurons_per_odd_layer
sf_n_hidden_1 = 32
sf_n_hidden_2 = 32
sf_n_hidden_3 = 32
sf_n_classes = 1

##################################  init Parameters  ####################################
##### ECC parameters
snr_lin = 10.0**(snr_db/10.0)
scaling_factor = np.sqrt(1.0/(2.0*snr_lin*code_rate))   # AWGN noise std for each SNR point
wordRandom = np.random.RandomState(word_seed)            # information-word generator
random = np.random.RandomState(noise_seed)               # channel-noise generator (read by create_mix_epoch)


##### NN parameters
# Connectivity matrices of the unrolled decoder. All start as float32 zeros;
# the "init" sections below write their 1-entries.
W_input = np.zeros([input_output_layer_size, neurons_per_odd_layer], np.float32)      # channel LLRs -> odd neurons
W_odd2even = np.zeros([neurons_per_odd_layer, neurons_per_even_layer], np.float32)    # odd -> even neurons
W_odd2even_graphnn = np.zeros(
    [neurons_per_odd_layer, neurons_per_even_layer, neurons_per_odd_layer], np.float32
)                                                                                    # per-even-neuron diagonal masks
W_skipconn2even = np.zeros([input_output_layer_size, neurons_per_even_layer], np.float32)  # channel LLRs -> even neurons
W_even2odd = np.zeros([neurons_per_even_layer, neurons_per_odd_layer], np.float32)    # even -> odd neurons
W_output = np.zeros([neurons_per_odd_layer, input_output_layer_size], np.float32)      # odd neurons -> output bits

# init W_input
# Column e of W_input belongs to the e-th 1-entry (check c, variable v) of the
# parity-check matrix, entries numbered row by row (np.nonzero returns them in
# that order). The column selects every variable of check c except v itself.
in_edge_chk, in_edge_var = np.nonzero(code_parityCheckMatrix == 1)
num_in_edges = in_edge_chk.size
W_input[:, :num_in_edges] = code_parityCheckMatrix[in_edge_chk, :].T
W_input[in_edge_var, np.arange(num_in_edges)] = 0

# init W_odd2even (W_skipconn2even is initialised in its own section)
# Maps odd-layer neurons (rows) to even-layer neurons (columns); both layers
# have one neuron per 1-entry ("edge") of the parity-check matrix H.
#   * rows enumerate edges check by check: edge (r, j) has row index
#     sum(H[:r, :]) + (position of j among the 1s of row r);
#   * columns (counter k) enumerate edges variable by variable, with the checks
#     of each variable in ascending order.
# Column k = edge (variable j, check idx[l]) gets a 1 in the row of every other
# edge (r, j) of variable j, i.e. all incoming check messages except the one
# from check idx[l].
k = 0
vec_tmp = np.zeros((neurons_per_odd_layer),dtype=np.float32)
for j in range(0,code_parityCheckMatrix.shape[1],1):
    for i in range(0,code_parityCheckMatrix.shape[0],1):
        if(code_parityCheckMatrix[i,j] == 1):
            # First 1 of column j found: emit all of variable j's columns here;
            # the `break` below then moves on to the next variable.

            num_of_conn = np.sum(code_parityCheckMatrix[:,j])        # degree of variable node j
            idx = np.argwhere(code_parityCheckMatrix[:,j] ==1)       # checks connected to j, ascending
            # NOTE(review): range() needs an integer count; this assumes the loaded
            # matrix has an integer dtype (a float matrix would raise TypeError here).
            for l in range(0, num_of_conn, 1):                                 # one column per edge of variable j
                vec_tmp = np.zeros((neurons_per_odd_layer),dtype=np.float32)
                for r in range(0, code_parityCheckMatrix.shape[0], 1):         # every other check r of variable j
                    if(code_parityCheckMatrix[r,j] == 1 and idx[l][0] != r):
                        # position of variable j among the 1s of row r
                        idx_vec = np.cumsum(code_parityCheckMatrix[r,0:j+1])[-1] - 1
                        # row index of edge (r, j): edges of checks before r, plus idx_vec
                        vec_tmp[int(idx_vec + np.sum(code_parityCheckMatrix[:r,:]))] = 1.0
                # (transpose of a 1-D array is a no-op)
                W_odd2even[:,k] = vec_tmp.transpose()
                k += 1
            break

# init W_even2odd
# Maps even-layer neurons (rows; edges enumerated variable by variable, counter
# k) to odd-layer neurons (columns; edges enumerated check by check).
# Row k = edge (check i, variable j) gets a 1 for every edge of check i except
# (i, j) itself, i.e. all messages into check i except the one from variable j.
k = 0
for j in range(0,code_parityCheckMatrix.shape[1],1):
    for i in range(0,code_parityCheckMatrix.shape[0],1):
        if(code_parityCheckMatrix[i,j] == 1):
            idx_row = np.cumsum(code_parityCheckMatrix[i,0:j+1])[-1] - 1 # position of variable j among the 1s of row i

            # check i's edges occupy columns [till_d_c, this_d_c) in check-major order
            till_d_c = np.sum(code_parityCheckMatrix[:i,:])
            this_d_c = np.sum(code_parityCheckMatrix[:(i+1),:])
            W_even2odd[k,int(till_d_c):int(this_d_c)] = 1.0
            # exclude edge (i, j) itself
            W_even2odd[k,int(till_d_c+idx_row)] = 0.0

            k += 1

# init W_output
# W_output[e, v] = 1 when the e-th 1-entry of the parity-check matrix (entries
# numbered row by row) lies in column v, so output bit v collects every edge of
# variable v.
out_edge_var = np.nonzero(code_parityCheckMatrix == 1)[1]
W_output[np.arange(out_edge_var.size), out_edge_var] = 1.0

# init W_skipconn2even
# Column e is the e-th edge in variable-major order (variables in column order,
# checks ascending within a variable); it selects that edge's variable:
# W_skipconn2even[v, e] = 1.
skip_edge_var = np.nonzero(code_parityCheckMatrix.T == 1)[0]
W_skipconn2even[skip_edge_var, np.arange(skip_edge_var.size)] = 1.0


# init W_odd2even_graphnn
# Slice W_odd2even_graphnn[b] is a diagonal matrix holding column b of
# W_odd2even: W_odd2even_graphnn[b, j, j] = W_odd2even[j, b].
graphnn_diag = np.arange(W_odd2even.shape[0])
W_odd2even_graphnn[:, graphnn_diag, graphnn_diag] = W_odd2even.T



##################################  Functions  ####################################
def create_mix_epoch(scaling_factor, wordRandom, numOfWordSim, code_n, code_k, code_generatorMatrix, is_zeros_word, noiseRandom=None):
    """Simulate encoded, noisy BPSK words for several noise levels and return LLRs.

    For every entry sf_i of scaling_factor (used as the noise standard
    deviation), numOfWordSim information words of length code_k are drawn from
    wordRandom (then zeroed when is_zeros_word), encoded as
    info @ code_generatorMatrix mod 2, mapped to (-1)**bit, corrupted with
    Gaussian noise of std sf_i and converted to LLRs 2*y/sf_i**2.

    Args:
        noiseRandom: RandomState used for the channel noise. Defaults to the
            module-level ``random`` generator.

    Returns:
        (X, Y): LLRs and the matching codeword bits, each with
        len(scaling_factor)*numOfWordSim rows, grouped by noise level in the
        order of scaling_factor. Empty input gives (0, code_n) arrays.
    """
    if noiseRandom is None:
        noiseRandom = random  # module-level channel-noise generator

    llr_blocks = []
    word_blocks = []
    for sf_i in scaling_factor:
        # Always draw from wordRandom so its stream advances the same way for
        # all-zero and random words.
        infoWord_i = wordRandom.randint(0, 2, size=(numOfWordSim, code_k))
        if is_zeros_word:
            infoWord_i = 0*infoWord_i

        Y_i = np.dot(infoWord_i, code_generatorMatrix) % 2
        X_p_i = noiseRandom.normal(0.0, 1.0, Y_i.shape)*sf_i + (-1)**(Y_i)
        llr_blocks.append(2*X_p_i/(sf_i**2))
        word_blocks.append(Y_i)

    if not llr_blocks:
        return np.zeros([0, code_n], dtype=np.float32), np.zeros([0, code_n], dtype=np.int32)

    # Concatenate once (linear time) instead of growing an array per noise level;
    # dtypes are promoted to at least float32 / int32, matching the empty case.
    X = np.concatenate(llr_blocks, axis=0)
    Y = np.concatenate(word_blocks, axis=0)
    X = X.astype(np.promote_types(np.float32, X.dtype), copy=False)
    Y = Y.astype(np.promote_types(np.int32, Y.dtype), copy=False)
    return X, Y



def calc_ber_fer(snr_db, Y_v_pred, Y_v, batches_for_val_per_snr):
    """Return per-SNR bit-error and frame-error rates.

    Y_v_pred and Y_v are row-aligned and grouped by SNR point in the order of
    snr_db, with batches_for_val_per_snr[i] rows for point i. A bit is decided
    as 1 where Y_v_pred < 0.5 and compared with the label in Y_v.

    Returns (ber, fer): float arrays of length len(snr_db) with the fraction of
    wrong bits and the fraction of rows containing at least one wrong bit.
    """
    num_points = snr_db.shape[0]
    ber_test, fer_test = np.zeros(num_points), np.zeros(num_points)
    row_end = 0
    for point in range(num_points):
        row_start = row_end
        row_end = row_start + int(batches_for_val_per_snr[point]*1.0)
        labels = Y_v[row_start:row_end, :]
        bit_errors = np.abs((Y_v_pred[row_start:row_end, :] < 0.5) - labels)
        ber_test[point] = bit_errors.sum() / (labels.shape[0]*labels.shape[1])
        fer_test[point] = (bit_errors.sum(axis=1) > 0).sum()*1.0 / labels.shape[0]
    return ber_test, fer_test




def MLP_VN(x, weights):
    """Per-sample two-layer tanh MLP used as the variable-node function.

    x       -- rank-3 tensor, contracted over its last axis with weights['h1']
    weights -- dict of per-sample matrices 'h1' and 'out'; the einsum
               'aij,ajb->aib' pairs sample a of x with matrix a of each weight

    Both layers use tanh (the output layer is not linear). Axis 2 of the result
    is squeezed away, so weights['out'] must have a trailing dimension of 1.
    """
    hidden = tf.nn.tanh(tf.einsum('aij,ajb->aib', x, weights['h1']))
    pre_activation = tf.einsum('aij,ajb->aib', hidden, weights['out'])
    return tf.squeeze(tf.nn.tanh(pre_activation), 2)


## SF network
# Trainable weights of the SF hypernetwork: four hidden layers plus two output
# heads sized to MLP_VN's 'h1' (n_input x n_hidden_1) and 'out'
# (n_hidden_2 x n_classes) matrices. The TF variable names (including the
# 'xaiver' spelling) are the names the Saver writes/restores, so keep them.
sf_weights = dict(
    (key, tf.get_variable(var_name, shape, initializer=tf.contrib.layers.xavier_initializer()))
    for key, var_name, shape in (
        ('sf_h1', 'sf_w1_xaiver', [sf_n_input, sf_n_hidden_1]),
        ('sf_h2', 'sf_w2_xaiver', [sf_n_hidden_1, sf_n_hidden_2]),
        ('sf_h3', 'sf_w3_xaiver', [sf_n_hidden_2, sf_n_hidden_3]),
        ('sf_h4', 'sf_w4_xaiver', [sf_n_hidden_3, sf_n_hidden_3]),
        ('sf_head1', 'sf_head1', [sf_n_hidden_3, n_input*n_hidden_1]),
        ('sf_head3', 'sf_head3', [sf_n_hidden_3, n_hidden_2*n_classes]),
    )
)


def SF(x, sf_weights):
    """SF hypernetwork: map a batch of message magnitudes to MLP_VN weights.

    x is fed through four tanh layers ('sf_h1'..'sf_h4') and then two linear
    heads. Returns (out_1, out_3): the flat 'sf_head1' and 'sf_head3' outputs,
    which callers reshape into MLP_VN's 'h1' and 'out' weight matrices.
    """
    hidden = x
    for layer_key in ('sf_h1', 'sf_h2', 'sf_h3', 'sf_h4'):
        hidden = tf.nn.tanh(tf.einsum('aj,jb->ab', hidden, sf_weights[layer_key]))

    # The heads have no activation.
    head_h1 = tf.einsum('aj,jb->ab', hidden, sf_weights['sf_head1'])
    head_out = tf.einsum('aj,jb->ab', hidden, sf_weights['sf_head3'])
    return head_h1, head_out



def ARC_TANH_LIKE(x, order):
    """Truncated artanh power series: x + x**3/3 + x**5/5 + ... .

    Adds x**i / i for every odd i with 3 <= i <= order (returns x unchanged when
    order < 3). Being a finite polynomial, it stays finite at |x| -> 1, where the
    exact artanh diverges.
    """
    out = x
    for power in range(3, order + 1, 2):
        # A Python-scalar exponent broadcasts inside tf.pow, so no full-size
        # exponent tensor (i * tf.ones_like(x)) and its gradient path are built
        # for each of the ~order/2 terms; the computed values are identical.
        out = out + (1.0/power)*tf.pow(x, float(power))
    return out



##################################  Network architecture ####################################
# Inputs: channel LLRs (x) and codeword bits (y), one row per word; the batch
# size is fixed when the graph is built.
x = tf.placeholder(tf.float32, shape=[batch_size, code_n])
y = tf.placeholder(tf.float32, shape=[batch_size, code_n])
# Trainable copy of W_odd2even_graphnn; the hidden layers multiply it by the
# fixed 0/1 W_odd2even_graphnn so only Tanner-graph connections can be learned.
W_odd2even_graphnn_var = tf.Variable(W_odd2even_graphnn)

# first layer
# Check-node update from the channel LLRs. For each odd neuron e (one per edge,
# check-major order) gather the LLRs of the other variables of that edge's
# check: after tile/multiply/reshape, element [a, e, v] is x[a, v] * W_input[v, e].
# W_input becomes shape (neurons_per_odd_layer, code_n) here.
W_input = W_input.transpose().copy()
x_tile = tf.tile(x,multiples=[1,neurons_per_odd_layer])
W_input_reshape = tf.reshape(W_input, [-1])
x_tile_mul = tf.multiply(x_tile, W_input_reshape)
x_tile_mul_reshape = tf.reshape(x_tile_mul, [batch_size,neurons_per_odd_layer,code_n])
# tanh(LLR/2) of every gathered LLR, with LLRs clipped to [-10, 10]
u_i = tf.tanh(0.5*tf.clip_by_value(x_tile_mul_reshape, clip_value_min=-10, clip_value_max=10))
# Replace entries equal to 0 with 1 so masked-out positions do not zero the
# product. NOTE(review): a gathered LLR that is exactly 0 is treated the same way.
u_i_1 = tf.add(u_i, 1-tf.to_float(tf.abs(u_i) > 0))
z_input = tf.reduce_prod(u_i_1,reduction_indices=2)
# log((1+z)/(1-z)) == 2*artanh(z): the sum-product check-node rule per edge
x_input = tf.log(tf.div(1+z_input, 1-z_input))

# hidden layers
# Each iteration i builds one variable-node layer (an MLP_VN whose weights are
# generated per sample by the SF hypernetwork) followed by one check-node layer.
# Intermediate tensors are kept in net_dict; later code reads
# net_dict["hidden_check_x_<i>"] and net_dict["hidden_check_output_<i>"].
net_dict = {}

# learn_only_FG_weights: restrict the trainable tensor to the fixed 0/1 graph
# structure. One application is enough (mask * mask == mask); applying it per
# layer would add a redundant E x E x E multiply and embed another graph
# constant copy of W_odd2even_graphnn for every layer.
W_odd2even_graphnn_var = tf.multiply(W_odd2even_graphnn, W_odd2even_graphnn_var)
W_odd2even_graphnn_var = tf.to_float(W_odd2even_graphnn_var)

# Skip connection from the channel LLRs into every variable-node layer. It does
# not depend on the layer, so it is built once and shared. (W_skipconn2even is
# 0/1, so its element-wise square equals W_skipconn2even.)
hidden_parity_sc = tf.multiply(W_skipconn2even, W_skipconn2even)
hidden_parity_skip = tf.expand_dims(tf.matmul(x, hidden_parity_sc), 2)

for i in range(0, num_hidden_layers-1, 1):

    net_dict["hidden_parity_sc_{0}".format(i)] = hidden_parity_sc

    # parity layer input: first-layer check messages for i == 0, otherwise the
    # previous check layer's output
    if i == 0:
        layer_in = x_input
        x_input_tile = tf.expand_dims(x_input, 1)
        layer_in_tile = x_input_tile
    else:
        layer_in = net_dict["hidden_check_x_{0}".format(i-1)]
        layer_in_tile = tf.expand_dims(layer_in, 1)
        net_dict["hidden_check_x_tile_in{0}".format(i)] = layer_in_tile

    # SF hypernetwork -> per-sample MLP_VN weights
    ada_h1, ada_out = SF(tf.abs(layer_in), sf_weights)
    net_dict["ada_wght1{0}".format(i)] = ada_h1
    net_dict["ada_wght3{0}".format(i)] = ada_out
    mlp_weights = {'h1': tf.reshape(ada_h1, [batch_size, n_input, n_hidden_1]),
                   'out': tf.reshape(ada_out, [batch_size, n_hidden_2, n_classes])}
    net_dict["weights_mlpvn_ada{0}".format(i)] = mlp_weights

    # For every even neuron b, gather layer_in scaled by the diagonal
    # W_odd2even_graphnn_var[b, k, k] (nonzero only where W_odd2even[k, b] is).
    gathered = tf.einsum('aij,bjk->abik', layer_in_tile, W_odd2even_graphnn_var)
    net_dict["hidden_parity_x_graphnn_0{0}".format(i)] = gathered
    gathered = tf.squeeze(gathered, 2)
    net_dict["hidden_parity_x_graphnn_1{0}".format(i)] = gathered

    # concat the skip connection and run the variable-node MLP
    net_dict["hidden_parity_x_graphnn_2{0}".format(i)] = hidden_parity_skip
    mlp_in = tf.concat([gathered, hidden_parity_skip], 2)
    net_dict["hidden_parity_x_graphnn_3{0}".format(i)] = mlp_in
    parity_out = MLP_VN(mlp_in, mlp_weights)
    net_dict["hidden_parity_x_graphnn_4_mlp{0}".format(i)] = parity_out
    net_dict["hidden_parity_x_{0}".format(i)] = parity_out

    # check layer: for each odd neuron, multiply the even-neuron outputs selected
    # by W_even2odd (non-selected zeros become 1), then apply 2*artanh (series).
    net_dict["hidden_check_x_0{0}".format(i)] = tf.tile(net_dict["hidden_parity_x_{0}".format(i)], multiples=[1, neurons_per_odd_layer])
    net_dict["hidden_check_x_1{0}".format(i)] = tf.multiply(net_dict["hidden_check_x_0{0}".format(i)],tf.reshape(W_even2odd.transpose(), [-1]))
    net_dict["hidden_check_x_2{0}".format(i)] = tf.reshape(net_dict["hidden_check_x_1{0}".format(i)],[batch_size, neurons_per_odd_layer, neurons_per_even_layer])
    net_dict["hidden_check_x_3{0}".format(i)] = tf.add(net_dict["hidden_check_x_2{0}".format(i)], 1 - tf.to_float(tf.abs(net_dict["hidden_check_x_2{0}".format(i)]) > 0))
    net_dict["hidden_check_x_4{0}".format(i)] = tf.reduce_prod(net_dict["hidden_check_x_3{0}".format(i)],reduction_indices=2)
    net_dict["hidden_check_x_{0}".format(i)] = 2*ARC_TANH_LIKE(net_dict["hidden_check_x_4{0}".format(i)], order=1005)

    # Intermediate output for the multi-loss (i < num_hidden_layers-1 always
    # holds inside this loop, so only the multiloss condition matters).
    if i % multiloss == 0:
        net_dict["hidden_check_output_{0}".format(i)] = tf.add(x, tf.matmul(net_dict["hidden_check_x_{0}".format(i)], W_output))

# output layer
# Trainable copy of W_output, re-masked by the fixed 0/1 W_output so only the
# Tanner-graph connections can carry weight.
W_output_var = tf.multiply(tf.Variable(W_output.copy()), W_output)
last_check_messages = net_dict["hidden_check_x_{0}".format(num_hidden_layers - 2)]
y_output_1 = tf.matmul(last_check_messages, W_output_var)
# final per-bit logits: channel LLR plus the collected check messages
y_output = tf.to_float(tf.add(x, y_output_1), name="ToFloat")

with tf.name_scope("xent") as scope:

    # Multi-loss: sigmoid cross-entropy of the final output plus every
    # intermediate output whose layer index is a multiple of multiloss.
    # Targets are 1-y, so a positive logit favours bit 0.
    arg_loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=y_output, labels=1-y)
    aux_layers = [layer for layer in range(num_hidden_layers - 1) if layer % multiloss == 0]
    for layer in aux_layers:
        aux_logits = net_dict["hidden_check_output_{0}".format(layer)]
        arg_loss = arg_loss + tf.nn.sigmoid_cross_entropy_with_logits(logits=aux_logits, labels=1-y)

    loss = tf.reduce_mean(arg_loss)


train_step = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(loss)

##################################  Train  ####################################
# GPU memory cap. It only takes effect when handed to the session's ConfigProto.
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_mem_fraction)
sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True, gpu_options=gpu_options))
sess.run(tf.global_variables_initializer())
# merge_all returns None when no summary ops have been defined.
merged = tf.summary.merge_all()

saver = tf.train.Saver()
f_results = open(results_path, 'w+')
# Learning-curve history; each starts with one dummy row that is skipped when plotting.
train_loss_vec = np.zeros(1,dtype=np.float32)
val_loss_vec = np.zeros(1,dtype=np.float32)

# load model
if load_weights:
    saver.restore(sess, weights_path)

temp_batch = np.zeros([1, code_n], dtype=np.float32)
# Training losses of the current epoch, below one dummy row.
loss_trn = np.zeros([1, 1], dtype=np.float32)
max_weights_init = 0

try:
    for i in range(num_of_batch):

        if i != 0:
            # one optimisation step on a fresh batch covering every SNR point
            training_data, training_labels = create_mix_epoch(scaling_factor, wordRandom, numOfWordSim_train, code_n, code_k, code_generatorMatrix, is_zeros_word=train_on_zero_word)

            # train
            y_train, train_loss, _ = sess.run(fetches=[y_output, loss, train_step], feed_dict={x: training_data, y: training_labels})
            loss_trn = np.vstack((loss_trn, train_loss))


        if(i%batch_in_epoch == 0):

            print('Finish Epoch - ', i/batch_in_epoch)

            # validation
            # Collect per-batch arrays in lists and concatenate once: stacking onto a
            # growing array copies everything seen so far on every batch, which is
            # quadratic in the (tens of thousands of) validation batches.
            y_v_parts = []
            y_v_pred_parts = []
            loss_v_parts = []

            for kk, k_sf in enumerate(scaling_factor):
                for val_batch in range(batches_for_val_per_snr[kk]):

                    x_v_j, y_v_j = create_mix_epoch([k_sf], wordRandom, batch_size, code_n, code_k, code_generatorMatrix, is_zeros_word=test_on_zero_word)
                    y_v_pred_j, loss_v_j = sess.run(fetches = [y_output, loss], feed_dict={x:x_v_j, y:y_v_j})

                    y_v_parts.append(y_v_j)
                    y_v_pred_parts.append(y_v_pred_j)
                    loss_v_parts.append(loss_v_j)

            if y_v_parts:
                y_v = np.concatenate(y_v_parts, axis=0)
                y_v_pred = np.concatenate(y_v_pred_parts, axis=0)
            else:
                y_v = np.zeros([0, code_n], dtype=np.float32)
                y_v_pred = np.zeros([0, code_n], dtype=np.float32)
            loss_v = np.asarray(loss_v_parts, dtype=np.float32)

            # logits -> probability of bit 0 (the loss is trained against 1-y)
            y_v_pred = 1.0 / (1.0 + np.exp(-1.0 * y_v_pred))

            ber_val, fer_val = calc_ber_fer(snr_db, y_v_pred, y_v, batch_size*batches_for_val_per_snr)

            # print & write to file
            print('SNR[dB] validation - ', snr_db)
            print('BER validation - ', ber_val)
            print('FER validation - ', fer_val)

            f_results.write('-' * 50 + '\n')
            f_results.write('epoch %d\n' % i)
            f_results.write('----- SNR[dB]:' + str(snr_db) + '\n')
            f_results.write('----- BER:' + str(ber_val) + '\n')
            f_results.write('----- FER:' + str(fer_val) + '\n')
            f_results.flush()

            # save weights
            saver.save(sess, weights_path_format % i)


            # print learning curve
            # At i == 0 no training step has run yet: record NaN explicitly rather
            # than averaging an empty slice (which warns and yields NaN anyway).
            if loss_trn.shape[0] > 1:
                train_loss_mean = np.mean(loss_trn[1:])
            else:
                train_loss_mean = np.nan
            train_loss_vec = np.vstack((train_loss_vec, train_loss_mean))
            val_loss_vec = np.vstack((val_loss_vec, np.mean(loss_v)))
            pltf.figure()
            pltf.plot(train_loss_vec[1:])
            pltf.plot(val_loss_vec[1:], color='red')
            pltf.xlabel('epoch')
            pltf.ylabel('loss')
            pltf.legend(('Train', 'Validation'))
            pltf.grid(True)
            pltf.savefig(dir_name + '/Learning_curve.png')
            pltf.close()
            loss_trn = np.zeros([1, 1], dtype=np.float32)
finally:
    # Close the results file even when training is interrupted.
    f_results.close()