diff options
author | Malik Talha <talhamalik727x@gmail.com> | 2023-09-14 22:41:26 +0500 |
---|---|---|
committer | Malik Talha <talhamalik727x@gmail.com> | 2023-09-25 00:40:38 +0500 |
commit | a10c988b5480ca5b937a2793b450cfa01f569d76 (patch) | |
tree | 23c032557a36afd671c7b7db9d6dd843253ae835 /agl_service_voiceagent | |
parent | 3e300cdc7fff19e5f338b282266444061f74506e (diff) |
Add gRPC-based voice agent service for AGL
Introducing a gRPC-based voice agent service for Automotive Grade Linux
(AGL) that leverages GStreamer, Vosk, Snips, and RASA. It seamlessly
processes user voice commands, converting spoken words to text,
extracting intents, and performing actions via the Kuksa interface (WIP).
Bug-AGL: SPEC-4906
Signed-off-by: Malik Talha <talhamalik727x@gmail.com>
Change-Id: I47e61c66149c67bb97fecc745e4c3afd79f447a5
Diffstat (limited to 'agl_service_voiceagent')
20 files changed, 1185 insertions, 0 deletions
diff --git a/agl_service_voiceagent/__init__.py b/agl_service_voiceagent/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/agl_service_voiceagent/__init__.py diff --git a/agl_service_voiceagent/client.py b/agl_service_voiceagent/client.py new file mode 100644 index 0000000..9b2e0a0 --- /dev/null +++ b/agl_service_voiceagent/client.py @@ -0,0 +1,78 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import time
import grpc
from agl_service_voiceagent.generated import voice_agent_pb2
from agl_service_voiceagent.generated import voice_agent_pb2_grpc
from agl_service_voiceagent.utils.config import get_config_value

# The logging setup below is only required to surface gRPC debug output.
import logging
logging.basicConfig()
logging.getLogger("grpc").setLevel(logging.DEBUG)

# Address the client dials, e.g. "localhost:51053", assembled from config.ini.
SERVER_URL = f"{get_config_value('SERVER_ADDRESS')}:{get_config_value('SERVER_PORT')}"

def run_client(mode, nlu_model):
    """Connect to the Voice Agent gRPC server and drive one client session.

    mode      -- "wake-word", "auto" (not implemented) or "manual"
    nlu_model -- "snips" selects the Snips engine; any other value selects RASA
    """
    nlu_model = voice_agent_pb2.SNIPS if nlu_model == "snips" else voice_agent_pb2.RASA
    print("Starting Voice Agent Client...")
    print(f"Client connecting to URL: {SERVER_URL}")
    with grpc.insecure_channel(SERVER_URL) as channel:
        print("Press Ctrl+C to stop the client.")
        print("Voice Agent Client started!")
        if mode == 'wake-word':
            _wait_for_wake_word(channel)
        elif mode == 'auto':
            raise ValueError("Auto mode is not implemented yet.")
        elif mode == 'manual':
            _record_manual_command(channel, nlu_model)


def _wait_for_wake_word(channel):
    """Stream wake-word status updates until the server reports a detection."""
    stub = voice_agent_pb2_grpc.VoiceAgentServiceStub(channel)
    print("Listening for wake word...")
    detected = False
    for wake_result in stub.DetectWakeWord(voice_agent_pb2.Empty()):
        print("Wake word status: ", detected)
        if wake_result.status:
            print("Wake word status: ", wake_result.status)
            detected = True
            break


def _record_manual_command(channel, nlu_model):
    """Start a manual recording, stop it after a fixed pause, print the result."""
    stub = voice_agent_pb2_grpc.VoiceAgentServiceStub(channel)
    print("Recording voice command...")
    start_request = voice_agent_pb2.RecognizeControl(
        action=voice_agent_pb2.START,
        nlu_model=nlu_model,
        record_mode=voice_agent_pb2.MANUAL,
    )
    stream_id = stub.RecognizeVoiceCommand(iter([start_request])).stream_id
    time.sleep(5)  # any arbitrary pause here
    stop_request = voice_agent_pb2.RecognizeControl(
        action=voice_agent_pb2.STOP,
        nlu_model=nlu_model,
        record_mode=voice_agent_pb2.MANUAL,
        stream_id=stream_id,
    )
    record_result = stub.RecognizeVoiceCommand(iter([stop_request]))
    print("Voice command recorded!")

    # Map the recognition status enum to a human-readable message.
    status = {
        voice_agent_pb2.REC_SUCCESS: "Yay! Status is success.",
        voice_agent_pb2.VOICE_NOT_RECOGNIZED: "Voice not recognized.",
        voice_agent_pb2.INTENT_NOT_RECOGNIZED: "Intent not recognized.",
    }.get(record_result.status, "Uh oh! Status is unknown.")

    # Process the response
    print("Command:", record_result.command)
    print("Status:", status)
    print("Intent:", record_result.intent)
    for slot in record_result.intent_slots:
        print("Slot Name:", slot.name)
        print("Slot Value:", slot.value)
\ No newline at end of file diff --git a/agl_service_voiceagent/config.ini b/agl_service_voiceagent/config.ini new file mode 100644 index 0000000..9455d6a --- /dev/null +++ b/agl_service_voiceagent/config.ini @@ -0,0 +1,22 @@ +[General] +service_version = 0.2.0 +base_audio_dir = /usr/share/nlu/commands/ +stt_model_path = /usr/share/vosk/vosk-model-small-en-us-0.15 +snips_model_path = /usr/share/nlu/snips/model/ +channels = 1 +sample_rate = 16000 +bits_per_sample = 16 +wake_word = hello auto +server_port = 51053 +server_address = localhost +rasa_model_path = /usr/share/nlu/rasa/models/ +rasa_server_port = 51054 +base_log_dir = /usr/share/nlu/logs/ +store_voice_commands = 0 + +[Kuksa] +ip = localhost +port = 8090 +protocol = ws +insecure = False +token = / diff --git a/agl_service_voiceagent/generated/__init__.py b/agl_service_voiceagent/generated/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/agl_service_voiceagent/generated/__init__.py diff --git a/agl_service_voiceagent/nlu/__init__.py b/agl_service_voiceagent/nlu/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/agl_service_voiceagent/nlu/__init__.py diff --git a/agl_service_voiceagent/nlu/rasa_interface.py b/agl_service_voiceagent/nlu/rasa_interface.py new file mode 100644 index 0000000..0232126 --- /dev/null +++ b/agl_service_voiceagent/nlu/rasa_interface.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and
# limitations under the License.

