import numpy as np
import tensorflow as tf
import argparse
import loss

import scipy.io.wavfile as wav

import time
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
import sys
from collections import namedtuple
sys.path.append("../../new/DeepSpeech")
import DeepSpeech

from tf_logits import get_logits
from sklearn.metrics import roc_curve, auc

# Character alphabet of the DeepSpeech model: decoder output index i maps to
# toks[i] (see `decode` in main()).
# The final '-' token is special: it is the epsilon (blank) symbol in CTC
# decoding and cannot occur in the phrase.
toks = " abcdefghijklmnopqrstuvwxyz'-"



def main():
    """Score the truncated-transcription detector on adversarial vs benign audio.

    Every clip is transcribed twice with the DeepSpeech 0.4.1 checkpoint:
    once with its full sequence length and once with only a ``ratio``
    fraction of it. The two transcripts are compared with ``loss.newWER``,
    ``loss.newCER`` and ``loss.lcp``. Clips in ``adaptive/`` are labelled
    adversarial (1) and the first ``num_benign`` clips of
    ``commonvoice_subset/`` benign (0). The per-clip scores are then used to
    compute one ROC AUC per metric.

    Prints the following, in order:
      * the three scores and both transcripts for each clip;
      * the mean of each score for the benign group, then the adversarial group;
      * the three AUCs.
    """
    parser = argparse.ArgumentParser(description=None)
    parser.parse_args()  # defines no options: any extra CLI argument is an error
    # Drop everything after the program name. NOTE(review): presumably so
    # that flag parsing inside the DeepSpeech modules does not see it -- confirm.
    del sys.argv[1:]

    num_benign = 100          # commonvoice_subset/sample-000000.wav .. sample-000099.wav
    ratio = 0.25              # fraction of the sequence length used for the 2nd transcript
    max_len = (1 << 15) * 10  # fixed model input length in samples; clips are zero-padded

    # Adversarial clips, ordered by the integer after the first '_' in the
    # file name (e.g. "adv_12.wav" -> 12).
    adv_files = sorted(
        os.listdir("adaptive"),
        key=lambda name: int(name.split("_")[1].split(".")[0]))
    num_adv = len(adv_files)
    total = num_adv + num_benign

    with tf.Session() as sess:
        new_input = tf.placeholder(tf.float32, [1, max_len])
        lengths = tf.placeholder(tf.int32, [1])

        with tf.variable_scope("", reuse=tf.AUTO_REUSE):
            logits = get_logits(new_input, lengths)

        saver = tf.train.Saver()
        saver.restore(sess, "deepspeech-0.4.1-checkpoint/model.v0.4.1")

        decoded, _ = tf.nn.ctc_beam_search_decoder(
            logits, lengths, merge_repeated=False, beam_width=100)

        def decode(audio, K=1):
            """Return the transcript of ``audio`` using a ``K`` fraction of its length.

            The samples are zero-padded to ``max_len``. ``K`` only scales the
            sequence length fed to ``get_logits`` and the CTC decoder.

            Raises:
                ValueError: if ``audio`` has more than ``max_len`` samples.
            """
            if len(audio) > max_len:
                raise ValueError("clip has %d samples, at most %d are supported"
                                 % (len(audio), max_len))
            padded = np.zeros(max_len, dtype=np.float32)
            padded[:len(audio)] = audio
            # One sequence step per 320 samples. NOTE(review): presumably the
            # feature hop size used by get_logits -- confirm.
            length = int((len(audio) - 1) // 320 * K)
            r = sess.run(decoded, {new_input: [padded], lengths: [length]})
            return "".join(toks[x] for x in r[0].values)

        def evaluate(audio):
            """Decode ``audio`` in full and truncated; print and return (WER, CER, LCP)."""
            full = decode(audio, 1)
            part = decode(audio, ratio)
            s1 = loss.newWER(full, part)
            s2 = loss.newCER(full, part)
            s3 = loss.lcp(full, part)
            print("WER: " + str(s1) + " CER: " + str(s2) + " LCP: " + str(s3))
            print("Adv: " + full)
            print("Half of Adv: " + part)
            return float(s1), float(s2), float(s3)

        # Rows 0/1/2 of `scores` hold WER/CER/LCP; column j belongs to clip j
        # (adversarial clips first, then benign). Both arrays are sized
        # exactly to the number of clips, so no unused padding columns are
        # passed to roc_curve.
        is_adv = np.zeros(total)  # 1 = adversarial, 0 = benign
        scores = np.zeros((3, total), dtype=np.float32)

        for col, fname in enumerate(adv_files):
            _, audio = wav.read("adaptive/" + fname)
            scores[:, col] = evaluate(audio)
            is_adv[col] = 1

        for idx in range(num_benign):
            print("HAVE", idx)
            _, audio = wav.read("commonvoice_subset/sample-%06d" % idx + ".wav")
            scores[:, num_adv + idx] = evaluate(audio)  # label stays 0

        # Mean of each score over the clips actually evaluated:
        # benign group first, then adversarial group.
        print(*scores[:, num_adv:].mean(axis=1, dtype=np.float64))
        print(*scores[:, :num_adv].mean(axis=1, dtype=np.float64))

        roc_auc = np.zeros(3)
        for i in range(3):
            # NOTE(review): labels are inverted for LCP, presumably because a
            # longer common prefix (closer agreement) indicates a benign clip.
            labels = 1 - is_adv if i == 2 else is_adv
            fpr, tpr, _ = roc_curve(labels, scores[i])
            roc_auc[i] = auc(fpr, tpr)

    print("WER: " + str(roc_auc[0]) + " CER: " + str(roc_auc[1]) + " LCP: " + str(roc_auc[2]))
            
# Run the evaluation as soon as this module is executed or imported.
# NOTE(review): guarding this with `if __name__ == "__main__":` would allow
# importing the module without side effects. The triple-quoted string that
# follows is an inert string expression containing an older version of this
# script; it is never executed.
main()



"""

import numpy as np
import tensorflow as tf
import argparse

import scipy.io.wavfile as wav

import time
import os
os.environ['CUDA_VISIBLE_DEVICES'] = '2'
import sys
from collections import namedtuple
sys.path.append("../../new/DeepSpeech")
import DeepSpeech

from tf_logits import get_logits


# These are the tokens that we're allowed to use.
# The - token is special and corresponds to the epsilon
# value in CTC decoding, and can not occur in the phrase.
toks = " abcdefghijklmnopqrstuvwxyz'-"



def main():
    parser = argparse.ArgumentParser(description=None)
    args = parser.parse_args()
    while len(sys.argv) > 1:
        sys.argv.pop()
    with tf.Session() as sess:
        def decode(audio, K):
            audio = audio[:int(len(audio)*K)]
            N = len(audio)
            new_input = tf.placeholder(tf.float32, [1, N])
            lengths = tf.placeholder(tf.int32, [1])
            
            with tf.variable_scope("", reuse=tf.AUTO_REUSE):
                logits = get_logits(new_input, lengths)
                
            saver = tf.train.Saver()
            saver.restore(sess, "deepspeech-0.4.1-checkpoint/model.v0.4.1")
            
            decoded, _ = tf.nn.ctc_beam_search_decoder(logits, lengths, merge_repeated=False, beam_width=500)

            aa = list(audio)
            aa = np.array(aa)
            length = (len(aa)-1)//320
            l = len(aa)
            r = sess.run(decoded, {new_input: [aa],
                                   lengths: [length]})
    
            return "".join([toks[x] for x in r[0].values])

        

main()
"""
