# waiServer.py : whisper daemon server code.
import pika
import time
import os
import signal
import psutil
import math
import numpy as np
import struct
import array
import logging
import sys
from daemonize import Daemonize
from logging.handlers import TimedRotatingFileHandler
import traceback
from http.server import BaseHTTPRequestHandler, HTTPServer
import threading
import json
import requests

import waiMsg_pb2
import waiWorkerMsg_pb2
import waiUtil
from waiSessionTable import SessionTable
from waiSessionTable import SessionError
from waiSessionTable import SessionStatus
from waiConstants import SERVER_INBOX
from waiConstants import FFMPEG_INBOX_PREFIX
from waiConstants import WORKER_INBOX
from waiConstants import WORKER_OUTBOX
from waiConstants import FFMPEG_PID_CHECK_INTERVAL
from waiConstants import FFMPEG_TIMEBASE
from waiConstants import AI_TOKEN_MODE
from waiConstants import CHUNK_CUTOFF_PTS #= 8000 # 0.5 sec
from waiConstants import LOG_FILE_SIZE_LIMIT 
from waiConstants import LOG_DIR_SIZE_CHECK_INTERVAL
from waiConstants import NO_SPEECH_PROB_MAX
from waiConstants import AVG_LOGPROB_MIN
from waiConstants import COMPRESSION_RATIO_MAX
from waiConstants import SUPPRESS_LOW
from waiConstants import SUPPRESS_HIGH
from waiConstants import GARBAGE_LIST
from waiConstants import HTTP_PORT
from waiConstants import FFMPEG_PID_CHECK_URL


# Module-wide session registry shared by every handler in this file.
session_table = SessionTable()
# Queue names, resolved once from waiConstants:
#   server_inbox_name  - client command messages (waiCommandMsg), polled in main()
#   worker_inbox_name  - chunk jobs published for the workers (process_chunk_msg)
#   worker_outbox_name - worker results (waiWorkerResponseMsg), polled in main()
server_inbox_name = f'{SERVER_INBOX}'
worker_inbox_name = f'{WORKER_INBOX}'
worker_outbox_name = f'{WORKER_OUTBOX}'

def process_open_session(channel, wai_msg):
    """Open a session for the client that sent an OPEN_SESSION command.

    Registers the sender in the session table, declares the sender's reply
    queue, and answers on that queue with an OPEN_SESSION waiResponseMsg
    whose result is SUCCESS or ERROR_GENERAL.
    """
    pid = wai_msg.sender_pid
    client_id = session_table.get_sender_id_from_pid_uid(pid, wai_msg.sender_uid)
    new_session, err = session_table.open_session(client_id, pid)

    reply_queue = FFMPEG_INBOX_PREFIX.format(client_id)
    channel.queue_declare(queue=reply_queue)

    reply = waiMsg_pb2.waiResponseMsg()
    reply.msg_type = waiMsg_pb2.waiMsgType.OPEN_SESSION
    opened = err == SessionError.NO_ERROR
    if opened:
        logging.info(f"Opened session: {new_session} for sender: {client_id}")
    else:
        logging.error(f"Failed to open session for sender: {client_id}, Error Code: {err}")
    reply.result = (waiMsg_pb2.waiResultType.SUCCESS if opened
                    else waiMsg_pb2.waiResultType.ERROR_GENERAL)

    channel.basic_publish(exchange='amq.direct', routing_key=reply_queue,
                          body=reply.SerializeToString())
    waiUtil.print_wai_msg(f"Sent to {client_id}", reply)

def process_close_session(channel, wai_msg):
    """Close the sender's session and delete its reply queue.

    If no session id is found for the sender, only an error is logged.
    """
    client_id = session_table.get_sender_id_from_pid_uid(wai_msg.sender_pid, wai_msg.sender_uid)
    sid, err = session_table.get_session_id_from_sender_id(client_id)

    if sid is None:
        logging.error(f"Failed to close session (Invalid sender_id: {client_id} ({err})")
        return

    session_table.close_session(sid)
    channel.queue_delete(queue=FFMPEG_INBOX_PREFIX.format(client_id))
    logging.info(f"Closed session: {sid} for sender: {client_id}")