import re
import time
import requests
import subprocess
from concurrent.futures import ThreadPoolExecutor

class RASAInterface:
    """Wrapper around a locally spawned RASA NLU server.

    The server is launched via the `rasa` CLI in a background thread and
    queried over its HTTP API to extract intents from recognized text.
    """

    def __init__(self, port, model_path, log_dir, max_threads=5):
        # port       -- TCP port the RASA HTTP server listens on
        # model_path -- path to the trained RASA model
        # log_dir    -- directory for the server's combined stdout/stderr log;
        #               expected to carry a trailing slash (see service.py)
        self.port = port
        self.model_path = model_path
        self.max_threads = max_threads
        self.server_process = None
        self.thread_pool = ThreadPoolExecutor(max_workers=max_threads)
        self.log_file = log_dir+"rasa_server_logs.txt"


    def _start_server(self):
        """Blocking helper: spawn `rasa run` and wait for the process to exit."""
        command = (
            f"rasa run --enable-api -m \"{self.model_path}\" -p {self.port}"
        )
        # Redirect stdout and stderr to capture the output
        with open(self.log_file, "w") as output_file:
            self.server_process = subprocess.Popen(command, shell=True, stdout=output_file, stderr=subprocess.STDOUT)
            self.server_process.wait()  # Wait for the server process to finish


    def start_server(self):
        """Launch the server in the background and wait until it is reachable.

        Polls the RASA HTTP status endpoint instead of sleeping for a fixed
        interval; returns as soon as the server answers, or after the same
        25-second budget the previous fixed delay allowed.
        """
        self.thread_pool.submit(self._start_server)

        deadline = time.time() + 25
        status_url = f"http://localhost:{self.port}/status"
        while time.time() < deadline:
            try:
                # A 200 from /status means the HTTP server is up.
                if requests.get(status_url, timeout=1).status_code == 200:
                    return
            except requests.exceptions.RequestException:
                pass  # server not up yet; keep polling
            time.sleep(0.5)


    def stop_server(self):
        """Terminate the server process (if running) and shut down the pool."""
        if self.server_process:
            self.server_process.terminate()
            self.server_process.wait()
            self.server_process = None
        self.thread_pool.shutdown(wait=True)


    def preprocess_text(self, text):
        """Lower-case *text* and strip punctuation and surrounding whitespace."""
        # text to lower case and remove trailing and leading spaces
        preprocessed_text = text.lower().strip()
        # remove special characters, punctuation, and extra whitespaces
        preprocessed_text = re.sub(r'[^\w\s]', '', preprocessed_text).strip()
        return preprocessed_text


    def extract_intent(self, text):
        """POST *text* to the local RASA parse endpoint.

        Returns the parsed JSON payload, or None on a non-200 response.
        """
        preprocessed_text = self.preprocess_text(text)
        url = f"http://localhost:{self.port}/model/parse"
        data = {
            "text": preprocessed_text
        }
        response = requests.post(url, json=data)
        if response.status_code == 200:
            return response.json()
        else:
            return None


    def process_intent(self, intent_output):
        """Flatten a RASA parse result into (intent_name, {entity: value})."""
        intent = intent_output["intent"]["name"]
        entities = {}
        for entity in intent_output["entities"]:
            entity_name = entity["entity"]
            entity_value = entity["value"]
            entities[entity_name] = entity_value

        return intent, entities
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (c) 2023 Malik Talha
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import re
from typing import Text
from snips_inference_agl import SnipsNLUEngine

class SnipsInterface:
    """Wrapper around a pre-trained Snips NLU engine."""

    def __init__(self, model_path: Text):
        # Load the trained engine from disk once, up front.
        self.engine = SnipsNLUEngine.from_path(model_path)

    def preprocess_text(self, text):
        """Lower-case *text* and strip punctuation and surrounding whitespace."""
        lowered = text.lower().strip()
        return re.sub(r'[^\w\s]', '', lowered).strip()

    def extract_intent(self, text: Text):
        """Run the Snips engine on the cleaned-up *text* and return its parse result."""
        return self.engine.parse(self.preprocess_text(text))

    def process_intent(self, intent_output):
        """Flatten a Snips parse result into (intent_name, {entity: value})."""
        intent_name = intent_output['intent']['intentName']
        slot_values = {
            slot['entity']: slot['value']['value']
            for slot in intent_output.get('slots', [])
        }
        return intent_name, slot_values
