Diffstat (limited to 'agl_service_voiceagent/utils/stt_model.py')
-rw-r--r-- | agl_service_voiceagent/utils/stt_model.py | 91
1 file changed, 77 insertions(+), 14 deletions(-)
diff --git a/agl_service_voiceagent/utils/stt_model.py b/agl_service_voiceagent/utils/stt_model.py
index d51ae31..7e8ad8b 100644
--- a/agl_service_voiceagent/utils/stt_model.py
+++ b/agl_service_voiceagent/utils/stt_model.py
@@ -20,12 +20,18 @@
 import vosk
 import wave
 from agl_service_voiceagent.utils.common import generate_unique_uuid
+# import the whisper model
+import whisper
+# for the whisper timeout feature
+from concurrent.futures import ThreadPoolExecutor, TimeoutError
+import subprocess
+
 
 
 class STTModel:
     """
     STTModel is a class for speech-to-text (STT) recognition using the Vosk speech recognition library.
     """
-    def __init__(self, model_path, sample_rate=16000):
+    def __init__(self, vosk_model_path, whisper_model_path, whisper_cpp_path, whisper_cpp_model_path, sample_rate=16000):
         """
         Initialize the STTModel instance with the provided model and sample rate.
@@ -34,12 +40,15 @@ class STTModel:
             sample_rate (int, optional): The audio sample rate in Hz (default is 16000).
         """
         self.sample_rate = sample_rate
-        self.model = vosk.Model(model_path)
+        self.vosk_model = vosk.Model(vosk_model_path)
         self.recognizer = {}
         self.chunk_size = 1024
+        # self.whisper_model = whisper.load_model(whisper_model_path)
+        self.whisper_cpp_path = whisper_cpp_path
+        self.whisper_cpp_model_path = whisper_cpp_model_path
 
 
-    def setup_recognizer(self):
+    def setup_vosk_recognizer(self):
         """
         Set up a Vosk recognizer for a new session and return a unique identifier (UUID) for the session.
 
@@ -47,10 +56,9 @@ class STTModel:
             str: A unique identifier (UUID) for the session.
         """
         uuid = generate_unique_uuid(6)
-        self.recognizer[uuid] = vosk.KaldiRecognizer(self.model, self.sample_rate)
+        self.recognizer[uuid] = vosk.KaldiRecognizer(self.vosk_model, self.sample_rate)
         return uuid
-
     def init_recognition(self, uuid, audio_data):
         """
         Initialize the Vosk recognizer for a session with audio data.
@@ -64,8 +72,8 @@ class STTModel:
         """
         return self.recognizer[uuid].AcceptWaveform(audio_data)
 
-
-    def recognize(self, uuid, partial=False):
+    # Recognize speech using the Vosk recognizer
+    def recognize_using_vosk(self, uuid, partial=False):
         """
         Recognize speech and return the result as a JSON object.
 
@@ -84,14 +92,53 @@ class STTModel:
             self.recognizer[uuid].Reset()
         return result
 
+    # Recognize speech using the whisper model
+    def recognize_using_whisper(self, filename, language=None, timeout=5, fp16=False):
+        """
+        Recognize speech and return the result as a JSON object.
+
+        Args:
+            filename (str): The path to the audio file.
+            timeout (int, optional): The timeout for recognition (default is 5 seconds).
+            fp16 (bool, optional): If True, use 16-bit floating-point precision (default is False, since CUDA is not supported).
+            language (str, optional): The language code for recognition (default is None).
+
+        Returns:
+            dict: A JSON object containing recognition results.
+        """
+        def transcribe_with_whisper():
+            return self.whisper_model.transcribe(filename, language=language, fp16=fp16)
+
+        with ThreadPoolExecutor() as executor:
+            future = executor.submit(transcribe_with_whisper)
+            try:
+                return future.result(timeout=timeout)
+            except TimeoutError:
+                return {"error": "Transcription with Whisper exceeded the timeout."}
+
+    def recognize_using_whisper_cpp(self, filename):
+        command = self.whisper_cpp_path
+        arguments = ["-m", self.whisper_cpp_model_path, "-f", filename, "-l", "en", "-nt"]
+
+        # Run the executable with the specified arguments
+        result = subprocess.run([command] + arguments, capture_output=True, text=True)
+
+        if result.returncode == 0:
+            result = result.stdout.replace('\n', ' ').strip()
+            return {"text": result}
+        else:
+            print("Error:\n", result.stderr)
+            return {"error": result.stderr}
+
-    def recognize_from_file(self, uuid, filename):
+    def recognize_from_file(self, uuid, filename, stt_framework="vosk"):
         """
         Recognize speech from an audio file and return the recognized text.
 
         Args:
             uuid (str): The unique identifier (UUID) for the session.
             filename (str): The path to the audio file.
+            stt_framework (str): The STT framework to use for recognition (default is "vosk").
 
         Returns:
             str: The recognized text or error messages.
@@ -115,12 +162,28 @@ class STTModel:
             audio_data += chunk
 
         if audio_data:
-            if self.init_recognition(uuid, audio_data):
-                result = self.recognize(uuid)
-                return result['text']
-            else:
-                result = self.recognize(uuid, partial=True)
-                return result['partial']
+            # Perform speech recognition using the specified STT framework
+            if stt_framework == "vosk":
+                if self.init_recognition(uuid, audio_data):
+                    result = self.recognize_using_vosk(uuid)
+                    return result['text']
+                else:
+                    result = self.recognize_using_vosk(uuid, partial=True)
+                    return result['partial']
+
+            elif stt_framework == "whisper":
+                result = self.recognize_using_whisper_cpp(filename)
+                if 'error' in result:
+                    print(result['error'])
+                    # If whisper.cpp fails, fall back to Vosk
+                    if self.init_recognition(uuid, audio_data):
+                        result = self.recognize_using_vosk(uuid)
+                        return result['text']
+                    else:
+                        result = self.recognize_using_vosk(uuid, partial=True)
+                        return result['partial']
+                else:
+                    return result.get('text', '')
 
         else:
             print("Voice not recognized. Please speak again...")