import tensorflow as tf
import numpy as np
from Bio import SeqIO
import sys
from helper import *
import tensorflow.keras.backend as K
from tensorflow.keras.layers import Layer, InputSpec

# evaluate the recall rate and reconstruction rate
def recall_reconstruction_rate(Recovered_Haplo, SNVHaplo, region_len):
    """Score recovered haplotypes against ground-truth haplotypes.

    Each recovered haplotype is paired with at most one true haplotype and
    vice versa. All such pairings are enumerated and the one with the
    smallest total ``hamming_distance`` is kept (the first one on ties).

    Args:
        Recovered_Haplo: 2-D array with one recovered haplotype per row.
        SNVHaplo: 2-D array with one ground-truth haplotype per row.
        region_len: divisor used to turn a distance into an error rate.

    Returns:
        A tuple ``(reconstruction_rate, CPR)``. ``reconstruction_rate`` is a
        list with one entry per row of ``SNVHaplo``: ``1 - distance /
        region_len`` for a matched true haplotype and ``0`` for one left
        unmatched. ``CPR`` is the mean of that list.
    """
    # Imported locally so the function does not depend on a star-import
    # happening to supply this name.
    from itertools import permutations

    n_recovered = len(Recovered_Haplo)
    n_true = len(SNVHaplo)

    # distance_table[i, j]: distance between recovered row i and true row j.
    distance_table = np.zeros((n_recovered, n_true))
    for i in range(n_recovered):
        for j in range(n_true):
            distance_table[i, j] = hamming_distance(SNVHaplo[j, :], Recovered_Haplo[i, :])

    # Enumerate injective pairings as lists of (recovered_row, true_row).
    # Only as many rows as the smaller side has can be matched, so the
    # permutations are drawn from the larger side with that length. This
    # avoids scoring full-length permutations that differ only in unused
    # positions, and lets any true haplotype be matched when fewer haplotypes
    # were recovered than exist in the ground truth.
    # NOTE: the search is still exhaustive and grows factorially.
    if n_recovered >= n_true:
        pairings = [list(zip(chosen, range(n_true)))
                    for chosen in permutations(range(n_recovered), n_true)]
    else:
        pairings = [list(zip(range(n_recovered), chosen))
                    for chosen in permutations(range(n_true), n_recovered)]

    costs = [sum(distance_table[rec, truth] for rec, truth in pairing)
             for pairing in pairings]
    best_pairing = pairings[int(np.argmin(costs))]

    # Unmatched true haplotypes keep a rate of 0.
    reconstruction_rate = [0] * n_true
    for rec, truth in best_pairing:
        reconstruction_rate[truth] = 1 - distance_table[rec, truth] / region_len
    CPR = np.mean(reconstruction_rate)

    return reconstruction_rate, CPR

# Load the SNV fragment matrix for the requested region together with its
# one-hot encoding (each SNV converted to a one-hot vector).
zone_name = sys.argv[1]  # region name, taken from the first command-line argument
SNVmatrix_name = '{}_SNV_matrix.txt'.format(zone_name)
SNVmatrix, SNVonehot = import_SNV(SNVmatrix_name)
# SNVonehot is unpacked as a 4-D array; axes 0 and 2 are used as the number of
# reads and the number of SNVs.
n_read, _, n_SNV, _ = SNVonehot.shape
# Number of clusters used for each supported region name.
estimated_population_size = {
    'p17': 5,
    'p24': 5,
    'p2p6': 5,
    'PR': 5,
    'RT': 6,
    'RNase': 5,
    'int': 5,
    'vif': 5,
    'vpr': 5,
    'vpu': 6,
    'gp120': 5,
    'gp41': 5,
    'nef': 5,
}
n_clusters = estimated_population_size[zone_name]

class ClusteringLayer(Layer):
    """Keras layer that turns feature vectors into soft cluster assignments.

    The layer owns ``n_clusters`` trainable centroids. For a 2-D input of
    shape ``(batch, input_dim)`` it returns a ``(batch, n_clusters)`` tensor
    whose rows sum to one, computed from squared Euclidean distances to the
    centroids with a Student's t-style kernel controlled by ``alpha``.

    Args:
        n_clusters: number of centroids. The default is the module-level
            ``n_clusters`` value at class-definition time, so a saved model
            whose layer config lacks this field can still be rebuilt.
        weights: optional list of initial weight arrays, applied in ``build``.
        alpha: kernel parameter (degrees of freedom of the t-style kernel).
        **kwargs: forwarded to ``Layer.__init__``; ``input_dim`` is accepted
            as a shorthand for ``input_shape=(input_dim,)``.
    """

    def __init__(self, n_clusters = n_clusters, weights = None, alpha = 1.0, **kwargs):
        if 'input_shape' not in kwargs and 'input_dim' in kwargs:
            kwargs['input_shape'] = (kwargs.pop('input_dim'),)
        super().__init__(**kwargs)
        self.n_clusters = n_clusters
        self.alpha = alpha
        self.initial_weights = weights
        self.input_spec = InputSpec(ndim = 2)

    def build(self, input_shape):
        """Create the ``(n_clusters, input_dim)`` centroid weight."""
        assert len(input_shape) == 2
        input_dim = int(input_shape[1])
        # Now that the feature size is known, pin the input spec to it.
        self.input_spec = InputSpec(dtype = K.floatx(), shape = (None, input_dim))
        self.clusters = self.add_weight(shape = (self.n_clusters, input_dim), initializer = 'glorot_uniform', name = 'clusters')
        if self.initial_weights is not None:
            self.set_weights(self.initial_weights)
            # The initial weights are only needed once.
            del self.initial_weights
        self.built = True

    def call(self, inputs, **kwargs):
        """Return the row-normalised soft assignment of each input row."""
        # (batch, 1, dim) - (n_clusters, dim) broadcasts to
        # (batch, n_clusters, dim); summing over the last axis gives the
        # squared distance of every input row to every centroid.
        squared_dist = K.sum(K.square(K.expand_dims(inputs, axis = 1) - self.clusters), axis = 2)
        q = 1.0 / (1.0 + squared_dist / self.alpha)
        q **= (self.alpha + 1.0) / 2.0
        # Normalise each row so the assignments of one input sum to one.
        q = q / K.sum(q, axis = 1, keepdims = True)

        return q

    def get_config(self):
        """Return the layer config including the constructor arguments.

        Without this override ``n_clusters`` and ``alpha`` would not be part
        of the serialized layer configuration.
        """
        config = super().get_config()
        config.update({'n_clusters': self.n_clusters, 'alpha': self.alpha})
        return config

