# -*- coding: utf-8 -*-
"""
asr-ns + llm-fd + tts-ns(s)
"""

import argparse
import json
import logging
import sys
import threading
import time
from concurrent.futures import ThreadPoolExecutor
from queue import Empty, Queue
from time import sleep
from typing import Optional, List, Dict

import numpy as np
from loguru import logger

from chatbot.utils.asr_utils import Listener, speech2text
from chatbot.utils.llm_utils import chat
from chatbot.utils.llm_utils import remove_last_messages
from chatbot.utils.tts_utils import Speaker, text2speech, text2speech_stream


def play_audio(speaker, play_audio_queue, play_audio_queue_lock, run_event, start_record_audio_event, replay_text_queue):
    """Play queued audio chunks until ``run_event`` is cleared.

    Each queue item is an ``(audio_data, text_data)`` pair.  After a chunk has
    been played, the stripped text is pushed to ``replay_text_queue``.  When
    the play queue is empty after a chunk, ``start_record_audio_event`` is set
    following a short pause.

    Args:
        speaker: Object exposing ``play(audio_data)``.
        play_audio_queue: Queue of ``(audio_data, text_data)`` tuples.
        play_audio_queue_lock: Lock guarding removal from ``play_audio_queue``.
        run_event: Event that stays set while the worker should keep running.
        start_record_audio_event: Event set once playback has caught up.
        replay_text_queue: Queue receiving the text of every played chunk.
    """
    logger.info("play_audio prepared")
    while run_event.is_set():
        # Poll until there is something to play and nobody holds the lock;
        # polling (instead of blocking) keeps the loop responsive to shutdown.
        while play_audio_queue.empty() or play_audio_queue_lock.locked():
            sleep(0.1)
            if not run_event.is_set():
                return
        logger.info("play_audio get play_audio_queue_lock")
        with play_audio_queue_lock:
            try:
                # Non-blocking: the queue may have been flushed between the
                # poll above and acquiring the lock.  A blocking get() here
                # would hang while still holding the lock.
                audio_data, text_data = play_audio_queue.get_nowait()
            except Empty:
                continue
        logger.info("play_audio get audio_data done")
        try:
            logger.info(f"play length: {len(audio_data)}")
            speaker.play(audio_data)
            replay_text_queue.put(text_data.strip())
        except Exception:
            # Best effort: keep the worker alive, but record the traceback
            # instead of discarding the cause.
            logger.exception("error play.")
            sleep(1.0)
        print("play done")
        if play_audio_queue.empty():
            sleep(0.5)
            start_record_audio_event.set()


def stop_play_audio(force_stop_play_audio_event,
                    play_audio_queue_lock,
                    play_audio_queue,
                    start_record_audio_event,
                    run_event):
    """Discard all pending playback whenever a forced stop is requested.

    Blocks on ``force_stop_play_audio_event``; each time it fires, the play
    queue is emptied while holding ``play_audio_queue_lock``, the event is
    cleared, and after one second ``start_record_audio_event`` is set.
    The loop ends once ``run_event`` is no longer set.
    """
    while run_event.is_set():
        force_stop_play_audio_event.wait()

        with play_audio_queue_lock:
            logger.info("Begin to clear play_audio_queue.")
            # Throw away every chunk that is still waiting to be played.
            while True:
                if play_audio_queue.empty():
                    break
                play_audio_queue.get()

        # Re-arm the trigger, pause, then let recording start again.
        force_stop_play_audio_event.clear()
        sleep(1.0)
        start_record_audio_event.set()


def record_audio(licenser, record_audio_queue, start_record_audio_event, stop_record_audio_event, run_event):
    """Read raw audio from ``licenser.stream`` and push it to ``record_audio_queue``.

    The worker waits for ``start_record_audio_event``, then keeps reading
    fixed-size chunks until ``stop_record_audio_event`` is set or ``run_event``
    is cleared.  After a recording session ends it clears the stop event,
    sleeps two seconds and sets ``start_record_audio_event`` again itself.

    NOTE(review): ``licenser`` presumably wraps a PyAudio-style input stream
    (``read(frames, exception_on_overflow=...)``) -- confirm against Listener.
    """
    frames_per_read = 1200 * 10

    logger.info("Begin record_audio process.")
    while run_event.is_set():
        logger.info("await record_audio")
        start_record_audio_event.wait()
        start_record_audio_event.clear()
        logger.info("start record_audio")

        keep_recording = run_event.is_set()
        while keep_recording:
            logger.info("record 0.5s audio")
            chunk = licenser.stream.read(frames_per_read, exception_on_overflow=False)
            record_audio_queue.put(chunk)
            # A stop request wins; otherwise continue only while still running.
            if stop_record_audio_event.is_set():
                keep_recording = False
            else:
                keep_recording = run_event.is_set()

        logger.info("stop record_audio")
        stop_record_audio_event.clear()
        sleep(2.0)
        start_record_audio_event.set()