def build_worker_chunk_msg(session_id, wai_msg):
    """Build the waiWorkerCommandMsg that carries one audio chunk to a worker.

    The raw ``wai_msg.chunk_data`` bytes are reinterpreted as signed 16-bit
    samples (``array('h')``) and copied into the worker message. The
    session's worker send counter is stamped on the message as
    ``chunk_index`` and then incremented, so results coming back from
    several workers can be delivered to the client in order.

    Args:
        session_id: id of an open session in ``session_table``.
        wai_msg: incoming waiCommandMsg of type CHUNK.

    Returns:
        The populated waiWorkerMsg_pb2.waiWorkerCommandMsg.
    """
    wai_worker_msg = waiWorkerMsg_pb2.waiWorkerCommandMsg()
    wai_worker_msg.session_id = session_id
    wai_worker_msg.chunk_offset = wai_msg.chunk_pts

    raw = wai_msg.chunk_data
    itemsize = array.array('h').itemsize
    remainder = len(raw) % itemsize
    if remainder:
        # array.frombytes() raises ValueError on a partial sample, and an
        # exception escaping here would end the server's main loop for all
        # sessions; drop the incomplete trailing sample instead.
        logging.warning(f'Session {session_id}: chunk_data length {len(raw)} is not a multiple of {itemsize}; '
                        f'dropping {remainder} trailing byte(s)')
        raw = raw[:len(raw) - remainder]

    # NOTE(review): array('h') uses the host byte order; this assumes the
    # client sends host-endian 16-bit PCM - TODO confirm.
    samples = array.array('h')
    samples.frombytes(raw)
    wai_worker_msg.chunk_data.extend(samples)

    # Sequence number used to restore chunk order across multiple workers.
    wai_worker_msg.chunk_index = session_table.get_worker_chunk_data_send_count(session_id)
    session_table.increase_worker_chunk_data_send_count(session_id)
    return wai_worker_msg

def process_chunk_msg(channel, wai_msg):
    """Route an incoming audio CHUNK to the workers or to the session backlog.

    Without AI_TOKEN_MODE every chunk is published to the worker inbox
    immediately. In AI_TOKEN_MODE chunks go out one at a time, so each can
    carry the tokens produced for the previous one. A chunk that arrives
    while another is still in progress is queued with push_chunk_data().
    """
    client_id = session_table.get_sender_id_from_pid_uid(wai_msg.sender_pid, wai_msg.sender_uid)
    sid, err = session_table.get_session_id_from_sender_id(client_id)
    if err != SessionError.NO_ERROR:
        logging.error(f"Failed to process chunk message (Invalid sender_id: {client_id} ({err})")
        return

    in_flight, _ = session_table.increase_chunk_data_in_progress_count(sid)
    send_now = (AI_TOKEN_MODE == False) or (in_flight == 1)
    if not send_now:
        # Another chunk is still with a worker; hold this one back.
        session_table.push_chunk_data(sid, wai_msg)
        return

    job = build_worker_chunk_msg(sid, wai_msg)
    if AI_TOKEN_MODE == True:
        latest_tokens, _ = session_table.get_latest_tokens(sid)
        job.tokens.extend(latest_tokens)
    channel.basic_publish(exchange='', routing_key=worker_inbox_name, body=job.SerializeToString())
    waiUtil.print_wai_msg(f"Sent to workers ({in_flight})", job)

def send_eos(channel, session_id):
    """Publish an EOS response (result SUCCESS) to the session's client queue.

    Logs an error instead if ``session_id`` does not map to a sender.
    """
    client_id, _ = session_table.get_sender_id(session_id)
    if client_id is None:
        logging.error(f"Failed to send EOS (Invalid session_id: {session_id})")
        return

    reply = waiMsg_pb2.waiResponseMsg()
    reply.msg_type = waiMsg_pb2.waiMsgType.EOS
    reply.result = waiMsg_pb2.waiResultType.SUCCESS

    reply_queue = FFMPEG_INBOX_PREFIX.format(client_id)
    channel.basic_publish(exchange='amq.direct', routing_key=reply_queue, body=reply.SerializeToString())
    waiUtil.print_wai_msg(f"Sent EOS to {reply_queue}", reply)

