aboutsummaryrefslogtreecommitdiffstats
path: root/agl_service_voiceagent
diff options
context:
space:
mode:
authorMalik Talha <talhamalik727x@gmail.com>2023-10-29 20:52:29 +0500
committerMalik Talha <talhamalik727x@gmail.com>2023-10-29 20:52:29 +0500
commit42a03d2550f60a8064078f19a743afb944f9ff69 (patch)
treec9a7b3d028737d5fecd2e05f69e1c744810ed5fb /agl_service_voiceagent
parenta10c988b5480ca5b937a2793b450cfa01f569d76 (diff)
Update voice agent service
Add new features such as an option to load service using an external config file, enhanced kuksa client, and a more robust mapper. Signed-off-by: Malik Talha <talhamalik727x@gmail.com> Change-Id: Iba3cfd234c0aabad67b293669d456bb73d8e3135
Diffstat (limited to 'agl_service_voiceagent')
-rw-r--r--agl_service_voiceagent/client.py16
-rw-r--r--agl_service_voiceagent/config.ini17
-rw-r--r--agl_service_voiceagent/nlu/rasa_interface.py41
-rw-r--r--agl_service_voiceagent/nlu/snips_interface.py38
-rw-r--r--agl_service_voiceagent/protos/voice_agent.proto3
-rw-r--r--agl_service_voiceagent/server.py3
-rw-r--r--agl_service_voiceagent/service.py63
-rw-r--r--agl_service_voiceagent/servicers/voice_agent_servicer.py109
-rw-r--r--agl_service_voiceagent/utils/audio_recorder.py52
-rw-r--r--agl_service_voiceagent/utils/common.py91
-rw-r--r--agl_service_voiceagent/utils/config.py34
-rw-r--r--agl_service_voiceagent/utils/kuksa_interface.py178
-rw-r--r--agl_service_voiceagent/utils/mapper.py261
-rw-r--r--agl_service_voiceagent/utils/stt_model.py83
-rw-r--r--agl_service_voiceagent/utils/wake_word.py95
15 files changed, 979 insertions, 105 deletions
diff --git a/agl_service_voiceagent/client.py b/agl_service_voiceagent/client.py
index 9b2e0a0..922e08c 100644
--- a/agl_service_voiceagent/client.py
+++ b/agl_service_voiceagent/client.py
@@ -20,14 +20,8 @@ from agl_service_voiceagent.generated import voice_agent_pb2
from agl_service_voiceagent.generated import voice_agent_pb2_grpc
from agl_service_voiceagent.utils.config import get_config_value
-# following code is only reqired for logging
-import logging
-logging.basicConfig()
-logging.getLogger("grpc").setLevel(logging.DEBUG)
-
-SERVER_URL = get_config_value('SERVER_ADDRESS') + ":" + str(get_config_value('SERVER_PORT'))
-
def run_client(mode, nlu_model):
+ SERVER_URL = get_config_value('SERVER_ADDRESS') + ":" + str(get_config_value('SERVER_PORT'))
nlu_model = voice_agent_pb2.SNIPS if nlu_model == "snips" else voice_agent_pb2.RASA
print("Starting Voice Agent Client...")
print(f"Client connecting to URL: {SERVER_URL}")
@@ -73,6 +67,12 @@ def run_client(mode, nlu_model):
print("Command:", record_result.command)
print("Status:", status)
print("Intent:", record_result.intent)
+ intent_slots = []
for slot in record_result.intent_slots:
print("Slot Name:", slot.name)
- print("Slot Value:", slot.value) \ No newline at end of file
+ print("Slot Value:", slot.value)
+ i_slot = voice_agent_pb2.IntentSlot(name=slot.name, value=slot.value)
+ intent_slots.append(i_slot)
+
+ exec_voice_command_request = voice_agent_pb2.ExecuteInput(intent=record_result.intent, intent_slots=intent_slots)
+ response = stub.ExecuteVoiceCommand(exec_voice_command_request) \ No newline at end of file
diff --git a/agl_service_voiceagent/config.ini b/agl_service_voiceagent/config.ini
index 9455d6a..074f6a8 100644
--- a/agl_service_voiceagent/config.ini
+++ b/agl_service_voiceagent/config.ini
@@ -1,22 +1,27 @@
[General]
-service_version = 0.2.0
base_audio_dir = /usr/share/nlu/commands/
-stt_model_path = /usr/share/vosk/vosk-model-small-en-us-0.15
+stt_model_path = /usr/share/vosk/vosk-model-small-en-us-0.15/
+wake_word_model_path = /usr/share/vosk/vosk-model-small-en-us-0.15/
snips_model_path = /usr/share/nlu/snips/model/
channels = 1
sample_rate = 16000
bits_per_sample = 16
wake_word = hello auto
server_port = 51053
-server_address = localhost
+server_address = 127.0.0.1
rasa_model_path = /usr/share/nlu/rasa/models/
rasa_server_port = 51054
+rasa_detached_mode = 0
base_log_dir = /usr/share/nlu/logs/
store_voice_commands = 0
[Kuksa]
-ip = localhost
+ip = 127.0.0.1
port = 8090
protocol = ws
-insecure = False
-token = /
+insecure = True
+token = /usr/lib/python3.10/site-packages/kuksa_certificates/jwt/super-admin.json.token
+
+[Mapper]
+intents_vss_map = /usr/share/nlu/mappings/intents_vss_map.json
+vss_signals_spec = /usr/share/nlu/mappings/vss_signals_spec.json \ No newline at end of file
diff --git a/agl_service_voiceagent/nlu/rasa_interface.py b/agl_service_voiceagent/nlu/rasa_interface.py
index 0232126..537a318 100644
--- a/agl_service_voiceagent/nlu/rasa_interface.py
+++ b/agl_service_voiceagent/nlu/rasa_interface.py
@@ -21,7 +21,20 @@ import subprocess
from concurrent.futures import ThreadPoolExecutor
class RASAInterface:
+ """
+ RASAInterface is a class for interfacing with a Rasa NLU server to extract intents and entities from text input.
+ """
+
def __init__(self, port, model_path, log_dir, max_threads=5):
+ """
+ Initialize the RASAInterface instance with the provided parameters.
+
+ Args:
+ port (int): The port number on which the Rasa NLU server will run.
+ model_path (str): The path to the Rasa NLU model.
+ log_dir (str): The directory where server logs will be saved.
+ max_threads (int, optional): The maximum number of concurrent threads (default is 5).
+ """
self.port = port
self.model_path = model_path
self.max_threads = max_threads
@@ -31,6 +44,9 @@ class RASAInterface:
def _start_server(self):
+ """
+ Start the Rasa NLU server in a subprocess and redirect its output to the log file.
+ """
command = (
f"rasa run --enable-api -m \"{self.model_path}\" -p {self.port}"
)
@@ -41,6 +57,9 @@ class RASAInterface:
def start_server(self):
+ """
+ Start the Rasa NLU server in a separate thread and wait for it to initialize.
+ """
self.thread_pool.submit(self._start_server)
# Wait for a brief moment to allow the server to start
@@ -48,6 +67,9 @@ class RASAInterface:
def stop_server(self):
+ """
+ Stop the Rasa NLU server and shut down the thread pool.
+ """
if self.server_process:
self.server_process.terminate()
self.server_process.wait()
@@ -56,6 +78,16 @@ class RASAInterface:
def preprocess_text(self, text):
+ """
+ Preprocess the input text by converting it to lowercase, removing leading/trailing spaces,
+ and removing special characters and punctuation.
+
+ Args:
+ text (str): The input text to preprocess.
+
+ Returns:
+ str: The preprocessed text.
+ """
# text to lower case and remove trailing and leading spaces
preprocessed_text = text.lower().strip()
# remove special characters, punctuation, and extra whitespaces
@@ -77,6 +109,15 @@ class RASAInterface:
def process_intent(self, intent_output):
+ """
+ Extract intents and entities from preprocessed text using the Rasa NLU server.
+
+ Args:
+ text (str): The preprocessed input text.
+
+ Returns:
+ dict: Intent and entity extraction result as a dictionary.
+ """
intent = intent_output["intent"]["name"]
entities = {}
for entity in intent_output["entities"]:
diff --git a/agl_service_voiceagent/nlu/snips_interface.py b/agl_service_voiceagent/nlu/snips_interface.py
index f0b05d2..1febe92 100644
--- a/agl_service_voiceagent/nlu/snips_interface.py
+++ b/agl_service_voiceagent/nlu/snips_interface.py
@@ -19,10 +19,30 @@ from typing import Text
from snips_inference_agl import SnipsNLUEngine
class SnipsInterface:
+ """
+ SnipsInterface is a class for interacting with the Snips Natural Language Understanding Engine (Snips NLU).
+ """
+
def __init__(self, model_path: Text):
+ """
+ Initialize the SnipsInterface instance with the provided Snips NLU model.
+
+ Args:
+ model_path (Text): The path to the Snips NLU model.
+ """
self.engine = SnipsNLUEngine.from_path(model_path)
def preprocess_text(self, text):
+ """
+ Preprocess the input text by converting it to lowercase, removing leading/trailing spaces,
+ and removing special characters and punctuation.
+
+ Args:
+ text (str): The input text to preprocess.
+
+ Returns:
+ str: The preprocessed text.
+ """
# text to lower case and remove trailing and leading spaces
preprocessed_text = text.lower().strip()
# remove special characters, punctuation, and extra whitespaces
@@ -30,11 +50,29 @@ class SnipsInterface:
return preprocessed_text
def extract_intent(self, text: Text):
+ """
+ Extract the intent from preprocessed text using the Snips NLU engine.
+
+ Args:
+ text (Text): The preprocessed input text.
+
+ Returns:
+ dict: The intent extraction result as a dictionary.
+ """
preprocessed_text = self.preprocess_text(text)
result = self.engine.parse(preprocessed_text)
return result
def process_intent(self, intent_output):
+ """
+ Extract intent and slot values from Snips NLU output.
+
+ Args:
+ intent_output (dict): The intent extraction result from Snips NLU.
+
+ Returns:
+ tuple: A tuple containing the intent name (str) and a dictionary of intent actions (entity-value pairs).
+ """
intent_actions = {}
intent = intent_output['intent']['intentName']
slots = intent_output.get('slots', [])
diff --git a/agl_service_voiceagent/protos/voice_agent.proto b/agl_service_voiceagent/protos/voice_agent.proto
index 8ee8324..8c3ab65 100644
--- a/agl_service_voiceagent/protos/voice_agent.proto
+++ b/agl_service_voiceagent/protos/voice_agent.proto
@@ -78,5 +78,6 @@ message ExecuteInput {
}
message ExecuteResult {
- ExecuteStatusType status = 1;
+ string response = 1;
+ ExecuteStatusType status = 2;
}
diff --git a/agl_service_voiceagent/server.py b/agl_service_voiceagent/server.py
index 5fe0746..d8ce785 100644
--- a/agl_service_voiceagent/server.py
+++ b/agl_service_voiceagent/server.py
@@ -20,9 +20,8 @@ from agl_service_voiceagent.generated import voice_agent_pb2_grpc
from agl_service_voiceagent.servicers.voice_agent_servicer import VoiceAgentServicer
from agl_service_voiceagent.utils.config import get_config_value
-SERVER_URL = get_config_value('SERVER_ADDRESS') + ":" + str(get_config_value('SERVER_PORT'))
-
def run_server():
+ SERVER_URL = get_config_value('SERVER_ADDRESS') + ":" + str(get_config_value('SERVER_PORT'))
print("Starting Voice Agent Service...")
print(f"Server running at URL: {SERVER_URL}")
print(f"STT Model Path: {get_config_value('STT_MODEL_PATH')}")
diff --git a/agl_service_voiceagent/service.py b/agl_service_voiceagent/service.py
index 1b34c27..784d8d9 100644
--- a/agl_service_voiceagent/service.py
+++ b/agl_service_voiceagent/service.py
@@ -25,7 +25,7 @@ generated_dir = os.path.join(current_dir, "generated")
sys.path.append(generated_dir)
import argparse
-from agl_service_voiceagent.utils.config import update_config_value, get_config_value
+from agl_service_voiceagent.utils.config import set_config_path, load_config, update_config_value, get_config_value
from agl_service_voiceagent.utils.common import add_trailing_slash
from agl_service_voiceagent.server import run_server
from agl_service_voiceagent.client import run_client
@@ -33,7 +33,7 @@ from agl_service_voiceagent.client import run_client
def print_version():
print("Automotive Grade Linux (AGL)")
- print(f"Voice Agent Service v{get_config_value('SERVICE_VERSION')}")
+ print(f"Voice Agent Service v0.3.0")
def main():
@@ -47,10 +47,15 @@ def main():
server_parser = subparsers.add_parser('run-server', help='Run the Voice Agent gRPC Server')
client_parser = subparsers.add_parser('run-client', help='Run the Voice Agent gRPC Client')
- server_parser.add_argument('--config', action='store_true', help='Starts the server solely based on values provided in config file.')
- server_parser.add_argument('--stt-model-path', required=False, help='Path to the Speech To Text model. Currently only supports VOSK Kaldi.')
+ server_parser.add_argument('--default', action='store_true', help='Starts the server based on default config file.')
+ server_parser.add_argument('--config', required=False, help='Path to a config file. Server is started based on this config file.')
+ server_parser.add_argument('--stt-model-path', required=False, help='Path to the Speech To Text model for Voice Command detection. Currently only supports VOSK Kaldi.')
+ server_parser.add_argument('--ww-model-path', required=False, help='Path to the Speech To Text model for Wake Word detection. Currently only supports VOSK Kaldi. Defaults to the same model as --stt-model-path if not provided.')
server_parser.add_argument('--snips-model-path', required=False, help='Path to the Snips NLU model.')
server_parser.add_argument('--rasa-model-path', required=False, help='Path to the RASA NLU model.')
+ server_parser.add_argument('--rasa-detached-mode', required=False, help='Assume that the RASA server is already running and do not start it as a subprocess.')
+ server_parser.add_argument('--intents-vss-map-path', required=False, help='Path to the JSON file containing Intent to VSS map.')
+ server_parser.add_argument('--vss-signals-spec-path', required=False, help='Path to the VSS signals specification JSON file.')
server_parser.add_argument('--audio-store-dir', required=False, help='Directory to store the generated audio files.')
server_parser.add_argument('--log-store-dir', required=False, help='Directory to store the generated log files.')
@@ -63,7 +68,7 @@ def main():
print_version()
elif args.subcommand == 'run-server':
- if not args.config:
+ if not args.default and not args.config:
if not args.stt_model_path:
raise ValueError("The --stt-model-path is missing. Please provide a value. Use --help to see available options.")
@@ -72,20 +77,42 @@ def main():
if not args.rasa_model_path:
raise ValueError("The --rasa-model-path is missing. Please provide a value. Use --help to see available options.")
+
+ if not args.intents_vss_map_path:
+ raise ValueError("The --intents-vss-map-path is missing. Please provide a value. Use --help to see available options.")
+
+ if not args.vss_signals_spec_path:
+ raise ValueError("The --vss-signals-spec-path is missing. Please provide a value. Use --help to see available options.")
+
+ # Construct the default config file path
+ config_path = os.path.join(current_dir, "config.ini")
+ # Load the config values from the config file
+ set_config_path(config_path)
+ load_config()
+
+ # Get the values provided by the user
stt_path = args.stt_model_path
snips_model_path = args.snips_model_path
rasa_model_path = args.rasa_model_path
+ intents_vss_map_path = args.intents_vss_map_path
+ vss_signals_spec_path = args.vss_signals_spec_path
# Convert to an absolute path if it's a relative path
stt_path = add_trailing_slash(os.path.abspath(stt_path)) if not os.path.isabs(stt_path) else stt_path
snips_model_path = add_trailing_slash(os.path.abspath(snips_model_path)) if not os.path.isabs(snips_model_path) else snips_model_path
rasa_model_path = add_trailing_slash(os.path.abspath(rasa_model_path)) if not os.path.isabs(rasa_model_path) else rasa_model_path
+ intents_vss_map_path = os.path.abspath(intents_vss_map_path) if not os.path.isabs(intents_vss_map_path) else intents_vss_map_path
+ vss_signals_spec_path = os.path.abspath(vss_signals_spec_path) if not os.path.isabs(vss_signals_spec_path) else vss_signals_spec_path
# Also update the config.ini file
update_config_value(stt_path, 'STT_MODEL_PATH')
update_config_value(snips_model_path, 'SNIPS_MODEL_PATH')
update_config_value(rasa_model_path, 'RASA_MODEL_PATH')
+ update_config_value(intents_vss_map_path, 'INTENTS_VSS_MAP')
+ update_config_value(vss_signals_spec_path, 'VSS_SIGNALS_SPEC')
+ if args.rasa_detached_mode:
+ update_config_value('1', 'RASA_DETACHED_MODE')
# Update the audio store dir in config.ini if provided
audio_dir = args.audio_store_dir or get_config_value('BASE_AUDIO_DIR')
@@ -98,6 +125,25 @@ def main():
update_config_value(log_dir, 'BASE_LOG_DIR')
+ elif args.config:
+ # Get config file path value
+ cli_config_path = args.config
+
+ # if config file path provided then load the config values from it
+ if cli_config_path :
+ cli_config_path = os.path.abspath(cli_config_path) if not os.path.isabs(cli_config_path) else cli_config_path
+ print(f"New config file path provided: {cli_config_path}. Overriding the default config file path.")
+ set_config_path(cli_config_path)
+ load_config()
+
+ elif args.default:
+ # Construct the default config file path
+ config_path = os.path.join(current_dir, "config.ini")
+
+ # Load the config values from the config file
+ set_config_path(config_path)
+ load_config()
+
# create the base audio dir if not exists
if not os.path.exists(get_config_value('BASE_AUDIO_DIR')):
os.makedirs(get_config_value('BASE_AUDIO_DIR'))
@@ -109,6 +155,13 @@ def main():
run_server()
elif args.subcommand == 'run-client':
+ # Construct the default config file path
+ config_path = os.path.join(current_dir, "config.ini")
+
+ # Load the config values from the config file
+ set_config_path(config_path)
+ load_config()
+
mode = args.mode
if mode not in ['wake-word', 'auto', 'manual']:
raise ValueError("Invalid mode. Supported modes: 'wake-word', 'auto' and 'manual'. Use --help to see available options.")
diff --git a/agl_service_voiceagent/servicers/voice_agent_servicer.py b/agl_service_voiceagent/servicers/voice_agent_servicer.py
index 4038b85..69af10b 100644
--- a/agl_service_voiceagent/servicers/voice_agent_servicer.py
+++ b/agl_service_voiceagent/servicers/voice_agent_servicer.py
@@ -17,42 +17,65 @@
import grpc
import time
import threading
-from generated import voice_agent_pb2
-from generated import voice_agent_pb2_grpc
+from agl_service_voiceagent.generated import voice_agent_pb2
+from agl_service_voiceagent.generated import voice_agent_pb2_grpc
from agl_service_voiceagent.utils.audio_recorder import AudioRecorder
from agl_service_voiceagent.utils.wake_word import WakeWordDetector
from agl_service_voiceagent.utils.stt_model import STTModel
+from agl_service_voiceagent.utils.kuksa_interface import KuksaInterface
+from agl_service_voiceagent.utils.mapper import Intent2VSSMapper
from agl_service_voiceagent.utils.config import get_config_value
+from agl_service_voiceagent.utils.common import generate_unique_uuid, delete_file
from agl_service_voiceagent.nlu.snips_interface import SnipsInterface
from agl_service_voiceagent.nlu.rasa_interface import RASAInterface
-from agl_service_voiceagent.utils.common import generate_unique_uuid, delete_file
class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer):
+ """
+ Voice Agent Servicer class that implements the gRPC service defined in voice_agent.proto.
+ """
+
def __init__(self):
+ """
+ Constructor for VoiceAgentServicer class.
+ """
# Get the config values
- self.service_version = get_config_value('SERVICE_VERSION')
+ self.service_version = "v0.3.0"
self.wake_word = get_config_value('WAKE_WORD')
self.base_audio_dir = get_config_value('BASE_AUDIO_DIR')
self.channels = int(get_config_value('CHANNELS'))
self.sample_rate = int(get_config_value('SAMPLE_RATE'))
self.bits_per_sample = int(get_config_value('BITS_PER_SAMPLE'))
self.stt_model_path = get_config_value('STT_MODEL_PATH')
+ self.wake_word_model_path = get_config_value('WAKE_WORD_MODEL_PATH')
self.snips_model_path = get_config_value('SNIPS_MODEL_PATH')
self.rasa_model_path = get_config_value('RASA_MODEL_PATH')
self.rasa_server_port = int(get_config_value('RASA_SERVER_PORT'))
+ self.rasa_detached_mode = bool(int(get_config_value('RASA_DETACHED_MODE')))
self.base_log_dir = get_config_value('BASE_LOG_DIR')
self.store_voice_command = bool(int(get_config_value('STORE_VOICE_COMMANDS')))
# Initialize class methods
self.stt_model = STTModel(self.stt_model_path, self.sample_rate)
+ self.stt_wake_word_model = STTModel(self.wake_word_model_path, self.sample_rate)
self.snips_interface = SnipsInterface(self.snips_model_path)
self.rasa_interface = RASAInterface(self.rasa_server_port, self.rasa_model_path, self.base_log_dir)
- self.rasa_interface.start_server()
+
+ # Only start the RASA server if it's not in detached mode, else we assume the server is already running
+ if not self.rasa_detached_mode:
+ self.rasa_interface.start_server()
+
self.rvc_stream_uuids = {}
+ self.kuksa_client = KuksaInterface()
+ self.kuksa_client.connect_kuksa_client()
+ self.kuksa_client.authorize_kuksa_client()
+ self.mapper = Intent2VSSMapper()
def CheckServiceStatus(self, request, context):
+ """
+ Check the status of the Voice Agent service including the version.
+ """
response = voice_agent_pb2.ServiceStatus(
version=self.service_version,
status=True
@@ -61,13 +84,16 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer):
def DetectWakeWord(self, request, context):
+ """
+ Detect the wake word using the wake word detection model.
+ """
wake_word_detector = WakeWordDetector(self.wake_word, self.stt_model, self.channels, self.sample_rate, self.bits_per_sample)
wake_word_detector.create_pipeline()
detection_thread = threading.Thread(target=wake_word_detector.start_listening)
detection_thread.start()
while True:
status = wake_word_detector.get_wake_word_status()
- time.sleep(1)
+ time.sleep(0.5)
if not context.is_active():
wake_word_detector.send_eos()
break
@@ -79,6 +105,9 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer):
def RecognizeVoiceCommand(self, requests, context):
+ """
+ Recognize the voice command using the STT model and extract the intent using the NLU model.
+ """
stt = ""
intent = ""
intent_slots = []
@@ -154,3 +183,71 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer):
status=status
)
return response
+
+
+ def ExecuteVoiceCommand(self, request, context):
+ """
+ Execute the voice command by sending the intent to Kuksa.
+ """
+ intent = request.intent
+ intent_slots = request.intent_slots
+ processed_slots = []
+ for slot in intent_slots:
+ slot_name = slot.name
+ slot_value = slot.value
+ processed_slots.append({"name": slot_name, "value": slot_value})
+
+ print(intent)
+ print(processed_slots)
+ execution_list = self.mapper.parse_intent(intent, processed_slots)
+ exec_response = f"Sorry, I failed to execute command against intent '{intent}'. Maybe try again with more specific instructions."
+ exec_status = voice_agent_pb2.EXEC_ERROR
+
+ for execution_item in execution_list:
+ print(execution_item)
+ action = execution_item["action"]
+ signal = execution_item["signal"]
+
+ if self.kuksa_client.get_kuksa_status():
+ if action == "set" and "value" in execution_item:
+ value = execution_item["value"]
+ if self.kuksa_client.send_values(signal, value):
+ exec_response = f"Yay, I successfully updated the intent '{intent}' to value '{value}'."
+ exec_status = voice_agent_pb2.EXEC_SUCCESS
+
+ elif action in ["increase", "decrease"]:
+ if "value" in execution_item:
+ value = execution_item["value"]
+ if self.kuksa_client.send_values(signal, value):
+ exec_response = f"Yay, I successfully updated the intent '{intent}' to value '{value}'."
+ exec_status = voice_agent_pb2.EXEC_SUCCESS
+
+ elif "factor" in execution_item:
+ factor = execution_item["factor"]
+ current_value = self.kuksa_client.get_value(signal)
+ if current_value:
+ current_value = int(current_value)
+ if action == "increase":
+ value = current_value + factor
+ value = str(value)
+ elif action == "decrease":
+ value = current_value - factor
+ value = str(value)
+ if self.kuksa_client.send_values(signal, value):
+ exec_response = f"Yay, I successfully updated the intent '{intent}' to value '{value}'."
+ exec_status = voice_agent_pb2.EXEC_SUCCESS
+
+ else:
+ exec_response = f"Uh oh, there is no value set for intent '{intent}'. Why not try setting a value first?"
+ exec_status = voice_agent_pb2.EXEC_KUKSA_CONN_ERROR
+
+ else:
+ exec_response = "Uh oh, I failed to connect to Kuksa."
+ exec_status = voice_agent_pb2.EXEC_KUKSA_CONN_ERROR
+
+ response = voice_agent_pb2.ExecuteResult(
+ response=exec_response,
+ status=exec_status
+ )
+
+ return response
diff --git a/agl_service_voiceagent/utils/audio_recorder.py b/agl_service_voiceagent/utils/audio_recorder.py
index 61ce994..2e8f11d 100644
--- a/agl_service_voiceagent/utils/audio_recorder.py
+++ b/agl_service_voiceagent/utils/audio_recorder.py
@@ -15,7 +15,6 @@
# limitations under the License.
import gi
-import vosk
import time
gi.require_version('Gst', '1.0')
from gi.repository import Gst, GLib
@@ -24,7 +23,21 @@ Gst.init(None)
GLib.threads_init()
class AudioRecorder:
+ """
+ AudioRecorder is a class for recording audio using GStreamer in various modes.
+ """
+
def __init__(self, stt_model, audio_files_basedir, channels=1, sample_rate=16000, bits_per_sample=16):
+ """
+ Initialize the AudioRecorder instance with the provided parameters.
+
+ Args:
+ stt_model (str): The speech-to-text model to use for voice input recognition.
+ audio_files_basedir (str): The base directory for saving audio files.
+ channels (int, optional): The number of audio channels (default is 1).
+ sample_rate (int, optional): The audio sample rate in Hz (default is 16000).
+ bits_per_sample (int, optional): The number of bits per sample (default is 16).
+ """
self.loop = GLib.MainLoop()
self.mode = None
self.pipeline = None
@@ -43,6 +56,12 @@ class AudioRecorder:
def create_pipeline(self):
+ """
+ Create and configure the GStreamer audio recording pipeline.
+
+ Returns:
+ str: The name of the audio file being recorded.
+ """
print("Creating pipeline for audio recording in {} mode...".format(self.mode))
self.pipeline = Gst.Pipeline()
autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None)
@@ -83,32 +102,50 @@ class AudioRecorder:
def start_recording(self):
+ """
+ Start recording audio using the GStreamer pipeline.
+ """
self.pipeline.set_state(Gst.State.PLAYING)
print("Recording Voice Input...")
def stop_recording(self):
+ """
+ Stop audio recording and clean up the GStreamer pipeline.
+ """
print("Stopping recording...")
- self.frames_above_threshold = 0
- self.cleanup_pipeline()
+ # self.cleanup_pipeline()
+ self.pipeline.send_event(Gst.Event.new_eos())
print("Recording finished!")
def set_pipeline_mode(self, mode):
+ """
+ Set the recording mode to 'auto' or 'manual'.
+
+ Args:
+ mode (str): The recording mode ('auto' or 'manual').
+ """
self.mode = mode
- # this method helps with error handling
def on_bus_message(self, bus, message):
+ """
+ Handle GStreamer bus messages and perform actions based on the message type.
+
+ Args:
+ bus (Gst.Bus): The GStreamer bus.
+ message (Gst.Message): The GStreamer message to process.
+ """
if message.type == Gst.MessageType.EOS:
print("End-of-stream message received")
- self.stop_recording()
+ self.cleanup_pipeline()
elif message.type == Gst.MessageType.ERROR:
err, debug_info = message.parse_error()
print(f"Error received from element {message.src.get_name()}: {err.message}")
print(f"Debugging information: {debug_info}")
- self.stop_recording()
+ self.cleanup_pipeline()
elif message.type == Gst.MessageType.WARNING:
err, debug_info = message.parse_warning()
@@ -136,6 +173,9 @@ class AudioRecorder:
def cleanup_pipeline(self):
+ """
+ Clean up the GStreamer pipeline, set it to NULL state, and remove the signal watch.
+ """
if self.pipeline is not None:
print("Cleaning up pipeline...")
self.pipeline.set_state(Gst.State.NULL)
diff --git a/agl_service_voiceagent/utils/common.py b/agl_service_voiceagent/utils/common.py
index 682473e..b9d6577 100644
--- a/agl_service_voiceagent/utils/common.py
+++ b/agl_service_voiceagent/utils/common.py
@@ -20,12 +20,18 @@ import json
def add_trailing_slash(path):
+ """
+ Adds a trailing slash to a path if it does not already have one.
+ """
if path and not path.endswith('/'):
path += '/'
return path
def generate_unique_uuid(length):
+ """
+ Generates a unique ID of specified length.
+ """
unique_id = str(uuid.uuid4().int)
# Ensure the generated ID is exactly 'length' digits by taking the last 'length' characters
unique_id = unique_id[-length:]
@@ -33,18 +39,91 @@ def generate_unique_uuid(length):
def load_json_file(file_path):
- try:
- with open(file_path, 'r') as file:
- return json.load(file)
- except FileNotFoundError:
- raise ValueError(f"File '{file_path}' not found.")
+ """
+ Loads a JSON file and returns the data.
+ """
+ try:
+ with open(file_path, 'r') as file:
+ return json.load(file)
+ except FileNotFoundError:
+ raise ValueError(f"File '{file_path}' not found.")
def delete_file(file_path):
+ """
+ Deletes a file if it exists.
+ """
if os.path.exists(file_path):
try:
os.remove(file_path)
except Exception as e:
print(f"Error deleting '{file_path}': {e}")
else:
- print(f"File '{file_path}' does not exist.") \ No newline at end of file
+ print(f"File '{file_path}' does not exist.")
+
+
+def words_to_number(words):
+ """
+ Converts a string of words to a number.
+ """
+ word_to_number = {
+ 'one': 1,
+ 'two': 2,
+ 'three': 3,
+ 'four': 4,
+ 'five': 5,
+ 'six': 6,
+ 'seven': 7,
+ 'eight': 8,
+ 'nine': 9,
+ 'ten': 10,
+ 'eleven': 11,
+ 'twelve': 12,
+ 'thirteen': 13,
+ 'fourteen': 14,
+ 'fifteen': 15,
+ 'sixteen': 16,
+ 'seventeen': 17,
+ 'eighteen': 18,
+ 'nineteen': 19,
+ 'twenty': 20,
+ 'thirty': 30,
+ 'forty': 40,
+ 'fourty': 40,
+ 'fifty': 50,
+ 'sixty': 60,
+ 'seventy': 70,
+ 'eighty': 80,
+ 'ninety': 90
+ }
+
+ # Split the input words and initialize total and current number
+ words = words.split()
+
+ if len(words) == 1:
+ if words[0] in ["zero", "min", "minimum"]:
+ return 0
+
+ elif words[0] in ["half", "halfway"]:
+ return 50
+
+ elif words[0] in ["max", "maximum", "full", "fully", "completely", "hundred"]:
+ return 100
+
+
+ total = 0
+ current_number = 0
+
+ for word in words:
+ if word in word_to_number:
+ current_number += word_to_number[word]
+ elif word == 'hundred':
+ current_number *= 100
+ else:
+ total += current_number
+ current_number = 0
+
+ total += current_number
+
+ # we return number in str format because kuksa expects str input
+ return str(total) or None \ No newline at end of file
diff --git a/agl_service_voiceagent/utils/config.py b/agl_service_voiceagent/utils/config.py
index 8d7f346..7295c7f 100644
--- a/agl_service_voiceagent/utils/config.py
+++ b/agl_service_voiceagent/utils/config.py
@@ -14,21 +14,41 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-import os
import configparser
-# Get the absolute path to the directory of the current script
-current_dir = os.path.dirname(os.path.abspath(__file__))
-# Construct the path to the config.ini file located in the base directory
-config_path = os.path.join(current_dir, '..', 'config.ini')
-
config = configparser.ConfigParser()
-config.read(config_path)
+config_path = None
+
+def set_config_path(path):
+ """
+ Sets the path to the config file.
+ """
+ global config_path
+ config_path = path
+ config.read(config_path)
+
+def load_config():
+ """
+ Loads the config file.
+ """
+ if config_path is not None:
+ config.read(config_path)
+ else:
+ raise Exception("Config file path not provided.")
def update_config_value(value, key, group="General"):
+ """
+ Updates a value in the config file.
+ """
+ if config_path is None:
+ raise Exception("Config file path not set.")
+
config.set(group, key, value)
with open(config_path, 'w') as configfile:
config.write(configfile)
def get_config_value(key, group="General"):
+ """
+ Gets a value from the config file.
+ """
return config.get(group, key)
diff --git a/agl_service_voiceagent/utils/kuksa_interface.py b/agl_service_voiceagent/utils/kuksa_interface.py
index 3e1c045..9270379 100644
--- a/agl_service_voiceagent/utils/kuksa_interface.py
+++ b/agl_service_voiceagent/utils/kuksa_interface.py
@@ -14,53 +14,197 @@
# See the License for the specific language governing permissions and
# limitations under the License.
+import time
+import json
+import threading
from kuksa_client import KuksaClientThread
from agl_service_voiceagent.utils.config import get_config_value
class KuksaInterface:
- def __init__(self):
+ """
+ Kuksa Interface
+
+ This class provides methods to initialize, authorize, connect, send values,
+ check the status, and close the Kuksa client.
+ """
+
+ _instance = None
+ _lock = threading.Lock()
+
+ def __new__(cls):
+ """
+ Get the unique instance of the class.
+
+ Returns:
+ KuksaInterface: The instance of the class.
+ """
+ with cls._lock:
+ if cls._instance is None:
+ cls._instance = super(KuksaInterface, cls).__new__(cls)
+ cls._instance.init_client()
+ return cls._instance
+
+
+ def init_client(self):
+ """
+ Initialize the Kuksa client configuration.
+ """
# get config values
- self.ip = get_config_value("ip", "Kuksa")
- self.port = get_config_value("port", "Kuksa")
+ self.ip = str(get_config_value("ip", "Kuksa"))
+ self.port = str(get_config_value("port", "Kuksa"))
self.insecure = get_config_value("insecure", "Kuksa")
self.protocol = get_config_value("protocol", "Kuksa")
self.token = get_config_value("token", "Kuksa")
+ print(self.ip, self.port, self.insecure, self.protocol, self.token)
+
# define class methods
self.kuksa_client = None
def get_kuksa_client(self):
+ """
+ Get the Kuksa client instance.
+
+ Returns:
+ KuksaClientThread: The Kuksa client instance.
+ """
return self.kuksa_client
-
+
def get_kuksa_status(self):
+ """
+ Check the status of the Kuksa client connection.
+
+ Returns:
+ bool: True if the client is connected, False otherwise.
+ """
if self.kuksa_client:
return self.kuksa_client.checkConnection()
-
return False
def connect_kuksa_client(self):
+ """
+ Connect and start the Kuksa client.
+ """
try:
- self.kuksa_client = KuksaClientThread({
- "ip": self.ip,
- "port": self.port,
- "insecure": self.insecure,
- "protocol": self.protocol,
- })
- self.kuksa_client.authorize(self.token)
+ with self._lock:
+ if self.kuksa_client is None:
+ self.kuksa_client = KuksaClientThread({
+ "ip": self.ip,
+ "port": self.port,
+ "insecure": self.insecure,
+ "protocol": self.protocol,
+ })
+ self.kuksa_client.start()
+ time.sleep(2) # Give the thread time to start
+
+ if not self.get_kuksa_status():
+ print("[-] Error: Connection to Kuksa server failed.")
+ else:
+ print("[+] Connection to Kuksa established.")
except Exception as e:
- print("Error: ", e)
+ print("[-] Error: Connection to Kuksa server failed. ", str(e))
+
+
+ def authorize_kuksa_client(self):
+ """
+ Authorize the Kuksa client with the provided token.
+ """
+ if self.kuksa_client:
+ response = self.kuksa_client.authorize(self.token)
+ response = json.loads(response)
+ if "error" in response:
+ error_message = response.get("error", "Unknown error")
+ print(f"[-] Error: Authorization failed. {error_message}")
+ else:
+ print("[+] Kuksa client authorized successfully.")
+ else:
+ print("[-] Error: Kuksa client is not initialized. Call `connect_kuksa_client` first.")
+
+
+ def send_values(self, path=None, value=None):
+ """
+ Send values to the Kuksa server.
+
+ Args:
+ path (str): The path to the value.
+ value (str): The value to be sent.
+ Returns:
+ bool: True if the value was set, False otherwise.
+ """
+ result = False
+ if self.kuksa_client is None:
+ print("[-] Error: Kuksa client is not initialized.")
+ return
+
+ if self.get_kuksa_status():
+ try:
+ response = self.kuksa_client.setValue(path, value)
+ response = json.loads(response)
+ if not "error" in response:
+ print(f"[+] Value '{value}' sent to Kuksa successfully.")
+ result = True
+ else:
+ error_message = response.get("error", "Unknown error")
+ print(f"[-] Error: Failed to send value '{value}' to Kuksa. {error_message}")
+
+ except Exception as e:
+ print("[-] Error: Failed to send values to Kuksa. ", str(e))
+
+ else:
+ print("[-] Error: Connection to Kuksa failed.")
+
+ return result
- def send_values(self, Path=None, Value=None):
+ def get_value(self, path=None):
+ """
+ Get values from the Kuksa server.
+
+ Args:
+ path (str): The path to the value.
+ Returns:
+ str: The value if the path is valid, None otherwise.
+ """
+ result = None
if self.kuksa_client is None:
- print("Error: Kuksa client is not initialized.")
+ print("[-] Error: Kuksa client is not initialized.")
+ return
if self.get_kuksa_status():
- self.kuksa_client.setValue(Path, Value)
+ try:
+ response = self.kuksa_client.getValue(path)
+ response = json.loads(response)
+ if not "error" in response:
+ result = response.get("data", None)
+ result = result.get("dp", None)
+ result = result.get("value", None)
+
+ else:
+ error_message = response.get("error", "Unknown error")
+ print(f"[-] Error: Failed to get value from Kuksa. {error_message}")
+
+ except Exception as e:
+ print("[-] Error: Failed to get values from Kuksa. ", str(e))
else:
- print("Error: Connection to Kuksa failed.")
+ print("[-] Error: Connection to Kuksa failed.")
+
+ return result
+
+
+ def close_kuksa_client(self):
+ """
+ Close and stop the Kuksa client.
+ """
+ try:
+ with self._lock:
+ if self.kuksa_client:
+ self.kuksa_client.stop()
+ self.kuksa_client = None
+ print("[+] Kuksa client stopped.")
+ except Exception as e:
+ print("[-] Error: Failed to close Kuksa client. ", str(e)) \ No newline at end of file
diff --git a/agl_service_voiceagent/utils/mapper.py b/agl_service_voiceagent/utils/mapper.py
new file mode 100644
index 0000000..7529645
--- /dev/null
+++ b/agl_service_voiceagent/utils/mapper.py
@@ -0,0 +1,261 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright (c) 2023 Malik Talha
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from agl_service_voiceagent.utils.config import get_config_value
+from agl_service_voiceagent.utils.common import load_json_file, words_to_number
+
+
+class Intent2VSSMapper:
+ """
+ Intent2VSSMapper is a class that facilitates the mapping of natural language intent to
+ corresponding vehicle signal specifications (VSS) for automated vehicle control systems.
+ """
+
+ def __init__(self):
+ """
+ Initializes the Intent2VSSMapper class by loading Intent-to-VSS signal mappings
+ and VSS signal specifications from external configuration files.
+ """
+ intents_vss_map_file = get_config_value("intents_vss_map", "Mapper")
+ vss_signals_spec_file = get_config_value("vss_signals_spec", "Mapper")
+ self.intents_vss_map = load_json_file(intents_vss_map_file).get("intents", {})
+ self.vss_signals_spec = load_json_file(vss_signals_spec_file).get("signals", {})
+
+ if not self.validate_signal_spec_structure():
+ raise ValueError("[-] Invalid VSS signal specification structure.")
+
+ def validate_signal_spec_structure(self):
+ """
+ Validates the structure of the VSS signal specification data.
+ """
+
+ signals = self.vss_signals_spec
+
+ # Iterate over each signal in the 'signals' dictionary
+ for signal_name, signal_data in signals.items():
+ # Check if the required keys are present in the signal data
+ if not all(key in signal_data for key in ['default_value', 'default_change_factor', 'actions', 'values', 'default_fallback', 'value_set_intents']):
+ print(f"[-] {signal_name}: Missing required keys in signal data.")
+ return False
+
+ actions = signal_data['actions']
+
+ # Check if 'actions' is a dictionary with at least one action
+ if not isinstance(actions, dict) or not actions:
+ print(f"[-] {signal_name}: Invalid 'actions' key in signal data. Must be an object with at least one action.")
+ return False
+
+ # Check if the actions match the allowed actions ["set", "increase", "decrease"]
+ for action in actions.keys():
+ if action not in ["set", "increase", "decrease"]:
+ print(f"[-] {signal_name}: Invalid action in signal data. Allowed actions: ['set', 'increase', 'decrease']")
+ return False
+
+ # Check if the 'synonyms' list is present for each action and is either a list or None
+ for action_data in actions.values():
+ synonyms = action_data.get('synonyms')
+ if synonyms is not None and (not isinstance(synonyms, list) or not all(isinstance(synonym, str) for synonym in synonyms)):
+ print(f"[-] {signal_name}: Invalid 'synonyms' value in signal data. Must be a list of strings.")
+ return False
+
+ values = signal_data['values']
+
+ # Check if 'values' is a dictionary with the required keys
+ if not isinstance(values, dict) or not all(key in values for key in ['ranged', 'start', 'end', 'ignore', 'additional']):
+ print(f"[-] {signal_name}: Invalid 'values' key in signal data. Required keys: ['ranged', 'start', 'end', 'ignore', 'additional']")
+ return False
+
+ # Check if 'ranged' is a boolean
+ if not isinstance(values['ranged'], bool):
+ print(f"[-] {signal_name}: Invalid 'ranged' value in signal data. Allowed values: [true, false]")
+ return False
+
+ default_fallback = signal_data['default_fallback']
+
+ # Check if 'default_fallback' is a boolean
+ if not isinstance(default_fallback, bool):
+ print(f"[-] {signal_name}: Invalid 'default_fallback' value in signal data. Allowed values: [true, false]")
+ return False
+
+ # If all checks pass, the self.vss_signals_spec structure is valid
+ return True
+
+
+ def map_intent_to_signal(self, intent_name):
+ """
+ Maps an intent name to the corresponding VSS signals and their specifications.
+
+ Args:
+ intent_name (str): The name of the intent to be mapped.
+
+ Returns:
+ dict: A dictionary containing VSS signals as keys and their specifications as values.
+ """
+
+ intent_data = self.intents_vss_map.get(intent_name, None)
+ result = {}
+ if intent_data:
+ signals = intent_data.get("signals", [])
+
+ for signal in signals:
+ signal_info = self.vss_signals_spec.get(signal, {})
+ if signal_info:
+ result.update({signal: signal_info})
+
+ return result
+
+
+ def parse_intent(self, intent_name, intent_slots = []):
+ """
+ Parses an intent, extracting relevant VSS signals, actions, modifiers, and values
+ based on the intent and its associated slots.
+
+ Args:
+ intent_name (str): The name of the intent to be parsed.
+ intent_slots (list): A list of dictionaries representing intent slots.
+
+ Returns:
+ list: A list of dictionaries describing actions and signal-related details for execution.
+
+ Note:
+ - If no relevant VSS signals are found for the intent, an empty list is returned.
+ - If no specific action or modifier is determined, default values are used.
+ """
+ vss_signal_data = self.map_intent_to_signal(intent_name)
+ execution_list = []
+ for signal_name, signal_data in vss_signal_data.items():
+ action = self.determine_action(signal_data, intent_slots)
+ modifier = self.determine_modifier(signal_data, intent_slots)
+ value = self.determine_value(signal_data, intent_slots)
+
+ if value != None and not self.verify_value(signal_data, value):
+ value = None
+
+ change_factor = signal_data["default_change_factor"]
+
+ if action in ["increase", "decrease"]:
+ if value and modifier == "to":
+ execution_list.append({"action": action, "signal": signal_name, "value": str(value)})
+
+ elif value and modifier == "by":
+ execution_list.append({"action": action, "signal": signal_name, "factor": str(value)})
+
+ elif value:
+ execution_list.append({"action": action, "signal": signal_name, "value": str(value)})
+
+ elif signal_data["default_fallback"]:
+ execution_list.append({"action": action, "signal": signal_name, "factor": str(change_factor)})
+
+ # if no value is found, set the default value
+ if value == None and signal_data["default_fallback"]:
+ value = signal_data["default_value"]
+
+ if action == "set" and value != None:
+ execution_list.append({"action": action, "signal": signal_name, "value": str(value)})
+
+
+ return execution_list
+
+
+ def determine_action(self, signal_data, intent_slots):
+ """
+ Determines the action (e.g., set, increase, decrease) based on the intent slots
+ and VSS signal data.
+
+ Args:
+ signal_data (dict): The specification data for a VSS signal.
+ intent_slots (list): A list of dictionaries representing intent slots.
+
+ Returns:
+ str: The determined action or None if no action can be determined.
+ """
+ action_res = None
+ for intent_slot in intent_slots:
+ for action, action_data in signal_data["actions"].items():
+ if intent_slot["name"] in action_data["intents"] and intent_slot["value"] in action_data["synonyms"]:
+ action_res = action
+ break
+
+ return action_res
+
+
+ def determine_modifier(self, signal_data, intent_slots):
+ """
+ Determines the modifier (e.g., 'to' or 'by') based on the intent slots
+ and VSS signal data.
+
+ Args:
+ signal_data (dict): The specification data for a VSS signal.
+ intent_slots (list): A list of dictionaries representing intent slots.
+
+ Returns:
+ str: The determined modifier or None if no modifier can be determined.
+ """
+ modifier_res = None
+ for intent_slot in intent_slots:
+ for _, action_data in signal_data["actions"].items():
+ intent_val = intent_slot["value"]
+ if "modifier_intents" in action_data and intent_slot["name"] in action_data["modifier_intents"] and ("to" in intent_val or "by" in intent_val):
+ modifier_res = "to" if "to" in intent_val else "by" if "by" in intent_val else None
+ break
+
+ return modifier_res
+
+
+ def determine_value(self, signal_data, intent_slots):
+ """
+ Determines the value associated with the intent slot, considering the data type
+ and converting it to a numeric string representation if necessary.
+
+ Args:
+ signal_data (dict): The specification data for a VSS signal.
+ intent_slots (list): A list of dictionaries representing intent slots.
+
+ Returns:
+ str: The determined value or None if no value can be determined.
+ """
+ result = None
+ for intent_slot in intent_slots:
+ for value, value_data in signal_data["value_set_intents"].items():
+ if intent_slot["name"] == value:
+ result = intent_slot["value"]
+
+ if value_data["datatype"] == "number":
+ result = words_to_number(result) # we assume our model will always return a number in words
+
+ # the value should always be returned as str because Kuksa expects str values
+ return str(result) if result != None else None
+
+
+ def verify_value(self, signal_data, value):
+ """
+ Verifies that the value is valid based on the VSS signal data.
+
+ Args:
+ signal_data (dict): The specification data for a VSS signal.
+ value (str): The value to be verified.
+
+ Returns:
+ bool: True if the value is valid, False otherwise.
+ """
+ if value in signal_data["values"]["ignore"]:
+ return False
+
+ elif signal_data["values"]["ranged"] and isinstance(value, (int, float)):
+ return value >= signal_data["values"]["start"] and value <= signal_data["values"]["end"]
+
+ else:
+ return value in signal_data["values"]["additional"]
diff --git a/agl_service_voiceagent/utils/stt_model.py b/agl_service_voiceagent/utils/stt_model.py
index 5337162..d51ae31 100644
--- a/agl_service_voiceagent/utils/stt_model.py
+++ b/agl_service_voiceagent/utils/stt_model.py
@@ -21,21 +21,61 @@ import wave
from agl_service_voiceagent.utils.common import generate_unique_uuid
class STTModel:
+ """
+ STTModel is a class for speech-to-text (STT) recognition using the Vosk speech recognition library.
+ """
+
def __init__(self, model_path, sample_rate=16000):
+ """
+ Initialize the STTModel instance with the provided model and sample rate.
+
+ Args:
+ model_path (str): The path to the Vosk speech recognition model.
+ sample_rate (int, optional): The audio sample rate in Hz (default is 16000).
+ """
self.sample_rate = sample_rate
self.model = vosk.Model(model_path)
self.recognizer = {}
self.chunk_size = 1024
+
def setup_recognizer(self):
+ """
+ Set up a Vosk recognizer for a new session and return a unique identifier (UUID) for the session.
+
+ Returns:
+ str: A unique identifier (UUID) for the session.
+ """
uuid = generate_unique_uuid(6)
self.recognizer[uuid] = vosk.KaldiRecognizer(self.model, self.sample_rate)
return uuid
+
def init_recognition(self, uuid, audio_data):
+ """
+ Initialize the Vosk recognizer for a session with audio data.
+
+ Args:
+ uuid (str): The unique identifier (UUID) for the session.
+ audio_data (bytes): Audio data to process.
+
+ Returns:
+ bool: True if initialization was successful, False otherwise.
+ """
return self.recognizer[uuid].AcceptWaveform(audio_data)
+
def recognize(self, uuid, partial=False):
+ """
+ Recognize speech and return the result as a JSON object.
+
+ Args:
+ uuid (str): The unique identifier (UUID) for the session.
+ partial (bool, optional): If True, return partial recognition results (default is False).
+
+ Returns:
+ dict: A JSON object containing recognition results.
+ """
self.recognizer[uuid].SetWords(True)
if partial:
result = json.loads(self.recognizer[uuid].PartialResult())
@@ -44,7 +84,18 @@ class STTModel:
self.recognizer[uuid].Reset()
return result
+
def recognize_from_file(self, uuid, filename):
+ """
+ Recognize speech from an audio file and return the recognized text.
+
+ Args:
+ uuid (str): The unique identifier (UUID) for the session.
+ filename (str): The path to the audio file.
+
+ Returns:
+ str: The recognized text or error messages.
+ """
if not os.path.exists(filename):
print(f"Audio file '{filename}' not found.")
return "FILE_NOT_FOUND"
@@ -75,31 +126,13 @@ class STTModel:
print("Voice not recognized. Please speak again...")
return "VOICE_NOT_RECOGNIZED"
- def cleanup_recognizer(self, uuid):
- del self.recognizer[uuid]
-import wave
-
-def read_wav_file(filename, chunk_size=1024):
- try:
- wf = wave.open(filename, "rb")
- if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
- print("Audio file must be WAV format mono PCM.")
- return "FILE_FORMAT_INVALID"
-
- audio_data = b"" # Initialize an empty bytes object to store audio data
- while True:
- chunk = wf.readframes(chunk_size)
- if not chunk:
- break # End of file reached
- audio_data += chunk
-
- return audio_data
- except Exception as e:
- print(f"Error reading audio file: {e}")
- return None
+ def cleanup_recognizer(self, uuid):
+ """
+ Clean up and remove the Vosk recognizer for a session.
-# Example usage:
-filename = "your_audio.wav"
-audio_data = read_wav_file(filename)
+ Args:
+ uuid (str): The unique identifier (UUID) for the session.
+ """
+ del self.recognizer[uuid]
\ No newline at end of file
diff --git a/agl_service_voiceagent/utils/wake_word.py b/agl_service_voiceagent/utils/wake_word.py
index 066ae6d..47e547e 100644
--- a/agl_service_voiceagent/utils/wake_word.py
+++ b/agl_service_voiceagent/utils/wake_word.py
@@ -15,7 +15,6 @@
# limitations under the License.
import gi
-import vosk
gi.require_version('Gst', '1.0')
from gi.repository import Gst, GLib
@@ -23,7 +22,21 @@ Gst.init(None)
GLib.threads_init()
class WakeWordDetector:
+ """
+ WakeWordDetector is a class for detecting a wake word in an audio stream using GStreamer and Vosk.
+ """
+
def __init__(self, wake_word, stt_model, channels=1, sample_rate=16000, bits_per_sample=16):
+ """
+ Initialize the WakeWordDetector instance with the provided parameters.
+
+ Args:
+ wake_word (str): The wake word to detect in the audio stream.
+ stt_model (STTModel): An instance of the STTModel for speech-to-text recognition.
+ channels (int, optional): The number of audio channels (default is 1).
+ sample_rate (int, optional): The audio sample rate in Hz (default is 16000).
+ bits_per_sample (int, optional): The number of bits per sample (default is 16).
+ """
self.loop = GLib.MainLoop()
self.pipeline = None
self.bus = None
@@ -32,16 +45,25 @@ class WakeWordDetector:
self.sample_rate = sample_rate
self.channels = channels
self.bits_per_sample = bits_per_sample
- self.frame_size = int(self.sample_rate * 0.02)
- self.stt_model = stt_model # Speech to text model recognizer
+ self.wake_word_model = stt_model # Speech to text model recognizer
self.recognizer_uuid = stt_model.setup_recognizer()
- self.buffer_duration = 1 # Buffer audio for atleast 1 second
self.audio_buffer = bytearray()
+ self.segment_size = int(self.sample_rate * 1.0) # Adjust the segment size (e.g., 1 second)
+
def get_wake_word_status(self):
+ """
+ Get the status of wake word detection.
+
+ Returns:
+ bool: True if the wake word has been detected, False otherwise.
+ """
return self.wake_word_detected
def create_pipeline(self):
+ """
+ Create and configure the GStreamer audio processing pipeline for wake word detection.
+ """
print("Creating pipeline for Wake Word Detection...")
self.pipeline = Gst.Pipeline()
autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None)
@@ -77,49 +99,87 @@ class WakeWordDetector:
self.bus.add_signal_watch()
self.bus.connect("message", self.on_bus_message)
+
def on_new_buffer(self, appsink, data) -> Gst.FlowReturn:
+ """
+ Callback function to handle new audio buffers from GStreamer appsink.
+
+ Args:
+ appsink (Gst.AppSink): The GStreamer appsink.
+ data (object): User data (not used).
+
+ Returns:
+ Gst.FlowReturn: Indicates the status of buffer processing.
+ """
sample = appsink.emit("pull-sample")
buffer = sample.get_buffer()
data = buffer.extract_dup(0, buffer.get_size())
+
+ # Add the new data to the buffer
self.audio_buffer.extend(data)
- if len(self.audio_buffer) >= self.sample_rate * self.buffer_duration * self.channels * self.bits_per_sample // 8:
- self.process_audio_buffer()
+ # Process audio in segments
+ while len(self.audio_buffer) >= self.segment_size:
+ segment = self.audio_buffer[:self.segment_size]
+ self.process_audio_segment(segment)
+
+ # Advance the buffer by the segment size
+ self.audio_buffer = self.audio_buffer[self.segment_size:]
return Gst.FlowReturn.OK
-
- def process_audio_buffer(self):
- # Process the accumulated audio data using the audio model
- audio_data = bytes(self.audio_buffer)
- if self.stt_model.init_recognition(self.recognizer_uuid, audio_data):
- stt_result = self.stt_model.recognize(self.recognizer_uuid)
+ def process_audio_segment(self, segment):
+ """
+ Process an audio segment for wake word detection.
+
+ Args:
+ segment (bytes): The audio segment to process.
+ """
+ # Process the audio data segment
+ audio_data = bytes(segment)
+
+ # Perform wake word detection on the audio_data
+ if self.wake_word_model.init_recognition(self.recognizer_uuid, audio_data):
+ stt_result = self.wake_word_model.recognize(self.recognizer_uuid)
print("STT Result: ", stt_result)
if self.wake_word in stt_result["text"]:
self.wake_word_detected = True
print("Wake word detected!")
self.pipeline.send_event(Gst.Event.new_eos())
- self.audio_buffer.clear() # Clear the buffer
-
-
def send_eos(self):
+ """
+ Send an End-of-Stream (EOS) event to the pipeline.
+ """
self.pipeline.send_event(Gst.Event.new_eos())
self.audio_buffer.clear()
def start_listening(self):
+ """
+ Start listening for the wake word and enter the event loop.
+ """
self.pipeline.set_state(Gst.State.PLAYING)
print("Listening for Wake Word...")
self.loop.run()
def stop_listening(self):
+ """
+ Stop listening for the wake word and clean up the pipeline.
+ """
self.cleanup_pipeline()
self.loop.quit()
def on_bus_message(self, bus, message):
+ """
+ Handle GStreamer bus messages and perform actions based on the message type.
+
+ Args:
+ bus (Gst.Bus): The GStreamer bus.
+ message (Gst.Message): The GStreamer message to process.
+ """
if message.type == Gst.MessageType.EOS:
print("End-of-stream message received")
self.stop_listening()
@@ -140,6 +200,9 @@ class WakeWordDetector:
def cleanup_pipeline(self):
+ """
+ Clean up the GStreamer pipeline and release associated resources.
+ """
if self.pipeline is not None:
print("Cleaning up pipeline...")
self.pipeline.set_state(Gst.State.NULL)
@@ -147,4 +210,4 @@ class WakeWordDetector:
print("Pipeline cleanup complete!")
self.bus = None
self.pipeline = None
- self.stt_model.cleanup_recognizer(self.recognizer_uuid)
+ self.wake_word_model.cleanup_recognizer(self.recognizer_uuid)