def stop_record_audio(force_stop_record_audio_event,
                      stop_record_audio_event,
                      record_audio_queue,
                      asr_audio_queue,
                      start_record_audio_event,
                      run_event,
                      llm_query_queue,
                      llm_answer_queue,
                      reset_process_audio_event):
    """Reset the capture side of the pipeline when a forced stop is requested.

    On every ``force_stop_record_audio_event`` the worker:

    1. sets ``stop_record_audio_event``,
    2. flushes the record, ASR and LLM-query queues,
    3. sets ``reset_process_audio_event`` and waits 0.5 s,
    4. flushes the same three queues a second time (items may have been
       queued while the other workers were reacting),
    5. clears the trigger and sets ``start_record_audio_event``.

    ``llm_answer_queue`` is accepted for interface compatibility but is not
    touched here.
    """

    def _flush_pending():
        # Non-blocking drain.  These queues have other consumers, so an
        # ``empty()`` check followed by a blocking ``get()`` could hang this
        # thread if the last item is taken in between.
        for pending in (record_audio_queue, asr_audio_queue, llm_query_queue):
            while True:
                try:
                    pending.get_nowait()
                except Empty:
                    break

    while run_event.is_set():
        force_stop_record_audio_event.wait()
        stop_record_audio_event.set()
        logger.info("begin to reset")
        _flush_pending()

        reset_process_audio_event.set()
        sleep(0.5)

        logger.info("re begin to reset")
        _flush_pending()
        force_stop_record_audio_event.clear()

        start_record_audio_event.set()
        logger.info("stop to reset")


def process_audio_record(licenser, 
                         record_audio_queue, 
                         record_audio_queue_lock, 
                         start_record_audio_event, 
                         run_event, 
                         force_stop_play_audio_event,
                         stop_record_audio_event,
                         asr_audio_queue, reset_process_audio_event):
    """Accumulate loud audio chunks and publish WAV snapshots for ASR.

    Chunks are taken from ``record_audio_queue``.  Once a chunk's level
    exceeds the threshold, every further loud chunk is appended to the running
    buffer and the whole buffer (prefixed with ``licenser.get_wav_header``) is
    put on ``asr_audio_queue``.  While the input is quiet or idle, the
    unchanged buffer is periodically re-submitted.  Setting
    ``reset_process_audio_event`` discards the buffer.

    ``force_stop_play_audio_event`` and ``stop_record_audio_event`` are
    accepted but not used in this function.

    NOTE(review): re-submitting an identical snapshot presumably lets the
    downstream stages see the same transcript twice and treat the utterance
    as complete -- confirm against the ASR/LLM workers.
    """
    logger.info("Begin process_audio_record process.")
    # Kick off the recorder as soon as this worker is running.
    start_record_audio_event.set()
    
    audio_records = b""
    last_audio_records = audio_records
    # cnt: idle/quiet counter used to pace re-submission; flag: speech seen.
    cnt, flag = 0, False
    
    while run_event.is_set():
        logger.info("start process audio record")
        while run_event.is_set():
            # Idle loop: nothing to process (or the queue lock is held).
            while record_audio_queue.empty() or record_audio_queue_lock.locked():
                if reset_process_audio_event.is_set():
                    logger.info(f"reset audio_records")
                    audio_records = b""
                    cnt, flag = 0, False
                    sleep(0.5)
                    reset_process_audio_event.clear()
                sleep(0.1)
                if flag:
                    # NOTE(review): counter advances 0.2 per 0.1 s poll and is
                    # shared with the integer counting below -- confirm intent.
                    cnt += 0.2
                    if cnt > 2:
                        logger.info(f"put again: {cnt}")
                        t = licenser.get_wav_header(len(audio_records)) + audio_records
                        asr_audio_queue.put(t)
                        last_audio_records = t
                        cnt = 0
                if not run_event.is_set():
                    return
            
            with record_audio_queue_lock:
                audio_record = record_audio_queue.get()
            
            # np.frombuffer replaces the deprecated binary mode of
            # np.fromstring; it yields the same int16 samples without a copy.
            data = np.frombuffer(audio_record, dtype=np.short)
            # Loudness estimate: 99.9th percentile of the signed samples.
            level = np.percentile(data, 99.9)
            
            level_limit = 99
            
            if level > level_limit and not flag:
                flag = True
            if flag and level > level_limit: 
                # Loud chunk: extend the buffer and publish the new snapshot.
                audio_records += audio_record
                logger.info("reset process audio record")
                t = licenser.get_wav_header(len(audio_records)) + audio_records
                if last_audio_records != t:
                    asr_audio_queue.put(t)
                    last_audio_records = t
                    cnt = 0
            elif flag:
                # Quiet chunk after speech started: the buffer is unchanged;
                # re-submit it whenever the counter reaches an even value.
                t = licenser.get_wav_header(len(audio_records)) + audio_records
                cnt += 1
                if last_audio_records == t and cnt % 2 == 0:
                    asr_audio_queue.put(t)
                    last_audio_records = t
                    cnt = 0
            if reset_process_audio_event.is_set():
                logger.info(f"reset audio_records")
                audio_records = b""
                cnt, flag = 0, False
                sleep(0.5)
                reset_process_audio_event.clear()