def process_eos_msg(channel, wai_msg):
    """Handle end-of-stream from a client.

    Marks the session as having received EOS. If no chunks are still in
    progress, the EOS reply is sent right away. Otherwise it is sent later,
    when processing of worker results brings the in-progress count to zero.
    """
    client_id = session_table.get_sender_id_from_pid_uid(wai_msg.sender_pid, wai_msg.sender_uid)
    sid, err = session_table.get_session_id_from_sender_id(client_id)
    if err != SessionError.NO_ERROR:
        logging.error(f"Failed to process eos (Invalid sender_id: {client_id})")
        return

    session_table.set_is_eos_received(sid, 1)
    pending, _ = session_table.get_chunk_data_in_progress_count(sid)
    if pending == 0:
        send_eos(channel, sid)


def check_process_status(current_time, timeout=5.0):
    """Close sessions whose sender process no longer answers as alive.

    For every session entry that has a ``sender_pid`` and whose
    ``alive_timestamp`` is at least FFMPEG_PID_CHECK_INTERVAL seconds older
    than ``current_time``, ask ``{FFMPEG_PID_CHECK_URL}/checkpid?pid=...``.
    Any non-200 answer closes that session. Request errors are logged and
    the session is left open.

    Args:
        current_time: "now" in the same clock as ``alive_timestamp``
            (main() passes ``time.time()``).
        timeout: seconds to wait for each checkpid request. This runs on
            the server's single message loop, so a request without a
            timeout could stall all message processing indefinitely.
    """
    logging.info("Checking Process...")
    # Iterate over a snapshot so close_session() cannot disturb the loop.
    for session_index, entry in list(enumerate(session_table.sessions)):
        sender_pid = entry.get("sender_pid")
        if sender_pid is None:
            continue
        alive_delta = current_time - entry.get("alive_timestamp")
        logging.info(f"Checking Process with PID {sender_pid:x} - {alive_delta:.0f}s is alive?")
        if alive_delta < FFMPEG_PID_CHECK_INTERVAL:
            continue

        try:
            response = requests.get(f"{FFMPEG_PID_CHECK_URL}/checkpid",
                                    params={'pid': sender_pid}, timeout=timeout)
        except requests.RequestException as e:
            logging.warning(f"An error occurred doing checkpid: {e}")
            continue

        if response.status_code != 200:
            logging.info(f"Process with PID {sender_pid:x} is not alive. Closing session.")
            # NOTE(review): the enumerate() index is passed as the session id,
            # and unlike process_close_session() the client's reply queue is
            # not deleted here - confirm both are intended.
            session_table.close_session(session_index)

def _is_garbage_caption(text):
    """Return True when every non-blank space-separated word of *text* is in GARBAGE_LIST."""
    return all(not word.strip() or word.strip() in GARBAGE_LIST
               for word in text.split(" "))


def _suppressed_logprob(text, avg_logprob):
    """Lower *avg_logprob* by 0.15 per SUPPRESS_LOW and 0.35 per SUPPRESS_HIGH entry found in *text*."""
    for penalty, needles in ((0.15, SUPPRESS_LOW), (0.35, SUPPRESS_HIGH)):
        for needle in needles:
            if needle in text:
                avg_logprob -= penalty
    return avg_logprob