\ No newline at end of file diff --git a/agl_service_voiceagent/protos/__init__.py b/agl_service_voiceagent/protos/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/agl_service_voiceagent/protos/__init__.py diff --git a/agl_service_voiceagent/protos/voice_agent.proto b/agl_service_voiceagent/protos/voice_agent.proto new file mode 100644 index 0000000..8ee8324 --- /dev/null +++ b/agl_service_voiceagent/protos/voice_agent.proto @@ -0,0 +1,82 @@ +syntax = "proto3"; + + +service VoiceAgentService { + rpc CheckServiceStatus(Empty) returns (ServiceStatus); + rpc DetectWakeWord(Empty) returns (stream WakeWordStatus); + rpc RecognizeVoiceCommand(stream RecognizeControl) returns (RecognizeResult); + rpc ExecuteVoiceCommand(ExecuteInput) returns (ExecuteResult); +} + + +enum RecordAction { + START = 0; + STOP = 1; +} + +enum NLUModel { + SNIPS = 0; + RASA = 1; +} + +enum RecordMode { + MANUAL = 0; + AUTO = 1; +} + +enum RecognizeStatusType { + REC_ERROR = 0; + REC_SUCCESS = 1; + REC_PROCESSING = 2; + VOICE_NOT_RECOGNIZED = 3; + INTENT_NOT_RECOGNIZED = 4; +} + +enum ExecuteStatusType { + EXEC_ERROR = 0; + EXEC_SUCCESS = 1; + KUKSA_CONN_ERROR = 2; + INTENT_NOT_SUPPORTED = 3; + INTENT_SLOTS_INCOMPLETE = 4; +} + + +message Empty {} + +message ServiceStatus { + string version = 1; + bool status = 2; +} + +message WakeWordStatus { + bool status = 1; +} + +message RecognizeControl { + RecordAction action = 1; + NLUModel nlu_model = 2; + RecordMode record_mode = 3; + string stream_id = 4; +} + +message IntentSlot { + string name = 1; + string value = 2; +} + +message RecognizeResult { + string command = 1; + string intent = 2; + repeated IntentSlot intent_slots = 3; + string stream_id = 4; + RecognizeStatusType status = 5; +} + +message ExecuteInput { + string intent = 1; + repeated IntentSlot intent_slots = 2; +} + +message ExecuteResult { + ExecuteStatusType status = 1; +} diff --git a/agl_service_voiceagent/server.py b/agl_service_voiceagent/server.py 
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (c) 2023 Malik Talha
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grpc
from concurrent import futures
from agl_service_voiceagent.generated import voice_agent_pb2_grpc
from agl_service_voiceagent.servicers.voice_agent_servicer import VoiceAgentServicer
from agl_service_voiceagent.utils.config import get_config_value

# Listen address, e.g. "localhost:51053", assembled from config.ini values.
SERVER_URL = f"{get_config_value('SERVER_ADDRESS')}:{get_config_value('SERVER_PORT')}"

def run_server():
    """Create, configure and run the Voice Agent gRPC server (blocking)."""
    print("Starting Voice Agent Service...")
    print(f"Server running at URL: {SERVER_URL}")
    print(f"STT Model Path: {get_config_value('STT_MODEL_PATH')}")
    print(f"Audio Store Directory: {get_config_value('BASE_AUDIO_DIR')}")
    grpc_server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    voice_agent_pb2_grpc.add_VoiceAgentServiceServicer_to_server(VoiceAgentServicer(), grpc_server)
    grpc_server.add_insecure_port(SERVER_URL)
    print("Press Ctrl+C to stop the server.")
    print("Voice Agent Server started!")
    grpc_server.start()
    grpc_server.wait_for_termination()
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (c) 2023 Malik Talha
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys

# Get the path to the directory containing this script
current_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the path to the "generated" folder
generated_dir = os.path.join(current_dir, "generated")
# Add the "generated" folder to sys.path so the generated gRPC stubs can
# import their sibling modules by bare name.
sys.path.append(generated_dir)

import argparse
from agl_service_voiceagent.utils.config import update_config_value, get_config_value
from agl_service_voiceagent.utils.common import add_trailing_slash
from agl_service_voiceagent.server import run_server
from agl_service_voiceagent.client import run_client


def print_version():
    """Print the service banner and the version read from config.ini."""
    print("Automotive Grade Linux (AGL)")
    print(f"Voice Agent Service v{get_config_value('SERVICE_VERSION')}")


def _normalize_dir_path(path):
    """Return *path* as an absolute, normalized path with a trailing slash.

    Downstream code builds file names by plain string concatenation
    (e.g. base_log_dir + "rasa_server_logs.txt"), so the trailing slash is
    required even for absolute input paths. The previous implementation
    skipped normalization whenever os.path.isabs(path) was true, which
    silently broke that concatenation.
    """
    return add_trailing_slash(os.path.abspath(path))


def main():
    """CLI entry point: parse arguments and dispatch to server or client."""
    parser = argparse.ArgumentParser(description="Automotive Grade Linux (AGL) - Voice Agent Service")
    parser.add_argument('--version', action='store_true', help='Show version')

    subparsers = parser.add_subparsers(dest='subcommand', title='Available Commands')
    subparsers.required = False

    # Create subparsers for "run server" and "run client"
    server_parser = subparsers.add_parser('run-server', help='Run the Voice Agent gRPC Server')
    client_parser = subparsers.add_parser('run-client', help='Run the Voice Agent gRPC Client')

    server_parser.add_argument('--config', action='store_true', help='Starts the server solely based on values provided in config file.')
    server_parser.add_argument('--stt-model-path', required=False, help='Path to the Speech To Text model. Currently only supports VOSK Kaldi.')
    server_parser.add_argument('--snips-model-path', required=False, help='Path to the Snips NLU model.')
    server_parser.add_argument('--rasa-model-path', required=False, help='Path to the RASA NLU model.')
    server_parser.add_argument('--audio-store-dir', required=False, help='Directory to store the generated audio files.')
    server_parser.add_argument('--log-store-dir', required=False, help='Directory to store the generated log files.')

    client_parser.add_argument('--mode', required=True, help='Mode to run the client in. Supported modes: "wake-word", "auto" and "manual".')
    client_parser.add_argument('--nlu', required=True, help='NLU engine to use. Supported NLU engines: "snips" and "rasa".')

    args = parser.parse_args()

    if args.version:
        print_version()

    elif args.subcommand == 'run-server':
        if not args.config:
            if not args.stt_model_path:
                raise ValueError("The --stt-model-path is missing. Please provide a value. Use --help to see available options.")

            if not args.snips_model_path:
                raise ValueError("The --snips-model-path is missing. Please provide a value. Use --help to see available options.")

            if not args.rasa_model_path:
                raise ValueError("The --rasa-model-path is missing. Please provide a value. Use --help to see available options.")

            # Normalize the model paths (absolute + trailing slash) and
            # persist them to config.ini.
            update_config_value(_normalize_dir_path(args.stt_model_path), 'STT_MODEL_PATH')
            update_config_value(_normalize_dir_path(args.snips_model_path), 'SNIPS_MODEL_PATH')
            update_config_value(_normalize_dir_path(args.rasa_model_path), 'RASA_MODEL_PATH')

            # Update the audio store dir in config.ini if provided
            audio_dir = args.audio_store_dir or get_config_value('BASE_AUDIO_DIR')
            update_config_value(_normalize_dir_path(audio_dir), 'BASE_AUDIO_DIR')

            # Update the log store dir in config.ini if provided
            log_dir = args.log_store_dir or get_config_value('BASE_LOG_DIR')
            update_config_value(_normalize_dir_path(log_dir), 'BASE_LOG_DIR')

        # Create the base audio and log dirs if they don't exist yet;
        # exist_ok avoids a race between the existence check and the mkdir.
        os.makedirs(get_config_value('BASE_AUDIO_DIR'), exist_ok=True)
        os.makedirs(get_config_value('BASE_LOG_DIR'), exist_ok=True)

        run_server()

    elif args.subcommand == 'run-client':
        mode = args.mode
        if mode not in ['wake-word', 'auto', 'manual']:
            raise ValueError("Invalid mode. Supported modes: 'wake-word', 'auto' and 'manual'. Use --help to see available options.")

        model = args.nlu
        if model not in ['snips', 'rasa']:
            raise ValueError("Invalid NLU engine. Supported NLU engines: 'snips' and 'rasa'. Use --help to see available options.")

        run_client(mode, model)

    else:
        print_version()
        print("Use --help to see available options.")


