Diffstat (limited to 'agl_service_voiceagent/utils')
-rw-r--r--  agl_service_voiceagent/utils/audio_recorder.py    52
-rw-r--r--  agl_service_voiceagent/utils/common.py            91
-rw-r--r--  agl_service_voiceagent/utils/config.py            34
-rw-r--r--  agl_service_voiceagent/utils/kuksa_interface.py   178
-rw-r--r--  agl_service_voiceagent/utils/mapper.py            261
-rw-r--r--  agl_service_voiceagent/utils/stt_model.py         83
-rw-r--r--  agl_service_voiceagent/utils/wake_word.py         95
7 files changed, 717 insertions, 77 deletions
diff --git a/agl_service_voiceagent/utils/audio_recorder.py b/agl_service_voiceagent/utils/audio_recorder.py
index 61ce994..2e8f11d 100644
--- a/agl_service_voiceagent/utils/audio_recorder.py
+++ b/agl_service_voiceagent/utils/audio_recorder.py
@@ -15,7 +15,6 @@
# limitations under the License.
import gi
-import vosk
import time
gi.require_version('Gst', '1.0')
from gi.repository import Gst, GLib
@@ -24,7 +23,21 @@ Gst.init(None)
GLib.threads_init()
class AudioRecorder:
+ """
+ AudioRecorder is a class for recording audio using GStreamer in various modes.
+ """
+
def __init__(self, stt_model, audio_files_basedir, channels=1, sample_rate=16000, bits_per_sample=16):
+ """
+ Initialize the AudioRecorder instance with the provided parameters.
+
+ Args:
+            stt_model (STTModel): The speech-to-text model instance used for voice input recognition.
+ audio_files_basedir (str): The base directory for saving audio files.
+ channels (int, optional): The number of audio channels (default is 1).
+ sample_rate (int, optional): The audio sample rate in Hz (default is 16000).
+ bits_per_sample (int, optional): The number of bits per sample (default is 16).
+ """
self.loop = GLib.MainLoop()
self.mode = None
self.pipeline = None
@@ -43,6 +56,12 @@ class AudioRecorder:
def create_pipeline(self):
+ """
+ Create and configure the GStreamer audio recording pipeline.
+
+ Returns:
+ str: The name of the audio file being recorded.
+ """
print("Creating pipeline for audio recording in {} mode...".format(self.mode))
self.pipeline = Gst.Pipeline()
autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None)
@@ -83,32 +102,50 @@ class AudioRecorder:
def start_recording(self):
+ """
+ Start recording audio using the GStreamer pipeline.
+ """
self.pipeline.set_state(Gst.State.PLAYING)
print("Recording Voice Input...")
def stop_recording(self):
+ """
+ Stop audio recording and clean up the GStreamer pipeline.
+ """
print("Stopping recording...")
- self.frames_above_threshold = 0
- self.cleanup_pipeline()
+ # self.cleanup_pipeline()
+ self.pipeline.send_event(Gst.Event.new_eos())
print("Recording finished!")
def set_pipeline_mode(self, mode):
+ """
+ Set the recording mode to 'auto' or 'manual'.
+
+ Args:
+ mode (str): The recording mode ('auto' or 'manual').
+ """
self.mode = mode
- # this method helps with error handling
def on_bus_message(self, bus, message):
+ """
+ Handle GStreamer bus messages and perform actions based on the message type.
+
+ Args:
+ bus (Gst.Bus): The GStreamer bus.
+ message (Gst.Message): The GStreamer message to process.
+ """
if message.type == Gst.MessageType.EOS:
print("End-of-stream message received")
- self.stop_recording()
+ self.cleanup_pipeline()
elif message.type == Gst.MessageType.ERROR:
err, debug_info = message.parse_error()
print(f"Error received from element {message.src.get_name()}: {err.message}")
print(f"Debugging information: {debug_info}")
- self.stop_recording()
+ self.cleanup_pipeline()
elif message.type == Gst.MessageType.WARNING:
err, debug_info = message.parse_warning()
@@ -136,6 +173,9 @@ class AudioRecorder:
def cleanup_pipeline(self):
+ """
+ Clean up the GStreamer pipeline, set it to NULL state, and remove the signal watch.
+ """
if self.pipeline is not None:
print("Cleaning up pipeline...")
self.pipeline.set_state(Gst.State.NULL)
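With this change, stop_recording() only requests a stop by emitting EOS; the pipeline is actually torn down once the EOS message reaches the bus handler, so the filesink can finalize the WAV file first. A minimal sketch of that EOS-driven shutdown pattern, assuming PyGObject/GStreamer are installed (the pipeline string below is illustrative, not the one AudioRecorder builds element by element):

import time
import gi
gi.require_version('Gst', '1.0')
from gi.repository import Gst

Gst.init(None)

# Illustrative recording pipeline written as a parse_launch string.
pipeline = Gst.parse_launch(
    "autoaudiosrc ! audioconvert ! audioresample ! wavenc ! filesink location=/tmp/sample.wav"
)
pipeline.set_state(Gst.State.PLAYING)
time.sleep(2)  # record for a couple of seconds

# Ask the pipeline to finish instead of tearing it down immediately...
pipeline.send_event(Gst.Event.new_eos())

# ...and only set it to NULL once the EOS (or an error) shows up on the bus.
bus = pipeline.get_bus()
bus.timed_pop_filtered(Gst.CLOCK_TIME_NONE, Gst.MessageType.EOS | Gst.MessageType.ERROR)
pipeline.set_state(Gst.State.NULL)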
diff --git a/agl_service_voiceagent/utils/common.py b/agl_service_voiceagent/utils/common.py
index 682473e..b9d6577 100644
--- a/agl_service_voiceagent/utils/common.py
+++ b/agl_service_voiceagent/utils/common.py
@@ -20,12 +20,18 @@ import json
def add_trailing_slash(path):
+ """
+ Adds a trailing slash to a path if it does not already have one.
+ """
if path and not path.endswith('/'):
path += '/'
return path
def generate_unique_uuid(length):
+ """
+ Generates a unique ID of specified length.
+ """
unique_id = str(uuid.uuid4().int)
# Ensure the generated ID is exactly 'length' digits by taking the last 'length' characters
unique_id = unique_id[-length:]
@@ -33,18 +39,91 @@ def generate_unique_uuid(length):
def load_json_file(file_path):
- try:
- with open(file_path, 'r') as file:
- return json.load(file)
- except FileNotFoundError:
- raise ValueError(f"File '{file_path}' not found.")
+ """
+ Loads a JSON file and returns the data.
+ """
+ try:
+ with open(file_path, 'r') as file:
+ return json.load(file)
+ except FileNotFoundError:
+ raise ValueError(f"File '{file_path}' not found.")
def delete_file(file_path):
+ """
+ Deletes a file if it exists.
+ """
if os.path.exists(file_path):
try:
os.remove(file_path)
except Exception as e:
print(f"Error deleting '{file_path}': {e}")
else:
-        print(f"File '{file_path}' does not exist.")
\ No newline at end of file
+ print(f"File '{file_path}' does not exist.")
+
+
+def words_to_number(words):
+ """
+ Converts a string of words to a number.
+ """
+ word_to_number = {
+ 'one': 1,
+ 'two': 2,
+ 'three': 3,
+ 'four': 4,
+ 'five': 5,
+ 'six': 6,
+ 'seven': 7,
+ 'eight': 8,
+ 'nine': 9,
+ 'ten': 10,
+ 'eleven': 11,
+ 'twelve': 12,
+ 'thirteen': 13,
+ 'fourteen': 14,
+ 'fifteen': 15,
+ 'sixteen': 16,
+ 'seventeen': 17,
+ 'eighteen': 18,
+ 'nineteen': 19,
+ 'twenty': 20,
+ 'thirty': 30,
+ 'forty': 40,
+ 'fourty': 40,
+ 'fifty': 50,
+ 'sixty': 60,
+ 'seventy': 70,
+ 'eighty': 80,
+ 'ninety': 90
+ }
+
+ # Split the input words and initialize total and current number
+ words = words.split()
+
+ if len(words) == 1:
+ if words[0] in ["zero", "min", "minimum"]:
+ return 0
+
+ elif words[0] in ["half", "halfway"]:
+ return 50
+
+ elif words[0] in ["max", "maximum", "full", "fully", "completely", "hundred"]:
+ return 100
+
+
+ total = 0
+ current_number = 0
+
+ for word in words:
+ if word in word_to_number:
+ current_number += word_to_number[word]
+ elif word == 'hundred':
+ current_number *= 100
+ else:
+ total += current_number
+ current_number = 0
+
+ total += current_number
+
+    # we return the number as a str because Kuksa expects str input
+    return str(total) or None
\ No newline at end of file
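A few illustrative calls to the new words_to_number() helper, showing the values it produces under the implementation above (assuming the package is importable):

from agl_service_voiceagent.utils.common import words_to_number

print(words_to_number("twenty five"))   # -> "25"
print(words_to_number("one hundred"))   # -> "100"
print(words_to_number("sixty"))         # -> "60"
print(words_to_number("half"))          # -> 50 (the single-word shortcuts return ints)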
diff --git a/agl_service_voiceagent/utils/config.py b/agl_service_voiceagent/utils/config.py
index 8d7f346..7295c7f 100644
--- a/agl_service_voiceagent/utils/config.py
+++ b/agl_service_voiceagent/utils/config.py
@@ -14,21 +14,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import configparser
-# Get the absolute path to the directory of the current script
-current_dir = os.path.dirname(os.path.abspath(__file__))
-# Construct the path to the config.ini file located in the base directory
-config_path = os.path.join(current_dir, '..', 'config.ini')
-
config = configparser.ConfigParser()
-config.read(config_path)
+config_path = None
+
+def set_config_path(path):
+ """
+ Sets the path to the config file.
+ """
+ global config_path
+ config_path = path
+ config.read(config_path)
+
+def load_config():
+ """
+ Loads the config file.
+ """
+ if config_path is not None:
+ config.read(config_path)
+ else:
+ raise Exception("Config file path not provided.")
def update_config_value(value, key, group="General"):
+ """
+ Updates a value in the config file.
+ """
+ if config_path is None:
+ raise Exception("Config file path not set.")
+
config.set(group, key, value)
with open(config_path, 'w') as configfile:
config.write(configfile)
def get_config_value(key, group="General"):
+ """
+ Gets a value from the config file.
+ """
return config.get(group, key)
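A short usage sketch of the reworked config module: callers now have to point it at a config.ini via set_config_path() before reading or updating values. The path and keys below are placeholders, not values taken from the repository:

from agl_service_voiceagent.utils import config

# Hypothetical config path and keys, used only to show the call order.
config.set_config_path("/etc/agl-service-voiceagent/config.ini")

stt_model_path = config.get_config_value("stt_model_path", "General")   # read a value
config.update_config_value("auto", "recording_mode", "General")         # value first, then key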
diff --git a/agl_service_voiceagent/utils/kuksa_interface.py b/agl_service_voiceagent/utils/kuksa_interface.py
index 3e1c045..9270379 100644
--- a/agl_service_voiceagent/utils/kuksa_interface.py
+++ b/agl_service_voiceagent/utils/kuksa_interface.py
@@ -14,53 +14,197 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import time
+import json
+import threading
from kuksa_client import KuksaClientThread
from agl_service_voiceagent.utils.config import get_config_value
class KuksaInterface:
- def __init__(self):
+ """
+ Kuksa Interface
+
+ This class provides methods to initialize, authorize, connect, send values,
+ check the status, and close the Kuksa client.
+ """
+
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls):
+ """
+ Get the unique instance of the class.
+
+ Returns:
+ KuksaInterface: The instance of the class.
+ """
+ with cls._lock:
+ if cls._instance is None:
+ cls._instance = super(KuksaInterface, cls).__new__(cls)
+ cls._instance.init_client()
+ return cls._instance
+
+
+ def init_client(self):
+ """
+ Initialize the Kuksa client configuration.
+ """
# get config values
- self.ip = get_config_value("ip", "Kuksa")
- self.port = get_config_value("port", "Kuksa")
+ self.ip = str(get_config_value("ip", "Kuksa"))
+ self.port = str(get_config_value("port", "Kuksa"))
self.insecure = get_config_value("insecure", "Kuksa")
self.protocol = get_config_value("protocol", "Kuksa")
self.token = get_config_value("token", "Kuksa")
+ print(self.ip, self.port, self.insecure, self.protocol, self.token)
+
# define class methods
self.kuksa_client = None
def get_kuksa_client(self):
+ """
+ Get the Kuksa client instance.
+
+ Returns:
+ KuksaClientThread: The Kuksa client instance.
+ """
return self.kuksa_client
-
+
def get_kuksa_status(self):
+ """
+ Check the status of the Kuksa client connection.
+
+ Returns:
+ bool: True if the client is connected, False otherwise.
+ """
if self.kuksa_client:
return self.kuksa_client.checkConnection()
-
return False
def connect_kuksa_client(self):
+ """
+ Connect and start the Kuksa client.
+ """
try:
- self.kuksa_client = KuksaClientThread({
- "ip": self.ip,
- "port": self.port,
- "insecure": self.insecure,
- "protocol": self.protocol,
- })
- self.kuksa_client.authorize(self.token)
+ with self._lock:
+ if self.kuksa_client is None:
+ self.kuksa_client = KuksaClientThread({
+ "ip": self.ip,
+ "port": self.port,
+ "insecure": self.insecure,
+ "protocol": self.protocol,
+ })
+ self.kuksa_client.start()
+ time.sleep(2) # Give the thread time to start
+
+ if not self.get_kuksa_status():
+ print("[-] Error: Connection to Kuksa server failed.")
+ else:
+ print("[+] Connection to Kuksa established.")
except Exception as e:
- print("Error: ", e)
+ print("[-] Error: Connection to Kuksa server failed. ", str(e))
+
+
+ def authorize_kuksa_client(self):
+ """
+ Authorize the Kuksa client with the provided token.
+ """
+ if self.kuksa_client:
+ response = self.kuksa_client.authorize(self.token)
+ response = json.loads(response)
+ if "error" in response:
+ error_message = response.get("error", "Unknown error")
+ print(f"[-] Error: Authorization failed. {error_message}")
+ else:
+ print("[+] Kuksa client authorized successfully.")
+ else:
+ print("[-] Error: Kuksa client is not initialized. Call `connect_kuksa_client` first.")
+
+
+ def send_values(self, path=None, value=None):
+ """
+ Send values to the Kuksa server.
+
+ Args:
+ path (str): The path to the value.
+ value (str): The value to be sent.
+ Returns:
+ bool: True if the value was set, False otherwise.
+ """
+ result = False
+ if self.kuksa_client is None:
+ print("[-] Error: Kuksa client is not initialized.")
+ return
+
+ if self.get_kuksa_status():
+ try:
+ response = self.kuksa_client.setValue(path, value)
+ response = json.loads(response)
+ if not "error" in response:
+ print(f"[+] Value '{value}' sent to Kuksa successfully.")
+ result = True
+ else:
+ error_message = response.get("error", "Unknown error")
+ print(f"[-] Error: Failed to send value '{value}' to Kuksa. {error_message}")
+
+ except Exception as e:
+ print("[-] Error: Failed to send values to Kuksa. ", str(e))
+
+ else:
+ print("[-] Error: Connection to Kuksa failed.")
+
+ return result
- def send_values(self, Path=None, Value=None):
+ def get_value(self, path=None):
+ """
+ Get values from the Kuksa server.
+
+ Args:
+ path (str): The path to the value.
+ Returns:
+ str: The value if the path is valid, None otherwise.
+ """
+ result = None
if self.kuksa_client is None:
- print("Error: Kuksa client is not initialized.")
+ print("[-] Error: Kuksa client is not initialized.")
+ return
if self.get_kuksa_status():
- self.kuksa_client.setValue(Path, Value)
+ try:
+ response = self.kuksa_client.getValue(path)
+ response = json.loads(response)
+ if not "error" in response:
+ result = response.get("data", None)
+ result = result.get("dp", None)
+ result = result.get("value", None)
+
+ else:
+ error_message = response.get("error", "Unknown error")
+ print(f"[-] Error: Failed to get value from Kuksa. {error_message}")
+
+ except Exception as e:
+ print("[-] Error: Failed to get values from Kuksa. ", str(e))
else:
- print("Error: Connection to Kuksa failed.")
+ print("[-] Error: Connection to Kuksa failed.")
+
+ return result
+
+
+ def close_kuksa_client(self):
+ """
+ Close and stop the Kuksa client.
+ """
+ try:
+ with self._lock:
+ if self.kuksa_client:
+ self.kuksa_client.stop()
+ self.kuksa_client = None
+ print("[+] Kuksa client stopped.")
+ except Exception as e:
+            print("[-] Error: Failed to close Kuksa client. ", str(e))
\ No newline at end of file
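A hedged sketch of the new KuksaInterface call sequence: the class is now a thread-safe singleton, and connecting and authorizing are separate steps. This assumes the config file path has already been set via set_config_path(); the VSS signal path is illustrative:

from agl_service_voiceagent.utils.kuksa_interface import KuksaInterface

kuksa = KuksaInterface()          # any later KuksaInterface() returns this same instance
kuksa.connect_kuksa_client()
kuksa.authorize_kuksa_client()

if kuksa.get_kuksa_status():
    # Illustrative VSS path; send_values() returns True on success.
    if kuksa.send_values("Vehicle.Cabin.HVAC.Station.Row1.Left.FanSpeed", "50"):
        print(kuksa.get_value("Vehicle.Cabin.HVAC.Station.Row1.Left.FanSpeed"))

kuksa.close_kuksa_client()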
diff --git a/agl_service_voiceagent/utils/mapper.py b/agl_service_voiceagent/utils/mapper.py
new file mode 100644
index 0000000..7529645
--- /dev/null
+++ b/agl_service_voiceagent/utils/mapper.py
@@ -0,0 +1,261 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright (c) 2023 Malik Talha
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from agl_service_voiceagent.utils.config import get_config_value
+from agl_service_voiceagent.utils.common import load_json_file, words_to_number
+
+
+class Intent2VSSMapper:
+ """
+ Intent2VSSMapper is a class that facilitates the mapping of natural language intent to
+ corresponding vehicle signal specifications (VSS) for automated vehicle control systems.
+ """
+
+ def __init__(self):
+ """
+ Initializes the Intent2VSSMapper class by loading Intent-to-VSS signal mappings
+ and VSS signal specifications from external configuration files.
+ """
+ intents_vss_map_file = get_config_value("intents_vss_map", "Mapper")
+ vss_signals_spec_file = get_config_value("vss_signals_spec", "Mapper")
+ self.intents_vss_map = load_json_file(intents_vss_map_file).get("intents", {})
+ self.vss_signals_spec = load_json_file(vss_signals_spec_file).get("signals", {})
+
+ if not self.validate_signal_spec_structure():
+ raise ValueError("[-] Invalid VSS signal specification structure.")
+
+ def validate_signal_spec_structure(self):
+ """
+ Validates the structure of the VSS signal specification data.
+ """
+
+ signals = self.vss_signals_spec
+
+ # Iterate over each signal in the 'signals' dictionary
+ for signal_name, signal_data in signals.items():
+ # Check if the required keys are present in the signal data
+ if not all(key in signal_data for key in ['default_value', 'default_change_factor', 'actions', 'values', 'default_fallback', 'value_set_intents']):
+ print(f"[-] {signal_name}: Missing required keys in signal data.")
+ return False
+
+ actions = signal_data['actions']
+
+ # Check if 'actions' is a dictionary with at least one action
+ if not isinstance(actions, dict) or not actions:
+ print(f"[-] {signal_name}: Invalid 'actions' key in signal data. Must be an object with at least one action.")
+ return False
+
+ # Check if the actions match the allowed actions ["set", "increase", "decrease"]
+ for action in actions.keys():
+ if action not in ["set", "increase", "decrease"]:
+ print(f"[-] {signal_name}: Invalid action in signal data. Allowed actions: ['set', 'increase', 'decrease']")
+ return False
+
+ # Check if the 'synonyms' list is present for each action and is either a list or None
+ for action_data in actions.values():
+ synonyms = action_data.get('synonyms')
+ if synonyms is not None and (not isinstance(synonyms, list) or not all(isinstance(synonym, str) for synonym in synonyms)):
+ print(f"[-] {signal_name}: Invalid 'synonyms' value in signal data. Must be a list of strings.")
+ return False
+
+ values = signal_data['values']
+
+ # Check if 'values' is a dictionary with the required keys
+ if not isinstance(values, dict) or not all(key in values for key in ['ranged', 'start', 'end', 'ignore', 'additional']):
+ print(f"[-] {signal_name}: Invalid 'values' key in signal data. Required keys: ['ranged', 'start', 'end', 'ignore', 'additional']")
+ return False
+
+ # Check if 'ranged' is a boolean
+ if not isinstance(values['ranged'], bool):
+ print(f"[-] {signal_name}: Invalid 'ranged' value in signal data. Allowed values: [true, false]")
+ return False
+
+ default_fallback = signal_data['default_fallback']
+
+ # Check if 'default_fallback' is a boolean
+ if not isinstance(default_fallback, bool):
+ print(f"[-] {signal_name}: Invalid 'default_fallback' value in signal data. Allowed values: [true, false]")
+ return False
+
+ # If all checks pass, the self.vss_signals_spec structure is valid
+ return True
+
+
+ def map_intent_to_signal(self, intent_name):
+ """
+ Maps an intent name to the corresponding VSS signals and their specifications.
+
+ Args:
+ intent_name (str): The name of the intent to be mapped.
+
+ Returns:
+ dict: A dictionary containing VSS signals as keys and their specifications as values.
+ """
+
+ intent_data = self.intents_vss_map.get(intent_name, None)
+ result = {}
+ if intent_data:
+ signals = intent_data.get("signals", [])
+
+ for signal in signals:
+ signal_info = self.vss_signals_spec.get(signal, {})
+ if signal_info:
+ result.update({signal: signal_info})
+
+ return result
+
+
+ def parse_intent(self, intent_name, intent_slots = []):
+ """
+ Parses an intent, extracting relevant VSS signals, actions, modifiers, and values
+ based on the intent and its associated slots.
+
+ Args:
+ intent_name (str): The name of the intent to be parsed.
+ intent_slots (list): A list of dictionaries representing intent slots.
+
+ Returns:
+ list: A list of dictionaries describing actions and signal-related details for execution.
+
+ Note:
+ - If no relevant VSS signals are found for the intent, an empty list is returned.
+ - If no specific action or modifier is determined, default values are used.
+ """
+ vss_signal_data = self.map_intent_to_signal(intent_name)
+ execution_list = []
+ for signal_name, signal_data in vss_signal_data.items():
+ action = self.determine_action(signal_data, intent_slots)
+ modifier = self.determine_modifier(signal_data, intent_slots)
+ value = self.determine_value(signal_data, intent_slots)
+
+ if value != None and not self.verify_value(signal_data, value):
+ value = None
+
+ change_factor = signal_data["default_change_factor"]
+
+ if action in ["increase", "decrease"]:
+ if value and modifier == "to":
+ execution_list.append({"action": action, "signal": signal_name, "value": str(value)})
+
+ elif value and modifier == "by":
+ execution_list.append({"action": action, "signal": signal_name, "factor": str(value)})
+
+ elif value:
+ execution_list.append({"action": action, "signal": signal_name, "value": str(value)})
+
+ elif signal_data["default_fallback"]:
+ execution_list.append({"action": action, "signal": signal_name, "factor": str(change_factor)})
+
+            # if no value was found, fall back to the default value
+ if value == None and signal_data["default_fallback"]:
+ value = signal_data["default_value"]
+
+ if action == "set" and value != None:
+ execution_list.append({"action": action, "signal": signal_name, "value": str(value)})
+
+
+ return execution_list
+
+
+ def determine_action(self, signal_data, intent_slots):
+ """
+ Determines the action (e.g., set, increase, decrease) based on the intent slots
+ and VSS signal data.
+
+ Args:
+ signal_data (dict): The specification data for a VSS signal.
+ intent_slots (list): A list of dictionaries representing intent slots.
+
+ Returns:
+ str: The determined action or None if no action can be determined.
+ """
+ action_res = None
+ for intent_slot in intent_slots:
+ for action, action_data in signal_data["actions"].items():
+ if intent_slot["name"] in action_data["intents"] and intent_slot["value"] in action_data["synonyms"]:
+ action_res = action
+ break
+
+ return action_res
+
+
+ def determine_modifier(self, signal_data, intent_slots):
+ """
+ Determines the modifier (e.g., 'to' or 'by') based on the intent slots
+ and VSS signal data.
+
+ Args:
+ signal_data (dict): The specification data for a VSS signal.
+ intent_slots (list): A list of dictionaries representing intent slots.
+
+ Returns:
+ str: The determined modifier or None if no modifier can be determined.
+ """
+ modifier_res = None
+ for intent_slot in intent_slots:
+ for _, action_data in signal_data["actions"].items():
+ intent_val = intent_slot["value"]
+ if "modifier_intents" in action_data and intent_slot["name"] in action_data["modifier_intents"] and ("to" in intent_val or "by" in intent_val):
+ modifier_res = "to" if "to" in intent_val else "by" if "by" in intent_val else None
+ break
+
+ return modifier_res
+
+
+ def determine_value(self, signal_data, intent_slots):
+ """
+ Determines the value associated with the intent slot, considering the data type
+ and converting it to a numeric string representation if necessary.
+
+ Args:
+ signal_data (dict): The specification data for a VSS signal.
+ intent_slots (list): A list of dictionaries representing intent slots.
+
+ Returns:
+ str: The determined value or None if no value can be determined.
+ """
+ result = None
+ for intent_slot in intent_slots:
+ for value, value_data in signal_data["value_set_intents"].items():
+ if intent_slot["name"] == value:
+ result = intent_slot["value"]
+
+ if value_data["datatype"] == "number":
+ result = words_to_number(result) # we assume our model will always return a number in words
+
+        # the value should always be returned as a str because Kuksa expects str values
+ return str(result) if result != None else None
+
+
+ def verify_value(self, signal_data, value):
+ """
+ Verifies that the value is valid based on the VSS signal data.
+
+ Args:
+ signal_data (dict): The specification data for a VSS signal.
+ value (str): The value to be verified.
+
+ Returns:
+ bool: True if the value is valid, False otherwise.
+ """
+ if value in signal_data["values"]["ignore"]:
+ return False
+
+ elif signal_data["values"]["ranged"] and isinstance(value, (int, float)):
+ return value >= signal_data["values"]["start"] and value <= signal_data["values"]["end"]
+
+ else:
+ return value in signal_data["values"]["additional"]
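A hypothetical example of how Intent2VSSMapper.parse_intent() might be driven. The intent name, slot names, and signal path are assumptions for illustration only; the real values live in the intents_vss_map and vss_signals_spec JSON files referenced by the config:

from agl_service_voiceagent.utils.mapper import Intent2VSSMapper

mapper = Intent2VSSMapper()  # requires valid Mapper entries in the config file

# Slot names and values here are made up for illustration.
intent_slots = [
    {"name": "volume_control_action", "value": "increase"},
    {"name": "numeric_value", "value": "twenty"},
]
execution_list = mapper.parse_intent("VolumeControl", intent_slots)

# Expected shape of each entry (signal path is hypothetical):
# {"action": "increase", "signal": "Vehicle.Cabin.Infotainment.Media.Volume", "value": "20"}
for step in execution_list:
    print(step["action"], step["signal"], step.get("value", step.get("factor")))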
diff --git a/agl_service_voiceagent/utils/stt_model.py b/agl_service_voiceagent/utils/stt_model.py
index 5337162..d51ae31 100644
--- a/agl_service_voiceagent/utils/stt_model.py
+++ b/agl_service_voiceagent/utils/stt_model.py
@@ -21,21 +21,61 @@ import wave
from agl_service_voiceagent.utils.common import generate_unique_uuid
class STTModel:
+ """
+ STTModel is a class for speech-to-text (STT) recognition using the Vosk speech recognition library.
+ """
+
def __init__(self, model_path, sample_rate=16000):
+ """
+ Initialize the STTModel instance with the provided model and sample rate.
+
+ Args:
+ model_path (str): The path to the Vosk speech recognition model.
+ sample_rate (int, optional): The audio sample rate in Hz (default is 16000).
+ """
self.sample_rate = sample_rate
self.model = vosk.Model(model_path)
self.recognizer = {}
self.chunk_size = 1024
+
def setup_recognizer(self):
+ """
+ Set up a Vosk recognizer for a new session and return a unique identifier (UUID) for the session.
+
+ Returns:
+ str: A unique identifier (UUID) for the session.
+ """
uuid = generate_unique_uuid(6)
self.recognizer[uuid] = vosk.KaldiRecognizer(self.model, self.sample_rate)
return uuid
+
def init_recognition(self, uuid, audio_data):
+ """
+ Initialize the Vosk recognizer for a session with audio data.
+
+ Args:
+ uuid (str): The unique identifier (UUID) for the session.
+ audio_data (bytes): Audio data to process.
+
+ Returns:
+ bool: True if initialization was successful, False otherwise.
+ """
return self.recognizer[uuid].AcceptWaveform(audio_data)
+
def recognize(self, uuid, partial=False):
+ """
+ Recognize speech and return the result as a JSON object.
+
+ Args:
+ uuid (str): The unique identifier (UUID) for the session.
+ partial (bool, optional): If True, return partial recognition results (default is False).
+
+ Returns:
+ dict: A JSON object containing recognition results.
+ """
self.recognizer[uuid].SetWords(True)
if partial:
result = json.loads(self.recognizer[uuid].PartialResult())
@@ -44,7 +84,18 @@ class STTModel:
self.recognizer[uuid].Reset()
return result
+
def recognize_from_file(self, uuid, filename):
+ """
+ Recognize speech from an audio file and return the recognized text.
+
+ Args:
+ uuid (str): The unique identifier (UUID) for the session.
+ filename (str): The path to the audio file.
+
+ Returns:
+ str: The recognized text or error messages.
+ """
if not os.path.exists(filename):
print(f"Audio file '{filename}' not found.")
return "FILE_NOT_FOUND"
@@ -75,31 +126,13 @@ class STTModel:
print("Voice not recognized. Please speak again...")
return "VOICE_NOT_RECOGNIZED"
- def cleanup_recognizer(self, uuid):
- del self.recognizer[uuid]
-import wave
-
-def read_wav_file(filename, chunk_size=1024):
- try:
- wf = wave.open(filename, "rb")
- if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
- print("Audio file must be WAV format mono PCM.")
- return "FILE_FORMAT_INVALID"
-
- audio_data = b"" # Initialize an empty bytes object to store audio data
- while True:
- chunk = wf.readframes(chunk_size)
- if not chunk:
- break # End of file reached
- audio_data += chunk
-
- return audio_data
- except Exception as e:
- print(f"Error reading audio file: {e}")
- return None
+ def cleanup_recognizer(self, uuid):
+ """
+ Clean up and remove the Vosk recognizer for a session.
-# Example usage:
-filename = "your_audio.wav"
-audio_data = read_wav_file(filename)
+ Args:
+ uuid (str): The unique identifier (UUID) for the session.
+ """
+ del self.recognizer[uuid]
\ No newline at end of file
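A minimal sketch of the per-session recognizer flow exposed by STTModel: each caller obtains a UUID-keyed Vosk recognizer and releases it when done. The model and audio file paths are placeholders:

from agl_service_voiceagent.utils.stt_model import STTModel

stt = STTModel("/usr/share/vosk/vosk-model-small-en-us-0.15", sample_rate=16000)

session_id = stt.setup_recognizer()
text = stt.recognize_from_file(session_id, "/tmp/voice_input.wav")
print(text)  # recognized text, or "FILE_NOT_FOUND" / "VOICE_NOT_RECOGNIZED"
stt.cleanup_recognizer(session_id)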
diff --git a/agl_service_voiceagent/utils/wake_word.py b/agl_service_voiceagent/utils/wake_word.py
index 066ae6d..47e547e 100644
--- a/agl_service_voiceagent/utils/wake_word.py
+++ b/agl_service_voiceagent/utils/wake_word.py
@@ -15,7 +15,6 @@
# limitations under the License.
import gi
-import vosk
gi.require_version('Gst', '1.0')
from gi.repository import Gst, GLib
@@ -23,7 +22,21 @@ Gst.init(None)
GLib.threads_init()
class WakeWordDetector:
+ """
+ WakeWordDetector is a class for detecting a wake word in an audio stream using GStreamer and Vosk.
+ """
+
def __init__(self, wake_word, stt_model, channels=1, sample_rate=16000, bits_per_sample=16):
+ """
+ Initialize the WakeWordDetector instance with the provided parameters.
+
+ Args:
+ wake_word (str): The wake word to detect in the audio stream.
+ stt_model (STTModel): An instance of the STTModel for speech-to-text recognition.
+ channels (int, optional): The number of audio channels (default is 1).
+ sample_rate (int, optional): The audio sample rate in Hz (default is 16000).
+ bits_per_sample (int, optional): The number of bits per sample (default is 16).
+ """
self.loop = GLib.MainLoop()
self.pipeline = None
self.bus = None
@@ -32,16 +45,25 @@ class WakeWordDetector:
self.sample_rate = sample_rate
self.channels = channels
self.bits_per_sample = bits_per_sample
- self.frame_size = int(self.sample_rate * 0.02)
- self.stt_model = stt_model # Speech to text model recognizer
+ self.wake_word_model = stt_model # Speech to text model recognizer
self.recognizer_uuid = stt_model.setup_recognizer()
- self.buffer_duration = 1 # Buffer audio for atleast 1 second
self.audio_buffer = bytearray()
+ self.segment_size = int(self.sample_rate * 1.0) # Adjust the segment size (e.g., 1 second)
+
def get_wake_word_status(self):
+ """
+ Get the status of wake word detection.
+
+ Returns:
+ bool: True if the wake word has been detected, False otherwise.
+ """
return self.wake_word_detected
def create_pipeline(self):
+ """
+ Create and configure the GStreamer audio processing pipeline for wake word detection.
+ """
print("Creating pipeline for Wake Word Detection...")
self.pipeline = Gst.Pipeline()
autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None)
@@ -77,49 +99,87 @@ class WakeWordDetector:
self.bus.add_signal_watch()
self.bus.connect("message", self.on_bus_message)
+
def on_new_buffer(self, appsink, data) -> Gst.FlowReturn:
+ """
+ Callback function to handle new audio buffers from GStreamer appsink.
+
+ Args:
+ appsink (Gst.AppSink): The GStreamer appsink.
+ data (object): User data (not used).
+
+ Returns:
+ Gst.FlowReturn: Indicates the status of buffer processing.
+ """
sample = appsink.emit("pull-sample")
buffer = sample.get_buffer()
data = buffer.extract_dup(0, buffer.get_size())
+
+ # Add the new data to the buffer
self.audio_buffer.extend(data)
- if len(self.audio_buffer) >= self.sample_rate * self.buffer_duration * self.channels * self.bits_per_sample // 8:
- self.process_audio_buffer()
+ # Process audio in segments
+ while len(self.audio_buffer) >= self.segment_size:
+ segment = self.audio_buffer[:self.segment_size]
+ self.process_audio_segment(segment)
+
+ # Advance the buffer by the segment size
+ self.audio_buffer = self.audio_buffer[self.segment_size:]
return Gst.FlowReturn.OK
-
- def process_audio_buffer(self):
- # Process the accumulated audio data using the audio model
- audio_data = bytes(self.audio_buffer)
- if self.stt_model.init_recognition(self.recognizer_uuid, audio_data):
- stt_result = self.stt_model.recognize(self.recognizer_uuid)
+ def process_audio_segment(self, segment):
+ """
+ Process an audio segment for wake word detection.
+
+ Args:
+ segment (bytes): The audio segment to process.
+ """
+ # Process the audio data segment
+ audio_data = bytes(segment)
+
+ # Perform wake word detection on the audio_data
+ if self.wake_word_model.init_recognition(self.recognizer_uuid, audio_data):
+ stt_result = self.wake_word_model.recognize(self.recognizer_uuid)
print("STT Result: ", stt_result)
if self.wake_word in stt_result["text"]:
self.wake_word_detected = True
print("Wake word detected!")
self.pipeline.send_event(Gst.Event.new_eos())
- self.audio_buffer.clear() # Clear the buffer
-
-
def send_eos(self):
+ """
+ Send an End-of-Stream (EOS) event to the pipeline.
+ """
self.pipeline.send_event(Gst.Event.new_eos())
self.audio_buffer.clear()
def start_listening(self):
+ """
+ Start listening for the wake word and enter the event loop.
+ """
self.pipeline.set_state(Gst.State.PLAYING)
print("Listening for Wake Word...")
self.loop.run()
def stop_listening(self):
+ """
+ Stop listening for the wake word and clean up the pipeline.
+ """
self.cleanup_pipeline()
self.loop.quit()
def on_bus_message(self, bus, message):
+ """
+ Handle GStreamer bus messages and perform actions based on the message type.
+
+ Args:
+ bus (Gst.Bus): The GStreamer bus.
+ message (Gst.Message): The GStreamer message to process.
+ """
if message.type == Gst.MessageType.EOS:
print("End-of-stream message received")
self.stop_listening()
@@ -140,6 +200,9 @@ class WakeWordDetector:
def cleanup_pipeline(self):
+ """
+ Clean up the GStreamer pipeline and release associated resources.
+ """
if self.pipeline is not None:
print("Cleaning up pipeline...")
self.pipeline.set_state(Gst.State.NULL)
@@ -147,4 +210,4 @@ class WakeWordDetector:
print("Pipeline cleanup complete!")
self.bus = None
self.pipeline = None
- self.stt_model.cleanup_recognizer(self.recognizer_uuid)
+ self.wake_word_model.cleanup_recognizer(self.recognizer_uuid)
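A usage sketch for the updated WakeWordDetector, which now shares an STTModel instance and processes audio in fixed-size segments. start_listening() blocks inside the GLib main loop and returns after the buffer callback sends EOS on detection. The model path and wake word below are placeholders:

from agl_service_voiceagent.utils.stt_model import STTModel
from agl_service_voiceagent.utils.wake_word import WakeWordDetector

stt = STTModel("/usr/share/vosk/vosk-model-small-en-us-0.15")
detector = WakeWordDetector("hello auto", stt, sample_rate=16000)
detector.create_pipeline()
detector.start_listening()        # blocks; returns once the wake word is heard

if detector.get_wake_word_status():
    print("Wake word detected, switch to command recording...")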