# waiWhisperUtil.py : Whisper AI interface code.

import pika
import sys
import time
import random
import whisper
#import stable_whisper 
import numpy as np
import logging
import pickle

import waiMsg_pb2
import waiWorkerMsg_pb2
import waiUtil
from waiConstants import AI_TOKEN_MODE 
from waiConstants import WORKER_OUTBOX
from waiConstants import WHISPER_AI_MODEL
from waiConstants import FFMPEG_TIMEBASE

# Debug switches, both off by default:
#   enable_print_statements - when True, workerCallback dumps the received and
#       sent messages through waiUtil.print_wai_msg and logs "done decode".
#   enable_save_ai_data - when True, loadDefaultWhisperModel resets the dump
#       file and workerCallback appends every raw transcription result to it.
enable_print_statements = False 
enable_save_ai_data = False 
# Routing key used by workerCallback when publishing responses.
worker_outbox_name = f'{WORKER_OUTBOX}'

def setupConnection():
    """Connect to the broker on localhost and publish its channel globally.

    Opens a pika BlockingConnection (heartbeat=600,
    blocked_connection_timeout=300) and stores a fresh channel in the
    module-level ``channelWorkerOutbox``. The outbox queue is not declared
    here; presumably it is declared by another process - verify.
    """
    global channelWorkerOutbox
    broker = pika.BlockingConnection(
        pika.ConnectionParameters(
            host='localhost',
            heartbeat=600,
            blocked_connection_timeout=300,
        )
    )
    channelWorkerOutbox = broker.channel()

def loadWhisperModel(modelName):
    """Load the Whisper model ``modelName`` into the global ``gWhisperModel``.

    The model is loaded with whisper.load_model() and then has eval() called
    on it. The global binding matters: the warm-up and worker callback code
    read the model from ``gWhisperModel``.
    """
    global gWhisperModel
    model = whisper.load_model(modelName)
    gWhisperModel = model
    model.eval()

def loadDefaultWhisperModel(workerId):
    """Record the worker id and load the model named by WHISPER_AI_MODEL.

    Args:
        workerId: Identifier stored in the global ``worker_id`` and used in
            this module's log messages.

    When ``enable_save_ai_data`` is set, the global ``ai_data_filename`` is
    pointed at 'waiData.txt' and waiUtil.delete_file() is called on it.
    """
    global worker_id, ai_data_filename
    worker_id = workerId
    default_model = WHISPER_AI_MODEL
    loadWhisperModel(default_model)
    logging.info(f'Worker_{worker_id} Initializing... loading model \"{default_model}\"')
    if not enable_save_ai_data:
        return
    ai_data_filename = 'waiData.txt'
    waiUtil.delete_file(ai_data_filename)

def save_result_to_file(result, filename):
    """Append ``result`` (formatted with str()) plus a newline to ``filename``.

    Args:
        result: Any object; it is rendered through an f-string.
        filename: Path of the file to append to; created if missing.

    The file is written as UTF-8 explicitly so that transcriptions containing
    non-ASCII characters do not depend on (or fail under) the platform's
    default encoding.
    """
    with open(filename, 'a', encoding='utf-8') as file:
        file.write(f"{result}\n")

def load_audios(filename):
    """Return the object unpickled from ``filename``.

    Args:
        filename: Path to a pickle file, opened in binary mode.

    Note: unpickling can execute arbitrary code, so only point this at
    trusted files.
    """
    with open(filename, mode='rb') as pickled:
        audios = pickle.load(pickled)
    return audios

def warmUpWhisperModel(warmup_filename):
    """Run one throw-away transcription on the already-loaded model.

    Args:
        warmup_filename: Path to a pickle file (read via load_audios) holding
            an iterable of audio chunks accepted by the model's transcribe().

    Only the first chunk is transcribed and its result is discarded. The
    globals ``gWhisperModel`` and ``worker_id`` must already be set, which
    loadDefaultWhisperModel() does.
    """
    logging.info(f'Worker_{worker_id} warmup model with \"{warmup_filename}\"')
    audios = load_audios(warmup_filename)
    for chunk in audios:
        gWhisperModel.transcribe(chunk, fp16=False, language='english')
        break  # one chunk is enough for the warm-up
    logging.info(f'Worker_{worker_id} model warmup completed')