if __name__ == '__main__':
    main()
# --- agl_service_voiceagent/servicers/voice_agent_servicer.py ---
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (c) 2023 Malik Talha
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import grpc
import time
import threading
# Import the generated stubs via the package path, consistent with
# client.py and server.py (no reliance on the sys.path hack in service.py).
from agl_service_voiceagent.generated import voice_agent_pb2
from agl_service_voiceagent.generated import voice_agent_pb2_grpc
from agl_service_voiceagent.utils.audio_recorder import AudioRecorder
from agl_service_voiceagent.utils.wake_word import WakeWordDetector
from agl_service_voiceagent.utils.stt_model import STTModel
from agl_service_voiceagent.utils.config import get_config_value
from agl_service_voiceagent.nlu.snips_interface import SnipsInterface
from agl_service_voiceagent.nlu.rasa_interface import RASAInterface
from agl_service_voiceagent.utils.common import generate_unique_uuid, delete_file


class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer):
    """gRPC servicer implementing the VoiceAgentService RPCs."""

    def __init__(self):
        # Get the config values
        self.service_version = get_config_value('SERVICE_VERSION')
        self.wake_word = get_config_value('WAKE_WORD')
        self.base_audio_dir = get_config_value('BASE_AUDIO_DIR')
        self.channels = int(get_config_value('CHANNELS'))
        self.sample_rate = int(get_config_value('SAMPLE_RATE'))
        self.bits_per_sample = int(get_config_value('BITS_PER_SAMPLE'))
        self.stt_model_path = get_config_value('STT_MODEL_PATH')
        self.snips_model_path = get_config_value('SNIPS_MODEL_PATH')
        self.rasa_model_path = get_config_value('RASA_MODEL_PATH')
        self.rasa_server_port = int(get_config_value('RASA_SERVER_PORT'))
        self.base_log_dir = get_config_value('BASE_LOG_DIR')
        self.store_voice_command = bool(int(get_config_value('STORE_VOICE_COMMANDS')))

        # Initialize the STT/NLU backends; starting the RASA server blocks
        # until it is ready (see RASAInterface.start_server).
        self.stt_model = STTModel(self.stt_model_path, self.sample_rate)
        self.snips_interface = SnipsInterface(self.snips_model_path)
        self.rasa_interface = RASAInterface(self.rasa_server_port, self.rasa_model_path, self.base_log_dir)
        self.rasa_interface.start_server()
        # Maps stream ids of in-flight manual recordings to their recorder
        # and output file, between the START and STOP requests.
        self.rvc_stream_uuids = {}


    def CheckServiceStatus(self, request, context):
        """Report the service version and a static 'alive' flag."""
        response = voice_agent_pb2.ServiceStatus(
            version=self.service_version,
            status=True
        )
        return response


    def DetectWakeWord(self, request, context):
        """Stream WakeWordStatus updates until the wake word is detected or
        the client disconnects."""
        wake_word_detector = WakeWordDetector(self.wake_word, self.stt_model, self.channels, self.sample_rate, self.bits_per_sample)
        wake_word_detector.create_pipeline()
        detection_thread = threading.Thread(target=wake_word_detector.start_listening)
        detection_thread.start()
        while True:
            status = wake_word_detector.get_wake_word_status()
            time.sleep(1)
            if not context.is_active():
                # Client went away: stop the detector pipeline cleanly.
                wake_word_detector.send_eos()
                break
            yield voice_agent_pb2.WakeWordStatus(status=status)
            if status:
                break

        detection_thread.join()


    def RecognizeVoiceCommand(self, requests, context):
        """Handle a manual-mode recording stream.

        A START request creates a recorder and returns its stream id; a
        later STOP request (carrying that id) stops the recording, runs STT
        and NLU, and returns the recognized command/intent.
        """
        stt = ""
        intent = ""
        intent_slots = []
        # Defaults so an empty or unsupported request stream still produces
        # a well-formed error response instead of an UnboundLocalError.
        stream_uuid = ""
        status = voice_agent_pb2.REC_ERROR

        for request in requests:
            if request.record_mode == voice_agent_pb2.MANUAL:

                if request.action == voice_agent_pb2.START:
                    status = voice_agent_pb2.REC_PROCESSING
                    stream_uuid = generate_unique_uuid(8)
                    recorder = AudioRecorder(self.stt_model, self.base_audio_dir, self.channels, self.sample_rate, self.bits_per_sample)
                    recorder.set_pipeline_mode("manual")
                    audio_file = recorder.create_pipeline()

                    self.rvc_stream_uuids[stream_uuid] = {
                        "recorder": recorder,
                        "audio_file": audio_file
                    }

                    recorder.start_recording()

                elif request.action == voice_agent_pb2.STOP:
                    stream_uuid = request.stream_id
                    status = voice_agent_pb2.REC_SUCCESS

                    recorder = self.rvc_stream_uuids[stream_uuid]["recorder"]
                    audio_file = self.rvc_stream_uuids[stream_uuid]["audio_file"]
                    del self.rvc_stream_uuids[stream_uuid]

                    recorder.stop_recording()
                    recognizer_uuid = self.stt_model.setup_recognizer()
                    stt = self.stt_model.recognize_from_file(recognizer_uuid, audio_file)

                    if stt not in ["FILE_NOT_FOUND", "FILE_FORMAT_INVALID", "VOICE_NOT_RECOGNIZED", ""]:
                        if request.nlu_model == voice_agent_pb2.SNIPS:
                            extracted_intent = self.snips_interface.extract_intent(stt)
                            intent, intent_actions = self.snips_interface.process_intent(extracted_intent)

                            if not intent or intent == "":
                                status = voice_agent_pb2.INTENT_NOT_RECOGNIZED

                            for action, value in intent_actions.items():
                                intent_slots.append(voice_agent_pb2.IntentSlot(name=action, value=value))

                        elif request.nlu_model == voice_agent_pb2.RASA:
                            extracted_intent = self.rasa_interface.extract_intent(stt)
                            intent, intent_actions = self.rasa_interface.process_intent(extracted_intent)

                            if not intent or intent == "":
                                status = voice_agent_pb2.INTENT_NOT_RECOGNIZED

                            for action, value in intent_actions.items():
                                intent_slots.append(voice_agent_pb2.IntentSlot(name=action, value=value))

                    else:
                        stt = ""
                        status = voice_agent_pb2.VOICE_NOT_RECOGNIZED

                    # cleanup the kaldi recognizer
                    self.stt_model.cleanup_recognizer(recognizer_uuid)

                    # delete the audio file
                    if not self.store_voice_command:
                        delete_file(audio_file)


        # Process the request and generate a RecognizeResult
        response = voice_agent_pb2.RecognizeResult(
            command=stt,
            intent=intent,
            intent_slots=intent_slots,
            stream_id=stream_uuid,
            status=status
        )
        return response

