Diffstat (limited to 'agl_service_voiceagent')
20 files changed, 1185 insertions, 0 deletions
diff --git a/agl_service_voiceagent/__init__.py b/agl_service_voiceagent/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/agl_service_voiceagent/__init__.py
diff --git a/agl_service_voiceagent/client.py b/agl_service_voiceagent/client.py
new file mode 100644
index 0000000..9b2e0a0
--- /dev/null
+++ b/agl_service_voiceagent/client.py
@@ -0,0 +1,78 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright (c) 2023 Malik Talha
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+import grpc
+from agl_service_voiceagent.generated import voice_agent_pb2
+from agl_service_voiceagent.generated import voice_agent_pb2_grpc
+from agl_service_voiceagent.utils.config import get_config_value
+
+# following code is only required for logging
+import logging
+logging.basicConfig()
+logging.getLogger("grpc").setLevel(logging.DEBUG)
+
+SERVER_URL = get_config_value('SERVER_ADDRESS') + ":" + str(get_config_value('SERVER_PORT'))
+
+def run_client(mode, nlu_model):
+    nlu_model = voice_agent_pb2.SNIPS if nlu_model == "snips" else voice_agent_pb2.RASA
+    print("Starting Voice Agent Client...")
+    print(f"Client connecting to URL: {SERVER_URL}")
+    with grpc.insecure_channel(SERVER_URL) as channel:
+        print("Press Ctrl+C to stop the client.")
+        print("Voice Agent Client started!")
+        if mode == 'wake-word':
+            stub = voice_agent_pb2_grpc.VoiceAgentServiceStub(channel)
+            print("Listening for wake word...")
+            wake_request = voice_agent_pb2.Empty()
+            wake_results = stub.DetectWakeWord(wake_request)
+            wake_word_detected = False
+            for wake_result in wake_results:
+                print("Wake word status: ", wake_word_detected)
+                if wake_result.status:
+                    print("Wake word status: ", wake_result.status)
+                    wake_word_detected = True
+                    break
+
+        elif mode == 'auto':
+            raise ValueError("Auto mode is not implemented yet.")
+
+        elif mode == 'manual':
+            stub = voice_agent_pb2_grpc.VoiceAgentServiceStub(channel)
+            print("Recording voice command...")
+            record_start_request = voice_agent_pb2.RecognizeControl(action=voice_agent_pb2.START, nlu_model=nlu_model, record_mode=voice_agent_pb2.MANUAL)
+            response = stub.RecognizeVoiceCommand(iter([record_start_request]))
+            stream_id = response.stream_id
+            time.sleep(5) # any arbitrary pause here
+            record_stop_request = voice_agent_pb2.RecognizeControl(action=voice_agent_pb2.STOP, nlu_model=nlu_model, record_mode=voice_agent_pb2.MANUAL, stream_id=stream_id)
+            record_result = stub.RecognizeVoiceCommand(iter([record_stop_request]))
+            print("Voice command recorded!")
+
+            status = "Uh oh! Status is unknown."
+            if record_result.status == voice_agent_pb2.REC_SUCCESS:
+                status = "Yay! Status is success."
+            elif record_result.status == voice_agent_pb2.VOICE_NOT_RECOGNIZED:
+                status = "Voice not recognized."
+            elif record_result.status == voice_agent_pb2.INTENT_NOT_RECOGNIZED:
+                status = "Intent not recognized."
+ + # Process the response + print("Command:", record_result.command) + print("Status:", status) + print("Intent:", record_result.intent) + for slot in record_result.intent_slots: + print("Slot Name:", slot.name) + print("Slot Value:", slot.value)
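For reference, a minimal sketch (not part of this change) of the one RPC the client above never calls, CheckServiceStatus; it assumes the generated stubs from voice_agent.proto and a server already listening, with the address hardcoded here purely for illustration:

    import grpc
    from agl_service_voiceagent.generated import voice_agent_pb2
    from agl_service_voiceagent.generated import voice_agent_pb2_grpc

    SERVER_URL = "localhost:51053"  # illustrative; the real client reads this from config.ini

    with grpc.insecure_channel(SERVER_URL) as channel:
        stub = voice_agent_pb2_grpc.VoiceAgentServiceStub(channel)
        # CheckServiceStatus takes an Empty message and returns ServiceStatus(version, status)
        status = stub.CheckServiceStatus(voice_agent_pb2.Empty())
        print("Service version:", status.version, "| running:", status.status)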
\ No newline at end of file diff --git a/agl_service_voiceagent/config.ini b/agl_service_voiceagent/config.ini new file mode 100644 index 0000000..9455d6a --- /dev/null +++ b/agl_service_voiceagent/config.ini @@ -0,0 +1,22 @@ +[General] +service_version = 0.2.0 +base_audio_dir = /usr/share/nlu/commands/ +stt_model_path = /usr/share/vosk/vosk-model-small-en-us-0.15 +snips_model_path = /usr/share/nlu/snips/model/ +channels = 1 +sample_rate = 16000 +bits_per_sample = 16 +wake_word = hello auto +server_port = 51053 +server_address = localhost +rasa_model_path = /usr/share/nlu/rasa/models/ +rasa_server_port = 51054 +base_log_dir = /usr/share/nlu/logs/ +store_voice_commands = 0 + +[Kuksa] +ip = localhost +port = 8090 +protocol = ws +insecure = False +token = / diff --git a/agl_service_voiceagent/generated/__init__.py b/agl_service_voiceagent/generated/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/agl_service_voiceagent/generated/__init__.py diff --git a/agl_service_voiceagent/nlu/__init__.py b/agl_service_voiceagent/nlu/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/agl_service_voiceagent/nlu/__init__.py diff --git a/agl_service_voiceagent/nlu/rasa_interface.py b/agl_service_voiceagent/nlu/rasa_interface.py new file mode 100644 index 0000000..0232126 --- /dev/null +++ b/agl_service_voiceagent/nlu/rasa_interface.py @@ -0,0 +1,87 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +import time +import requests +import subprocess +from concurrent.futures import ThreadPoolExecutor + +class RASAInterface: + def __init__(self, port, model_path, log_dir, max_threads=5): + self.port = port + self.model_path = model_path + self.max_threads = max_threads + self.server_process = None + self.thread_pool = ThreadPoolExecutor(max_workers=max_threads) + self.log_file = log_dir+"rasa_server_logs.txt" + + + def _start_server(self): + command = ( + f"rasa run --enable-api -m \"{self.model_path}\" -p {self.port}" + ) + # Redirect stdout and stderr to capture the output + with open(self.log_file, "w") as output_file: + self.server_process = subprocess.Popen(command, shell=True, stdout=output_file, stderr=subprocess.STDOUT) + self.server_process.wait() # Wait for the server process to finish + + + def start_server(self): + self.thread_pool.submit(self._start_server) + + # Wait for a brief moment to allow the server to start + time.sleep(25) + + + def stop_server(self): + if self.server_process: + self.server_process.terminate() + self.server_process.wait() + self.server_process = None + self.thread_pool.shutdown(wait=True) + + + def preprocess_text(self, text): + # text to lower case and remove trailing and leading spaces + preprocessed_text = text.lower().strip() + # remove special characters, punctuation, and extra whitespaces + preprocessed_text = re.sub(r'[^\w\s]', '', preprocessed_text).strip() + return preprocessed_text + + + def extract_intent(self, text): + preprocessed_text = self.preprocess_text(text) + url = f"http://localhost:{self.port}/model/parse" + data = { + "text": preprocessed_text + } + response = requests.post(url, json=data) + if response.status_code == 200: + return response.json() + else: + return None + + + def process_intent(self, intent_output): + intent = intent_output["intent"]["name"] + entities = {} + for entity in intent_output["entities"]: + entity_name = entity["entity"] + entity_value = entity["value"] + entities[entity_name] = entity_value + + return intent, entities
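A hedged usage sketch of the RASAInterface class above, assuming the rasa CLI and a trained model are installed at the paths used in config.ini; start_server() waits roughly 25 seconds for the REST endpoint to come up:

    from agl_service_voiceagent.nlu.rasa_interface import RASAInterface

    rasa = RASAInterface(port=51054,
                         model_path="/usr/share/nlu/rasa/models/",
                         log_dir="/usr/share/nlu/logs/")
    rasa.start_server()                      # launches "rasa run --enable-api" in the background
    result = rasa.extract_intent("turn on the air conditioner")
    if result is not None:                   # extract_intent returns None on a non-200 response
        intent, slots = rasa.process_intent(result)
        print("Intent:", intent, "Slots:", slots)
    rasa.stop_server()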
\ No newline at end of file diff --git a/agl_service_voiceagent/nlu/snips_interface.py b/agl_service_voiceagent/nlu/snips_interface.py new file mode 100644 index 0000000..f0b05d2 --- /dev/null +++ b/agl_service_voiceagent/nlu/snips_interface.py @@ -0,0 +1,46 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import re +from typing import Text +from snips_inference_agl import SnipsNLUEngine + +class SnipsInterface: + def __init__(self, model_path: Text): + self.engine = SnipsNLUEngine.from_path(model_path) + + def preprocess_text(self, text): + # text to lower case and remove trailing and leading spaces + preprocessed_text = text.lower().strip() + # remove special characters, punctuation, and extra whitespaces + preprocessed_text = re.sub(r'[^\w\s]', '', preprocessed_text).strip() + return preprocessed_text + + def extract_intent(self, text: Text): + preprocessed_text = self.preprocess_text(text) + result = self.engine.parse(preprocessed_text) + return result + + def process_intent(self, intent_output): + intent_actions = {} + intent = intent_output['intent']['intentName'] + slots = intent_output.get('slots', []) + for slot in slots: + action = slot['entity'] + value = slot['value']['value'] + intent_actions[action] = value + + return intent, intent_actions
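Similarly, a short sketch of SnipsInterface, assuming a trained Snips NLU model at the config.ini path:

    from agl_service_voiceagent.nlu.snips_interface import SnipsInterface

    snips = SnipsInterface("/usr/share/nlu/snips/model/")
    parsed = snips.extract_intent("set the temperature to twenty degrees")
    intent, slots = snips.process_intent(parsed)
    print("Intent:", intent)   # intentName reported by the Snips engine
    print("Slots:", slots)     # dict mapping entity name to resolved value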
\ No newline at end of file diff --git a/agl_service_voiceagent/protos/__init__.py b/agl_service_voiceagent/protos/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/agl_service_voiceagent/protos/__init__.py diff --git a/agl_service_voiceagent/protos/voice_agent.proto b/agl_service_voiceagent/protos/voice_agent.proto new file mode 100644 index 0000000..8ee8324 --- /dev/null +++ b/agl_service_voiceagent/protos/voice_agent.proto @@ -0,0 +1,82 @@ +syntax = "proto3"; + + +service VoiceAgentService { + rpc CheckServiceStatus(Empty) returns (ServiceStatus); + rpc DetectWakeWord(Empty) returns (stream WakeWordStatus); + rpc RecognizeVoiceCommand(stream RecognizeControl) returns (RecognizeResult); + rpc ExecuteVoiceCommand(ExecuteInput) returns (ExecuteResult); +} + + +enum RecordAction { + START = 0; + STOP = 1; +} + +enum NLUModel { + SNIPS = 0; + RASA = 1; +} + +enum RecordMode { + MANUAL = 0; + AUTO = 1; +} + +enum RecognizeStatusType { + REC_ERROR = 0; + REC_SUCCESS = 1; + REC_PROCESSING = 2; + VOICE_NOT_RECOGNIZED = 3; + INTENT_NOT_RECOGNIZED = 4; +} + +enum ExecuteStatusType { + EXEC_ERROR = 0; + EXEC_SUCCESS = 1; + KUKSA_CONN_ERROR = 2; + INTENT_NOT_SUPPORTED = 3; + INTENT_SLOTS_INCOMPLETE = 4; +} + + +message Empty {} + +message ServiceStatus { + string version = 1; + bool status = 2; +} + +message WakeWordStatus { + bool status = 1; +} + +message RecognizeControl { + RecordAction action = 1; + NLUModel nlu_model = 2; + RecordMode record_mode = 3; + string stream_id = 4; +} + +message IntentSlot { + string name = 1; + string value = 2; +} + +message RecognizeResult { + string command = 1; + string intent = 2; + repeated IntentSlot intent_slots = 3; + string stream_id = 4; + RecognizeStatusType status = 5; +} + +message ExecuteInput { + string intent = 1; + repeated IntentSlot intent_slots = 2; +} + +message ExecuteResult { + ExecuteStatusType status = 1; +} diff --git a/agl_service_voiceagent/server.py b/agl_service_voiceagent/server.py new file mode 100644 index 0000000..5fe0746 --- /dev/null +++ b/agl_service_voiceagent/server.py @@ -0,0 +1,36 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import grpc +from concurrent import futures +from agl_service_voiceagent.generated import voice_agent_pb2_grpc +from agl_service_voiceagent.servicers.voice_agent_servicer import VoiceAgentServicer +from agl_service_voiceagent.utils.config import get_config_value + +SERVER_URL = get_config_value('SERVER_ADDRESS') + ":" + str(get_config_value('SERVER_PORT')) + +def run_server(): + print("Starting Voice Agent Service...") + print(f"Server running at URL: {SERVER_URL}") + print(f"STT Model Path: {get_config_value('STT_MODEL_PATH')}") + print(f"Audio Store Directory: {get_config_value('BASE_AUDIO_DIR')}") + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + voice_agent_pb2_grpc.add_VoiceAgentServiceServicer_to_server(VoiceAgentServicer(), server) + server.add_insecure_port(SERVER_URL) + print("Press Ctrl+C to stop the server.") + print("Voice Agent Server started!") + server.start() + server.wait_for_termination()
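run_server() blocks in wait_for_termination() until the process exits. A sketch (not part of this commit) of the same wiring with an explicit graceful-shutdown path, should one ever be needed:

    import grpc
    from concurrent import futures
    from agl_service_voiceagent.generated import voice_agent_pb2_grpc
    from agl_service_voiceagent.servicers.voice_agent_servicer import VoiceAgentServicer

    server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
    voice_agent_pb2_grpc.add_VoiceAgentServiceServicer_to_server(VoiceAgentServicer(), server)
    server.add_insecure_port("localhost:51053")   # illustrative address
    server.start()
    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        server.stop(grace=5)   # give in-flight RPCs five seconds to finish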
\ No newline at end of file
diff --git a/agl_service_voiceagent/service.py b/agl_service_voiceagent/service.py
new file mode 100644
index 0000000..1b34c27
--- /dev/null
+++ b/agl_service_voiceagent/service.py
@@ -0,0 +1,128 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright (c) 2023 Malik Talha
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import sys
+
+# Get the path to the directory containing this script
+current_dir = os.path.dirname(os.path.abspath(__file__))
+# Construct the path to the "generated" folder
+generated_dir = os.path.join(current_dir, "generated")
+# Add the "generated" folder to sys.path
+sys.path.append(generated_dir)
+
+import argparse
+from agl_service_voiceagent.utils.config import update_config_value, get_config_value
+from agl_service_voiceagent.utils.common import add_trailing_slash
+from agl_service_voiceagent.server import run_server
+from agl_service_voiceagent.client import run_client
+
+
+def print_version():
+    print("Automotive Grade Linux (AGL)")
+    print(f"Voice Agent Service v{get_config_value('SERVICE_VERSION')}")
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Automotive Grade Linux (AGL) - Voice Agent Service")
+    parser.add_argument('--version', action='store_true', help='Show version')
+
+    subparsers = parser.add_subparsers(dest='subcommand', title='Available Commands')
+    subparsers.required = False
+
+    # Create subparsers for "run server" and "run client"
+    server_parser = subparsers.add_parser('run-server', help='Run the Voice Agent gRPC Server')
+    client_parser = subparsers.add_parser('run-client', help='Run the Voice Agent gRPC Client')
+
+    server_parser.add_argument('--config', action='store_true', help='Starts the server solely based on values provided in config file.')
+    server_parser.add_argument('--stt-model-path', required=False, help='Path to the Speech To Text model. Currently only supports VOSK Kaldi.')
+    server_parser.add_argument('--snips-model-path', required=False, help='Path to the Snips NLU model.')
+    server_parser.add_argument('--rasa-model-path', required=False, help='Path to the RASA NLU model.')
+    server_parser.add_argument('--audio-store-dir', required=False, help='Directory to store the generated audio files.')
+    server_parser.add_argument('--log-store-dir', required=False, help='Directory to store the generated log files.')
+
+    client_parser.add_argument('--mode', required=True, help='Mode to run the client in. Supported modes: "wake-word", "auto" and "manual".')
+    client_parser.add_argument('--nlu', required=True, help='NLU engine to use. Supported NLU engines: "snips" and "rasa".')
+
+    args = parser.parse_args()
+
+    if args.version:
+        print_version()
+
+    elif args.subcommand == 'run-server':
+        if not args.config:
+            if not args.stt_model_path:
+                raise ValueError("The --stt-model-path is missing. Please provide a value. Use --help to see available options.")
+
+            if not args.snips_model_path:
+                raise ValueError("The --snips-model-path is missing. Please provide a value. Use --help to see available options.")
+
+            if not args.rasa_model_path:
+                raise ValueError("The --rasa-model-path is missing. Please provide a value. Use --help to see available options.")
+
+            stt_path = args.stt_model_path
+            snips_model_path = args.snips_model_path
+            rasa_model_path = args.rasa_model_path
+
+            # Convert to an absolute path if it's a relative path
+            stt_path = add_trailing_slash(os.path.abspath(stt_path)) if not os.path.isabs(stt_path) else stt_path
+            snips_model_path = add_trailing_slash(os.path.abspath(snips_model_path)) if not os.path.isabs(snips_model_path) else snips_model_path
+            rasa_model_path = add_trailing_slash(os.path.abspath(rasa_model_path)) if not os.path.isabs(rasa_model_path) else rasa_model_path
+
+            # Also update the config.ini file
+            update_config_value(stt_path, 'STT_MODEL_PATH')
+            update_config_value(snips_model_path, 'SNIPS_MODEL_PATH')
+            update_config_value(rasa_model_path, 'RASA_MODEL_PATH')
+
+            # Update the audio store dir in config.ini if provided
+            audio_dir = args.audio_store_dir or get_config_value('BASE_AUDIO_DIR')
+            audio_dir = add_trailing_slash(os.path.abspath(audio_dir)) if not os.path.isabs(audio_dir) else audio_dir
+            update_config_value(audio_dir, 'BASE_AUDIO_DIR')
+
+            # Update the log store dir in config.ini if provided
+            log_dir = args.log_store_dir or get_config_value('BASE_LOG_DIR')
+            log_dir = add_trailing_slash(os.path.abspath(log_dir)) if not os.path.isabs(log_dir) else log_dir
+            update_config_value(log_dir, 'BASE_LOG_DIR')
+
+
+        # create the base audio dir if not exists
+        if not os.path.exists(get_config_value('BASE_AUDIO_DIR')):
+            os.makedirs(get_config_value('BASE_AUDIO_DIR'))
+
+        # create the base log dir if not exists
+        if not os.path.exists(get_config_value('BASE_LOG_DIR')):
+            os.makedirs(get_config_value('BASE_LOG_DIR'))
+
+        run_server()
+
+    elif args.subcommand == 'run-client':
+        mode = args.mode
+        if mode not in ['wake-word', 'auto', 'manual']:
+            raise ValueError("Invalid mode. Supported modes: 'wake-word', 'auto' and 'manual'. Use --help to see available options.")
+
+        model = args.nlu
+        if model not in ['snips', 'rasa']:
+            raise ValueError("Invalid NLU engine. Supported NLU engines: 'snips' and 'rasa'. Use --help to see available options.")
+
+        run_client(mode, model)
+
+    else:
+        print_version()
+        print("Use --help to see available options.")
+
+
+if __name__ == '__main__':
+    main()
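The argparse wiring above can also be exercised programmatically, which can be handy in tests; a sketch, with the argv[0] value being an arbitrary placeholder:

    import sys
    from agl_service_voiceagent import service

    # Equivalent to: <entry point> run-client --mode manual --nlu snips
    sys.argv = ["voiceagent", "run-client", "--mode", "manual", "--nlu", "snips"]
    service.main()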
\ No newline at end of file diff --git a/agl_service_voiceagent/servicers/__init__.py b/agl_service_voiceagent/servicers/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/agl_service_voiceagent/servicers/__init__.py diff --git a/agl_service_voiceagent/servicers/voice_agent_servicer.py b/agl_service_voiceagent/servicers/voice_agent_servicer.py new file mode 100644 index 0000000..4038b85 --- /dev/null +++ b/agl_service_voiceagent/servicers/voice_agent_servicer.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import grpc +import time +import threading +from generated import voice_agent_pb2 +from generated import voice_agent_pb2_grpc +from agl_service_voiceagent.utils.audio_recorder import AudioRecorder +from agl_service_voiceagent.utils.wake_word import WakeWordDetector +from agl_service_voiceagent.utils.stt_model import STTModel +from agl_service_voiceagent.utils.config import get_config_value +from agl_service_voiceagent.nlu.snips_interface import SnipsInterface +from agl_service_voiceagent.nlu.rasa_interface import RASAInterface +from agl_service_voiceagent.utils.common import generate_unique_uuid, delete_file + + +class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): + def __init__(self): + # Get the config values + self.service_version = get_config_value('SERVICE_VERSION') + self.wake_word = get_config_value('WAKE_WORD') + self.base_audio_dir = get_config_value('BASE_AUDIO_DIR') + self.channels = int(get_config_value('CHANNELS')) + self.sample_rate = int(get_config_value('SAMPLE_RATE')) + self.bits_per_sample = int(get_config_value('BITS_PER_SAMPLE')) + self.stt_model_path = get_config_value('STT_MODEL_PATH') + self.snips_model_path = get_config_value('SNIPS_MODEL_PATH') + self.rasa_model_path = get_config_value('RASA_MODEL_PATH') + self.rasa_server_port = int(get_config_value('RASA_SERVER_PORT')) + self.base_log_dir = get_config_value('BASE_LOG_DIR') + self.store_voice_command = bool(int(get_config_value('STORE_VOICE_COMMANDS'))) + + # Initialize class methods + self.stt_model = STTModel(self.stt_model_path, self.sample_rate) + self.snips_interface = SnipsInterface(self.snips_model_path) + self.rasa_interface = RASAInterface(self.rasa_server_port, self.rasa_model_path, self.base_log_dir) + self.rasa_interface.start_server() + self.rvc_stream_uuids = {} + + + def CheckServiceStatus(self, request, context): + response = voice_agent_pb2.ServiceStatus( + version=self.service_version, + status=True + ) + return response + + + def DetectWakeWord(self, request, context): + wake_word_detector = WakeWordDetector(self.wake_word, self.stt_model, self.channels, self.sample_rate, self.bits_per_sample) + wake_word_detector.create_pipeline() + detection_thread = threading.Thread(target=wake_word_detector.start_listening) + detection_thread.start() + while True: + status = wake_word_detector.get_wake_word_status() + time.sleep(1) + if not context.is_active(): + 
wake_word_detector.send_eos() + break + yield voice_agent_pb2.WakeWordStatus(status=status) + if status: + break + + detection_thread.join() + + + def RecognizeVoiceCommand(self, requests, context): + stt = "" + intent = "" + intent_slots = [] + + for request in requests: + if request.record_mode == voice_agent_pb2.MANUAL: + + if request.action == voice_agent_pb2.START: + status = voice_agent_pb2.REC_PROCESSING + stream_uuid = generate_unique_uuid(8) + recorder = AudioRecorder(self.stt_model, self.base_audio_dir, self.channels, self.sample_rate, self.bits_per_sample) + recorder.set_pipeline_mode("manual") + audio_file = recorder.create_pipeline() + + self.rvc_stream_uuids[stream_uuid] = { + "recorder": recorder, + "audio_file": audio_file + } + + recorder.start_recording() + + elif request.action == voice_agent_pb2.STOP: + stream_uuid = request.stream_id + status = voice_agent_pb2.REC_SUCCESS + + recorder = self.rvc_stream_uuids[stream_uuid]["recorder"] + audio_file = self.rvc_stream_uuids[stream_uuid]["audio_file"] + del self.rvc_stream_uuids[stream_uuid] + + recorder.stop_recording() + recognizer_uuid = self.stt_model.setup_recognizer() + stt = self.stt_model.recognize_from_file(recognizer_uuid, audio_file) + + if stt not in ["FILE_NOT_FOUND", "FILE_FORMAT_INVALID", "VOICE_NOT_RECOGNIZED", ""]: + if request.nlu_model == voice_agent_pb2.SNIPS: + extracted_intent = self.snips_interface.extract_intent(stt) + intent, intent_actions = self.snips_interface.process_intent(extracted_intent) + + if not intent or intent == "": + status = voice_agent_pb2.INTENT_NOT_RECOGNIZED + + for action, value in intent_actions.items(): + intent_slots.append(voice_agent_pb2.IntentSlot(name=action, value=value)) + + elif request.nlu_model == voice_agent_pb2.RASA: + extracted_intent = self.rasa_interface.extract_intent(stt) + intent, intent_actions = self.rasa_interface.process_intent(extracted_intent) + + if not intent or intent == "": + status = voice_agent_pb2.INTENT_NOT_RECOGNIZED + + for action, value in intent_actions.items(): + intent_slots.append(voice_agent_pb2.IntentSlot(name=action, value=value)) + + else: + stt = "" + status = voice_agent_pb2.VOICE_NOT_RECOGNIZED + + # cleanup the kaldi recognizer + self.stt_model.cleanup_recognizer(recognizer_uuid) + + # delete the audio file + if not self.store_voice_command: + delete_file(audio_file) + + + # Process the request and generate a RecognizeResult + response = voice_agent_pb2.RecognizeResult( + command=stt, + intent=intent, + intent_slots=intent_slots, + stream_id=stream_uuid, + status=status + ) + return response diff --git a/agl_service_voiceagent/utils/__init__.py b/agl_service_voiceagent/utils/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/agl_service_voiceagent/utils/__init__.py diff --git a/agl_service_voiceagent/utils/audio_recorder.py b/agl_service_voiceagent/utils/audio_recorder.py new file mode 100644 index 0000000..61ce994 --- /dev/null +++ b/agl_service_voiceagent/utils/audio_recorder.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gi +import vosk +import time +gi.require_version('Gst', '1.0') +from gi.repository import Gst, GLib + +Gst.init(None) +GLib.threads_init() + +class AudioRecorder: + def __init__(self, stt_model, audio_files_basedir, channels=1, sample_rate=16000, bits_per_sample=16): + self.loop = GLib.MainLoop() + self.mode = None + self.pipeline = None + self.bus = None + self.audio_files_basedir = audio_files_basedir + self.sample_rate = sample_rate + self.channels = channels + self.bits_per_sample = bits_per_sample + self.frame_size = int(self.sample_rate * 0.02) + self.audio_model = stt_model + self.buffer_duration = 1 # Buffer audio for atleast 1 second + self.audio_buffer = bytearray() + self.energy_threshold = 50000 # Adjust this threshold as needed + self.silence_frames_threshold = 10 + self.frames_above_threshold = 0 + + + def create_pipeline(self): + print("Creating pipeline for audio recording in {} mode...".format(self.mode)) + self.pipeline = Gst.Pipeline() + autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None) + queue = Gst.ElementFactory.make("queue", None) + audioconvert = Gst.ElementFactory.make("audioconvert", None) + wavenc = Gst.ElementFactory.make("wavenc", None) + + capsfilter = Gst.ElementFactory.make("capsfilter", None) + caps = Gst.Caps.new_empty_simple("audio/x-raw") + caps.set_value("format", "S16LE") + caps.set_value("rate", self.sample_rate) + caps.set_value("channels", self.channels) + capsfilter.set_property("caps", caps) + + self.pipeline.add(autoaudiosrc) + self.pipeline.add(queue) + self.pipeline.add(audioconvert) + self.pipeline.add(wavenc) + self.pipeline.add(capsfilter) + + autoaudiosrc.link(queue) + queue.link(audioconvert) + audioconvert.link(capsfilter) + + audio_file_name = f"{self.audio_files_basedir}{int(time.time())}.wav" + + filesink = Gst.ElementFactory.make("filesink", None) + filesink.set_property("location", audio_file_name) + self.pipeline.add(filesink) + capsfilter.link(wavenc) + wavenc.link(filesink) + + self.bus = self.pipeline.get_bus() + self.bus.add_signal_watch() + self.bus.connect("message", self.on_bus_message) + + return audio_file_name + + + def start_recording(self): + self.pipeline.set_state(Gst.State.PLAYING) + print("Recording Voice Input...") + + + def stop_recording(self): + print("Stopping recording...") + self.frames_above_threshold = 0 + self.cleanup_pipeline() + print("Recording finished!") + + + def set_pipeline_mode(self, mode): + self.mode = mode + + + # this method helps with error handling + def on_bus_message(self, bus, message): + if message.type == Gst.MessageType.EOS: + print("End-of-stream message received") + self.stop_recording() + + elif message.type == Gst.MessageType.ERROR: + err, debug_info = message.parse_error() + print(f"Error received from element {message.src.get_name()}: {err.message}") + print(f"Debugging information: {debug_info}") + self.stop_recording() + + elif message.type == Gst.MessageType.WARNING: + err, debug_info = message.parse_warning() + print(f"Warning received from element {message.src.get_name()}: {err.message}") + print(f"Debugging information: {debug_info}") + + elif 
message.type == Gst.MessageType.STATE_CHANGED: + if isinstance(message.src, Gst.Pipeline): + old_state, new_state, pending_state = message.parse_state_changed() + print(("Pipeline state changed from %s to %s." % + (old_state.value_nick, new_state.value_nick))) + + elif self.mode == "auto" and message.type == Gst.MessageType.ELEMENT: + if message.get_structure().get_name() == "level": + rms = message.get_structure()["rms"][0] + if rms > self.energy_threshold: + self.frames_above_threshold += 1 + # if self.frames_above_threshold >= self.silence_frames_threshold: + # self.start_recording() + else: + if self.frames_above_threshold > 0: + self.frames_above_threshold -= 1 + if self.frames_above_threshold == 0: + self.stop_recording() + + + def cleanup_pipeline(self): + if self.pipeline is not None: + print("Cleaning up pipeline...") + self.pipeline.set_state(Gst.State.NULL) + self.bus.remove_signal_watch() + print("Pipeline cleanup complete!") + self.bus = None + self.pipeline = None diff --git a/agl_service_voiceagent/utils/common.py b/agl_service_voiceagent/utils/common.py new file mode 100644 index 0000000..682473e --- /dev/null +++ b/agl_service_voiceagent/utils/common.py @@ -0,0 +1,50 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import uuid +import json + + +def add_trailing_slash(path): + if path and not path.endswith('/'): + path += '/' + return path + + +def generate_unique_uuid(length): + unique_id = str(uuid.uuid4().int) + # Ensure the generated ID is exactly 'length' digits by taking the last 'length' characters + unique_id = unique_id[-length:] + return unique_id + + +def load_json_file(file_path): + try: + with open(file_path, 'r') as file: + return json.load(file) + except FileNotFoundError: + raise ValueError(f"File '{file_path}' not found.") + + +def delete_file(file_path): + if os.path.exists(file_path): + try: + os.remove(file_path) + except Exception as e: + print(f"Error deleting '{file_path}': {e}") + else: + print(f"File '{file_path}' does not exist.")
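The servicer drives AudioRecorder (audio_recorder.py above) in manual mode; a condensed, standalone sketch of that flow, assuming a working GStreamer audio source, a local Vosk model, and an existing scratch directory (both paths illustrative):

    import time
    from agl_service_voiceagent.utils.stt_model import STTModel
    from agl_service_voiceagent.utils.audio_recorder import AudioRecorder

    stt = STTModel("/usr/share/vosk/vosk-model-small-en-us-0.15", sample_rate=16000)
    recorder = AudioRecorder(stt, "/tmp/voice_commands/", channels=1, sample_rate=16000, bits_per_sample=16)
    recorder.set_pipeline_mode("manual")
    wav_path = recorder.create_pipeline()   # returns the .wav file the pipeline will write
    recorder.start_recording()
    time.sleep(5)                           # speak during this window
    recorder.stop_recording()

    rec_id = stt.setup_recognizer()
    print("Transcript:", stt.recognize_from_file(rec_id, wav_path))
    stt.cleanup_recognizer(rec_id)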
\ No newline at end of file diff --git a/agl_service_voiceagent/utils/config.py b/agl_service_voiceagent/utils/config.py new file mode 100644 index 0000000..8d7f346 --- /dev/null +++ b/agl_service_voiceagent/utils/config.py @@ -0,0 +1,34 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import configparser + +# Get the absolute path to the directory of the current script +current_dir = os.path.dirname(os.path.abspath(__file__)) +# Construct the path to the config.ini file located in the base directory +config_path = os.path.join(current_dir, '..', 'config.ini') + +config = configparser.ConfigParser() +config.read(config_path) + +def update_config_value(value, key, group="General"): + config.set(group, key, value) + with open(config_path, 'w') as configfile: + config.write(configfile) + +def get_config_value(key, group="General"): + return config.get(group, key) diff --git a/agl_service_voiceagent/utils/kuksa_interface.py b/agl_service_voiceagent/utils/kuksa_interface.py new file mode 100644 index 0000000..3e1c045 --- /dev/null +++ b/agl_service_voiceagent/utils/kuksa_interface.py @@ -0,0 +1,66 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from kuksa_client import KuksaClientThread +from agl_service_voiceagent.utils.config import get_config_value + +class KuksaInterface: + def __init__(self): + # get config values + self.ip = get_config_value("ip", "Kuksa") + self.port = get_config_value("port", "Kuksa") + self.insecure = get_config_value("insecure", "Kuksa") + self.protocol = get_config_value("protocol", "Kuksa") + self.token = get_config_value("token", "Kuksa") + + # define class methods + self.kuksa_client = None + + + def get_kuksa_client(self): + return self.kuksa_client + + + def get_kuksa_status(self): + if self.kuksa_client: + return self.kuksa_client.checkConnection() + + return False + + + def connect_kuksa_client(self): + try: + self.kuksa_client = KuksaClientThread({ + "ip": self.ip, + "port": self.port, + "insecure": self.insecure, + "protocol": self.protocol, + }) + self.kuksa_client.authorize(self.token) + + except Exception as e: + print("Error: ", e) + + + def send_values(self, Path=None, Value=None): + if self.kuksa_client is None: + print("Error: Kuksa client is not initialized.") + + if self.get_kuksa_status(): + self.kuksa_client.setValue(Path, Value) + + else: + print("Error: Connection to Kuksa failed.") diff --git a/agl_service_voiceagent/utils/stt_model.py b/agl_service_voiceagent/utils/stt_model.py new file mode 100644 index 0000000..5337162 --- /dev/null +++ b/agl_service_voiceagent/utils/stt_model.py @@ -0,0 +1,105 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import json +import vosk +import wave +from agl_service_voiceagent.utils.common import generate_unique_uuid + +class STTModel: + def __init__(self, model_path, sample_rate=16000): + self.sample_rate = sample_rate + self.model = vosk.Model(model_path) + self.recognizer = {} + self.chunk_size = 1024 + + def setup_recognizer(self): + uuid = generate_unique_uuid(6) + self.recognizer[uuid] = vosk.KaldiRecognizer(self.model, self.sample_rate) + return uuid + + def init_recognition(self, uuid, audio_data): + return self.recognizer[uuid].AcceptWaveform(audio_data) + + def recognize(self, uuid, partial=False): + self.recognizer[uuid].SetWords(True) + if partial: + result = json.loads(self.recognizer[uuid].PartialResult()) + else: + result = json.loads(self.recognizer[uuid].Result()) + self.recognizer[uuid].Reset() + return result + + def recognize_from_file(self, uuid, filename): + if not os.path.exists(filename): + print(f"Audio file '{filename}' not found.") + return "FILE_NOT_FOUND" + + wf = wave.open(filename, "rb") + if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE": + print("Audio file must be WAV format mono PCM.") + return "FILE_FORMAT_INVALID" + + # audio_data = wf.readframes(wf.getnframes()) + # we need to perform chunking as target AGL system can't handle an entire audio file + audio_data = b"" + while True: + chunk = wf.readframes(self.chunk_size) + if not chunk: + break # End of file reached + audio_data += chunk + + if audio_data: + if self.init_recognition(uuid, audio_data): + result = self.recognize(uuid) + return result['text'] + else: + result = self.recognize(uuid, partial=True) + return result['partial'] + + else: + print("Voice not recognized. Please speak again...") + return "VOICE_NOT_RECOGNIZED" + + def cleanup_recognizer(self, uuid): + del self.recognizer[uuid] + +import wave + +def read_wav_file(filename, chunk_size=1024): + try: + wf = wave.open(filename, "rb") + if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE": + print("Audio file must be WAV format mono PCM.") + return "FILE_FORMAT_INVALID" + + audio_data = b"" # Initialize an empty bytes object to store audio data + while True: + chunk = wf.readframes(chunk_size) + if not chunk: + break # End of file reached + audio_data += chunk + + return audio_data + except Exception as e: + print(f"Error reading audio file: {e}") + return None + +# Example usage: +filename = "your_audio.wav" +audio_data = read_wav_file(filename) +
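STTModel can also be used on its own; a sketch, where the WAV file path is hypothetical and must be 16 kHz mono PCM to pass the checks in recognize_from_file:

    from agl_service_voiceagent.utils.stt_model import STTModel

    stt = STTModel("/usr/share/vosk/vosk-model-small-en-us-0.15", sample_rate=16000)
    rec_id = stt.setup_recognizer()

    # Returns the transcript, or one of the sentinel strings
    # "FILE_NOT_FOUND", "FILE_FORMAT_INVALID", "VOICE_NOT_RECOGNIZED".
    text = stt.recognize_from_file(rec_id, "/tmp/command.wav")
    print("Transcript:", text)

    stt.cleanup_recognizer(rec_id)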
\ No newline at end of file diff --git a/agl_service_voiceagent/utils/wake_word.py b/agl_service_voiceagent/utils/wake_word.py new file mode 100644 index 0000000..066ae6d --- /dev/null +++ b/agl_service_voiceagent/utils/wake_word.py @@ -0,0 +1,150 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gi +import vosk +gi.require_version('Gst', '1.0') +from gi.repository import Gst, GLib + +Gst.init(None) +GLib.threads_init() + +class WakeWordDetector: + def __init__(self, wake_word, stt_model, channels=1, sample_rate=16000, bits_per_sample=16): + self.loop = GLib.MainLoop() + self.pipeline = None + self.bus = None + self.wake_word = wake_word + self.wake_word_detected = False + self.sample_rate = sample_rate + self.channels = channels + self.bits_per_sample = bits_per_sample + self.frame_size = int(self.sample_rate * 0.02) + self.stt_model = stt_model # Speech to text model recognizer + self.recognizer_uuid = stt_model.setup_recognizer() + self.buffer_duration = 1 # Buffer audio for atleast 1 second + self.audio_buffer = bytearray() + + def get_wake_word_status(self): + return self.wake_word_detected + + def create_pipeline(self): + print("Creating pipeline for Wake Word Detection...") + self.pipeline = Gst.Pipeline() + autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None) + queue = Gst.ElementFactory.make("queue", None) + audioconvert = Gst.ElementFactory.make("audioconvert", None) + wavenc = Gst.ElementFactory.make("wavenc", None) + + capsfilter = Gst.ElementFactory.make("capsfilter", None) + caps = Gst.Caps.new_empty_simple("audio/x-raw") + caps.set_value("format", "S16LE") + caps.set_value("rate", self.sample_rate) + caps.set_value("channels", self.channels) + capsfilter.set_property("caps", caps) + + appsink = Gst.ElementFactory.make("appsink", None) + appsink.set_property("emit-signals", True) + appsink.set_property("sync", False) # Set sync property to False to enable async processing + appsink.connect("new-sample", self.on_new_buffer, None) + + self.pipeline.add(autoaudiosrc) + self.pipeline.add(queue) + self.pipeline.add(audioconvert) + self.pipeline.add(wavenc) + self.pipeline.add(capsfilter) + self.pipeline.add(appsink) + + autoaudiosrc.link(queue) + queue.link(audioconvert) + audioconvert.link(capsfilter) + capsfilter.link(appsink) + + self.bus = self.pipeline.get_bus() + self.bus.add_signal_watch() + self.bus.connect("message", self.on_bus_message) + + def on_new_buffer(self, appsink, data) -> Gst.FlowReturn: + sample = appsink.emit("pull-sample") + buffer = sample.get_buffer() + data = buffer.extract_dup(0, buffer.get_size()) + self.audio_buffer.extend(data) + + if len(self.audio_buffer) >= self.sample_rate * self.buffer_duration * self.channels * self.bits_per_sample // 8: + self.process_audio_buffer() + + return Gst.FlowReturn.OK + + + def process_audio_buffer(self): + # Process the accumulated audio data using the audio model + audio_data = bytes(self.audio_buffer) + if 
self.stt_model.init_recognition(self.recognizer_uuid, audio_data): + stt_result = self.stt_model.recognize(self.recognizer_uuid) + print("STT Result: ", stt_result) + if self.wake_word in stt_result["text"]: + self.wake_word_detected = True + print("Wake word detected!") + self.pipeline.send_event(Gst.Event.new_eos()) + + self.audio_buffer.clear() # Clear the buffer + + + def send_eos(self): + self.pipeline.send_event(Gst.Event.new_eos()) + self.audio_buffer.clear() + + + def start_listening(self): + self.pipeline.set_state(Gst.State.PLAYING) + print("Listening for Wake Word...") + self.loop.run() + + + def stop_listening(self): + self.cleanup_pipeline() + self.loop.quit() + + + def on_bus_message(self, bus, message): + if message.type == Gst.MessageType.EOS: + print("End-of-stream message received") + self.stop_listening() + elif message.type == Gst.MessageType.ERROR: + err, debug_info = message.parse_error() + print(f"Error received from element {message.src.get_name()}: {err.message}") + print(f"Debugging information: {debug_info}") + self.stop_listening() + elif message.type == Gst.MessageType.WARNING: + err, debug_info = message.parse_warning() + print(f"Warning received from element {message.src.get_name()}: {err.message}") + print(f"Debugging information: {debug_info}") + elif message.type == Gst.MessageType.STATE_CHANGED: + if isinstance(message.src, Gst.Pipeline): + old_state, new_state, pending_state = message.parse_state_changed() + print(("Pipeline state changed from %s to %s." % + (old_state.value_nick, new_state.value_nick))) + + + def cleanup_pipeline(self): + if self.pipeline is not None: + print("Cleaning up pipeline...") + self.pipeline.set_state(Gst.State.NULL) + self.bus.remove_signal_watch() + print("Pipeline cleanup complete!") + self.bus = None + self.pipeline = None + self.stt_model.cleanup_recognizer(self.recognizer_uuid) |
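Finally, a sketch of WakeWordDetector used the same way the DetectWakeWord servicer method drives it, with start_listening() running its GLib main loop on a worker thread; the model path and wake word mirror the config.ini defaults:

    import threading
    from agl_service_voiceagent.utils.stt_model import STTModel
    from agl_service_voiceagent.utils.wake_word import WakeWordDetector

    stt = STTModel("/usr/share/vosk/vosk-model-small-en-us-0.15", sample_rate=16000)
    detector = WakeWordDetector("hello auto", stt, channels=1, sample_rate=16000, bits_per_sample=16)
    detector.create_pipeline()

    thread = threading.Thread(target=detector.start_listening)
    thread.start()
    thread.join()   # returns once the wake word sends EOS and the loop quits
    print("Wake word detected:", detector.get_wake_word_status())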