def workerCallback(body):
    """Transcribe one audio chunk and publish the captions to the worker outbox.

    Args:
        body: Serialized ``waiWorkerCommandMsg`` bytes.

    The ``waiWorkerResponseMsg`` that is published echoes session_id,
    chunk_index and chunk_offset from the command, carries one phrase per
    Whisper segment, and (when AI_TOKEN_MODE is on) the tokens of the last
    segment. Nothing is returned.

    If publishing fails with an AMQP connection error, the function waits two
    seconds, reconnects and retries once; a failure of the reconnect or of the
    retry is logged and the response is dropped.
    """
    wai_msg = waiWorkerMsg_pb2.waiWorkerCommandMsg()
    wai_msg.ParseFromString(body)

    result_msg = waiWorkerMsg_pb2.waiWorkerResponseMsg()
    result_msg.session_id = wai_msg.session_id
    result_msg.chunk_index = wai_msg.chunk_index
    result_msg.chunk_offset = wai_msg.chunk_offset
    result_msg.worker_id = int(sys.argv[1])

    # Convert the repeated field to a float32 NumPy array and scale by 1/32768
    # (presumably the samples are 16-bit PCM values - confirm with the sender).
    chunk = np.array(wai_msg.chunk_data, dtype=np.float32)
    chunk = chunk / 32768.0

    result_msg.chunk_len = len(chunk)
    if enable_print_statements:
        waiUtil.print_wai_msg(f"Recieve: chunk_len({result_msg.chunk_len}) ", wai_msg)

    # Uses transcribe() (segment-level output) rather than whisper.decode().
    use_tokens = (AI_TOKEN_MODE == True)
    if use_tokens:
        # NOTE(review): openai-whisper's transcribe() appears to overwrite
        # decode_options["prompt"] internally for every window - confirm that
        # the tokens passed here really condition the decoding.
        result = gWhisperModel.transcribe(chunk, fp16=False, language='english', prompt=wai_msg.tokens)
    else:
        result = gWhisperModel.transcribe(chunk, fp16=False, language='english')

    segments = result.get('segments', [])
    _fillCaptionPhrases(result_msg, segments)

    # Hand back the tokens of the last segment, if there is one.
    if use_tokens and segments:
        result_msg.tokens.extend(segments[-1]['tokens'])

    if enable_save_ai_data:
        save_result_to_file(result, ai_data_filename)

    if enable_print_statements:
        logging.info(f"done decode")

    # Serialize once; the same payload is reused if a retry is needed.
    payload = result_msg.SerializeToString()
    try:
        _publishWorkerResult(payload, result_msg, f"Sent: ")
    except pika.exceptions.AMQPConnectionError as e:
        logging.error(f"Worker_{worker_id}: {worker_outbox_name} Connection failed. Retrying in 2 seconds...")
        logging.error(f"Error details: {e}")
        time.sleep(2)
        try:
            # Reconnecting is part of the guarded retry so that a broker that
            # is still unreachable is logged instead of escaping the callback.
            setupConnection()
            _publishWorkerResult(payload, result_msg, f"Sent (Retry): ")
        except pika.exceptions.AMQPConnectionError as retry_err:
            logging.error(f"Worker_{worker_id}: Retry failed. Error details: {retry_err}")


def _fillCaptionPhrases(result_msg, segments):
    """Copy Whisper segment dicts into result_msg.caption_return, one phrase each."""
    caption = result_msg.caption_return
    caption.num_phrases = len(segments)
    for segment in segments:
        phrase = caption.phrases.add()
        phrase.start = segment.get('start', 0.0)
        phrase.end = segment.get('end', 0.0)
        phrase.text = segment.get('text', '')
        phrase.no_speech_prob = segment.get('no_speech_prob', 0.0)
        phrase.avg_logprob = segment.get('avg_logprob', 0.0)
        phrase.compression_ratio = segment.get('compression_ratio', 0.0)


def _publishWorkerResult(payload, result_msg, sent_label):
    """Publish an already-serialized response on the current outbox channel."""
    channelWorkerOutbox.basic_publish(exchange='', routing_key=worker_outbox_name, body=payload)
    if enable_print_statements:
        waiUtil.print_wai_msg(sent_label, result_msg)
'''
    connection.close()
    if enable_print_statements:
        logging.info(f"Connection closed")
'''