# --- agl_service_voiceagent/utils/audio_recorder.py ---
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (c) 2023 Malik Talha
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gi
import vosk
import time
gi.require_version('Gst', '1.0')
from gi.repository import Gst, GLib

Gst.init(None)
# NOTE: GLib.threads_init() was removed here — it has been a deprecated
# no-op since GLib 2.32 / PyGObject 3.11 and is no longer required.

class AudioRecorder:
    """Record microphone audio to a WAV file via a GStreamer pipeline.

    Supports a "manual" mode (explicit start/stop) and an "auto" mode that
    reacts to `level` element messages (auto start is still commented out).
    """

    def __init__(self, stt_model, audio_files_basedir, channels=1, sample_rate=16000, bits_per_sample=16):
        # stt_model           -- STT backend handle (kept for auto mode use)
        # audio_files_basedir -- directory prefix for output WAV files;
        #                        assumed to end with '/' (see service.py)
        self.loop = GLib.MainLoop()
        self.mode = None
        self.pipeline = None
        self.bus = None
        self.audio_files_basedir = audio_files_basedir
        self.sample_rate = sample_rate
        self.channels = channels
        self.bits_per_sample = bits_per_sample
        self.frame_size = int(self.sample_rate * 0.02)  # 20 ms frames
        self.audio_model = stt_model
        self.buffer_duration = 1 # Buffer audio for atleast 1 second
        self.audio_buffer = bytearray()
        self.energy_threshold = 50000 # Adjust this threshold as needed
        self.silence_frames_threshold = 10
        self.frames_above_threshold = 0


    def create_pipeline(self):
        """Build the pipeline (autoaudiosrc -> queue -> audioconvert ->
        capsfilter -> wavenc -> filesink) and return the output WAV path."""
        print("Creating pipeline for audio recording in {} mode...".format(self.mode))
        self.pipeline = Gst.Pipeline()
        autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None)
        queue = Gst.ElementFactory.make("queue", None)
        audioconvert = Gst.ElementFactory.make("audioconvert", None)
        wavenc = Gst.ElementFactory.make("wavenc", None)

        # Force raw S16LE audio at the configured rate/channel count.
        capsfilter = Gst.ElementFactory.make("capsfilter", None)
        caps = Gst.Caps.new_empty_simple("audio/x-raw")
        caps.set_value("format", "S16LE")
        caps.set_value("rate", self.sample_rate)
        caps.set_value("channels", self.channels)
        capsfilter.set_property("caps", caps)

        self.pipeline.add(autoaudiosrc)
        self.pipeline.add(queue)
        self.pipeline.add(audioconvert)
        self.pipeline.add(wavenc)
        self.pipeline.add(capsfilter)

        autoaudiosrc.link(queue)
        queue.link(audioconvert)
        audioconvert.link(capsfilter)

        # Name the output file after the current Unix timestamp.
        audio_file_name = f"{self.audio_files_basedir}{int(time.time())}.wav"

        filesink = Gst.ElementFactory.make("filesink", None)
        filesink.set_property("location", audio_file_name)
        self.pipeline.add(filesink)
        capsfilter.link(wavenc)
        wavenc.link(filesink)

        self.bus = self.pipeline.get_bus()
        self.bus.add_signal_watch()
        self.bus.connect("message", self.on_bus_message)

        return audio_file_name


    def start_recording(self):
        """Set the pipeline to PLAYING, which starts writing the WAV file."""
        self.pipeline.set_state(Gst.State.PLAYING)
        print("Recording Voice Input...")


    def stop_recording(self):
        """Stop recording and tear down the pipeline."""
        print("Stopping recording...")
        self.frames_above_threshold = 0
        self.cleanup_pipeline()
        print("Recording finished!")


    def set_pipeline_mode(self, mode):
        """Select "manual" or "auto" behavior for bus message handling."""
        self.mode = mode


    # this method helps with error handling
    def on_bus_message(self, bus, message):
        """GStreamer bus callback: handle EOS, errors, warnings, state
        changes and (auto mode) level messages."""
        if message.type == Gst.MessageType.EOS:
            print("End-of-stream message received")
            self.stop_recording()

        elif message.type == Gst.MessageType.ERROR:
            err, debug_info = message.parse_error()
            print(f"Error received from element {message.src.get_name()}: {err.message}")
            print(f"Debugging information: {debug_info}")
            self.stop_recording()

        elif message.type == Gst.MessageType.WARNING:
            err, debug_info = message.parse_warning()
            print(f"Warning received from element {message.src.get_name()}: {err.message}")
            print(f"Debugging information: {debug_info}")

        elif message.type == Gst.MessageType.STATE_CHANGED:
            if isinstance(message.src, Gst.Pipeline):
                old_state, new_state, pending_state = message.parse_state_changed()
                print(("Pipeline state changed from %s to %s." %
                       (old_state.value_nick, new_state.value_nick)))

        elif self.mode == "auto" and message.type == Gst.MessageType.ELEMENT:
            if message.get_structure().get_name() == "level":
                # Track loudness via the level element's RMS value; once the
                # signal drops back to silence, stop the recording.
                rms = message.get_structure()["rms"][0]
                if rms > self.energy_threshold:
                    self.frames_above_threshold += 1
                    # if self.frames_above_threshold >= self.silence_frames_threshold:
                    #     self.start_recording()
                else:
                    if self.frames_above_threshold > 0:
                        self.frames_above_threshold -= 1
                        if self.frames_above_threshold == 0:
                            self.stop_recording()


    def cleanup_pipeline(self):
        """Release the pipeline and bus; safe to call more than once."""
        if self.pipeline is not None:
            print("Cleaning up pipeline...")
            self.pipeline.set_state(Gst.State.NULL)
            self.bus.remove_signal_watch()
            print("Pipeline cleanup complete!")
            self.bus = None
            self.pipeline = None