def asr_process(asr_audio_queue,
                llm_query_queue,
                run_event):
    """Transcribe the newest audio snapshot and forward the text to the LLM.

    Waits for ``asr_audio_queue`` to become non-empty, drains it keeping only
    the most recent item, and -- unless that item is shorter than 22000 bytes
    -- runs ``speech2text`` on it and puts the result on ``llm_query_queue``.
    Returns once ``run_event`` is cleared.
    """
    logger.info("Begin asr process")
    while run_event.is_set():
        while asr_audio_queue.empty():
            sleep(0.1)
            if not run_event.is_set():
                return
        logger.info("start asr process")

        # Drain without blocking and keep only the latest snapshot.  Another
        # thread may flush this queue at any time, so the drain can come up
        # empty; in that case there is nothing to transcribe this round
        # (previously this left ``audio_data`` unbound or stale).
        audio_data = None
        while True:
            try:
                audio_data = asr_audio_queue.get_nowait()
            except Empty:
                break
        if audio_data is None:
            continue

        logger.info(f"audio_data lenasr_audio_queue: {len(audio_data)}")
        if len(audio_data) < 22000:
            # Too short to be worth transcribing.
            continue
        text = speech2text(speech_bytes=audio_data)
        logger.info(f"asr get query: {text}")
        llm_query_queue.put(text)
        logger.info("stop asr process")


def llm_process(llm_query_queue,
                llm_answer_queue,
                run_event, replay_text_queue,
                force_stop_play_audio_event,
                force_stop_record_audio_event,
                benchmark_record_queue=None):
    """Send the newest ASR query to the LLM and queue its answer sentences.

    The latest query on ``llm_query_queue`` is tagged ``<finished>`` when it
    repeats the previous query, otherwise ``<incomplete>``.  Text collected
    from ``replay_text_queue`` is passed to ``chat`` as ``replay``.  A
    ``<wait>`` sentence from the model aborts the turn via
    ``remove_last_messages()``.  On the first real sentence both force-stop
    events are set (and, if ``benchmark_record_queue`` is given, a timestamped
    record is pushed).  Every sentence goes to ``llm_answer_queue``.

    Args:
        llm_query_queue: Queue of transcribed user queries.
        llm_answer_queue: Queue receiving answer sentences.
        run_event: Event that stays set while the worker should keep running.
        replay_text_queue: Queue of text pieces reported as already played.
        force_stop_play_audio_event: Set when a new answer starts.
        force_stop_record_audio_event: Set when a new answer starts.
        benchmark_record_queue: Optional queue for benchmark records.
    """

    def _take_all(pending):
        # Non-blocking drain returning the items in arrival order.  The query
        # queue is also flushed by other threads, so ``empty()`` followed by
        # a blocking ``get()`` could hang or leave no item at all.
        items = []
        while True:
            try:
                items.append(pending.get_nowait())
            except Empty:
                return items

    logger.info("Begin llm process")
    last_query = ""
    cnt = 0
    while run_event.is_set():
        while llm_query_queue.empty():
            sleep(0.1)
            if not run_event.is_set():
                return
        logger.info("start llm process")
        queries = _take_all(llm_query_queue)
        if not queries:
            # Flushed by another thread in the meantime; previously this left
            # ``query`` unbound or reused the previous turn's value.
            continue
        query = queries[-1]

        # A repeated identical query marks the utterance as finished.
        if query == last_query:
            cnt += 1
        last_query = query

        if cnt < 1:
            query += "<incomplete>"
        else:
            query += "<finished>"
            cnt = 0

        replays = _take_all(replay_text_queue)
        logger.info(f"llm process get query: {query}")
        start = True
        for sentence in chat(query=query, replay=" ".join(replays)):
            logger.info(f"llm answer: {sentence}")
            if sentence.strip() == "<wait>":
                remove_last_messages()
                break
            if start:
                # First real sentence: interrupt playback and recording.
                force_stop_play_audio_event.set()
                force_stop_record_audio_event.set()
                start = False
                if benchmark_record_queue is not None:
                    benchmark_record_queue.put({"type": "llm-query", "query": query, "time_stamp": time.time()})
            llm_answer_queue.put(sentence)
            cnt = 0
        if not start:
            # An answer was produced: drop queries that arrived meanwhile.
            _take_all(llm_query_queue)
            sleep(1)
        logger.info("stop llm process")


