diff options
author | Malik Talha <talhamalik727x@gmail.com> | 2023-10-29 20:52:29 +0500 |
---|---|---|
committer | Malik Talha <talhamalik727x@gmail.com> | 2023-10-29 20:52:29 +0500 |
commit | 42a03d2550f60a8064078f19a743afb944f9ff69 (patch) | |
tree | c9a7b3d028737d5fecd2e05f69e1c744810ed5fb /agl_service_voiceagent | |
parent | a10c988b5480ca5b937a2793b450cfa01f569d76 (diff) |
Update voice agent service
Add new features such as an option to load service
using an external config file, enhanced kuksa client,
and a more robust mapper.
Signed-off-by: Malik Talha <talhamalik727x@gmail.com>
Change-Id: Iba3cfd234c0aabad67b293669d456bb73d8e3135
Diffstat (limited to 'agl_service_voiceagent')
-rw-r--r-- | agl_service_voiceagent/client.py | 16 | ||||
-rw-r--r-- | agl_service_voiceagent/config.ini | 17 | ||||
-rw-r--r-- | agl_service_voiceagent/nlu/rasa_interface.py | 41 | ||||
-rw-r--r-- | agl_service_voiceagent/nlu/snips_interface.py | 38 | ||||
-rw-r--r-- | agl_service_voiceagent/protos/voice_agent.proto | 3 | ||||
-rw-r--r-- | agl_service_voiceagent/server.py | 3 | ||||
-rw-r--r-- | agl_service_voiceagent/service.py | 63 | ||||
-rw-r--r-- | agl_service_voiceagent/servicers/voice_agent_servicer.py | 109 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/audio_recorder.py | 52 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/common.py | 91 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/config.py | 34 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/kuksa_interface.py | 178 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/mapper.py | 261 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/stt_model.py | 83 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/wake_word.py | 95 |
15 files changed, 979 insertions, 105 deletions
diff --git a/agl_service_voiceagent/client.py b/agl_service_voiceagent/client.py index 9b2e0a0..922e08c 100644 --- a/agl_service_voiceagent/client.py +++ b/agl_service_voiceagent/client.py @@ -20,14 +20,8 @@ from agl_service_voiceagent.generated import voice_agent_pb2 from agl_service_voiceagent.generated import voice_agent_pb2_grpc from agl_service_voiceagent.utils.config import get_config_value -# following code is only reqired for logging -import logging -logging.basicConfig() -logging.getLogger("grpc").setLevel(logging.DEBUG) - -SERVER_URL = get_config_value('SERVER_ADDRESS') + ":" + str(get_config_value('SERVER_PORT')) - def run_client(mode, nlu_model): + SERVER_URL = get_config_value('SERVER_ADDRESS') + ":" + str(get_config_value('SERVER_PORT')) nlu_model = voice_agent_pb2.SNIPS if nlu_model == "snips" else voice_agent_pb2.RASA print("Starting Voice Agent Client...") print(f"Client connecting to URL: {SERVER_URL}") @@ -73,6 +67,12 @@ def run_client(mode, nlu_model): print("Command:", record_result.command) print("Status:", status) print("Intent:", record_result.intent) + intent_slots = [] for slot in record_result.intent_slots: print("Slot Name:", slot.name) - print("Slot Value:", slot.value)
\ No newline at end of file + print("Slot Value:", slot.value) + i_slot = voice_agent_pb2.IntentSlot(name=slot.name, value=slot.value) + intent_slots.append(i_slot) + + exec_voice_command_request = voice_agent_pb2.ExecuteInput(intent=record_result.intent, intent_slots=intent_slots) + response = stub.ExecuteVoiceCommand(exec_voice_command_request)
\ No newline at end of file diff --git a/agl_service_voiceagent/config.ini b/agl_service_voiceagent/config.ini index 9455d6a..074f6a8 100644 --- a/agl_service_voiceagent/config.ini +++ b/agl_service_voiceagent/config.ini @@ -1,22 +1,27 @@ [General] -service_version = 0.2.0 base_audio_dir = /usr/share/nlu/commands/ -stt_model_path = /usr/share/vosk/vosk-model-small-en-us-0.15 +stt_model_path = /usr/share/vosk/vosk-model-small-en-us-0.15/ +wake_word_model_path = /usr/share/vosk/vosk-model-small-en-us-0.15/ snips_model_path = /usr/share/nlu/snips/model/ channels = 1 sample_rate = 16000 bits_per_sample = 16 wake_word = hello auto server_port = 51053 -server_address = localhost +server_address = 127.0.0.1 rasa_model_path = /usr/share/nlu/rasa/models/ rasa_server_port = 51054 +rasa_detached_mode = 0 base_log_dir = /usr/share/nlu/logs/ store_voice_commands = 0 [Kuksa] -ip = localhost +ip = 127.0.0.1 port = 8090 protocol = ws -insecure = False -token = / +insecure = True +token = /usr/lib/python3.10/site-packages/kuksa_certificates/jwt/super-admin.json.token + +[Mapper] +intents_vss_map = /usr/share/nlu/mappings/intents_vss_map.json +vss_signals_spec = /usr/share/nlu/mappings/vss_signals_spec.json
\ No newline at end of file diff --git a/agl_service_voiceagent/nlu/rasa_interface.py b/agl_service_voiceagent/nlu/rasa_interface.py index 0232126..537a318 100644 --- a/agl_service_voiceagent/nlu/rasa_interface.py +++ b/agl_service_voiceagent/nlu/rasa_interface.py @@ -21,7 +21,20 @@ import subprocess from concurrent.futures import ThreadPoolExecutor class RASAInterface: + """ + RASAInterface is a class for interfacing with a Rasa NLU server to extract intents and entities from text input. + """ + def __init__(self, port, model_path, log_dir, max_threads=5): + """ + Initialize the RASAInterface instance with the provided parameters. + + Args: + port (int): The port number on which the Rasa NLU server will run. + model_path (str): The path to the Rasa NLU model. + log_dir (str): The directory where server logs will be saved. + max_threads (int, optional): The maximum number of concurrent threads (default is 5). + """ self.port = port self.model_path = model_path self.max_threads = max_threads @@ -31,6 +44,9 @@ class RASAInterface: def _start_server(self): + """ + Start the Rasa NLU server in a subprocess and redirect its output to the log file. + """ command = ( f"rasa run --enable-api -m \"{self.model_path}\" -p {self.port}" ) @@ -41,6 +57,9 @@ class RASAInterface: def start_server(self): + """ + Start the Rasa NLU server in a separate thread and wait for it to initialize. + """ self.thread_pool.submit(self._start_server) # Wait for a brief moment to allow the server to start @@ -48,6 +67,9 @@ class RASAInterface: def stop_server(self): + """ + Stop the Rasa NLU server and shut down the thread pool. + """ if self.server_process: self.server_process.terminate() self.server_process.wait() @@ -56,6 +78,16 @@ class RASAInterface: def preprocess_text(self, text): + """ + Preprocess the input text by converting it to lowercase, removing leading/trailing spaces, + and removing special characters and punctuation. + + Args: + text (str): The input text to preprocess. + + Returns: + str: The preprocessed text. + """ # text to lower case and remove trailing and leading spaces preprocessed_text = text.lower().strip() # remove special characters, punctuation, and extra whitespaces @@ -77,6 +109,15 @@ class RASAInterface: def process_intent(self, intent_output): + """ + Extract intents and entities from preprocessed text using the Rasa NLU server. + + Args: + text (str): The preprocessed input text. + + Returns: + dict: Intent and entity extraction result as a dictionary. + """ intent = intent_output["intent"]["name"] entities = {} for entity in intent_output["entities"]: diff --git a/agl_service_voiceagent/nlu/snips_interface.py b/agl_service_voiceagent/nlu/snips_interface.py index f0b05d2..1febe92 100644 --- a/agl_service_voiceagent/nlu/snips_interface.py +++ b/agl_service_voiceagent/nlu/snips_interface.py @@ -19,10 +19,30 @@ from typing import Text from snips_inference_agl import SnipsNLUEngine class SnipsInterface: + """ + SnipsInterface is a class for interacting with the Snips Natural Language Understanding Engine (Snips NLU). + """ + def __init__(self, model_path: Text): + """ + Initialize the SnipsInterface instance with the provided Snips NLU model. + + Args: + model_path (Text): The path to the Snips NLU model. + """ self.engine = SnipsNLUEngine.from_path(model_path) def preprocess_text(self, text): + """ + Preprocess the input text by converting it to lowercase, removing leading/trailing spaces, + and removing special characters and punctuation. + + Args: + text (str): The input text to preprocess. + + Returns: + str: The preprocessed text. + """ # text to lower case and remove trailing and leading spaces preprocessed_text = text.lower().strip() # remove special characters, punctuation, and extra whitespaces @@ -30,11 +50,29 @@ class SnipsInterface: return preprocessed_text def extract_intent(self, text: Text): + """ + Extract the intent from preprocessed text using the Snips NLU engine. + + Args: + text (Text): The preprocessed input text. + + Returns: + dict: The intent extraction result as a dictionary. + """ preprocessed_text = self.preprocess_text(text) result = self.engine.parse(preprocessed_text) return result def process_intent(self, intent_output): + """ + Extract intent and slot values from Snips NLU output. + + Args: + intent_output (dict): The intent extraction result from Snips NLU. + + Returns: + tuple: A tuple containing the intent name (str) and a dictionary of intent actions (entity-value pairs). + """ intent_actions = {} intent = intent_output['intent']['intentName'] slots = intent_output.get('slots', []) diff --git a/agl_service_voiceagent/protos/voice_agent.proto b/agl_service_voiceagent/protos/voice_agent.proto index 8ee8324..8c3ab65 100644 --- a/agl_service_voiceagent/protos/voice_agent.proto +++ b/agl_service_voiceagent/protos/voice_agent.proto @@ -78,5 +78,6 @@ message ExecuteInput { } message ExecuteResult { - ExecuteStatusType status = 1; + string response = 1; + ExecuteStatusType status = 2; } diff --git a/agl_service_voiceagent/server.py b/agl_service_voiceagent/server.py index 5fe0746..d8ce785 100644 --- a/agl_service_voiceagent/server.py +++ b/agl_service_voiceagent/server.py @@ -20,9 +20,8 @@ from agl_service_voiceagent.generated import voice_agent_pb2_grpc from agl_service_voiceagent.servicers.voice_agent_servicer import VoiceAgentServicer from agl_service_voiceagent.utils.config import get_config_value -SERVER_URL = get_config_value('SERVER_ADDRESS') + ":" + str(get_config_value('SERVER_PORT')) - def run_server(): + SERVER_URL = get_config_value('SERVER_ADDRESS') + ":" + str(get_config_value('SERVER_PORT')) print("Starting Voice Agent Service...") print(f"Server running at URL: {SERVER_URL}") print(f"STT Model Path: {get_config_value('STT_MODEL_PATH')}") diff --git a/agl_service_voiceagent/service.py b/agl_service_voiceagent/service.py index 1b34c27..784d8d9 100644 --- a/agl_service_voiceagent/service.py +++ b/agl_service_voiceagent/service.py @@ -25,7 +25,7 @@ generated_dir = os.path.join(current_dir, "generated") sys.path.append(generated_dir) import argparse -from agl_service_voiceagent.utils.config import update_config_value, get_config_value +from agl_service_voiceagent.utils.config import set_config_path, load_config, update_config_value, get_config_value from agl_service_voiceagent.utils.common import add_trailing_slash from agl_service_voiceagent.server import run_server from agl_service_voiceagent.client import run_client @@ -33,7 +33,7 @@ from agl_service_voiceagent.client import run_client def print_version(): print("Automotive Grade Linux (AGL)") - print(f"Voice Agent Service v{get_config_value('SERVICE_VERSION')}") + print(f"Voice Agent Service v0.3.0") def main(): @@ -47,10 +47,15 @@ def main(): server_parser = subparsers.add_parser('run-server', help='Run the Voice Agent gRPC Server') client_parser = subparsers.add_parser('run-client', help='Run the Voice Agent gRPC Client') - server_parser.add_argument('--config', action='store_true', help='Starts the server solely based on values provided in config file.') - server_parser.add_argument('--stt-model-path', required=False, help='Path to the Speech To Text model. Currently only supports VOSK Kaldi.') + server_parser.add_argument('--default', action='store_true', help='Starts the server based on default config file.') + server_parser.add_argument('--config', required=False, help='Path to a config file. Server is started based on this config file.') + server_parser.add_argument('--stt-model-path', required=False, help='Path to the Speech To Text model for Voice Commad detection. Currently only supports VOSK Kaldi.') + server_parser.add_argument('--ww-model-path', required=False, help='Path to the Speech To Text model for Wake Word detection. Currently only supports VOSK Kaldi. Defaults to the same model as --stt-model-path if not provided.') server_parser.add_argument('--snips-model-path', required=False, help='Path to the Snips NLU model.') server_parser.add_argument('--rasa-model-path', required=False, help='Path to the RASA NLU model.') + server_parser.add_argument('--rasa-detached-mode', required=False, help='Assume that the RASA server is already running and does not start it as a sub process.') + server_parser.add_argument('--intents-vss-map-path', required=False, help='Path to the JSON file containing Intent to VSS map.') + server_parser.add_argument('--vss-signals-spec-path', required=False, help='Path to the VSS signals specification JSON file.') server_parser.add_argument('--audio-store-dir', required=False, help='Directory to store the generated audio files.') server_parser.add_argument('--log-store-dir', required=False, help='Directory to store the generated log files.') @@ -63,7 +68,7 @@ def main(): print_version() elif args.subcommand == 'run-server': - if not args.config: + if not args.default and not args.config: if not args.stt_model_path: raise ValueError("The --stt-model-path is missing. Please provide a value. Use --help to see available options.") @@ -72,20 +77,42 @@ def main(): if not args.rasa_model_path: raise ValueError("The --rasa-model-path is missing. Please provide a value. Use --help to see available options.") + + if not args.intents_vss_map_path: + raise ValueError("The --intents-vss-map-path is missing. Please provide a value. Use --help to see available options.") + + if not args.vss_signals_spec_path: + raise ValueError("The --vss-signals-spec is missing. Please provide a value. Use --help to see available options.") + + # Contruct the default config file path + config_path = os.path.join(current_dir, "config.ini") + # Load the config values from the config file + set_config_path(config_path) + load_config() + + # Get the values provided by the user stt_path = args.stt_model_path snips_model_path = args.snips_model_path rasa_model_path = args.rasa_model_path + intents_vss_map_path = args.intents_vss_map_path + vss_signals_spec_path = args.vss_signals_spec_path # Convert to an absolute path if it's a relative path stt_path = add_trailing_slash(os.path.abspath(stt_path)) if not os.path.isabs(stt_path) else stt_path snips_model_path = add_trailing_slash(os.path.abspath(snips_model_path)) if not os.path.isabs(snips_model_path) else snips_model_path rasa_model_path = add_trailing_slash(os.path.abspath(rasa_model_path)) if not os.path.isabs(rasa_model_path) else rasa_model_path + intents_vss_map_path = os.path.abspath(intents_vss_map_path) if not os.path.isabs(intents_vss_map_path) else intents_vss_map_path + vss_signals_spec_path = os.path.abspath(vss_signals_spec_path) if not os.path.isabs(vss_signals_spec_path) else vss_signals_spec_path # Also update the config.ini file update_config_value(stt_path, 'STT_MODEL_PATH') update_config_value(snips_model_path, 'SNIPS_MODEL_PATH') update_config_value(rasa_model_path, 'RASA_MODEL_PATH') + update_config_value(intents_vss_map_path, 'INTENTS_VSS_MAP') + update_config_value(vss_signals_spec_path, 'VSS_SIGNALS_SPEC') + if args.rasa_detached_mode: + update_config_value('1', 'RASA_DETACHED_MODE') # Update the audio store dir in config.ini if provided audio_dir = args.audio_store_dir or get_config_value('BASE_AUDIO_DIR') @@ -98,6 +125,25 @@ def main(): update_config_value(log_dir, 'BASE_LOG_DIR') + elif args.config: + # Get config file path value + cli_config_path = args.config + + # if config file path provided then load the config values from it + if cli_config_path : + cli_config_path = os.path.abspath(cli_config_path) if not os.path.isabs(cli_config_path) else cli_config_path + print(f"New config file path provided: {cli_config_path}. Overriding the default config file path.") + set_config_path(cli_config_path) + load_config() + + elif args.default: + # Contruct the default config file path + config_path = os.path.join(current_dir, "config.ini") + + # Load the config values from the config file + set_config_path(config_path) + load_config() + # create the base audio dir if not exists if not os.path.exists(get_config_value('BASE_AUDIO_DIR')): os.makedirs(get_config_value('BASE_AUDIO_DIR')) @@ -109,6 +155,13 @@ def main(): run_server() elif args.subcommand == 'run-client': + # Contruct the default config file path + config_path = os.path.join(current_dir, "config.ini") + + # Load the config values from the config file + set_config_path(config_path) + load_config() + mode = args.mode if mode not in ['wake-word', 'auto', 'manual']: raise ValueError("Invalid mode. Supported modes: 'wake-word', 'auto' and 'manual'. Use --help to see available options.") diff --git a/agl_service_voiceagent/servicers/voice_agent_servicer.py b/agl_service_voiceagent/servicers/voice_agent_servicer.py index 4038b85..69af10b 100644 --- a/agl_service_voiceagent/servicers/voice_agent_servicer.py +++ b/agl_service_voiceagent/servicers/voice_agent_servicer.py @@ -17,42 +17,65 @@ import grpc import time import threading -from generated import voice_agent_pb2 -from generated import voice_agent_pb2_grpc +from agl_service_voiceagent.generated import voice_agent_pb2 +from agl_service_voiceagent.generated import voice_agent_pb2_grpc from agl_service_voiceagent.utils.audio_recorder import AudioRecorder from agl_service_voiceagent.utils.wake_word import WakeWordDetector from agl_service_voiceagent.utils.stt_model import STTModel +from agl_service_voiceagent.utils.kuksa_interface import KuksaInterface +from agl_service_voiceagent.utils.mapper import Intent2VSSMapper from agl_service_voiceagent.utils.config import get_config_value +from agl_service_voiceagent.utils.common import generate_unique_uuid, delete_file from agl_service_voiceagent.nlu.snips_interface import SnipsInterface from agl_service_voiceagent.nlu.rasa_interface import RASAInterface -from agl_service_voiceagent.utils.common import generate_unique_uuid, delete_file class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): + """ + Voice Agent Servicer class that implements the gRPC service defined in voice_agent.proto. + """ + def __init__(self): + """ + Constructor for VoiceAgentServicer class. + """ # Get the config values - self.service_version = get_config_value('SERVICE_VERSION') + self.service_version = "v0.3.0" self.wake_word = get_config_value('WAKE_WORD') self.base_audio_dir = get_config_value('BASE_AUDIO_DIR') self.channels = int(get_config_value('CHANNELS')) self.sample_rate = int(get_config_value('SAMPLE_RATE')) self.bits_per_sample = int(get_config_value('BITS_PER_SAMPLE')) self.stt_model_path = get_config_value('STT_MODEL_PATH') + self.wake_word_model_path = get_config_value('WAKE_WORD_MODEL_PATH') self.snips_model_path = get_config_value('SNIPS_MODEL_PATH') self.rasa_model_path = get_config_value('RASA_MODEL_PATH') self.rasa_server_port = int(get_config_value('RASA_SERVER_PORT')) + self.rasa_detached_mode = bool(int(get_config_value('RASA_DETACHED_MODE'))) self.base_log_dir = get_config_value('BASE_LOG_DIR') self.store_voice_command = bool(int(get_config_value('STORE_VOICE_COMMANDS'))) # Initialize class methods self.stt_model = STTModel(self.stt_model_path, self.sample_rate) + self.stt_wake_word_model = STTModel(self.wake_word_model_path, self.sample_rate) self.snips_interface = SnipsInterface(self.snips_model_path) self.rasa_interface = RASAInterface(self.rasa_server_port, self.rasa_model_path, self.base_log_dir) - self.rasa_interface.start_server() + + # Only start RASA server if its not in detached mode, else we assume server is already running + if not self.rasa_detached_mode: + self.rasa_interface.start_server() + self.rvc_stream_uuids = {} + self.kuksa_client = KuksaInterface() + self.kuksa_client.connect_kuksa_client() + self.kuksa_client.authorize_kuksa_client() + self.mapper = Intent2VSSMapper() def CheckServiceStatus(self, request, context): + """ + Check the status of the Voice Agent service including the version. + """ response = voice_agent_pb2.ServiceStatus( version=self.service_version, status=True @@ -61,13 +84,16 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): def DetectWakeWord(self, request, context): + """ + Detect the wake word using the wake word detection model. + """ wake_word_detector = WakeWordDetector(self.wake_word, self.stt_model, self.channels, self.sample_rate, self.bits_per_sample) wake_word_detector.create_pipeline() detection_thread = threading.Thread(target=wake_word_detector.start_listening) detection_thread.start() while True: status = wake_word_detector.get_wake_word_status() - time.sleep(1) + time.sleep(0.5) if not context.is_active(): wake_word_detector.send_eos() break @@ -79,6 +105,9 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): def RecognizeVoiceCommand(self, requests, context): + """ + Recognize the voice command using the STT model and extract the intent using the NLU model. + """ stt = "" intent = "" intent_slots = [] @@ -154,3 +183,71 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): status=status ) return response + + + def ExecuteVoiceCommand(self, request, context): + """ + Execute the voice command by sending the intent to Kuksa. + """ + intent = request.intent + intent_slots = request.intent_slots + processed_slots = [] + for slot in intent_slots: + slot_name = slot.name + slot_value = slot.value + processed_slots.append({"name": slot_name, "value": slot_value}) + + print(intent) + print(processed_slots) + execution_list = self.mapper.parse_intent(intent, processed_slots) + exec_response = f"Sorry, I failed to execute command against intent '{intent}'. Maybe try again with more specific instructions." + exec_status = voice_agent_pb2.EXEC_ERROR + + for execution_item in execution_list: + print(execution_item) + action = execution_item["action"] + signal = execution_item["signal"] + + if self.kuksa_client.get_kuksa_status(): + if action == "set" and "value" in execution_item: + value = execution_item["value"] + if self.kuksa_client.send_values(signal, value): + exec_response = f"Yay, I successfully updated the intent '{intent}' to value '{value}'." + exec_status = voice_agent_pb2.EXEC_SUCCESS + + elif action in ["increase", "decrease"]: + if "value" in execution_item: + value = execution_item["value"] + if self.kuksa_client.send_values(signal, value): + exec_response = f"Yay, I successfully updated the intent '{intent}' to value '{value}'." + exec_status = voice_agent_pb2.EXEC_SUCCESS + + elif "factor" in execution_item: + factor = execution_item["factor"] + current_value = self.kuksa_client.get_value(signal) + if current_value: + current_value = int(current_value) + if action == "increase": + value = current_value + factor + value = str(value) + elif action == "decrease": + value = current_value - factor + value = str(value) + if self.kuksa_client.send_values(signal, value): + exec_response = f"Yay, I successfully updated the intent '{intent}' to value '{value}'." + exec_status = voice_agent_pb2.EXEC_SUCCESS + + else: + exec_response = f"Uh oh, there is no value set for intent '{intent}'. Why not try setting a value first?" + exec_status = voice_agent_pb2.EXEC_KUKSA_CONN_ERROR + + else: + exec_response = "Uh oh, I failed to connect to Kuksa." + exec_status = voice_agent_pb2.EXEC_KUKSA_CONN_ERROR + + response = voice_agent_pb2.ExecuteResult( + response=exec_response, + status=exec_status + ) + + return response diff --git a/agl_service_voiceagent/utils/audio_recorder.py b/agl_service_voiceagent/utils/audio_recorder.py index 61ce994..2e8f11d 100644 --- a/agl_service_voiceagent/utils/audio_recorder.py +++ b/agl_service_voiceagent/utils/audio_recorder.py @@ -15,7 +15,6 @@ # limitations under the License. import gi -import vosk import time gi.require_version('Gst', '1.0') from gi.repository import Gst, GLib @@ -24,7 +23,21 @@ Gst.init(None) GLib.threads_init() class AudioRecorder: + """ + AudioRecorder is a class for recording audio using GStreamer in various modes. + """ + def __init__(self, stt_model, audio_files_basedir, channels=1, sample_rate=16000, bits_per_sample=16): + """ + Initialize the AudioRecorder instance with the provided parameters. + + Args: + stt_model (str): The speech-to-text model to use for voice input recognition. + audio_files_basedir (str): The base directory for saving audio files. + channels (int, optional): The number of audio channels (default is 1). + sample_rate (int, optional): The audio sample rate in Hz (default is 16000). + bits_per_sample (int, optional): The number of bits per sample (default is 16). + """ self.loop = GLib.MainLoop() self.mode = None self.pipeline = None @@ -43,6 +56,12 @@ class AudioRecorder: def create_pipeline(self): + """ + Create and configure the GStreamer audio recording pipeline. + + Returns: + str: The name of the audio file being recorded. + """ print("Creating pipeline for audio recording in {} mode...".format(self.mode)) self.pipeline = Gst.Pipeline() autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None) @@ -83,32 +102,50 @@ class AudioRecorder: def start_recording(self): + """ + Start recording audio using the GStreamer pipeline. + """ self.pipeline.set_state(Gst.State.PLAYING) print("Recording Voice Input...") def stop_recording(self): + """ + Stop audio recording and clean up the GStreamer pipeline. + """ print("Stopping recording...") - self.frames_above_threshold = 0 - self.cleanup_pipeline() + # self.cleanup_pipeline() + self.pipeline.send_event(Gst.Event.new_eos()) print("Recording finished!") def set_pipeline_mode(self, mode): + """ + Set the recording mode to 'auto' or 'manual'. + + Args: + mode (str): The recording mode ('auto' or 'manual'). + """ self.mode = mode - # this method helps with error handling def on_bus_message(self, bus, message): + """ + Handle GStreamer bus messages and perform actions based on the message type. + + Args: + bus (Gst.Bus): The GStreamer bus. + message (Gst.Message): The GStreamer message to process. + """ if message.type == Gst.MessageType.EOS: print("End-of-stream message received") - self.stop_recording() + self.cleanup_pipeline() elif message.type == Gst.MessageType.ERROR: err, debug_info = message.parse_error() print(f"Error received from element {message.src.get_name()}: {err.message}") print(f"Debugging information: {debug_info}") - self.stop_recording() + self.cleanup_pipeline() elif message.type == Gst.MessageType.WARNING: err, debug_info = message.parse_warning() @@ -136,6 +173,9 @@ class AudioRecorder: def cleanup_pipeline(self): + """ + Clean up the GStreamer pipeline, set it to NULL state, and remove the signal watch. + """ if self.pipeline is not None: print("Cleaning up pipeline...") self.pipeline.set_state(Gst.State.NULL) diff --git a/agl_service_voiceagent/utils/common.py b/agl_service_voiceagent/utils/common.py index 682473e..b9d6577 100644 --- a/agl_service_voiceagent/utils/common.py +++ b/agl_service_voiceagent/utils/common.py @@ -20,12 +20,18 @@ import json def add_trailing_slash(path): + """ + Adds a trailing slash to a path if it does not already have one. + """ if path and not path.endswith('/'): path += '/' return path def generate_unique_uuid(length): + """ + Generates a unique ID of specified length. + """ unique_id = str(uuid.uuid4().int) # Ensure the generated ID is exactly 'length' digits by taking the last 'length' characters unique_id = unique_id[-length:] @@ -33,18 +39,91 @@ def generate_unique_uuid(length): def load_json_file(file_path): - try: - with open(file_path, 'r') as file: - return json.load(file) - except FileNotFoundError: - raise ValueError(f"File '{file_path}' not found.") + """ + Loads a JSON file and returns the data. + """ + try: + with open(file_path, 'r') as file: + return json.load(file) + except FileNotFoundError: + raise ValueError(f"File '{file_path}' not found.") def delete_file(file_path): + """ + Deletes a file if it exists. + """ if os.path.exists(file_path): try: os.remove(file_path) except Exception as e: print(f"Error deleting '{file_path}': {e}") else: - print(f"File '{file_path}' does not exist.")
\ No newline at end of file + print(f"File '{file_path}' does not exist.") + + +def words_to_number(words): + """ + Converts a string of words to a number. + """ + word_to_number = { + 'one': 1, + 'two': 2, + 'three': 3, + 'four': 4, + 'five': 5, + 'six': 6, + 'seven': 7, + 'eight': 8, + 'nine': 9, + 'ten': 10, + 'eleven': 11, + 'twelve': 12, + 'thirteen': 13, + 'fourteen': 14, + 'fifteen': 15, + 'sixteen': 16, + 'seventeen': 17, + 'eighteen': 18, + 'nineteen': 19, + 'twenty': 20, + 'thirty': 30, + 'forty': 40, + 'fourty': 40, + 'fifty': 50, + 'sixty': 60, + 'seventy': 70, + 'eighty': 80, + 'ninety': 90 + } + + # Split the input words and initialize total and current number + words = words.split() + + if len(words) == 1: + if words[0] in ["zero", "min", "minimum"]: + return 0 + + elif words[0] in ["half", "halfway"]: + return 50 + + elif words[0] in ["max", "maximum", "full", "fully", "completely", "hundred"]: + return 100 + + + total = 0 + current_number = 0 + + for word in words: + if word in word_to_number: + current_number += word_to_number[word] + elif word == 'hundred': + current_number *= 100 + else: + total += current_number + current_number = 0 + + total += current_number + + # we return number in str format because kuksa expects str input + return str(total) or None
\ No newline at end of file diff --git a/agl_service_voiceagent/utils/config.py b/agl_service_voiceagent/utils/config.py index 8d7f346..7295c7f 100644 --- a/agl_service_voiceagent/utils/config.py +++ b/agl_service_voiceagent/utils/config.py @@ -14,21 +14,41 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import configparser -# Get the absolute path to the directory of the current script -current_dir = os.path.dirname(os.path.abspath(__file__)) -# Construct the path to the config.ini file located in the base directory -config_path = os.path.join(current_dir, '..', 'config.ini') - config = configparser.ConfigParser() -config.read(config_path) +config_path = None + +def set_config_path(path): + """ + Sets the path to the config file. + """ + global config_path + config_path = path + config.read(config_path) + +def load_config(): + """ + Loads the config file. + """ + if config_path is not None: + config.read(config_path) + else: + raise Exception("Config file path not provided.") def update_config_value(value, key, group="General"): + """ + Updates a value in the config file. + """ + if config_path is None: + raise Exception("Config file path not set.") + config.set(group, key, value) with open(config_path, 'w') as configfile: config.write(configfile) def get_config_value(key, group="General"): + """ + Gets a value from the config file. + """ return config.get(group, key) diff --git a/agl_service_voiceagent/utils/kuksa_interface.py b/agl_service_voiceagent/utils/kuksa_interface.py index 3e1c045..9270379 100644 --- a/agl_service_voiceagent/utils/kuksa_interface.py +++ b/agl_service_voiceagent/utils/kuksa_interface.py @@ -14,53 +14,197 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time +import json +import threading from kuksa_client import KuksaClientThread from agl_service_voiceagent.utils.config import get_config_value class KuksaInterface: - def __init__(self): + """ + Kuksa Interface + + This class provides methods to initialize, authorize, connect, send values, + check the status, and close the Kuksa client. + """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls): + """ + Get the unique instance of the class. + + Returns: + KuksaInterface: The instance of the class. + """ + with cls._lock: + if cls._instance is None: + cls._instance = super(KuksaInterface, cls).__new__(cls) + cls._instance.init_client() + return cls._instance + + + def init_client(self): + """ + Initialize the Kuksa client configuration. + """ # get config values - self.ip = get_config_value("ip", "Kuksa") - self.port = get_config_value("port", "Kuksa") + self.ip = str(get_config_value("ip", "Kuksa")) + self.port = str(get_config_value("port", "Kuksa")) self.insecure = get_config_value("insecure", "Kuksa") self.protocol = get_config_value("protocol", "Kuksa") self.token = get_config_value("token", "Kuksa") + print(self.ip, self.port, self.insecure, self.protocol, self.token) + # define class methods self.kuksa_client = None def get_kuksa_client(self): + """ + Get the Kuksa client instance. + + Returns: + KuksaClientThread: The Kuksa client instance. + """ return self.kuksa_client - + def get_kuksa_status(self): + """ + Check the status of the Kuksa client connection. + + Returns: + bool: True if the client is connected, False otherwise. + """ if self.kuksa_client: return self.kuksa_client.checkConnection() - return False def connect_kuksa_client(self): + """ + Connect and start the Kuksa client. + """ try: - self.kuksa_client = KuksaClientThread({ - "ip": self.ip, - "port": self.port, - "insecure": self.insecure, - "protocol": self.protocol, - }) - self.kuksa_client.authorize(self.token) + with self._lock: + if self.kuksa_client is None: + self.kuksa_client = KuksaClientThread({ + "ip": self.ip, + "port": self.port, + "insecure": self.insecure, + "protocol": self.protocol, + }) + self.kuksa_client.start() + time.sleep(2) # Give the thread time to start + + if not self.get_kuksa_status(): + print("[-] Error: Connection to Kuksa server failed.") + else: + print("[+] Connection to Kuksa established.") except Exception as e: - print("Error: ", e) + print("[-] Error: Connection to Kuksa server failed. ", str(e)) + + + def authorize_kuksa_client(self): + """ + Authorize the Kuksa client with the provided token. + """ + if self.kuksa_client: + response = self.kuksa_client.authorize(self.token) + response = json.loads(response) + if "error" in response: + error_message = response.get("error", "Unknown error") + print(f"[-] Error: Authorization failed. {error_message}") + else: + print("[+] Kuksa client authorized successfully.") + else: + print("[-] Error: Kuksa client is not initialized. Call `connect_kuksa_client` first.") + + + def send_values(self, path=None, value=None): + """ + Send values to the Kuksa server. + + Args: + path (str): The path to the value. + value (str): The value to be sent. + Returns: + bool: True if the value was set, False otherwise. + """ + result = False + if self.kuksa_client is None: + print("[-] Error: Kuksa client is not initialized.") + return + + if self.get_kuksa_status(): + try: + response = self.kuksa_client.setValue(path, value) + response = json.loads(response) + if not "error" in response: + print(f"[+] Value '{value}' sent to Kuksa successfully.") + result = True + else: + error_message = response.get("error", "Unknown error") + print(f"[-] Error: Failed to send value '{value}' to Kuksa. {error_message}") + + except Exception as e: + print("[-] Error: Failed to send values to Kuksa. ", str(e)) + + else: + print("[-] Error: Connection to Kuksa failed.") + + return result - def send_values(self, Path=None, Value=None): + def get_value(self, path=None): + """ + Get values from the Kuksa server. + + Args: + path (str): The path to the value. + Returns: + str: The value if the path is valid, None otherwise. + """ + result = None if self.kuksa_client is None: - print("Error: Kuksa client is not initialized.") + print("[-] Error: Kuksa client is not initialized.") + return if self.get_kuksa_status(): - self.kuksa_client.setValue(Path, Value) + try: + response = self.kuksa_client.getValue(path) + response = json.loads(response) + if not "error" in response: + result = response.get("data", None) + result = result.get("dp", None) + result = result.get("value", None) + + else: + error_message = response.get("error", "Unknown error") + print(f"[-] Error: Failed to get value from Kuksa. {error_message}") + + except Exception as e: + print("[-] Error: Failed to get values from Kuksa. ", str(e)) else: - print("Error: Connection to Kuksa failed.") + print("[-] Error: Connection to Kuksa failed.") + + return result + + + def close_kuksa_client(self): + """ + Close and stop the Kuksa client. + """ + try: + with self._lock: + if self.kuksa_client: + self.kuksa_client.stop() + self.kuksa_client = None + print("[+] Kuksa client stopped.") + except Exception as e: + print("[-] Error: Failed to close Kuksa client. ", str(e))
\ No newline at end of file diff --git a/agl_service_voiceagent/utils/mapper.py b/agl_service_voiceagent/utils/mapper.py new file mode 100644 index 0000000..7529645 --- /dev/null +++ b/agl_service_voiceagent/utils/mapper.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from agl_service_voiceagent.utils.config import get_config_value +from agl_service_voiceagent.utils.common import load_json_file, words_to_number + + +class Intent2VSSMapper: + """ + Intent2VSSMapper is a class that facilitates the mapping of natural language intent to + corresponding vehicle signal specifications (VSS) for automated vehicle control systems. + """ + + def __init__(self): + """ + Initializes the Intent2VSSMapper class by loading Intent-to-VSS signal mappings + and VSS signal specifications from external configuration files. + """ + intents_vss_map_file = get_config_value("intents_vss_map", "Mapper") + vss_signals_spec_file = get_config_value("vss_signals_spec", "Mapper") + self.intents_vss_map = load_json_file(intents_vss_map_file).get("intents", {}) + self.vss_signals_spec = load_json_file(vss_signals_spec_file).get("signals", {}) + + if not self.validate_signal_spec_structure(): + raise ValueError("[-] Invalid VSS signal specification structure.") + + def validate_signal_spec_structure(self): + """ + Validates the structure of the VSS signal specification data. + """ + + signals = self.vss_signals_spec + + # Iterate over each signal in the 'signals' dictionary + for signal_name, signal_data in signals.items(): + # Check if the required keys are present in the signal data + if not all(key in signal_data for key in ['default_value', 'default_change_factor', 'actions', 'values', 'default_fallback', 'value_set_intents']): + print(f"[-] {signal_name}: Missing required keys in signal data.") + return False + + actions = signal_data['actions'] + + # Check if 'actions' is a dictionary with at least one action + if not isinstance(actions, dict) or not actions: + print(f"[-] {signal_name}: Invalid 'actions' key in signal data. Must be an object with at least one action.") + return False + + # Check if the actions match the allowed actions ["set", "increase", "decrease"] + for action in actions.keys(): + if action not in ["set", "increase", "decrease"]: + print(f"[-] {signal_name}: Invalid action in signal data. Allowed actions: ['set', 'increase', 'decrease']") + return False + + # Check if the 'synonyms' list is present for each action and is either a list or None + for action_data in actions.values(): + synonyms = action_data.get('synonyms') + if synonyms is not None and (not isinstance(synonyms, list) or not all(isinstance(synonym, str) for synonym in synonyms)): + print(f"[-] {signal_name}: Invalid 'synonyms' value in signal data. Must be a list of strings.") + return False + + values = signal_data['values'] + + # Check if 'values' is a dictionary with the required keys + if not isinstance(values, dict) or not all(key in values for key in ['ranged', 'start', 'end', 'ignore', 'additional']): + print(f"[-] {signal_name}: Invalid 'values' key in signal data. Required keys: ['ranged', 'start', 'end', 'ignore', 'additional']") + return False + + # Check if 'ranged' is a boolean + if not isinstance(values['ranged'], bool): + print(f"[-] {signal_name}: Invalid 'ranged' value in signal data. Allowed values: [true, false]") + return False + + default_fallback = signal_data['default_fallback'] + + # Check if 'default_fallback' is a boolean + if not isinstance(default_fallback, bool): + print(f"[-] {signal_name}: Invalid 'default_fallback' value in signal data. Allowed values: [true, false]") + return False + + # If all checks pass, the self.vss_signals_spec structure is valid + return True + + + def map_intent_to_signal(self, intent_name): + """ + Maps an intent name to the corresponding VSS signals and their specifications. + + Args: + intent_name (str): The name of the intent to be mapped. + + Returns: + dict: A dictionary containing VSS signals as keys and their specifications as values. + """ + + intent_data = self.intents_vss_map.get(intent_name, None) + result = {} + if intent_data: + signals = intent_data.get("signals", []) + + for signal in signals: + signal_info = self.vss_signals_spec.get(signal, {}) + if signal_info: + result.update({signal: signal_info}) + + return result + + + def parse_intent(self, intent_name, intent_slots = []): + """ + Parses an intent, extracting relevant VSS signals, actions, modifiers, and values + based on the intent and its associated slots. + + Args: + intent_name (str): The name of the intent to be parsed. + intent_slots (list): A list of dictionaries representing intent slots. + + Returns: + list: A list of dictionaries describing actions and signal-related details for execution. + + Note: + - If no relevant VSS signals are found for the intent, an empty list is returned. + - If no specific action or modifier is determined, default values are used. + """ + vss_signal_data = self.map_intent_to_signal(intent_name) + execution_list = [] + for signal_name, signal_data in vss_signal_data.items(): + action = self.determine_action(signal_data, intent_slots) + modifier = self.determine_modifier(signal_data, intent_slots) + value = self.determine_value(signal_data, intent_slots) + + if value != None and not self.verify_value(signal_data, value): + value = None + + change_factor = signal_data["default_change_factor"] + + if action in ["increase", "decrease"]: + if value and modifier == "to": + execution_list.append({"action": action, "signal": signal_name, "value": str(value)}) + + elif value and modifier == "by": + execution_list.append({"action": action, "signal": signal_name, "factor": str(value)}) + + elif value: + execution_list.append({"action": action, "signal": signal_name, "value": str(value)}) + + elif signal_data["default_fallback"]: + execution_list.append({"action": action, "signal": signal_name, "factor": str(change_factor)}) + + # if no value found set the default value + if value == None and signal_data["default_fallback"]: + value = signal_data["default_value"] + + if action == "set" and value != None: + execution_list.append({"action": action, "signal": signal_name, "value": str(value)}) + + + return execution_list + + + def determine_action(self, signal_data, intent_slots): + """ + Determines the action (e.g., set, increase, decrease) based on the intent slots + and VSS signal data. + + Args: + signal_data (dict): The specification data for a VSS signal. + intent_slots (list): A list of dictionaries representing intent slots. + + Returns: + str: The determined action or None if no action can be determined. + """ + action_res = None + for intent_slot in intent_slots: + for action, action_data in signal_data["actions"].items(): + if intent_slot["name"] in action_data["intents"] and intent_slot["value"] in action_data["synonyms"]: + action_res = action + break + + return action_res + + + def determine_modifier(self, signal_data, intent_slots): + """ + Determines the modifier (e.g., 'to' or 'by') based on the intent slots + and VSS signal data. + + Args: + signal_data (dict): The specification data for a VSS signal. + intent_slots (list): A list of dictionaries representing intent slots. + + Returns: + str: The determined modifier or None if no modifier can be determined. + """ + modifier_res = None + for intent_slot in intent_slots: + for _, action_data in signal_data["actions"].items(): + intent_val = intent_slot["value"] + if "modifier_intents" in action_data and intent_slot["name"] in action_data["modifier_intents"] and ("to" in intent_val or "by" in intent_val): + modifier_res = "to" if "to" in intent_val else "by" if "by" in intent_val else None + break + + return modifier_res + + + def determine_value(self, signal_data, intent_slots): + """ + Determines the value associated with the intent slot, considering the data type + and converting it to a numeric string representation if necessary. + + Args: + signal_data (dict): The specification data for a VSS signal. + intent_slots (list): A list of dictionaries representing intent slots. + + Returns: + str: The determined value or None if no value can be determined. + """ + result = None + for intent_slot in intent_slots: + for value, value_data in signal_data["value_set_intents"].items(): + if intent_slot["name"] == value: + result = intent_slot["value"] + + if value_data["datatype"] == "number": + result = words_to_number(result) # we assume our model will always return a number in words + + # the value should always returned as str because Kuksa expects str values + return str(result) if result != None else None + + + def verify_value(self, signal_data, value): + """ + Verifies that the value is valid based on the VSS signal data. + + Args: + signal_data (dict): The specification data for a VSS signal. + value (str): The value to be verified. + + Returns: + bool: True if the value is valid, False otherwise. + """ + if value in signal_data["values"]["ignore"]: + return False + + elif signal_data["values"]["ranged"] and isinstance(value, (int, float)): + return value >= signal_data["values"]["start"] and value <= signal_data["values"]["end"] + + else: + return value in signal_data["values"]["additional"] diff --git a/agl_service_voiceagent/utils/stt_model.py b/agl_service_voiceagent/utils/stt_model.py index 5337162..d51ae31 100644 --- a/agl_service_voiceagent/utils/stt_model.py +++ b/agl_service_voiceagent/utils/stt_model.py @@ -21,21 +21,61 @@ import wave from agl_service_voiceagent.utils.common import generate_unique_uuid class STTModel: + """ + STTModel is a class for speech-to-text (STT) recognition using the Vosk speech recognition library. + """ + def __init__(self, model_path, sample_rate=16000): + """ + Initialize the STTModel instance with the provided model and sample rate. + + Args: + model_path (str): The path to the Vosk speech recognition model. + sample_rate (int, optional): The audio sample rate in Hz (default is 16000). + """ self.sample_rate = sample_rate self.model = vosk.Model(model_path) self.recognizer = {} self.chunk_size = 1024 + def setup_recognizer(self): + """ + Set up a Vosk recognizer for a new session and return a unique identifier (UUID) for the session. + + Returns: + str: A unique identifier (UUID) for the session. + """ uuid = generate_unique_uuid(6) self.recognizer[uuid] = vosk.KaldiRecognizer(self.model, self.sample_rate) return uuid + def init_recognition(self, uuid, audio_data): + """ + Initialize the Vosk recognizer for a session with audio data. + + Args: + uuid (str): The unique identifier (UUID) for the session. + audio_data (bytes): Audio data to process. + + Returns: + bool: True if initialization was successful, False otherwise. + """ return self.recognizer[uuid].AcceptWaveform(audio_data) + def recognize(self, uuid, partial=False): + """ + Recognize speech and return the result as a JSON object. + + Args: + uuid (str): The unique identifier (UUID) for the session. + partial (bool, optional): If True, return partial recognition results (default is False). + + Returns: + dict: A JSON object containing recognition results. + """ self.recognizer[uuid].SetWords(True) if partial: result = json.loads(self.recognizer[uuid].PartialResult()) @@ -44,7 +84,18 @@ class STTModel: self.recognizer[uuid].Reset() return result + def recognize_from_file(self, uuid, filename): + """ + Recognize speech from an audio file and return the recognized text. + + Args: + uuid (str): The unique identifier (UUID) for the session. + filename (str): The path to the audio file. + + Returns: + str: The recognized text or error messages. + """ if not os.path.exists(filename): print(f"Audio file '{filename}' not found.") return "FILE_NOT_FOUND" @@ -75,31 +126,13 @@ class STTModel: print("Voice not recognized. Please speak again...") return "VOICE_NOT_RECOGNIZED" - def cleanup_recognizer(self, uuid): - del self.recognizer[uuid] -import wave - -def read_wav_file(filename, chunk_size=1024): - try: - wf = wave.open(filename, "rb") - if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE": - print("Audio file must be WAV format mono PCM.") - return "FILE_FORMAT_INVALID" - - audio_data = b"" # Initialize an empty bytes object to store audio data - while True: - chunk = wf.readframes(chunk_size) - if not chunk: - break # End of file reached - audio_data += chunk - - return audio_data - except Exception as e: - print(f"Error reading audio file: {e}") - return None + def cleanup_recognizer(self, uuid): + """ + Clean up and remove the Vosk recognizer for a session. -# Example usage: -filename = "your_audio.wav" -audio_data = read_wav_file(filename) + Args: + uuid (str): The unique identifier (UUID) for the session. + """ + del self.recognizer[uuid]
\ No newline at end of file diff --git a/agl_service_voiceagent/utils/wake_word.py b/agl_service_voiceagent/utils/wake_word.py index 066ae6d..47e547e 100644 --- a/agl_service_voiceagent/utils/wake_word.py +++ b/agl_service_voiceagent/utils/wake_word.py @@ -15,7 +15,6 @@ # limitations under the License. import gi -import vosk gi.require_version('Gst', '1.0') from gi.repository import Gst, GLib @@ -23,7 +22,21 @@ Gst.init(None) GLib.threads_init() class WakeWordDetector: + """ + WakeWordDetector is a class for detecting a wake word in an audio stream using GStreamer and Vosk. + """ + def __init__(self, wake_word, stt_model, channels=1, sample_rate=16000, bits_per_sample=16): + """ + Initialize the WakeWordDetector instance with the provided parameters. + + Args: + wake_word (str): The wake word to detect in the audio stream. + stt_model (STTModel): An instance of the STTModel for speech-to-text recognition. + channels (int, optional): The number of audio channels (default is 1). + sample_rate (int, optional): The audio sample rate in Hz (default is 16000). + bits_per_sample (int, optional): The number of bits per sample (default is 16). + """ self.loop = GLib.MainLoop() self.pipeline = None self.bus = None @@ -32,16 +45,25 @@ class WakeWordDetector: self.sample_rate = sample_rate self.channels = channels self.bits_per_sample = bits_per_sample - self.frame_size = int(self.sample_rate * 0.02) - self.stt_model = stt_model # Speech to text model recognizer + self.wake_word_model = stt_model # Speech to text model recognizer self.recognizer_uuid = stt_model.setup_recognizer() - self.buffer_duration = 1 # Buffer audio for atleast 1 second self.audio_buffer = bytearray() + self.segment_size = int(self.sample_rate * 1.0) # Adjust the segment size (e.g., 1 second) + def get_wake_word_status(self): + """ + Get the status of wake word detection. + + Returns: + bool: True if the wake word has been detected, False otherwise. + """ return self.wake_word_detected def create_pipeline(self): + """ + Create and configure the GStreamer audio processing pipeline for wake word detection. + """ print("Creating pipeline for Wake Word Detection...") self.pipeline = Gst.Pipeline() autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None) @@ -77,49 +99,87 @@ class WakeWordDetector: self.bus.add_signal_watch() self.bus.connect("message", self.on_bus_message) + def on_new_buffer(self, appsink, data) -> Gst.FlowReturn: + """ + Callback function to handle new audio buffers from GStreamer appsink. + + Args: + appsink (Gst.AppSink): The GStreamer appsink. + data (object): User data (not used). + + Returns: + Gst.FlowReturn: Indicates the status of buffer processing. + """ sample = appsink.emit("pull-sample") buffer = sample.get_buffer() data = buffer.extract_dup(0, buffer.get_size()) + + # Add the new data to the buffer self.audio_buffer.extend(data) - if len(self.audio_buffer) >= self.sample_rate * self.buffer_duration * self.channels * self.bits_per_sample // 8: - self.process_audio_buffer() + # Process audio in segments + while len(self.audio_buffer) >= self.segment_size: + segment = self.audio_buffer[:self.segment_size] + self.process_audio_segment(segment) + + # Advance the buffer by the segment size + self.audio_buffer = self.audio_buffer[self.segment_size:] return Gst.FlowReturn.OK - - def process_audio_buffer(self): - # Process the accumulated audio data using the audio model - audio_data = bytes(self.audio_buffer) - if self.stt_model.init_recognition(self.recognizer_uuid, audio_data): - stt_result = self.stt_model.recognize(self.recognizer_uuid) + def process_audio_segment(self, segment): + """ + Process an audio segment for wake word detection. + + Args: + segment (bytes): The audio segment to process. + """ + # Process the audio data segment + audio_data = bytes(segment) + + # Perform wake word detection on the audio_data + if self.wake_word_model.init_recognition(self.recognizer_uuid, audio_data): + stt_result = self.wake_word_model.recognize(self.recognizer_uuid) print("STT Result: ", stt_result) if self.wake_word in stt_result["text"]: self.wake_word_detected = True print("Wake word detected!") self.pipeline.send_event(Gst.Event.new_eos()) - self.audio_buffer.clear() # Clear the buffer - - def send_eos(self): + """ + Send an End-of-Stream (EOS) event to the pipeline. + """ self.pipeline.send_event(Gst.Event.new_eos()) self.audio_buffer.clear() def start_listening(self): + """ + Start listening for the wake word and enter the event loop. + """ self.pipeline.set_state(Gst.State.PLAYING) print("Listening for Wake Word...") self.loop.run() def stop_listening(self): + """ + Stop listening for the wake word and clean up the pipeline. + """ self.cleanup_pipeline() self.loop.quit() def on_bus_message(self, bus, message): + """ + Handle GStreamer bus messages and perform actions based on the message type. + + Args: + bus (Gst.Bus): The GStreamer bus. + message (Gst.Message): The GStreamer message to process. + """ if message.type == Gst.MessageType.EOS: print("End-of-stream message received") self.stop_listening() @@ -140,6 +200,9 @@ class WakeWordDetector: def cleanup_pipeline(self): + """ + Clean up the GStreamer pipeline and release associated resources. + """ if self.pipeline is not None: print("Cleaning up pipeline...") self.pipeline.set_state(Gst.State.NULL) @@ -147,4 +210,4 @@ class WakeWordDetector: print("Pipeline cleanup complete!") self.bus = None self.pipeline = None - self.stt_model.cleanup_recognizer(self.recognizer_uuid) + self.wake_word_model.cleanup_recognizer(self.recognizer_uuid) |