# --- agl_service_voiceagent/utils/common.py ---
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (c) 2023 Malik Talha
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import uuid
import json


def add_trailing_slash(path):
    """Return *path* with exactly one trailing '/' appended when it is a
    non-empty string that does not already end with one.

    Empty (or otherwise falsy) paths are returned unchanged.
    """
    if path and not path.endswith('/'):
        path += '/'
    return path


def generate_unique_uuid(length):
    """Return a random numeric ID string of exactly *length* digits.

    The ID is derived from uuid4; the last *length* decimal digits are used.
    Bug fix: the result is left-padded with zeros so it always has the
    requested length (the original could return fewer digits when the
    uuid's decimal representation was shorter than *length*).
    """
    unique_id = str(uuid.uuid4().int)[-length:]
    return unique_id.zfill(length)


def load_json_file(file_path):
    """Parse *file_path* as JSON and return the resulting object.

    Raises:
        ValueError: if the file does not exist or does not contain valid
            JSON. (json.JSONDecodeError subclasses ValueError, so callers
            already catching ValueError keep working; the explicit clause
            just produces a clearer message.)
    """
    try:
        with open(file_path, 'r') as file:
            return json.load(file)
    except FileNotFoundError:
        raise ValueError(f"File '{file_path}' not found.")
    except json.JSONDecodeError as err:
        raise ValueError(f"File '{file_path}' is not valid JSON: {err}")


def delete_file(file_path):
    """Best-effort removal of *file_path*; failures are logged, not raised."""
    if os.path.exists(file_path):
        try:
            os.remove(file_path)
        except Exception as e:
            print(f"Error deleting '{file_path}': {e}")
    else:
        print(f"File '{file_path}' does not exist.")
# ---------------------------------------------------------------------------
# agl_service_voiceagent/utils/config.py (new file in this change)
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2023 Malik Talha
# Licensed under the Apache License, Version 2.0.
# ---------------------------------------------------------------------------
"""Thin wrapper around configparser for the service-wide config.ini."""

import os
import configparser

# Get the absolute path to the directory of the current script
current_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the path to the config.ini file located in the base directory
config_path = os.path.join(current_dir, '..', 'config.ini')

# Module-level parser shared by all callers; the file is read once at
# import time (a missing file simply yields an empty parser).
config = configparser.ConfigParser()
config.read(config_path)


def update_config_value(value, key, group="General"):
    """Set *key* = *value* in section *group* and persist config.ini.

    Bug fix: the value is coerced to str because ConfigParser.set()
    raises TypeError for non-string values; string callers see no change.
    """
    config.set(group, key, str(value))
    with open(config_path, 'w') as configfile:
        config.write(configfile)


def get_config_value(key, group="General"):
    """Return the raw string value of *key* from section *group*.

    Raises configparser.NoSectionError / NoOptionError when missing.
    """
    return config.get(group, key)
# ---------------------------------------------------------------------------
# agl_service_voiceagent/utils/kuksa_interface.py (new file in this change)
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2023 Malik Talha
# Licensed under the Apache License, Version 2.0.
# ---------------------------------------------------------------------------