def tts_process(llm_answer_queue,
                play_audio_queue,
                run_event,
                stream_tts=False):
    """Synthesize answer sentences and queue ``(audio, text)`` pieces for playback.

    Streaming mode forwards every chunk yielded by ``text2speech_stream``;
    one word of the sentence is attached to a chunk each time more than
    22000 bytes have accumulated, otherwise the text part is empty.
    Non-streaming mode synthesizes the whole sentence with ``text2speech`` and
    slices the audio into 22000-byte pieces, attaching a proportional share of
    the words to each piece (the final piece may be up to 1.5x that size).

    A synthesis failure is logged and the sentence is skipped; the worker
    keeps running.

    NOTE(review): in streaming mode, words still left in ``text_pieces`` when
    the stream ends are never queued -- confirm this is intended.
    """
    audio_piece_size = 22000

    logger.info("Begin tts process")
    while run_event.is_set():
        while llm_answer_queue.empty():
            sleep(0.1)
            if not run_event.is_set():
                return
        logger.info("start tts process")
        text = llm_answer_queue.get()
        logger.info(f"get tts request: {text}")

        if stream_tts:
            # stream tts
            text_pieces = text.split(" ")
            length = 0
            try:
                for audio_data in text2speech_stream(text):
                    length += len(audio_data)
                    if length > audio_piece_size and text_pieces:
                        cur = text_pieces.pop(0)
                        length = 0
                    else:
                        cur = ""
                    logger.info(f"put play_audio_queue: {len(audio_data)}")
                    play_audio_queue.put((audio_data, cur))
            except Exception:
                # Same policy as the non-stream branch: one failed sentence
                # must not kill the worker.  Chunks queued so far still play.
                logger.exception("stream tts failed.")
                continue
        else:
            ## non-stream tts
            try:
                audio_data = text2speech(text)
            except Exception:
                # Skip this sentence, but leave a trace of the cause.
                logger.exception("tts failed.")
                continue
            text_pieces = text.split(" ")
            while audio_data:
                if len(audio_data) < audio_piece_size * 1.5:
                    # Remainder is small enough: ship it with all leftover text.
                    play_audio_queue.put((audio_data, " ".join(text_pieces)))
                    audio_data = b''
                else:
                    # Share of words proportional to this piece's audio share.
                    pos = int(len(text_pieces) * (audio_piece_size / len(audio_data)))
                    play_audio_queue.put((audio_data[:audio_piece_size], " ".join(text_pieces[:pos])))
                    audio_data = audio_data[audio_piece_size:]
                    text_pieces = text_pieces[pos:]
        logger.info("stop tts process")


def exit(run_event,
         force_stop_play_audio_event,
         start_record_audio_event,
         stop_record_audio_event,
         force_stop_record_audio_event):
    """Request shutdown: clear ``run_event`` and set every other event.

    Setting the remaining events releases any caller currently blocked in
    ``wait()`` on one of them, so it can observe the cleared ``run_event``.
    """
    run_event.clear()

    wake_up_events = (
        force_stop_play_audio_event,
        start_record_audio_event,
        stop_record_audio_event,
        force_stop_record_audio_event,
    )
    for wake_up_event in wake_up_events:
        wake_up_event.set()