def translate_worker_response(worker_response):
    """Convert one worker result into the waiResponseMsg list for the client.

    A non-SUCCESS worker result becomes a single CHUNK response that carries
    that result code and an empty caption. A SUCCESS result yields one CHUNK
    response per phrase that survives these filters:

    * phrases made only of GARBAGE_LIST words are dropped;
    * phrases whose avg_logprob (after SUPPRESS_LOW/HIGH penalties),
      no_speech_prob or compression_ratio are outside the configured limits
      are dropped;
    * phrases that start after the chunk end, or after
      ``chunk_len - CHUNK_CUTOFF_PTS`` when that is positive, are dropped.

    Phrase times are scaled by FFMPEG_TIMEBASE and shifted by the chunk
    offset. A nonzero start closer than CHUNK_CUTOFF_PTS to the chunk start
    is snapped to the chunk start.

    Returns:
        list of waiMsg_pb2.waiResponseMsg (possibly empty).
    """
    status = worker_response.result
    if status != waiMsg_pb2.waiResultType.SUCCESS:
        return [waiMsg_pb2.waiResponseMsg(
            msg_type=waiMsg_pb2.waiMsgType.CHUNK,
            result=status,
            pts=0,
            len=0,
            caption_return="",
        )]

    offset = worker_response.chunk_offset
    chunk_len = worker_response.chunk_len
    cutoff = chunk_len - CHUNK_CUTOFF_PTS
    replies = []
    for phrase in worker_response.caption_return.phrases:
        begin = round(phrase.start * FFMPEG_TIMEBASE)
        finish = round(phrase.end * FFMPEG_TIMEBASE)
        span = finish - begin
        text = phrase.text
        no_speech = phrase.no_speech_prob
        compression = phrase.compression_ratio

        if _is_garbage_caption(text):
            logging.error(f'Caption: {text}, DROP GARBAGE_LIST)')
            continue

        logprob = _suppressed_logprob(text, phrase.avg_logprob)
        if logprob < AVG_LOGPROB_MIN or no_speech > NO_SPEECH_PROB_MAX or compression > COMPRESSION_RATIO_MAX:
            logging.error(f'Caption: {text}, DROP no_speech_prob({no_speech:.2f}), avg_logprob({logprob:.2f}), compression_ratio({compression:.2f})')
            continue
        if begin > chunk_len or (cutoff > 0 and begin > cutoff):
            logging.error(f'Caption: {text}, DROP start({begin}), end({finish}), chunk_len({chunk_len/FFMPEG_TIMEBASE:.2f}s)')
            continue

        if begin != 0 and begin < CHUNK_CUTOFF_PTS:
            logging.error(f'Caption: {text}, clear start_pts({begin})')
            # NOTE(review): span is not recomputed after snapping, so the
            # caption's absolute end moves earlier by the original start
            # offset - confirm this is intended.
            begin = offset
        else:
            begin += offset
        replies.append(waiMsg_pb2.waiResponseMsg(
            msg_type=waiMsg_pb2.waiMsgType.CHUNK,
            result=waiMsg_pb2.waiResultType.SUCCESS,
            pts=np.int64(begin),
            len=np.int64(span),
            caption_return=text,
        ))

    return replies

def _dispatch_server_message(channel, body):
    """Parse one server-inbox payload (waiCommandMsg) and route it by msg_type.

    Messages of any other type are ignored.
    """
    wai_msg = waiMsg_pb2.waiCommandMsg()
    wai_msg.ParseFromString(body)
    waiUtil.print_wai_msg("Received", wai_msg)

    msg_type = wai_msg.msg_type
    if msg_type == waiMsg_pb2.waiMsgType.OPEN_SESSION:
        process_open_session(channel, wai_msg)
    elif msg_type == waiMsg_pb2.waiMsgType.CLOSE_SESSION:
        process_close_session(channel, wai_msg)
    elif msg_type == waiMsg_pb2.waiMsgType.CHUNK:
        process_chunk_msg(channel, wai_msg)
    elif msg_type == waiMsg_pb2.waiMsgType.EOS:
        process_eos_msg(channel, wai_msg)