class KuksaInterface:
    """Small facade over KuksaClientThread, configured from config.ini."""

    def __init__(self):
        # Imported lazily (not at module scope) so this module can be
        # loaded in environments where the Kuksa stack is absent.
        from agl_service_voiceagent.utils.config import get_config_value

        # get config values
        self.ip = get_config_value("ip", "Kuksa")
        self.port = get_config_value("port", "Kuksa")
        self.insecure = get_config_value("insecure", "Kuksa")
        self.protocol = get_config_value("protocol", "Kuksa")
        self.token = get_config_value("token", "Kuksa")

        # The client is created lazily by connect_kuksa_client().
        self.kuksa_client = None

    def get_kuksa_client(self):
        """Return the underlying KuksaClientThread, or None if not connected."""
        return self.kuksa_client

    def get_kuksa_status(self):
        """Return True when a client exists and its connection check passes."""
        if self.kuksa_client:
            return self.kuksa_client.checkConnection()
        return False

    def connect_kuksa_client(self):
        """Create and authorize the Kuksa client; errors are logged, not raised."""
        try:
            # Lazy import keeps kuksa_client an optional dependency.
            from kuksa_client import KuksaClientThread
            self.kuksa_client = KuksaClientThread({
                "ip": self.ip,
                "port": self.port,
                "insecure": self.insecure,
                "protocol": self.protocol,
            })
            self.kuksa_client.authorize(self.token)
        except Exception as e:
            print("Error: ", e)

    def send_values(self, Path=None, Value=None):
        """Write *Value* to VSS *Path* if connected; log errors otherwise.

        (Parameter names kept as-is for keyword-argument compatibility.)
        """
        if self.kuksa_client is None:
            print("Error: Kuksa client is not initialized.")
            # Bug fix: the original fell through and printed a second,
            # misleading "connection failed" error; bail out instead.
            return
        if self.get_kuksa_status():
            self.kuksa_client.setValue(Path, Value)
        else:
            print("Error: Connection to Kuksa failed.")
# ---------------------------------------------------------------------------
# agl_service_voiceagent/utils/stt_model.py (new file in this change)
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2023 Malik Talha
# Licensed under the Apache License, Version 2.0.
# ---------------------------------------------------------------------------
"""Vosk-based speech-to-text helpers."""

import os
import json
import wave


class STTModel:
    """Speech-to-text engine wrapping a Vosk model.

    One KaldiRecognizer is kept per stream, keyed by a short uuid, so
    several streams can be recognized concurrently.
    """

    def __init__(self, model_path, sample_rate=16000):
        # vosk is imported lazily so this module can still be imported
        # (e.g. for read_wav_file) on systems without the vosk wheel.
        import vosk
        self.sample_rate = sample_rate
        self.model = vosk.Model(model_path)
        self.recognizer = {}    # uuid -> vosk.KaldiRecognizer
        self.chunk_size = 1024  # frames per read when streaming a file

    def setup_recognizer(self):
        """Create a fresh recognizer and return its uuid handle."""
        import vosk
        from agl_service_voiceagent.utils.common import generate_unique_uuid
        uuid = generate_unique_uuid(6)
        self.recognizer[uuid] = vosk.KaldiRecognizer(self.model, self.sample_rate)
        return uuid

    def init_recognition(self, uuid, audio_data):
        """Feed audio to the recognizer; True when a full utterance is ready."""
        return self.recognizer[uuid].AcceptWaveform(audio_data)

    def recognize(self, uuid, partial=False):
        """Return the (partial or final) decode result as a dict and reset."""
        self.recognizer[uuid].SetWords(True)
        if partial:
            result = json.loads(self.recognizer[uuid].PartialResult())
        else:
            result = json.loads(self.recognizer[uuid].Result())
        self.recognizer[uuid].Reset()
        return result

    def recognize_from_file(self, uuid, filename):
        """Recognize speech in a mono 16-bit PCM WAV file.

        Returns the recognized text, or one of the sentinel strings
        "FILE_NOT_FOUND", "FILE_FORMAT_INVALID", "VOICE_NOT_RECOGNIZED".
        """
        if not os.path.exists(filename):
            # Bug fix: the original f-string had lost its placeholder and
            # printed a literal '(unknown)' instead of the file name.
            print(f"Audio file '{filename}' not found.")
            return "FILE_NOT_FOUND"

        wf = wave.open(filename, "rb")
        try:
            if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
                print("Audio file must be WAV format mono PCM.")
                return "FILE_FORMAT_INVALID"

            # Read in chunks: the target AGL system cannot hold an entire
            # audio file in a single readframes() call.
            audio_data = b""
            while True:
                chunk = wf.readframes(self.chunk_size)
                if not chunk:
                    break  # End of file reached
                audio_data += chunk
        finally:
            wf.close()  # Bug fix: the original leaked the wave file handle.

        if audio_data:
            if self.init_recognition(uuid, audio_data):
                result = self.recognize(uuid)
                return result['text']
            else:
                result = self.recognize(uuid, partial=True)
                return result['partial']
        else:
            print("Voice not recognized. Please speak again...")
            return "VOICE_NOT_RECOGNIZED"

    def cleanup_recognizer(self, uuid):
        """Drop the recognizer for *uuid*, freeing its native resources."""
        del self.recognizer[uuid]


def read_wav_file(filename, chunk_size=1024):
    """Return the raw PCM frames of a mono 16-bit WAV file as bytes.

    Returns "FILE_FORMAT_INVALID" for non mono/16-bit files and None on
    any read error (mirroring the original best-effort behaviour).
    """
    try:
        with wave.open(filename, "rb") as wf:
            if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
                print("Audio file must be WAV format mono PCM.")
                return "FILE_FORMAT_INVALID"
            audio_data = b""
            while True:
                chunk = wf.readframes(chunk_size)
                if not chunk:
                    break  # End of file reached
                audio_data += chunk
            return audio_data
    except Exception as e:
        print(f"Error reading audio file: {e}")
        return None

# NOTE(review): the original file ended with leftover example code that ran
# at import time (filename = "your_audio.wav"; read_wav_file(filename)) and
# a duplicate `import wave`; both removed as accidental debug residue.
# ---------------------------------------------------------------------------
# agl_service_voiceagent/utils/wake_word.py (new file in this change)
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) 2023 Malik Talha
# Licensed under the Apache License, Version 2.0.
# ---------------------------------------------------------------------------

import gi
import vosk
gi.require_version('Gst', '1.0')
from gi.repository import Gst, GLib

Gst.init(None)
# NOTE(review): GLib.threads_init() has been a deprecated no-op since
# GLib 2.32 — consider removing it; confirm against the target PyGObject.
GLib.threads_init()


