Diffstat (limited to 'agl_service_voiceagent/utils')
-rw-r--r--   agl_service_voiceagent/utils/audio_recorder.py    |  52
-rw-r--r--   agl_service_voiceagent/utils/common.py             |  91
-rw-r--r--   agl_service_voiceagent/utils/config.py             |  34
-rw-r--r--   agl_service_voiceagent/utils/kuksa_interface.py    | 178
-rw-r--r--   agl_service_voiceagent/utils/mapper.py             | 261
-rw-r--r--   agl_service_voiceagent/utils/stt_model.py          |  83
-rw-r--r--   agl_service_voiceagent/utils/wake_word.py          |  95
7 files changed, 717 insertions, 77 deletions
diff --git a/agl_service_voiceagent/utils/audio_recorder.py b/agl_service_voiceagent/utils/audio_recorder.py index 61ce994..2e8f11d 100644 --- a/agl_service_voiceagent/utils/audio_recorder.py +++ b/agl_service_voiceagent/utils/audio_recorder.py @@ -15,7 +15,6 @@ # limitations under the License. import gi -import vosk import time gi.require_version('Gst', '1.0') from gi.repository import Gst, GLib @@ -24,7 +23,21 @@ Gst.init(None) GLib.threads_init() class AudioRecorder: + """ + AudioRecorder is a class for recording audio using GStreamer in various modes. + """ + def __init__(self, stt_model, audio_files_basedir, channels=1, sample_rate=16000, bits_per_sample=16): + """ + Initialize the AudioRecorder instance with the provided parameters. + + Args: + stt_model (str): The speech-to-text model to use for voice input recognition. + audio_files_basedir (str): The base directory for saving audio files. + channels (int, optional): The number of audio channels (default is 1). + sample_rate (int, optional): The audio sample rate in Hz (default is 16000). + bits_per_sample (int, optional): The number of bits per sample (default is 16). + """ self.loop = GLib.MainLoop() self.mode = None self.pipeline = None @@ -43,6 +56,12 @@ class AudioRecorder: def create_pipeline(self): + """ + Create and configure the GStreamer audio recording pipeline. + + Returns: + str: The name of the audio file being recorded. + """ print("Creating pipeline for audio recording in {} mode...".format(self.mode)) self.pipeline = Gst.Pipeline() autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None) @@ -83,32 +102,50 @@ class AudioRecorder: def start_recording(self): + """ + Start recording audio using the GStreamer pipeline. + """ self.pipeline.set_state(Gst.State.PLAYING) print("Recording Voice Input...") def stop_recording(self): + """ + Stop audio recording and clean up the GStreamer pipeline. + """ print("Stopping recording...") - self.frames_above_threshold = 0 - self.cleanup_pipeline() + # self.cleanup_pipeline() + self.pipeline.send_event(Gst.Event.new_eos()) print("Recording finished!") def set_pipeline_mode(self, mode): + """ + Set the recording mode to 'auto' or 'manual'. + + Args: + mode (str): The recording mode ('auto' or 'manual'). + """ self.mode = mode - # this method helps with error handling def on_bus_message(self, bus, message): + """ + Handle GStreamer bus messages and perform actions based on the message type. + + Args: + bus (Gst.Bus): The GStreamer bus. + message (Gst.Message): The GStreamer message to process. + """ if message.type == Gst.MessageType.EOS: print("End-of-stream message received") - self.stop_recording() + self.cleanup_pipeline() elif message.type == Gst.MessageType.ERROR: err, debug_info = message.parse_error() print(f"Error received from element {message.src.get_name()}: {err.message}") print(f"Debugging information: {debug_info}") - self.stop_recording() + self.cleanup_pipeline() elif message.type == Gst.MessageType.WARNING: err, debug_info = message.parse_warning() @@ -136,6 +173,9 @@ class AudioRecorder: def cleanup_pipeline(self): + """ + Clean up the GStreamer pipeline, set it to NULL state, and remove the signal watch. 
+ """ if self.pipeline is not None: print("Cleaning up pipeline...") self.pipeline.set_state(Gst.State.NULL) diff --git a/agl_service_voiceagent/utils/common.py b/agl_service_voiceagent/utils/common.py index 682473e..b9d6577 100644 --- a/agl_service_voiceagent/utils/common.py +++ b/agl_service_voiceagent/utils/common.py @@ -20,12 +20,18 @@ import json def add_trailing_slash(path): + """ + Adds a trailing slash to a path if it does not already have one. + """ if path and not path.endswith('/'): path += '/' return path def generate_unique_uuid(length): + """ + Generates a unique ID of specified length. + """ unique_id = str(uuid.uuid4().int) # Ensure the generated ID is exactly 'length' digits by taking the last 'length' characters unique_id = unique_id[-length:] @@ -33,18 +39,91 @@ def generate_unique_uuid(length): def load_json_file(file_path): - try: - with open(file_path, 'r') as file: - return json.load(file) - except FileNotFoundError: - raise ValueError(f"File '{file_path}' not found.") + """ + Loads a JSON file and returns the data. + """ + try: + with open(file_path, 'r') as file: + return json.load(file) + except FileNotFoundError: + raise ValueError(f"File '{file_path}' not found.") def delete_file(file_path): + """ + Deletes a file if it exists. + """ if os.path.exists(file_path): try: os.remove(file_path) except Exception as e: print(f"Error deleting '{file_path}': {e}") else: - print(f"File '{file_path}' does not exist.")
\ No newline at end of file + print(f"File '{file_path}' does not exist.") + + +def words_to_number(words): + """ + Converts a string of words to a number. + """ + word_to_number = { + 'one': 1, + 'two': 2, + 'three': 3, + 'four': 4, + 'five': 5, + 'six': 6, + 'seven': 7, + 'eight': 8, + 'nine': 9, + 'ten': 10, + 'eleven': 11, + 'twelve': 12, + 'thirteen': 13, + 'fourteen': 14, + 'fifteen': 15, + 'sixteen': 16, + 'seventeen': 17, + 'eighteen': 18, + 'nineteen': 19, + 'twenty': 20, + 'thirty': 30, + 'forty': 40, + 'fourty': 40, + 'fifty': 50, + 'sixty': 60, + 'seventy': 70, + 'eighty': 80, + 'ninety': 90 + } + + # Split the input words and initialize total and current number + words = words.split() + + if len(words) == 1: + if words[0] in ["zero", "min", "minimum"]: + return 0 + + elif words[0] in ["half", "halfway"]: + return 50 + + elif words[0] in ["max", "maximum", "full", "fully", "completely", "hundred"]: + return 100 + + + total = 0 + current_number = 0 + + for word in words: + if word in word_to_number: + current_number += word_to_number[word] + elif word == 'hundred': + current_number *= 100 + else: + total += current_number + current_number = 0 + + total += current_number + + # we return number in str format because kuksa expects str input + return str(total) or None
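Editorial note (not part of the patch): a minimal sketch of how the new words_to_number() helper behaves, assuming the package import path used throughout this change; the sample phrases are illustrative only.

from agl_service_voiceagent.utils.common import words_to_number

print(words_to_number("twenty five"))   # "25"  -- tens and units are summed
print(words_to_number("one hundred"))   # "100" -- 'hundred' multiplies the running value
print(words_to_number("half"))          # 50    -- single-word shortcuts map to 0, 50 or 100

Note that the single-word shortcuts return an int while the general path returns a str; callers that forward the result to Kuksa may want to normalize the type.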
\ No newline at end of file diff --git a/agl_service_voiceagent/utils/config.py b/agl_service_voiceagent/utils/config.py index 8d7f346..7295c7f 100644 --- a/agl_service_voiceagent/utils/config.py +++ b/agl_service_voiceagent/utils/config.py @@ -14,21 +14,41 @@ # See the License for the specific language governing permissions and # limitations under the License. -import os import configparser -# Get the absolute path to the directory of the current script -current_dir = os.path.dirname(os.path.abspath(__file__)) -# Construct the path to the config.ini file located in the base directory -config_path = os.path.join(current_dir, '..', 'config.ini') - config = configparser.ConfigParser() -config.read(config_path) +config_path = None + +def set_config_path(path): + """ + Sets the path to the config file. + """ + global config_path + config_path = path + config.read(config_path) + +def load_config(): + """ + Loads the config file. + """ + if config_path is not None: + config.read(config_path) + else: + raise Exception("Config file path not provided.") def update_config_value(value, key, group="General"): + """ + Updates a value in the config file. + """ + if config_path is None: + raise Exception("Config file path not set.") + config.set(group, key, value) with open(config_path, 'w') as configfile: config.write(configfile) def get_config_value(key, group="General"): + """ + Gets a value from the config file. + """ return config.get(group, key) diff --git a/agl_service_voiceagent/utils/kuksa_interface.py b/agl_service_voiceagent/utils/kuksa_interface.py index 3e1c045..9270379 100644 --- a/agl_service_voiceagent/utils/kuksa_interface.py +++ b/agl_service_voiceagent/utils/kuksa_interface.py @@ -14,53 +14,197 @@ # See the License for the specific language governing permissions and # limitations under the License. +import time +import json +import threading from kuksa_client import KuksaClientThread from agl_service_voiceagent.utils.config import get_config_value class KuksaInterface: - def __init__(self): + """ + Kuksa Interface + + This class provides methods to initialize, authorize, connect, send values, + check the status, and close the Kuksa client. + """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls): + """ + Get the unique instance of the class. + + Returns: + KuksaInterface: The instance of the class. + """ + with cls._lock: + if cls._instance is None: + cls._instance = super(KuksaInterface, cls).__new__(cls) + cls._instance.init_client() + return cls._instance + + + def init_client(self): + """ + Initialize the Kuksa client configuration. + """ # get config values - self.ip = get_config_value("ip", "Kuksa") - self.port = get_config_value("port", "Kuksa") + self.ip = str(get_config_value("ip", "Kuksa")) + self.port = str(get_config_value("port", "Kuksa")) self.insecure = get_config_value("insecure", "Kuksa") self.protocol = get_config_value("protocol", "Kuksa") self.token = get_config_value("token", "Kuksa") + print(self.ip, self.port, self.insecure, self.protocol, self.token) + # define class methods self.kuksa_client = None def get_kuksa_client(self): + """ + Get the Kuksa client instance. + + Returns: + KuksaClientThread: The Kuksa client instance. + """ return self.kuksa_client - + def get_kuksa_status(self): + """ + Check the status of the Kuksa client connection. + + Returns: + bool: True if the client is connected, False otherwise. 
+ """ if self.kuksa_client: return self.kuksa_client.checkConnection() - return False def connect_kuksa_client(self): + """ + Connect and start the Kuksa client. + """ try: - self.kuksa_client = KuksaClientThread({ - "ip": self.ip, - "port": self.port, - "insecure": self.insecure, - "protocol": self.protocol, - }) - self.kuksa_client.authorize(self.token) + with self._lock: + if self.kuksa_client is None: + self.kuksa_client = KuksaClientThread({ + "ip": self.ip, + "port": self.port, + "insecure": self.insecure, + "protocol": self.protocol, + }) + self.kuksa_client.start() + time.sleep(2) # Give the thread time to start + + if not self.get_kuksa_status(): + print("[-] Error: Connection to Kuksa server failed.") + else: + print("[+] Connection to Kuksa established.") except Exception as e: - print("Error: ", e) + print("[-] Error: Connection to Kuksa server failed. ", str(e)) + + + def authorize_kuksa_client(self): + """ + Authorize the Kuksa client with the provided token. + """ + if self.kuksa_client: + response = self.kuksa_client.authorize(self.token) + response = json.loads(response) + if "error" in response: + error_message = response.get("error", "Unknown error") + print(f"[-] Error: Authorization failed. {error_message}") + else: + print("[+] Kuksa client authorized successfully.") + else: + print("[-] Error: Kuksa client is not initialized. Call `connect_kuksa_client` first.") + + + def send_values(self, path=None, value=None): + """ + Send values to the Kuksa server. + + Args: + path (str): The path to the value. + value (str): The value to be sent. + Returns: + bool: True if the value was set, False otherwise. + """ + result = False + if self.kuksa_client is None: + print("[-] Error: Kuksa client is not initialized.") + return + + if self.get_kuksa_status(): + try: + response = self.kuksa_client.setValue(path, value) + response = json.loads(response) + if not "error" in response: + print(f"[+] Value '{value}' sent to Kuksa successfully.") + result = True + else: + error_message = response.get("error", "Unknown error") + print(f"[-] Error: Failed to send value '{value}' to Kuksa. {error_message}") + + except Exception as e: + print("[-] Error: Failed to send values to Kuksa. ", str(e)) + + else: + print("[-] Error: Connection to Kuksa failed.") + + return result - def send_values(self, Path=None, Value=None): + def get_value(self, path=None): + """ + Get values from the Kuksa server. + + Args: + path (str): The path to the value. + Returns: + str: The value if the path is valid, None otherwise. + """ + result = None if self.kuksa_client is None: - print("Error: Kuksa client is not initialized.") + print("[-] Error: Kuksa client is not initialized.") + return if self.get_kuksa_status(): - self.kuksa_client.setValue(Path, Value) + try: + response = self.kuksa_client.getValue(path) + response = json.loads(response) + if not "error" in response: + result = response.get("data", None) + result = result.get("dp", None) + result = result.get("value", None) + + else: + error_message = response.get("error", "Unknown error") + print(f"[-] Error: Failed to get value from Kuksa. {error_message}") + + except Exception as e: + print("[-] Error: Failed to get values from Kuksa. ", str(e)) else: - print("Error: Connection to Kuksa failed.") + print("[-] Error: Connection to Kuksa failed.") + + return result + + + def close_kuksa_client(self): + """ + Close and stop the Kuksa client. 
+ """ + try: + with self._lock: + if self.kuksa_client: + self.kuksa_client.stop() + self.kuksa_client = None + print("[+] Kuksa client stopped.") + except Exception as e: + print("[-] Error: Failed to close Kuksa client. ", str(e))
\ No newline at end of file diff --git a/agl_service_voiceagent/utils/mapper.py b/agl_service_voiceagent/utils/mapper.py new file mode 100644 index 0000000..7529645 --- /dev/null +++ b/agl_service_voiceagent/utils/mapper.py @@ -0,0 +1,261 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from agl_service_voiceagent.utils.config import get_config_value +from agl_service_voiceagent.utils.common import load_json_file, words_to_number + + +class Intent2VSSMapper: + """ + Intent2VSSMapper is a class that facilitates the mapping of natural language intent to + corresponding vehicle signal specifications (VSS) for automated vehicle control systems. + """ + + def __init__(self): + """ + Initializes the Intent2VSSMapper class by loading Intent-to-VSS signal mappings + and VSS signal specifications from external configuration files. + """ + intents_vss_map_file = get_config_value("intents_vss_map", "Mapper") + vss_signals_spec_file = get_config_value("vss_signals_spec", "Mapper") + self.intents_vss_map = load_json_file(intents_vss_map_file).get("intents", {}) + self.vss_signals_spec = load_json_file(vss_signals_spec_file).get("signals", {}) + + if not self.validate_signal_spec_structure(): + raise ValueError("[-] Invalid VSS signal specification structure.") + + def validate_signal_spec_structure(self): + """ + Validates the structure of the VSS signal specification data. + """ + + signals = self.vss_signals_spec + + # Iterate over each signal in the 'signals' dictionary + for signal_name, signal_data in signals.items(): + # Check if the required keys are present in the signal data + if not all(key in signal_data for key in ['default_value', 'default_change_factor', 'actions', 'values', 'default_fallback', 'value_set_intents']): + print(f"[-] {signal_name}: Missing required keys in signal data.") + return False + + actions = signal_data['actions'] + + # Check if 'actions' is a dictionary with at least one action + if not isinstance(actions, dict) or not actions: + print(f"[-] {signal_name}: Invalid 'actions' key in signal data. Must be an object with at least one action.") + return False + + # Check if the actions match the allowed actions ["set", "increase", "decrease"] + for action in actions.keys(): + if action not in ["set", "increase", "decrease"]: + print(f"[-] {signal_name}: Invalid action in signal data. Allowed actions: ['set', 'increase', 'decrease']") + return False + + # Check if the 'synonyms' list is present for each action and is either a list or None + for action_data in actions.values(): + synonyms = action_data.get('synonyms') + if synonyms is not None and (not isinstance(synonyms, list) or not all(isinstance(synonym, str) for synonym in synonyms)): + print(f"[-] {signal_name}: Invalid 'synonyms' value in signal data. 
Must be a list of strings.") + return False + + values = signal_data['values'] + + # Check if 'values' is a dictionary with the required keys + if not isinstance(values, dict) or not all(key in values for key in ['ranged', 'start', 'end', 'ignore', 'additional']): + print(f"[-] {signal_name}: Invalid 'values' key in signal data. Required keys: ['ranged', 'start', 'end', 'ignore', 'additional']") + return False + + # Check if 'ranged' is a boolean + if not isinstance(values['ranged'], bool): + print(f"[-] {signal_name}: Invalid 'ranged' value in signal data. Allowed values: [true, false]") + return False + + default_fallback = signal_data['default_fallback'] + + # Check if 'default_fallback' is a boolean + if not isinstance(default_fallback, bool): + print(f"[-] {signal_name}: Invalid 'default_fallback' value in signal data. Allowed values: [true, false]") + return False + + # If all checks pass, the self.vss_signals_spec structure is valid + return True + + + def map_intent_to_signal(self, intent_name): + """ + Maps an intent name to the corresponding VSS signals and their specifications. + + Args: + intent_name (str): The name of the intent to be mapped. + + Returns: + dict: A dictionary containing VSS signals as keys and their specifications as values. + """ + + intent_data = self.intents_vss_map.get(intent_name, None) + result = {} + if intent_data: + signals = intent_data.get("signals", []) + + for signal in signals: + signal_info = self.vss_signals_spec.get(signal, {}) + if signal_info: + result.update({signal: signal_info}) + + return result + + + def parse_intent(self, intent_name, intent_slots = []): + """ + Parses an intent, extracting relevant VSS signals, actions, modifiers, and values + based on the intent and its associated slots. + + Args: + intent_name (str): The name of the intent to be parsed. + intent_slots (list): A list of dictionaries representing intent slots. + + Returns: + list: A list of dictionaries describing actions and signal-related details for execution. + + Note: + - If no relevant VSS signals are found for the intent, an empty list is returned. + - If no specific action or modifier is determined, default values are used. 
+ """ + vss_signal_data = self.map_intent_to_signal(intent_name) + execution_list = [] + for signal_name, signal_data in vss_signal_data.items(): + action = self.determine_action(signal_data, intent_slots) + modifier = self.determine_modifier(signal_data, intent_slots) + value = self.determine_value(signal_data, intent_slots) + + if value != None and not self.verify_value(signal_data, value): + value = None + + change_factor = signal_data["default_change_factor"] + + if action in ["increase", "decrease"]: + if value and modifier == "to": + execution_list.append({"action": action, "signal": signal_name, "value": str(value)}) + + elif value and modifier == "by": + execution_list.append({"action": action, "signal": signal_name, "factor": str(value)}) + + elif value: + execution_list.append({"action": action, "signal": signal_name, "value": str(value)}) + + elif signal_data["default_fallback"]: + execution_list.append({"action": action, "signal": signal_name, "factor": str(change_factor)}) + + # if no value found set the default value + if value == None and signal_data["default_fallback"]: + value = signal_data["default_value"] + + if action == "set" and value != None: + execution_list.append({"action": action, "signal": signal_name, "value": str(value)}) + + + return execution_list + + + def determine_action(self, signal_data, intent_slots): + """ + Determines the action (e.g., set, increase, decrease) based on the intent slots + and VSS signal data. + + Args: + signal_data (dict): The specification data for a VSS signal. + intent_slots (list): A list of dictionaries representing intent slots. + + Returns: + str: The determined action or None if no action can be determined. + """ + action_res = None + for intent_slot in intent_slots: + for action, action_data in signal_data["actions"].items(): + if intent_slot["name"] in action_data["intents"] and intent_slot["value"] in action_data["synonyms"]: + action_res = action + break + + return action_res + + + def determine_modifier(self, signal_data, intent_slots): + """ + Determines the modifier (e.g., 'to' or 'by') based on the intent slots + and VSS signal data. + + Args: + signal_data (dict): The specification data for a VSS signal. + intent_slots (list): A list of dictionaries representing intent slots. + + Returns: + str: The determined modifier or None if no modifier can be determined. + """ + modifier_res = None + for intent_slot in intent_slots: + for _, action_data in signal_data["actions"].items(): + intent_val = intent_slot["value"] + if "modifier_intents" in action_data and intent_slot["name"] in action_data["modifier_intents"] and ("to" in intent_val or "by" in intent_val): + modifier_res = "to" if "to" in intent_val else "by" if "by" in intent_val else None + break + + return modifier_res + + + def determine_value(self, signal_data, intent_slots): + """ + Determines the value associated with the intent slot, considering the data type + and converting it to a numeric string representation if necessary. + + Args: + signal_data (dict): The specification data for a VSS signal. + intent_slots (list): A list of dictionaries representing intent slots. + + Returns: + str: The determined value or None if no value can be determined. 
+ """ + result = None + for intent_slot in intent_slots: + for value, value_data in signal_data["value_set_intents"].items(): + if intent_slot["name"] == value: + result = intent_slot["value"] + + if value_data["datatype"] == "number": + result = words_to_number(result) # we assume our model will always return a number in words + + # the value should always returned as str because Kuksa expects str values + return str(result) if result != None else None + + + def verify_value(self, signal_data, value): + """ + Verifies that the value is valid based on the VSS signal data. + + Args: + signal_data (dict): The specification data for a VSS signal. + value (str): The value to be verified. + + Returns: + bool: True if the value is valid, False otherwise. + """ + if value in signal_data["values"]["ignore"]: + return False + + elif signal_data["values"]["ranged"] and isinstance(value, (int, float)): + return value >= signal_data["values"]["start"] and value <= signal_data["values"]["end"] + + else: + return value in signal_data["values"]["additional"] diff --git a/agl_service_voiceagent/utils/stt_model.py b/agl_service_voiceagent/utils/stt_model.py index 5337162..d51ae31 100644 --- a/agl_service_voiceagent/utils/stt_model.py +++ b/agl_service_voiceagent/utils/stt_model.py @@ -21,21 +21,61 @@ import wave from agl_service_voiceagent.utils.common import generate_unique_uuid class STTModel: + """ + STTModel is a class for speech-to-text (STT) recognition using the Vosk speech recognition library. + """ + def __init__(self, model_path, sample_rate=16000): + """ + Initialize the STTModel instance with the provided model and sample rate. + + Args: + model_path (str): The path to the Vosk speech recognition model. + sample_rate (int, optional): The audio sample rate in Hz (default is 16000). + """ self.sample_rate = sample_rate self.model = vosk.Model(model_path) self.recognizer = {} self.chunk_size = 1024 + def setup_recognizer(self): + """ + Set up a Vosk recognizer for a new session and return a unique identifier (UUID) for the session. + + Returns: + str: A unique identifier (UUID) for the session. + """ uuid = generate_unique_uuid(6) self.recognizer[uuid] = vosk.KaldiRecognizer(self.model, self.sample_rate) return uuid + def init_recognition(self, uuid, audio_data): + """ + Initialize the Vosk recognizer for a session with audio data. + + Args: + uuid (str): The unique identifier (UUID) for the session. + audio_data (bytes): Audio data to process. + + Returns: + bool: True if initialization was successful, False otherwise. + """ return self.recognizer[uuid].AcceptWaveform(audio_data) + def recognize(self, uuid, partial=False): + """ + Recognize speech and return the result as a JSON object. + + Args: + uuid (str): The unique identifier (UUID) for the session. + partial (bool, optional): If True, return partial recognition results (default is False). + + Returns: + dict: A JSON object containing recognition results. + """ self.recognizer[uuid].SetWords(True) if partial: result = json.loads(self.recognizer[uuid].PartialResult()) @@ -44,7 +84,18 @@ class STTModel: self.recognizer[uuid].Reset() return result + def recognize_from_file(self, uuid, filename): + """ + Recognize speech from an audio file and return the recognized text. + + Args: + uuid (str): The unique identifier (UUID) for the session. + filename (str): The path to the audio file. + + Returns: + str: The recognized text or error messages. 
+ """ if not os.path.exists(filename): print(f"Audio file '{filename}' not found.") return "FILE_NOT_FOUND" @@ -75,31 +126,13 @@ class STTModel: print("Voice not recognized. Please speak again...") return "VOICE_NOT_RECOGNIZED" - def cleanup_recognizer(self, uuid): - del self.recognizer[uuid] -import wave - -def read_wav_file(filename, chunk_size=1024): - try: - wf = wave.open(filename, "rb") - if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE": - print("Audio file must be WAV format mono PCM.") - return "FILE_FORMAT_INVALID" - - audio_data = b"" # Initialize an empty bytes object to store audio data - while True: - chunk = wf.readframes(chunk_size) - if not chunk: - break # End of file reached - audio_data += chunk - - return audio_data - except Exception as e: - print(f"Error reading audio file: {e}") - return None + def cleanup_recognizer(self, uuid): + """ + Clean up and remove the Vosk recognizer for a session. -# Example usage: -filename = "your_audio.wav" -audio_data = read_wav_file(filename) + Args: + uuid (str): The unique identifier (UUID) for the session. + """ + del self.recognizer[uuid]
\ No newline at end of file diff --git a/agl_service_voiceagent/utils/wake_word.py b/agl_service_voiceagent/utils/wake_word.py index 066ae6d..47e547e 100644 --- a/agl_service_voiceagent/utils/wake_word.py +++ b/agl_service_voiceagent/utils/wake_word.py @@ -15,7 +15,6 @@ # limitations under the License. import gi -import vosk gi.require_version('Gst', '1.0') from gi.repository import Gst, GLib @@ -23,7 +22,21 @@ Gst.init(None) GLib.threads_init() class WakeWordDetector: + """ + WakeWordDetector is a class for detecting a wake word in an audio stream using GStreamer and Vosk. + """ + def __init__(self, wake_word, stt_model, channels=1, sample_rate=16000, bits_per_sample=16): + """ + Initialize the WakeWordDetector instance with the provided parameters. + + Args: + wake_word (str): The wake word to detect in the audio stream. + stt_model (STTModel): An instance of the STTModel for speech-to-text recognition. + channels (int, optional): The number of audio channels (default is 1). + sample_rate (int, optional): The audio sample rate in Hz (default is 16000). + bits_per_sample (int, optional): The number of bits per sample (default is 16). + """ self.loop = GLib.MainLoop() self.pipeline = None self.bus = None @@ -32,16 +45,25 @@ class WakeWordDetector: self.sample_rate = sample_rate self.channels = channels self.bits_per_sample = bits_per_sample - self.frame_size = int(self.sample_rate * 0.02) - self.stt_model = stt_model # Speech to text model recognizer + self.wake_word_model = stt_model # Speech to text model recognizer self.recognizer_uuid = stt_model.setup_recognizer() - self.buffer_duration = 1 # Buffer audio for atleast 1 second self.audio_buffer = bytearray() + self.segment_size = int(self.sample_rate * 1.0) # Adjust the segment size (e.g., 1 second) + def get_wake_word_status(self): + """ + Get the status of wake word detection. + + Returns: + bool: True if the wake word has been detected, False otherwise. + """ return self.wake_word_detected def create_pipeline(self): + """ + Create and configure the GStreamer audio processing pipeline for wake word detection. + """ print("Creating pipeline for Wake Word Detection...") self.pipeline = Gst.Pipeline() autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None) @@ -77,49 +99,87 @@ class WakeWordDetector: self.bus.add_signal_watch() self.bus.connect("message", self.on_bus_message) + def on_new_buffer(self, appsink, data) -> Gst.FlowReturn: + """ + Callback function to handle new audio buffers from GStreamer appsink. + + Args: + appsink (Gst.AppSink): The GStreamer appsink. + data (object): User data (not used). + + Returns: + Gst.FlowReturn: Indicates the status of buffer processing. 
+ """ sample = appsink.emit("pull-sample") buffer = sample.get_buffer() data = buffer.extract_dup(0, buffer.get_size()) + + # Add the new data to the buffer self.audio_buffer.extend(data) - if len(self.audio_buffer) >= self.sample_rate * self.buffer_duration * self.channels * self.bits_per_sample // 8: - self.process_audio_buffer() + # Process audio in segments + while len(self.audio_buffer) >= self.segment_size: + segment = self.audio_buffer[:self.segment_size] + self.process_audio_segment(segment) + + # Advance the buffer by the segment size + self.audio_buffer = self.audio_buffer[self.segment_size:] return Gst.FlowReturn.OK - - def process_audio_buffer(self): - # Process the accumulated audio data using the audio model - audio_data = bytes(self.audio_buffer) - if self.stt_model.init_recognition(self.recognizer_uuid, audio_data): - stt_result = self.stt_model.recognize(self.recognizer_uuid) + def process_audio_segment(self, segment): + """ + Process an audio segment for wake word detection. + + Args: + segment (bytes): The audio segment to process. + """ + # Process the audio data segment + audio_data = bytes(segment) + + # Perform wake word detection on the audio_data + if self.wake_word_model.init_recognition(self.recognizer_uuid, audio_data): + stt_result = self.wake_word_model.recognize(self.recognizer_uuid) print("STT Result: ", stt_result) if self.wake_word in stt_result["text"]: self.wake_word_detected = True print("Wake word detected!") self.pipeline.send_event(Gst.Event.new_eos()) - self.audio_buffer.clear() # Clear the buffer - - def send_eos(self): + """ + Send an End-of-Stream (EOS) event to the pipeline. + """ self.pipeline.send_event(Gst.Event.new_eos()) self.audio_buffer.clear() def start_listening(self): + """ + Start listening for the wake word and enter the event loop. + """ self.pipeline.set_state(Gst.State.PLAYING) print("Listening for Wake Word...") self.loop.run() def stop_listening(self): + """ + Stop listening for the wake word and clean up the pipeline. + """ self.cleanup_pipeline() self.loop.quit() def on_bus_message(self, bus, message): + """ + Handle GStreamer bus messages and perform actions based on the message type. + + Args: + bus (Gst.Bus): The GStreamer bus. + message (Gst.Message): The GStreamer message to process. + """ if message.type == Gst.MessageType.EOS: print("End-of-stream message received") self.stop_listening() @@ -140,6 +200,9 @@ class WakeWordDetector: def cleanup_pipeline(self): + """ + Clean up the GStreamer pipeline and release associated resources. + """ if self.pipeline is not None: print("Cleaning up pipeline...") self.pipeline.set_state(Gst.State.NULL) @@ -147,4 +210,4 @@ class WakeWordDetector: print("Pipeline cleanup complete!") self.bus = None self.pipeline = None - self.stt_model.cleanup_recognizer(self.recognizer_uuid) + self.wake_word_model.cleanup_recognizer(self.recognizer_uuid) |