def _process_worker_result(channel, body):
    """Handle one worker-outbox payload (waiWorkerResponseMsg).

    With several workers, results can arrive out of order. A result whose
    chunk_index equals the session's client caption send count is
    translated and forwarded now, followed by any consecutive results that
    were saved earlier. Any other result is saved with save_caption() until
    its turn comes.
    """
    result_msg = waiWorkerMsg_pb2.waiWorkerResponseMsg()
    result_msg.ParseFromString(body)
    waiUtil.print_wai_msg(f'Received', result_msg)
    session_id = result_msg.session_id
    sender_id, error_code = session_table.get_sender_id(session_id)
    if sender_id is None:
        logging.error(f"Failed to process worker chunk message (Invalid session_id: {session_id})")
        return

    if result_msg.chunk_index != session_table.get_client_caption_send_count(session_id):
        # Out of order: park it until the earlier chunks have been delivered.
        session_table.save_caption(session_id, result_msg.chunk_index, result_msg)
        logging.debug(f'*******push {session_id}, {result_msg.chunk_index}')
        return

    outbox_queue = FFMPEG_INBOX_PREFIX.format(sender_id)
    while True:
        for response_msg in translate_worker_response(result_msg):
            channel.basic_publish(exchange='amq.direct', routing_key=outbox_queue,
                                  body=response_msg.SerializeToString())
            waiUtil.print_wai_msg(f'Sent to {sender_id}', response_msg)

        count, error_code = session_table.decrease_chunk_data_in_progress_count(session_id)
        if (AI_TOKEN_MODE == True):
            session_table.set_latest_tokens(session_id, result_msg.tokens)
        if count == 0:
            # Everything sent to workers is done; finish the stream if the
            # client already sent EOS.
            is_eos_received, error_code = session_table.get_is_eos_received(session_id)
            if is_eos_received == 1:
                send_eos(channel, session_id)
        elif (AI_TOKEN_MODE == True):
            # Token mode keeps one chunk in flight: release the next queued
            # chunk, seeded with the tokens from this result.
            wai_msg, error_code = session_table.pop_chunk_data(session_id)
            wai_worker_msg = build_worker_chunk_msg(session_id, wai_msg)
            wai_worker_msg.tokens.extend(result_msg.tokens)
            channel.basic_publish(exchange='', routing_key=worker_inbox_name,
                                  body=wai_worker_msg.SerializeToString())
            waiUtil.print_wai_msg("Sent to workers", wai_worker_msg)

        # Advance to the next expected chunk and drain it if it was saved.
        chunk_index = session_table.increase_client_caption_send_count(session_id)
        result_msg = session_table.pop_caption(session_id, chunk_index)
        if result_msg is None:
            break
        logging.debug(f'*******pop {session_id}, {chunk_index}')


def _handle_message(handler, channel, body):
    """Run *handler* on one message, containing failures to that message.

    An exception escaping here would end main()'s loop and stop the server
    for every session, so failures are logged with their traceback and the
    message is skipped. pika AMQP errors (broken connection or channel) are
    re-raised so main() still shuts down on broker failures.
    """
    try:
        handler(channel, body)
    except pika.exceptions.AMQPError:
        raise
    except Exception:
        logging.exception(f'Failed to handle message in {handler.__name__}; skipping it')


def main(log_directory):
    """Run the waiServer message loop.

    Connects to the AMQP broker on localhost, then declares and flushes the
    server inbox, worker inbox and worker outbox queues. After that it polls
    forever:

    * every LOG_DIR_SIZE_CHECK_INTERVAL s, calls waiUtil.delete_old_log_files();
    * every FFMPEG_PID_CHECK_INTERVAL s, runs check_process_status();
    * takes at most one client command from the server inbox;
    * while any session is active, takes at most one worker result.

    A failure while handling one message is logged and that message is
    skipped. The loop ends on KeyboardInterrupt, on a pika AMQP error, or on
    any other exception outside per-message handling. The broker connection
    is then closed.

    Args:
        log_directory: directory passed to waiUtil.delete_old_log_files().
    """
    # Log product version and commit ID
    product_version = os.getenv('PRODUCTVERSION', 'unknown')
    commit_id = os.getenv('COMMITID', 'unknown')
    logging.info(f"Product version: {product_version}, Commit ID: {commit_id}")

    logging.info(f'waiServer start....')
    connection = pika.BlockingConnection(pika.ConnectionParameters('localhost'))
    channel = connection.channel()
    channel.queue_declare(queue=server_inbox_name)
    channel.queue_bind(queue=server_inbox_name, exchange='amq.direct')
    channel.queue_declare(queue=worker_inbox_name)
    channel.queue_declare(queue=worker_outbox_name)

    # Start from empty queues.
    waiUtil.flush_queue(channel, server_inbox_name)
    waiUtil.flush_queue(channel, worker_inbox_name)
    waiUtil.flush_queue(channel, worker_outbox_name)

    try:
        last_pid_check_time = time.time()
        last_log_dir_size_check_time = 0  # 0 forces a log-dir check on the first pass

        while True:
            current_time = time.time()

            # Cap the total size of the log directory (limit in bytes).
            if current_time - last_log_dir_size_check_time >= LOG_DIR_SIZE_CHECK_INTERVAL:
                waiUtil.delete_old_log_files(log_directory, LOG_FILE_SIZE_LIMIT)
                last_log_dir_size_check_time = current_time

            # Close sessions of clients that no longer report as alive.
            if current_time - last_pid_check_time >= FFMPEG_PID_CHECK_INTERVAL:
                check_process_status(current_time)
                last_pid_check_time = current_time

            method_frame, _header, body = channel.basic_get(queue=server_inbox_name, auto_ack=True)
            if method_frame:
                _handle_message(_dispatch_server_message, channel, body)

            if session_table.get_active_session_count() == 0:
                # No active session: no worker results to expect, sleep longer.
                time.sleep(0.09)
            else:
                method_frame, _header, body = channel.basic_get(queue=worker_outbox_name, auto_ack=True)
                if method_frame:
                    _handle_message(_process_worker_result, channel, body)

            time.sleep(0.01)  # yield the CPU between polls to keep idle usage low
    except KeyboardInterrupt:
        # KeyboardInterrupt is not an Exception subclass, so it needs its own clause.
        logging.info('waiServer interrupted')
    except Exception:
        logging.exception('waiServer main loop failed')
    finally:
        # After a broker failure the connection may already be closed, and
        # close() would raise again from this finally block.
        if connection.is_open:
            connection.close()
        logging.info(f'waiServer stopped')