class WakeWordDetector:
    """Detects a configured wake word by streaming microphone audio
    (autoaudiosrc -> appsink) into a Vosk STT recognizer."""

    def __init__(self, wake_word, stt_model, channels=1, sample_rate=16000, bits_per_sample=16):
        self.loop = GLib.MainLoop()
        self.pipeline = None
        self.bus = None
        self.wake_word = wake_word            # phrase searched for in STT output
        self.wake_word_detected = False
        self.sample_rate = sample_rate
        self.channels = channels
        self.bits_per_sample = bits_per_sample
        self.frame_size = int(self.sample_rate * 0.02)  # 20 ms frame; unused here — TODO confirm
        self.stt_model = stt_model  # Speech to text model recognizer
        self.recognizer_uuid = stt_model.setup_recognizer()
        self.buffer_duration = 1  # Buffer audio for at least 1 second
        self.audio_buffer = bytearray()

    def get_wake_word_status(self):
        """Return True once the wake word has been heard."""
        return self.wake_word_detected

    def create_pipeline(self):
        """Build the GStreamer capture pipeline:
        autoaudiosrc -> queue -> audioconvert -> capsfilter(S16LE mono) -> appsink."""
        print("Creating pipeline for Wake Word Detection...")
        self.pipeline = Gst.Pipeline()
        autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None)
        queue = Gst.ElementFactory.make("queue", None)
        audioconvert = Gst.ElementFactory.make("audioconvert", None)
        # NOTE(review): wavenc is created and added but never linked into the
        # chain below — it appears to be dead code; verify before removing.
        wavenc = Gst.ElementFactory.make("wavenc", None)

        capsfilter = Gst.ElementFactory.make("capsfilter", None)
        caps = Gst.Caps.new_empty_simple("audio/x-raw")
        caps.set_value("format", "S16LE")
        caps.set_value("rate", self.sample_rate)
        caps.set_value("channels", self.channels)
        capsfilter.set_property("caps", caps)

        appsink = Gst.ElementFactory.make("appsink", None)
        appsink.set_property("emit-signals", True)
        appsink.set_property("sync", False)  # Set sync property to False to enable async processing
        appsink.connect("new-sample", self.on_new_buffer, None)

        self.pipeline.add(autoaudiosrc)
        self.pipeline.add(queue)
        self.pipeline.add(audioconvert)
        self.pipeline.add(wavenc)
        self.pipeline.add(capsfilter)
        self.pipeline.add(appsink)

        autoaudiosrc.link(queue)
        queue.link(audioconvert)
        audioconvert.link(capsfilter)
        capsfilter.link(appsink)

        # Route bus messages (EOS, errors, state changes) to on_bus_message.
        self.bus = self.pipeline.get_bus()
        self.bus.add_signal_watch()
        self.bus.connect("message", self.on_bus_message)

    def on_new_buffer(self, appsink, data) -> Gst.FlowReturn:
        """appsink 'new-sample' callback: accumulate raw PCM and process the
        buffer once at least buffer_duration seconds of audio are queued."""
        sample = appsink.emit("pull-sample")
        buffer = sample.get_buffer()
        data = buffer.extract_dup(0, buffer.get_size())
        self.audio_buffer.extend(data)

        # bytes needed for buffer_duration seconds of PCM at the configured format
        if len(self.audio_buffer) >= self.sample_rate * self.buffer_duration * self.channels * self.bits_per_sample // 8:
            self.process_audio_buffer()

        return Gst.FlowReturn.OK

    def process_audio_buffer(self):
        # Process the accumulated audio data using the audio model
        audio_data = bytes(self.audio_buffer)
        if self.stt_model.init_recognition(self.recognizer_uuid, audio_data):
            stt_result = self.stt_model.recognize(self.recognizer_uuid)
            print("STT Result: ", stt_result)
            # Substring match against the decoded utterance; EOS stops capture.
            if self.wake_word in stt_result["text"]:
                self.wake_word_detected = True
                print("Wake word detected!")
                self.pipeline.send_event(Gst.Event.new_eos())

        self.audio_buffer.clear()  # Clear the buffer

    def send_eos(self):
        """Push end-of-stream into the pipeline and drop buffered audio."""
        self.pipeline.send_event(Gst.Event.new_eos())
        self.audio_buffer.clear()

    def start_listening(self):
        """Start the pipeline and block in the GLib main loop until stopped."""
        self.pipeline.set_state(Gst.State.PLAYING)
        print("Listening for Wake Word...")
        self.loop.run()

    def stop_listening(self):
        """Tear down the pipeline and quit the main loop (unblocks start_listening)."""
        self.cleanup_pipeline()
        self.loop.quit()

    def on_bus_message(self, bus, message):
        """Bus callback: stop on EOS/errors, log warnings and state changes."""
        if message.type == Gst.MessageType.EOS:
            print("End-of-stream message received")
            self.stop_listening()
        elif message.type == Gst.MessageType.ERROR:
            err, debug_info = message.parse_error()
            print(f"Error received from element {message.src.get_name()}: {err.message}")
            print(f"Debugging information: {debug_info}")
            self.stop_listening()
        elif message.type == Gst.MessageType.WARNING:
            err, debug_info = message.parse_warning()
            print(f"Warning received from element {message.src.get_name()}: {err.message}")
            print(f"Debugging information: {debug_info}")
        elif message.type == Gst.MessageType.STATE_CHANGED:
            # Only log state changes of the pipeline itself, not of elements.
            if isinstance(message.src, Gst.Pipeline):
                old_state, new_state, pending_state = message.parse_state_changed()
                print(("Pipeline state changed from %s to %s." %
                       (old_state.value_nick, new_state.value_nick)))

    def cleanup_pipeline(self):
        """Release GStreamer and recognizer resources.

        NOTE(review): the recognizer is deleted here while recognizer_uuid is
        kept, so a detector instance cannot be restarted after cleanup —
        confirm this is intentional.
        """
        if self.pipeline is not None:
            print("Cleaning up pipeline...")
            self.pipeline.set_state(Gst.State.NULL)
            self.bus.remove_signal_watch()
            print("Pipeline cleanup complete!")
            self.bus = None
            self.pipeline = None
            self.stt_model.cleanup_recognizer(self.recognizer_uuid)