# Restore the trained model; the custom layer class has to be supplied so that
# Keras can rebuild it.
model = tf.keras.models.load_model('{}_done.h5'.format(zone_name), custom_objects = {'ClusteringLayer': ClusteringLayer})
# predict() is unpacked into two outputs; only the second one is used, and its
# row-wise argmax is taken as the cluster label of each read.
_, q = model.predict(SNVonehot)
y_pred = q.argmax(1)
haplotypes = origin2haplotype(y_pred, SNVmatrix, n_clusters)

# Correction: alternately reassign every read to its closest haplotype and
# rebuild each haplotype from the reads assigned to it, until the MEC value
# no longer changes between two consecutive rounds.
pre_mec = 0
mec = MEC(SNVmatrix, haplotypes)
count = 0  # number of correction rounds performed
while mec != pre_mec:
    n_haplo = haplotypes.shape[0]

    # Step 1: index[r] is the haplotype closest to read r (first one on ties).
    index = []
    for read_idx in range(SNVmatrix.shape[0]):
        read = SNVmatrix[read_idx, :]
        dis = np.array([hamming_distance(read, haplotypes[h, :]) for h in range(n_haplo)], dtype = float)
        index.append(np.argmin(dis))
    assigned = np.array(index)

    # Step 2: rebuild every haplotype from the argmax of ACGT_count over its
    # reads; the + 1 presumably maps the argmax position onto the 1-based base
    # code used in SNVmatrix (confirm against helper.ACGT_count).
    new_haplo = np.zeros((haplotypes.shape))
    for h in range(n_haplo):
        members = SNVmatrix[assigned == h, :]
        new_haplo[h, :] = np.argmax(ACGT_count(members), axis = 1) + 1
    haplotypes = new_haplo.copy()

    # Step 3: remember the previous score and evaluate the new haplotypes.
    pre_mec, mec = mec, MEC(SNVmatrix, haplotypes)
    count += 1

# One (name, start, end) entry per gene region. The position of an entry in
# this table is both its index in region_matrix and its value in gene_dic.
# Each region expands to the indices start - 1 .. end - 1 (presumably 1-based
# reference coordinates shifted to 0-based indices).
_gene_regions = [
    ('p17', 790, 1186),     # 0
    ('p24', 1186, 1879),    # 1
    ('p2p6', 1879, 2292),   # 2
    ('PR', 2253, 2550),     # 3
    ('RT', 2550, 3870),     # 4
    ('RNase', 3870, 4230),  # 5
    ('int', 4230, 5096),    # 6
    ('vif', 5041, 5620),    # 7
    ('vpr', 5559, 5850),    # 8
    ('vpu', 6062, 6310),    # 9
    ('gp120', 6225, 7758),  # 10
    ('gp41', 7758, 8795),   # 11
    ('nef', 8797, 9417),    # 12
]

# region_matrix[k]: list of column indices covered by region k.
region_matrix = [list(range(start - 1, end)) for _, start, end in _gene_regions]

# gene_dic[name]: row of region_matrix that belongs to the named region.
gene_dic = {name: idx for idx, (name, _, _) in enumerate(_gene_regions)}

SNVposition_name = zone_name + '_SNV_pos.txt'

# Only the first line of the position file is used.
with open(SNVposition_name, 'r') as f:
    SNVposition_list = f.readlines()

# Parse the whitespace-separated integers of that line. Splitting on
# whitespace, instead of slicing off the last character, keeps the final
# position intact when the line does not end with a newline.
SNVposition = np.array([int(token) for token in SNVposition_list[0].split()], dtype = int)

# Ground-truth strains: one sequence string per record of the reference FASTA.
fasta_filename = '5VirusMixReference.fasta'
fasta_list = [str(record.seq) for record in SeqIO.parse(fasta_filename, 'fasta')]

# Restrict the reference sequences to the selected region first, then to the
# SNV columns inside that region.
SNVHaplo = list2array(fasta_list)[:, region_matrix[gene_dic[zone_name]]][:, SNVposition]

# Score `haplotypes`, which always holds the final result. `new_haplo` is only
# bound when the correction loop ran at least once, and then `haplotypes` is a
# copy of it.
reconstruction_rate, CPR = recall_reconstruction_rate(haplotypes, SNVHaplo, len(region_matrix[gene_dic[zone_name]]))
print('Reconstruction Rate: {}\n'.format(reconstruction_rate))
print('CPR: {}'.format(CPR))