class SimpleHTTPRequestHandler(BaseHTTPRequestHandler):
    """Status endpoint: answers every GET with the build's version info as JSON."""

    def do_GET(self):
        """Reply 200 with ``{"product_version": ..., "commit_id": ...}``.

        Both values come from the PRODUCTVERSION / COMMITID environment
        variables and default to 'unknown'.
        """
        payload = json.dumps({
            "product_version": os.getenv('PRODUCTVERSION', 'unknown'),
            "commit_id": os.getenv('COMMITID', 'unknown'),
        }).encode()

        self.send_response(200)
        self.send_header("Content-type", "application/json")
        self.end_headers()
        self.wfile.write(payload)

def run_http_server():
    """Serve SimpleHTTPRequestHandler on all interfaces at HTTP_PORT, forever.

    This runs in a background thread of the daemonized process. There, an
    uncaught thread exception only reaches stderr, which is presumably not
    visible under Daemonize. A bind failure (e.g. port already in use) is
    therefore written to the log. The listening socket is closed if
    serve_forever() ever exits.
    """
    server_address = ('', HTTP_PORT)  # '' = listen on all interfaces
    try:
        httpd = HTTPServer(server_address, SimpleHTTPRequestHandler)
    except OSError:
        logging.exception(f"HTTP server failed to start on port {HTTP_PORT}")
        return
    with httpd:
        logging.info(f"HTTP server running on port {HTTP_PORT}")
        httpd.serve_forever()

if __name__ == '__main__':
    # Check for root privileges
    if os.geteuid() != 0:
        print("You need root privileges to run this script.")
        sys.exit(1)

    # Resolve the log directory to an absolute path now, before daemonizing,
    # since the daemon may run with a different working directory.
    log_directory = os.path.join(os.getcwd(), 'logs')
    os.makedirs(log_directory, exist_ok=True)

    # Start every run with an empty log file, readable by the parent host.
    logfile = os.path.join(log_directory, 'waiServer.log')
    with open(logfile, 'w'):
        pass
    os.chmod(logfile, 0o0644)

    def run_daemon():
        """Daemon body: set up logging, start the HTTP status server, run main()."""
        # Rotate the log file every hour.
        handler = TimedRotatingFileHandler(logfile, when="H", interval=1)
        handler.setFormatter(logging.Formatter('%(asctime)s - %(levelname)s - %(message)s'))
        logger = logging.getLogger()
        logger.addHandler(handler)
        logger.setLevel(logging.INFO)

        # Keep pika's own logging to warnings and above.
        logging.getLogger('pika').setLevel(logging.WARNING)

        logger.info("Daemon started successfully")

        # Daemon thread, so it does not keep the process alive after main() returns.
        http_thread = threading.Thread(target=run_http_server)
        http_thread.daemon = True
        http_thread.start()

        main(log_directory)

    # Daemonize calls `action` in the detached child process, so it must be a
    # callable. A string here only raised a TypeError inside Daemonize, and
    # the work ran merely because the code after start() fell through.
    pid = "/var/run/waiServer.pid"
    daemon = Daemonize(app="waiServer", pid=pid, action=run_daemon)
    daemon.start()
