diff options
Diffstat (limited to 'agl_service_voiceagent')
18 files changed, 1536 insertions, 71 deletions
diff --git a/agl_service_voiceagent/client.py b/agl_service_voiceagent/client.py index 88ef785..ee7bc52 100644 --- a/agl_service_voiceagent/client.py +++ b/agl_service_voiceagent/client.py @@ -14,16 +14,19 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys +sys.path.append("../") import time import grpc from agl_service_voiceagent.generated import voice_agent_pb2 from agl_service_voiceagent.generated import voice_agent_pb2_grpc -def run_client(server_address, server_port, action, mode, nlu_engine, recording_time): +def run_client(server_address, server_port, action, mode, nlu_engine, recording_time,stt_framework,online_mode): SERVER_URL = server_address + ":" + server_port nlu_engine = voice_agent_pb2.RASA if nlu_engine == "rasa" else voice_agent_pb2.SNIPS print("Starting Voice Agent Client...") print(f"Client connecting to URL: {SERVER_URL}") + print("STT Framework:", stt_framework) with grpc.insecure_channel(SERVER_URL) as channel: print("Press Ctrl+C to stop the client.") print("Voice Agent Client started!") @@ -51,18 +54,74 @@ def run_client(server_address, server_port, action, mode, nlu_engine, recording_ elif action == 'ExecuteVoiceCommand': if mode == 'auto': - raise ValueError("[-] Auto mode is not implemented yet.") + # raise ValueError("[-] Auto mode is not implemented yet.") + stub = voice_agent_pb2_grpc.VoiceAgentServiceStub(channel=channel) + stt_framework = voice_agent_pb2.VOSK if stt_framework == "vosk" else voice_agent_pb2.WHISPER + online_mode = voice_agent_pb2.ONLINE if online_mode == True else voice_agent_pb2.OFFLINE + while(True): + wake_request = voice_agent_pb2.Empty() + wake_results = stub.DetectWakeWord(wake_request) + wake_word_detected = False + for wake_result in wake_results: + print("Wake word status: ", wake_word_detected) + if wake_result.status: + print("Wake word status: ", wake_result.status) + wake_word_detected = True + break + print("Wake word detected: ", wake_word_detected) + if wake_word_detected: + print("[+] Wake Word detected! Recording voice command...") + record_start_request = voice_agent_pb2.RecognizeVoiceControl(action=voice_agent_pb2.START, nlu_model=nlu_engine, record_mode=voice_agent_pb2.MANUAL, stt_framework=stt_framework,online_mode=online_mode,) + response = stub.RecognizeVoiceCommand(iter([record_start_request])) + stream_id = response.stream_id + + time.sleep(recording_time) # pause here for the number of seconds passed by user or default 5 seconds + + record_stop_request = voice_agent_pb2.RecognizeVoiceControl(action=voice_agent_pb2.STOP, nlu_model=nlu_engine, record_mode=voice_agent_pb2.MANUAL, stream_id=stream_id,stt_framework=stt_framework,online_mode=online_mode,) + record_result = stub.RecognizeVoiceCommand(iter([record_stop_request])) + print("[+] Voice command recording ended!") + + status = "Uh oh! Status is unknown." + if record_result.status == voice_agent_pb2.REC_SUCCESS: + status = "Yay! Status is success." + elif record_result.status == voice_agent_pb2.VOICE_NOT_RECOGNIZED: + status = "Voice not recognized." + elif record_result.status == voice_agent_pb2.INTENT_NOT_RECOGNIZED: + status = "Intent not recognized." + + # Process the response + print("Status:", status) + print("Command:", record_result.command) + print("Intent:", record_result.intent) + intent_slots = [] + for slot in record_result.intent_slots: + print("Slot Name:", slot.name) + print("Slot Value:", slot.value) + i_slot = voice_agent_pb2.IntentSlot(name=slot.name, value=slot.value) + intent_slots.append(i_slot) + + if record_result.status == voice_agent_pb2.REC_SUCCESS: + print("[+] Executing voice command...") + exec_voice_command_request = voice_agent_pb2.ExecuteInput(intent=record_result.intent, intent_slots=intent_slots) + response = stub.ExecuteCommand(exec_voice_command_request) + print("Response:", response) + wake_word_detected = False + time.sleep(1) + + elif mode == 'manual': stub = voice_agent_pb2_grpc.VoiceAgentServiceStub(channel) + stt_framework = voice_agent_pb2.VOSK if stt_framework == "vosk" else voice_agent_pb2.WHISPER + online_mode = voice_agent_pb2.ONLINE if online_mode == True else voice_agent_pb2.OFFLINE print("[+] Recording voice command in manual mode...") - record_start_request = voice_agent_pb2.RecognizeVoiceControl(action=voice_agent_pb2.START, nlu_model=nlu_engine, record_mode=voice_agent_pb2.MANUAL) + record_start_request = voice_agent_pb2.RecognizeVoiceControl(action=voice_agent_pb2.START, nlu_model=nlu_engine, record_mode=voice_agent_pb2.MANUAL, stt_framework=stt_framework,online_mode=online_mode,) response = stub.RecognizeVoiceCommand(iter([record_start_request])) stream_id = response.stream_id time.sleep(recording_time) # pause here for the number of seconds passed by user or default 5 seconds - record_stop_request = voice_agent_pb2.RecognizeVoiceControl(action=voice_agent_pb2.STOP, nlu_model=nlu_engine, record_mode=voice_agent_pb2.MANUAL, stream_id=stream_id) + record_stop_request = voice_agent_pb2.RecognizeVoiceControl(action=voice_agent_pb2.STOP, nlu_model=nlu_engine, record_mode=voice_agent_pb2.MANUAL, stream_id=stream_id,stt_framework=stt_framework,online_mode=online_mode,) record_result = stub.RecognizeVoiceCommand(iter([record_stop_request])) print("[+] Voice command recording ended!") diff --git a/agl_service_voiceagent/config.ini b/agl_service_voiceagent/config.ini index 1651da5..d6d695e 100644 --- a/agl_service_voiceagent/config.ini +++ b/agl_service_voiceagent/config.ini @@ -1,12 +1,15 @@ [General] base_audio_dir = /usr/share/nlu/commands/ -stt_model_path = /usr/share/vosk/VOSK_STT_MODEL_NAME/ -wake_word_model_path = /usr/share/vosk/VOSK_WWD_MODEL_NAME/ +vosk_model_path = /usr/share/vosk/vosk-model-small-en-us-0.15/ +whisper_model_path = /usr/share/whisper/tiny.pt +whisper_cpp_path = /usr/bin/whisper-cpp +whisper_cpp_model_path = /usr/share/whisper-cpp/models/tiny.en.bin +wake_word_model_path = /usr/share/vosk/vosk-model-small-en-us-0.15/ snips_model_path = /usr/share/nlu/snips/model/ channels = 1 sample_rate = 16000 bits_per_sample = 16 -wake_word = WAKE_WORD_VALUE +wake_word = hello server_port = 51053 server_address = 127.0.0.1 rasa_model_path = /usr/share/nlu/rasa/models/ @@ -14,13 +17,28 @@ rasa_server_port = 51054 rasa_detached_mode = 1 base_log_dir = /usr/share/nlu/logs/ store_voice_commands = 0 +online_mode = 1 +online_mode_address = 65.108.107.216 +online_mode_port = 50051 +online_mode_timeout = 15 +mpd_ip = 127.0.0.1 +mpd_port = 6600 [Kuksa] ip = 127.0.0.1 port = 55555 protocol = grpc insecure = 0 -token = PYTHON_DIR/kuksa_certificates/jwt/super-admin.json.token +token = /usr/lib/python3.12/site-packages/kuksa_certificates/jwt/super-admin.json.token +tls_server_name = Server + +[VSS] +hostname = localhost +port = 55555 +protocol = grpc +insecure = 0 +token_filename = /etc/xdg/AGL/agl-vss-helper/agl-vss-helper.token +ca_cert_filename = /etc/kuksa-val/CA.pem tls_server_name = Server [Mapper] diff --git a/agl_service_voiceagent/generated/audio_processing_pb2.py b/agl_service_voiceagent/generated/audio_processing_pb2.py new file mode 100644 index 0000000..fdbeedb --- /dev/null +++ b/agl_service_voiceagent/generated/audio_processing_pb2.py @@ -0,0 +1,30 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: audio_processing.proto +# Protobuf Python Version: 5.26.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x16\x61udio_processing.proto\x12\taudioproc\"\"\n\x0c\x41udioRequest\x12\x12\n\naudio_data\x18\x01 \x01(\x0c\"\x1c\n\x0cTextResponse\x12\x0c\n\x04text\x18\x01 \x01(\t2S\n\x0f\x41udioProcessing\x12@\n\x0cProcessAudio\x12\x17.audioproc.AudioRequest\x1a\x17.audioproc.TextResponseb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'audio_processing_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_AUDIOREQUEST']._serialized_start=37 + _globals['_AUDIOREQUEST']._serialized_end=71 + _globals['_TEXTRESPONSE']._serialized_start=73 + _globals['_TEXTRESPONSE']._serialized_end=101 + _globals['_AUDIOPROCESSING']._serialized_start=103 + _globals['_AUDIOPROCESSING']._serialized_end=186 +# @@protoc_insertion_point(module_scope) diff --git a/agl_service_voiceagent/generated/audio_processing_pb2_grpc.py b/agl_service_voiceagent/generated/audio_processing_pb2_grpc.py new file mode 100644 index 0000000..4b54903 --- /dev/null +++ b/agl_service_voiceagent/generated/audio_processing_pb2_grpc.py @@ -0,0 +1,106 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from . import audio_processing_pb2 as audio__processing__pb2 + +GRPC_GENERATED_VERSION = '1.65.0rc1' +GRPC_VERSION = grpc.__version__ +EXPECTED_ERROR_RELEASE = '1.65.0' +SCHEDULED_RELEASE_DATE = 'June 25, 2024' +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + warnings.warn( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in audio_processing_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', + RuntimeWarning + ) + + +class AudioProcessingStub(object): + """The audio processing service definition. + """ + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.ProcessAudio = channel.unary_unary( + '/audioproc.AudioProcessing/ProcessAudio', + request_serializer=audio__processing__pb2.AudioRequest.SerializeToString, + response_deserializer=audio__processing__pb2.TextResponse.FromString, + _registered_method=True) + + +class AudioProcessingServicer(object): + """The audio processing service definition. + """ + + def ProcessAudio(self, request, context): + """Sends audio data and receives processed text. + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_AudioProcessingServicer_to_server(servicer, server): + rpc_method_handlers = { + 'ProcessAudio': grpc.unary_unary_rpc_method_handler( + servicer.ProcessAudio, + request_deserializer=audio__processing__pb2.AudioRequest.FromString, + response_serializer=audio__processing__pb2.TextResponse.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'audioproc.AudioProcessing', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('audioproc.AudioProcessing', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class AudioProcessing(object): + """The audio processing service definition. + """ + + @staticmethod + def ProcessAudio(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/audioproc.AudioProcessing/ProcessAudio', + audio__processing__pb2.AudioRequest.SerializeToString, + audio__processing__pb2.TextResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/agl_service_voiceagent/generated/voice_agent_pb2.py b/agl_service_voiceagent/generated/voice_agent_pb2.py new file mode 100644 index 0000000..4606f60 --- /dev/null +++ b/agl_service_voiceagent/generated/voice_agent_pb2.py @@ -0,0 +1,62 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: voice_agent.proto +# Protobuf Python Version: 5.26.1 +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x11voice_agent.proto\"\x07\n\x05\x45mpty\"C\n\rServiceStatus\x12\x0f\n\x07version\x18\x01 \x01(\t\x12\x0e\n\x06status\x18\x02 \x01(\x08\x12\x11\n\twake_word\x18\x03 \x01(\t\"^\n\nVoiceAudio\x12\x13\n\x0b\x61udio_chunk\x18\x01 \x01(\x0c\x12\x14\n\x0c\x61udio_format\x18\x02 \x01(\t\x12\x13\n\x0bsample_rate\x18\x03 \x01(\x05\x12\x10\n\x08language\x18\x04 \x01(\t\" \n\x0eWakeWordStatus\x12\x0e\n\x06status\x18\x01 \x01(\x08\"\x93\x01\n\x17S_RecognizeVoiceControl\x12!\n\x0c\x61udio_stream\x18\x01 \x01(\x0b\x32\x0b.VoiceAudio\x12\x1c\n\tnlu_model\x18\x02 \x01(\x0e\x32\t.NLUModel\x12\x11\n\tstream_id\x18\x03 \x01(\t\x12$\n\rstt_framework\x18\x04 \x01(\x0e\x32\r.STTFramework\"\xd1\x01\n\x15RecognizeVoiceControl\x12\x1d\n\x06\x61\x63tion\x18\x01 \x01(\x0e\x32\r.RecordAction\x12\x1c\n\tnlu_model\x18\x02 \x01(\x0e\x32\t.NLUModel\x12 \n\x0brecord_mode\x18\x03 \x01(\x0e\x32\x0b.RecordMode\x12\x11\n\tstream_id\x18\x04 \x01(\t\x12$\n\rstt_framework\x18\x05 \x01(\x0e\x32\r.STTFramework\x12 \n\x0bonline_mode\x18\x06 \x01(\x0e\x32\x0b.OnlineMode\"J\n\x14RecognizeTextControl\x12\x14\n\x0ctext_command\x18\x01 \x01(\t\x12\x1c\n\tnlu_model\x18\x02 \x01(\x0e\x32\t.NLUModel\")\n\nIntentSlot\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t\"\x8e\x01\n\x0fRecognizeResult\x12\x0f\n\x07\x63ommand\x18\x01 \x01(\t\x12\x0e\n\x06intent\x18\x02 \x01(\t\x12!\n\x0cintent_slots\x18\x03 \x03(\x0b\x32\x0b.IntentSlot\x12\x11\n\tstream_id\x18\x04 \x01(\t\x12$\n\x06status\x18\x05 \x01(\x0e\x32\x14.RecognizeStatusType\"A\n\x0c\x45xecuteInput\x12\x0e\n\x06intent\x18\x01 \x01(\t\x12!\n\x0cintent_slots\x18\x02 \x03(\x0b\x32\x0b.IntentSlot\"E\n\rExecuteResult\x12\x10\n\x08response\x18\x01 \x01(\t\x12\"\n\x06status\x18\x02 \x01(\x0e\x32\x12.ExecuteStatusType*%\n\x0cSTTFramework\x12\x08\n\x04VOSK\x10\x00\x12\x0b\n\x07WHISPER\x10\x01*%\n\nOnlineMode\x12\n\n\x06ONLINE\x10\x00\x12\x0b\n\x07OFFLINE\x10\x01*#\n\x0cRecordAction\x12\t\n\x05START\x10\x00\x12\x08\n\x04STOP\x10\x01*\x1f\n\x08NLUModel\x12\t\n\x05SNIPS\x10\x00\x12\x08\n\x04RASA\x10\x01*\"\n\nRecordMode\x12\n\n\x06MANUAL\x10\x00\x12\x08\n\x04\x41UTO\x10\x01*\xb4\x01\n\x13RecognizeStatusType\x12\r\n\tREC_ERROR\x10\x00\x12\x0f\n\x0bREC_SUCCESS\x10\x01\x12\x12\n\x0eREC_PROCESSING\x10\x02\x12\x18\n\x14VOICE_NOT_RECOGNIZED\x10\x03\x12\x19\n\x15INTENT_NOT_RECOGNIZED\x10\x04\x12\x17\n\x13TEXT_NOT_RECOGNIZED\x10\x05\x12\x1b\n\x17NLU_MODEL_NOT_SUPPORTED\x10\x06*\x82\x01\n\x11\x45xecuteStatusType\x12\x0e\n\nEXEC_ERROR\x10\x00\x12\x10\n\x0c\x45XEC_SUCCESS\x10\x01\x12\x14\n\x10KUKSA_CONN_ERROR\x10\x02\x12\x18\n\x14INTENT_NOT_SUPPORTED\x10\x03\x12\x1b\n\x17INTENT_SLOTS_INCOMPLETE\x10\x04\x32\xa4\x03\n\x11VoiceAgentService\x12,\n\x12\x43heckServiceStatus\x12\x06.Empty\x1a\x0e.ServiceStatus\x12\x34\n\x10S_DetectWakeWord\x12\x0b.VoiceAudio\x1a\x0f.WakeWordStatus(\x01\x30\x01\x12+\n\x0e\x44\x65tectWakeWord\x12\x06.Empty\x1a\x0f.WakeWordStatus0\x01\x12G\n\x17S_RecognizeVoiceCommand\x12\x18.S_RecognizeVoiceControl\x1a\x10.RecognizeResult(\x01\x12\x43\n\x15RecognizeVoiceCommand\x12\x16.RecognizeVoiceControl\x1a\x10.RecognizeResult(\x01\x12?\n\x14RecognizeTextCommand\x12\x15.RecognizeTextControl\x1a\x10.RecognizeResult\x12/\n\x0e\x45xecuteCommand\x12\r.ExecuteInput\x1a\x0e.ExecuteResultb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'voice_agent_pb2', _globals) +if not _descriptor._USE_C_DESCRIPTORS: + DESCRIPTOR._loaded_options = None + _globals['_STTFRAMEWORK']._serialized_start=993 + _globals['_STTFRAMEWORK']._serialized_end=1030 + _globals['_ONLINEMODE']._serialized_start=1032 + _globals['_ONLINEMODE']._serialized_end=1069 + _globals['_RECORDACTION']._serialized_start=1071 + _globals['_RECORDACTION']._serialized_end=1106 + _globals['_NLUMODEL']._serialized_start=1108 + _globals['_NLUMODEL']._serialized_end=1139 + _globals['_RECORDMODE']._serialized_start=1141 + _globals['_RECORDMODE']._serialized_end=1175 + _globals['_RECOGNIZESTATUSTYPE']._serialized_start=1178 + _globals['_RECOGNIZESTATUSTYPE']._serialized_end=1358 + _globals['_EXECUTESTATUSTYPE']._serialized_start=1361 + _globals['_EXECUTESTATUSTYPE']._serialized_end=1491 + _globals['_EMPTY']._serialized_start=21 + _globals['_EMPTY']._serialized_end=28 + _globals['_SERVICESTATUS']._serialized_start=30 + _globals['_SERVICESTATUS']._serialized_end=97 + _globals['_VOICEAUDIO']._serialized_start=99 + _globals['_VOICEAUDIO']._serialized_end=193 + _globals['_WAKEWORDSTATUS']._serialized_start=195 + _globals['_WAKEWORDSTATUS']._serialized_end=227 + _globals['_S_RECOGNIZEVOICECONTROL']._serialized_start=230 + _globals['_S_RECOGNIZEVOICECONTROL']._serialized_end=377 + _globals['_RECOGNIZEVOICECONTROL']._serialized_start=380 + _globals['_RECOGNIZEVOICECONTROL']._serialized_end=589 + _globals['_RECOGNIZETEXTCONTROL']._serialized_start=591 + _globals['_RECOGNIZETEXTCONTROL']._serialized_end=665 + _globals['_INTENTSLOT']._serialized_start=667 + _globals['_INTENTSLOT']._serialized_end=708 + _globals['_RECOGNIZERESULT']._serialized_start=711 + _globals['_RECOGNIZERESULT']._serialized_end=853 + _globals['_EXECUTEINPUT']._serialized_start=855 + _globals['_EXECUTEINPUT']._serialized_end=920 + _globals['_EXECUTERESULT']._serialized_start=922 + _globals['_EXECUTERESULT']._serialized_end=991 + _globals['_VOICEAGENTSERVICE']._serialized_start=1494 + _globals['_VOICEAGENTSERVICE']._serialized_end=1914 +# @@protoc_insertion_point(module_scope) diff --git a/agl_service_voiceagent/generated/voice_agent_pb2_grpc.py b/agl_service_voiceagent/generated/voice_agent_pb2_grpc.py new file mode 100644 index 0000000..15d76f4 --- /dev/null +++ b/agl_service_voiceagent/generated/voice_agent_pb2_grpc.py @@ -0,0 +1,362 @@ +# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT! +"""Client and server classes corresponding to protobuf-defined services.""" +import grpc +import warnings + +from . import voice_agent_pb2 as voice__agent__pb2 + +GRPC_GENERATED_VERSION = '1.65.0rc1' +GRPC_VERSION = grpc.__version__ +EXPECTED_ERROR_RELEASE = '1.65.0' +SCHEDULED_RELEASE_DATE = 'June 25, 2024' +_version_not_supported = False + +try: + from grpc._utilities import first_version_is_lower + _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION) +except ImportError: + _version_not_supported = True + +if _version_not_supported: + warnings.warn( + f'The grpc package installed is at version {GRPC_VERSION},' + + f' but the generated code in voice_agent_pb2_grpc.py depends on' + + f' grpcio>={GRPC_GENERATED_VERSION}.' + + f' Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}' + + f' or downgrade your generated code using grpcio-tools<={GRPC_VERSION}.' + + f' This warning will become an error in {EXPECTED_ERROR_RELEASE},' + + f' scheduled for release on {SCHEDULED_RELEASE_DATE}.', + RuntimeWarning + ) + + +class VoiceAgentServiceStub(object): + """Missing associated documentation comment in .proto file.""" + + def __init__(self, channel): + """Constructor. + + Args: + channel: A grpc.Channel. + """ + self.CheckServiceStatus = channel.unary_unary( + '/VoiceAgentService/CheckServiceStatus', + request_serializer=voice__agent__pb2.Empty.SerializeToString, + response_deserializer=voice__agent__pb2.ServiceStatus.FromString, + _registered_method=True) + self.S_DetectWakeWord = channel.stream_stream( + '/VoiceAgentService/S_DetectWakeWord', + request_serializer=voice__agent__pb2.VoiceAudio.SerializeToString, + response_deserializer=voice__agent__pb2.WakeWordStatus.FromString, + _registered_method=True) + self.DetectWakeWord = channel.unary_stream( + '/VoiceAgentService/DetectWakeWord', + request_serializer=voice__agent__pb2.Empty.SerializeToString, + response_deserializer=voice__agent__pb2.WakeWordStatus.FromString, + _registered_method=True) + self.S_RecognizeVoiceCommand = channel.stream_unary( + '/VoiceAgentService/S_RecognizeVoiceCommand', + request_serializer=voice__agent__pb2.S_RecognizeVoiceControl.SerializeToString, + response_deserializer=voice__agent__pb2.RecognizeResult.FromString, + _registered_method=True) + self.RecognizeVoiceCommand = channel.stream_unary( + '/VoiceAgentService/RecognizeVoiceCommand', + request_serializer=voice__agent__pb2.RecognizeVoiceControl.SerializeToString, + response_deserializer=voice__agent__pb2.RecognizeResult.FromString, + _registered_method=True) + self.RecognizeTextCommand = channel.unary_unary( + '/VoiceAgentService/RecognizeTextCommand', + request_serializer=voice__agent__pb2.RecognizeTextControl.SerializeToString, + response_deserializer=voice__agent__pb2.RecognizeResult.FromString, + _registered_method=True) + self.ExecuteCommand = channel.unary_unary( + '/VoiceAgentService/ExecuteCommand', + request_serializer=voice__agent__pb2.ExecuteInput.SerializeToString, + response_deserializer=voice__agent__pb2.ExecuteResult.FromString, + _registered_method=True) + + +class VoiceAgentServiceServicer(object): + """Missing associated documentation comment in .proto file.""" + + def CheckServiceStatus(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def S_DetectWakeWord(self, request_iterator, context): + """Stream version of DetectWakeWord, assumes audio is coming from client + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def DetectWakeWord(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def S_RecognizeVoiceCommand(self, request_iterator, context): + """Stream version of RecognizeVoiceCommand, assumes audio is coming from client + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def RecognizeVoiceCommand(self, request_iterator, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def RecognizeTextCommand(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def ExecuteCommand(self, request, context): + """Missing associated documentation comment in .proto file.""" + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + +def add_VoiceAgentServiceServicer_to_server(servicer, server): + rpc_method_handlers = { + 'CheckServiceStatus': grpc.unary_unary_rpc_method_handler( + servicer.CheckServiceStatus, + request_deserializer=voice__agent__pb2.Empty.FromString, + response_serializer=voice__agent__pb2.ServiceStatus.SerializeToString, + ), + 'S_DetectWakeWord': grpc.stream_stream_rpc_method_handler( + servicer.S_DetectWakeWord, + request_deserializer=voice__agent__pb2.VoiceAudio.FromString, + response_serializer=voice__agent__pb2.WakeWordStatus.SerializeToString, + ), + 'DetectWakeWord': grpc.unary_stream_rpc_method_handler( + servicer.DetectWakeWord, + request_deserializer=voice__agent__pb2.Empty.FromString, + response_serializer=voice__agent__pb2.WakeWordStatus.SerializeToString, + ), + 'S_RecognizeVoiceCommand': grpc.stream_unary_rpc_method_handler( + servicer.S_RecognizeVoiceCommand, + request_deserializer=voice__agent__pb2.S_RecognizeVoiceControl.FromString, + response_serializer=voice__agent__pb2.RecognizeResult.SerializeToString, + ), + 'RecognizeVoiceCommand': grpc.stream_unary_rpc_method_handler( + servicer.RecognizeVoiceCommand, + request_deserializer=voice__agent__pb2.RecognizeVoiceControl.FromString, + response_serializer=voice__agent__pb2.RecognizeResult.SerializeToString, + ), + 'RecognizeTextCommand': grpc.unary_unary_rpc_method_handler( + servicer.RecognizeTextCommand, + request_deserializer=voice__agent__pb2.RecognizeTextControl.FromString, + response_serializer=voice__agent__pb2.RecognizeResult.SerializeToString, + ), + 'ExecuteCommand': grpc.unary_unary_rpc_method_handler( + servicer.ExecuteCommand, + request_deserializer=voice__agent__pb2.ExecuteInput.FromString, + response_serializer=voice__agent__pb2.ExecuteResult.SerializeToString, + ), + } + generic_handler = grpc.method_handlers_generic_handler( + 'VoiceAgentService', rpc_method_handlers) + server.add_generic_rpc_handlers((generic_handler,)) + server.add_registered_method_handlers('VoiceAgentService', rpc_method_handlers) + + + # This class is part of an EXPERIMENTAL API. +class VoiceAgentService(object): + """Missing associated documentation comment in .proto file.""" + + @staticmethod + def CheckServiceStatus(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/VoiceAgentService/CheckServiceStatus', + voice__agent__pb2.Empty.SerializeToString, + voice__agent__pb2.ServiceStatus.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def S_DetectWakeWord(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_stream( + request_iterator, + target, + '/VoiceAgentService/S_DetectWakeWord', + voice__agent__pb2.VoiceAudio.SerializeToString, + voice__agent__pb2.WakeWordStatus.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def DetectWakeWord(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_stream( + request, + target, + '/VoiceAgentService/DetectWakeWord', + voice__agent__pb2.Empty.SerializeToString, + voice__agent__pb2.WakeWordStatus.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def S_RecognizeVoiceCommand(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/VoiceAgentService/S_RecognizeVoiceCommand', + voice__agent__pb2.S_RecognizeVoiceControl.SerializeToString, + voice__agent__pb2.RecognizeResult.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def RecognizeVoiceCommand(request_iterator, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.stream_unary( + request_iterator, + target, + '/VoiceAgentService/RecognizeVoiceCommand', + voice__agent__pb2.RecognizeVoiceControl.SerializeToString, + voice__agent__pb2.RecognizeResult.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def RecognizeTextCommand(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/VoiceAgentService/RecognizeTextCommand', + voice__agent__pb2.RecognizeTextControl.SerializeToString, + voice__agent__pb2.RecognizeResult.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def ExecuteCommand(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/VoiceAgentService/ExecuteCommand', + voice__agent__pb2.ExecuteInput.SerializeToString, + voice__agent__pb2.ExecuteResult.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/agl_service_voiceagent/nlu/snips_interface.py b/agl_service_voiceagent/nlu/snips_interface.py index 1febe92..a32f574 100644 --- a/agl_service_voiceagent/nlu/snips_interface.py +++ b/agl_service_voiceagent/nlu/snips_interface.py @@ -13,7 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - import re from typing import Text from snips_inference_agl import SnipsNLUEngine @@ -47,6 +46,10 @@ class SnipsInterface: preprocessed_text = text.lower().strip() # remove special characters, punctuation, and extra whitespaces preprocessed_text = re.sub(r'[^\w\s]', '', preprocessed_text).strip() + # replace % with " precent" + preprocessed_text = re.sub(r'%', ' percent', preprocessed_text) + # replace ° with " degrees" + preprocessed_text = re.sub(r'°', ' degrees ', preprocessed_text) return preprocessed_text def extract_intent(self, text: Text): diff --git a/agl_service_voiceagent/protos/audio_processing.proto b/agl_service_voiceagent/protos/audio_processing.proto new file mode 100644 index 0000000..edacc04 --- /dev/null +++ b/agl_service_voiceagent/protos/audio_processing.proto @@ -0,0 +1,23 @@ +// proto file for audio processing service for whiisper online service + +syntax = "proto3"; + +package audioproc; + +service AudioProcessing { + // Sends audio data and receives processed text. + rpc ProcessAudio (AudioRequest) returns (TextResponse); +} + +// The request message containing the audio data. +message AudioRequest { + bytes audio_data = 1; +} + +// The response message containing the processed text. +message TextResponse { + string text = 1; +} + +// usage: +// python -m grpc_tools.protoc -I. --python_out=./generated/ --grpc_python_out=./generated/ audio_processing.proto
\ No newline at end of file diff --git a/agl_service_voiceagent/protos/voice_agent.proto b/agl_service_voiceagent/protos/voice_agent.proto index 40dfe6a..bd2daa2 100644 --- a/agl_service_voiceagent/protos/voice_agent.proto +++ b/agl_service_voiceagent/protos/voice_agent.proto @@ -11,6 +11,15 @@ service VoiceAgentService { rpc ExecuteCommand(ExecuteInput) returns (ExecuteResult); } +enum STTFramework { + VOSK = 0; + WHISPER = 1; +} + +enum OnlineMode { + ONLINE = 0; + OFFLINE = 1; +} enum RecordAction { START = 0; @@ -69,6 +78,7 @@ message S_RecognizeVoiceControl { VoiceAudio audio_stream = 1; NLUModel nlu_model = 2; string stream_id = 3; + STTFramework stt_framework = 4; } message RecognizeVoiceControl { @@ -76,6 +86,8 @@ message RecognizeVoiceControl { NLUModel nlu_model = 2; RecordMode record_mode = 3; string stream_id = 4; + STTFramework stt_framework = 5; + OnlineMode online_mode = 6; } message RecognizeTextControl { diff --git a/agl_service_voiceagent/server.py b/agl_service_voiceagent/server.py index aa107dc..b244aa4 100644 --- a/agl_service_voiceagent/server.py +++ b/agl_service_voiceagent/server.py @@ -14,6 +14,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys +sys.path.append("../") import grpc from concurrent import futures from agl_service_voiceagent.generated import voice_agent_pb2_grpc @@ -24,7 +26,8 @@ def run_server(): logger = get_logger() SERVER_URL = get_config_value('SERVER_ADDRESS') + ":" + str(get_config_value('SERVER_PORT')) print("Starting Voice Agent Service...") - print(f"STT Model Path: {get_config_value('STT_MODEL_PATH')}") + print(f"VOSK Model Path: {get_config_value('VOSK_MODEL_PATH')}") + print(f"WHISPER Model Path: {get_config_value('WHISPER_MODEL_PATH')}") print(f"Audio Store Directory: {get_config_value('BASE_AUDIO_DIR')}") server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) voice_agent_pb2_grpc.add_VoiceAgentServiceServicer_to_server(VoiceAgentServicer(), server) diff --git a/agl_service_voiceagent/service.py b/agl_service_voiceagent/service.py index baf7b02..b5fb50e 100644 --- a/agl_service_voiceagent/service.py +++ b/agl_service_voiceagent/service.py @@ -23,6 +23,7 @@ current_dir = os.path.dirname(os.path.abspath(__file__)) generated_dir = os.path.join(current_dir, "generated") # Add the "generated" folder to sys.path sys.path.append(generated_dir) +sys.path.append("../") import argparse from agl_service_voiceagent.utils.config import set_config_path, load_config, update_config_value, get_config_value, get_logger @@ -49,8 +50,9 @@ def main(): # Add the arguments for the server server_parser.add_argument('--default', action='store_true', help='Starts the server based on default config file.') server_parser.add_argument('--config', required=False, help='Path to a config file. Server is started based on this config file.') - server_parser.add_argument('--stt-model-path', required=False, help='Path to the Speech To Text model for Voice Commad detection. Currently only supports VOSK Kaldi.') - server_parser.add_argument('--ww-model-path', required=False, help='Path to the Speech To Text model for Wake Word detection. Currently only supports VOSK Kaldi. Defaults to the same model as --stt-model-path if not provided.') + server_parser.add_argument('--vosk-model-path', required=False, help='Path to the Vosk Speech To Text model for Voice Commad detection.') + server_parser.add_argument('--whisper-model-path', required=False, help='Path to the Whisper Speech To Text model for Voice Commad detection.') + server_parser.add_argument('--ww-model-path', required=False, help='Path to the Speech To Text model for Wake Word detection. Currently only supports VOSK Kaldi. Defaults to the same model as --vosk-model-path if not provided.') server_parser.add_argument('--snips-model-path', required=False, help='Path to the Snips NLU model.') server_parser.add_argument('--rasa-model-path', required=False, help='Path to the RASA NLU model.') server_parser.add_argument('--rasa-detached-mode', required=False, help='Assume that the RASA server is already running and does not start it as a sub process.') @@ -59,6 +61,13 @@ def main(): server_parser.add_argument('--audio-store-dir', required=False, help='Directory to store the generated audio files.') server_parser.add_argument('--log-store-dir', required=False, help='Directory to store the generated log files.') + # Arguments for online mode + server_parser.add_argument('--online-mode', required=False, help='Enable online mode for the Voice Agent Service (default is False).') + server_parser.add_argument('--online-mode-address', required=False, help='URL of the online server to connect to.') + server_parser.add_argument('--online-mode-port', required=False, help='Port of the online server to connect to.') + server_parser.add_argument('--online-mode-timeout', required=False, help='Timeout value in seconds for the online server connection.') + + # Add the arguments for the client client_parser.add_argument('--server-address', required=True, help='Address of the gRPC server running the Voice Agent Service.') client_parser.add_argument('--server-port', required=True, help='Port of the gRPC server running the Voice Agent Service.') @@ -66,6 +75,10 @@ def main(): client_parser.add_argument('--mode', help='Mode to run the client in. Supported modes: "auto" and "manual".') client_parser.add_argument('--nlu', help='NLU engine/model to use. Supported NLU engines: "snips" and "rasa".') client_parser.add_argument('--recording-time', help='Number of seconds to continue recording the voice command. Required by the \'manual\' mode. Defaults to 10 seconds.') + client_parser.add_argument('--stt-framework', help='STT framework to use. Supported frameworks: "vosk". Defaults to "vosk".') + + # Arguments for online mode in client as --online-mode is a reserved keyword + client_parser.add_argument('--online-mode', required=False, help='Enable online mode for the Voice Agent Service (default is False).') args = parser.parse_args() @@ -74,8 +87,12 @@ def main(): elif args.subcommand == 'run-server': if not args.default and not args.config: - if not args.stt_model_path: - print("Error: The --stt-model-path is missing. Please provide a value. Use --help to see available options.") + if not args.vosk_model_path: + print("Error: The --vosk-model-path is missing. Please provide a value. Use --help to see available options.") + exit(1) + + if not args.whisper_model_path: + print("Error: The --whisper-model-path is missing. Please provide a value. Use --help to see available options.") exit(1) if not args.snips_model_path: @@ -94,6 +111,16 @@ def main(): print("Error: The --vss-signals-spec-path is missing. Please provide a value. Use --help to see available options.") exit(1) + # Error check for online mode + if args.online_mode: + if not args.online_mode_address: + print("Error: The --online-mode-address is missing. Please provide a value. Use --help to see available options.") + exit(1) + + if not args.online_mode_port: + print("Error: The --online-mode-port is missing. Please provide a value. Use --help to see available options.") + exit(1) + # Contruct the default config file path config_path = os.path.join(current_dir, "config.ini") @@ -105,21 +132,36 @@ def main(): logger.info("Starting Voice Agent Service in server mode using CLI provided params...") # Get the values provided by the user - stt_path = args.stt_model_path + vosk_path = args.vosk_model_path + whisper_path = args.whisper_model_path snips_model_path = args.snips_model_path rasa_model_path = args.rasa_model_path intents_vss_map_path = args.intents_vss_map_path vss_signals_spec_path = args.vss_signals_spec_path + # Get the values for online mode + online_mode = False + if args.online_mode: + online_mode = True + online_mode_address = args.online_mode_address + online_mode_port = args.online_mode_port + online_mode_timeout = args.online_mode_timeout or 5 + update_config_value('1', 'ONLINE_MODE') + update_config_value(online_mode_address, 'ONLINE_MODE_ADDRESS') + update_config_value(online_mode_port, 'ONLINE_MODE_PORT') + update_config_value(online_mode_timeout, 'ONLINE_MODE_TIMEOUT') + # Convert to an absolute path if it's a relative path - stt_path = add_trailing_slash(os.path.abspath(stt_path)) if not os.path.isabs(stt_path) else stt_path + vosk_path = add_trailing_slash(os.path.abspath(vosk_path)) if not os.path.isabs(vosk_path) else vosk_path + whisper_path = add_trailing_slash(os.path.abspath(whisper_path)) if not os.path.isabs(whisper_path) else whisper_path snips_model_path = add_trailing_slash(os.path.abspath(snips_model_path)) if not os.path.isabs(snips_model_path) else snips_model_path rasa_model_path = add_trailing_slash(os.path.abspath(rasa_model_path)) if not os.path.isabs(rasa_model_path) else rasa_model_path intents_vss_map_path = os.path.abspath(intents_vss_map_path) if not os.path.isabs(intents_vss_map_path) else intents_vss_map_path vss_signals_spec_path = os.path.abspath(vss_signals_spec_path) if not os.path.isabs(vss_signals_spec_path) else vss_signals_spec_path # Also update the config.ini file - update_config_value(stt_path, 'STT_MODEL_PATH') + update_config_value(vosk_path, 'VOSK_MODEL_PATH') + update_config_value(whisper_path, 'WHISPER_MODEL_PATH') update_config_value(snips_model_path, 'SNIPS_MODEL_PATH') update_config_value(rasa_model_path, 'RASA_MODEL_PATH') update_config_value(intents_vss_map_path, 'INTENTS_VSS_MAP') @@ -162,7 +204,6 @@ def main(): logger = get_logger() logger.info(f"Starting Voice Agent Service in server mode using the default config file...") - # create the base audio dir if not exists if not os.path.exists(get_config_value('BASE_AUDIO_DIR')): os.makedirs(get_config_value('BASE_AUDIO_DIR')) @@ -176,6 +217,8 @@ def main(): mode = "" action = args.action recording_time = 5 # seconds + stt_framework = args.stt_framework or "vosk" + online_mode = args.online_mode or False if action not in ["GetStatus", "DetectWakeWord", "ExecuteVoiceCommand", "ExecuteTextCommand"]: print("Error: Invalid value for --action. Supported actions: 'GetStatus', 'DetectWakeWord', 'ExecuteVoiceCommand' and 'ExecuteTextCommand'. Use --help to see available options.") @@ -199,8 +242,19 @@ def main(): mode = args.mode if mode == "manual" and args.recording_time: recording_time = int(args.recording_time) - - run_client(server_address, server_port, action, mode, nlu_engine, recording_time) + if args.stt_framework and args.stt_framework not in ['vosk', 'whisper']: + print("Error: Invalid value for --stt-framework. Supported frameworks: 'vosk' and 'whisper'. Use --help to see available options.") + exit(1) + if args.stt_framework: + stt_framework = args.stt_framework + + if args.online_mode and args.online_mode not in ['True', 'False', 'true', 'false', '1', '0']: + print("Error: Invalid value for --online-mode. Supported values: 'True' and 'False'. Use --help to see available options.") + exit(1) + if args.online_mode: + online_mode = True if args.online_mode in ['True', 'true', '1'] else False + + run_client(server_address, server_port, action, mode, nlu_engine, recording_time, stt_framework, online_mode) else: print_version() diff --git a/agl_service_voiceagent/servicers/voice_agent_servicer.py b/agl_service_voiceagent/servicers/voice_agent_servicer.py index 0565655..2a4de33 100644 --- a/agl_service_voiceagent/servicers/voice_agent_servicer.py +++ b/agl_service_voiceagent/servicers/voice_agent_servicer.py @@ -14,9 +14,12 @@ # See the License for the specific language governing permissions and # limitations under the License. +import sys +sys.path.append("../") import json import time import threading +import asyncio from agl_service_voiceagent.generated import voice_agent_pb2 from agl_service_voiceagent.generated import voice_agent_pb2_grpc from agl_service_voiceagent.utils.audio_recorder import AudioRecorder @@ -28,6 +31,10 @@ from agl_service_voiceagent.utils.config import get_config_value, get_logger from agl_service_voiceagent.utils.common import generate_unique_uuid, delete_file from agl_service_voiceagent.nlu.snips_interface import SnipsInterface from agl_service_voiceagent.nlu.rasa_interface import RASAInterface +from agl_service_voiceagent.utils.stt_online_service import STTOnlineService +from agl_service_voiceagent.utils.vss_interface import VSSInterface +from kuksa_client.grpc import Datapoint +from agl_service_voiceagent.utils.media_controller import MediaController class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): @@ -46,7 +53,7 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): self.channels = int(get_config_value('CHANNELS')) self.sample_rate = int(get_config_value('SAMPLE_RATE')) self.bits_per_sample = int(get_config_value('BITS_PER_SAMPLE')) - self.stt_model_path = get_config_value('STT_MODEL_PATH') + self.vosk_model_path = get_config_value('VOSK_MODEL_PATH') self.wake_word_model_path = get_config_value('WAKE_WORD_MODEL_PATH') self.snips_model_path = get_config_value('SNIPS_MODEL_PATH') self.rasa_model_path = get_config_value('RASA_MODEL_PATH') @@ -56,10 +63,25 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): self.store_voice_command = bool(int(get_config_value('STORE_VOICE_COMMANDS'))) self.logger = get_logger() + # load the whisper model_path + self.whisper_model_path = get_config_value('WHISPER_MODEL_PATH') + self.whisper_cpp_path = get_config_value('WHISPER_CPP_PATH') + self.whisper_cpp_model_path = get_config_value('WHISPER_CPP_MODEL_PATH') + + # loading values for online mode + self.online_mode = bool(int(get_config_value('ONLINE_MODE'))) + if self.online_mode: + self.online_mode_address = get_config_value('ONLINE_MODE_ADDRESS') + self.online_mode_port = int(get_config_value('ONLINE_MODE_PORT')) + self.online_mode_timeout = int(get_config_value('ONLINE_MODE_TIMEOUT')) + self.stt_online = STTOnlineService(self.online_mode_address, self.online_mode_port, self.online_mode_timeout) + self.stt_online.initialize_connection() + + # Initialize class methods self.logger.info("Loading Speech to Text and Wake Word Model...") - self.stt_model = STTModel(self.stt_model_path, self.sample_rate) - self.stt_wake_word_model = STTModel(self.wake_word_model_path, self.sample_rate) + self.stt_model = STTModel(self.vosk_model_path, self.whisper_model_path,self.whisper_cpp_path,self.whisper_cpp_model_path,self.sample_rate) + self.stt_wake_word_model = STTModel(self.vosk_model_path, self.whisper_model_path,self.whisper_cpp_path,self.whisper_cpp_model_path,self.sample_rate) self.logger.info("Speech to Text and Wake Word Model loaded successfully.") self.logger.info("Starting SNIPS intent engine...") @@ -78,14 +100,91 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): self.logger.info(f"RASA intent engine detached mode detected! Assuming RASA server is running at URL: 127.0.0.1:{self.rasa_server_port}") self.rvc_stream_uuids = {} - self.kuksa_client = KuksaInterface() - self.kuksa_client.connect_kuksa_client() - self.kuksa_client.authorize_kuksa_client() self.logger.info(f"Loading and parsing mapping files...") self.mapper = Intent2VSSMapper() self.logger.info(f"Successfully loaded and parsed mapping files.") + # Media controller + self.media_controller = MediaController() + + self.vss_interface = VSSInterface() + self.vss_thread = threading.Thread(target=self.start_vss_client) + self.vss_thread.start() + self.vss_event_loop = None + + # VSS client methods + + def start_vss_client(self): + """ + Start the VSS client. + """ + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + loop.run_until_complete(self.vss_interface.connect_vss_client()) + self.vss_event_loop = loop + loop.run_forever() + + def connect_vss_client(self): + """ + Connect the VSS client. + """ + future = asyncio.run_coroutine_threadsafe( + self.vss_interface.connect_vss_client(), + self.vss_event_loop + ) + return future.result() + + def set_current_values(self, path=None, value=None): + """ + Set the current values. + + Args: + path (str): The path to set. + value (any): The value to set. + """ + future = asyncio.run_coroutine_threadsafe( + self.vss_interface.set_current_values(path, value), + self.vss_event_loop + ) + return future.result() + + def get_current_values(self, paths=None): + """ + Get the current values. + + Args: + paths (list): The paths to get. + + Returns: + dict: The current values. + """ + print("Getting current values for paths:", paths) + future = asyncio.run_coroutine_threadsafe( + self.vss_interface.get_current_values(paths), + self.vss_event_loop + ) + return future.result() + + def disconnect_vss_client(self): + """ + Disconnect the VSS client. + """ + future = asyncio.run_coroutine_threadsafe( + self.vss_interface.disconnect_vss_client(), + self.vss_event_loop + ) + return future.result() + + def get_vss_server_info(self): + """ + Get the VSS server information. + """ + future = asyncio.run_coroutine_threadsafe( + self.vss_interface.get_server_info(), + self.vss_event_loop + ) + return future.result() def CheckServiceStatus(self, request, context): """ @@ -153,6 +252,16 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): log_intent_slots = [] for request in requests: + stt_framework = '' + if request.stt_framework == voice_agent_pb2.VOSK: + stt_framework = 'vosk' + elif request.stt_framework == voice_agent_pb2.WHISPER: + stt_framework = 'whisper' + + use_online_mode = False + if request.online_mode == voice_agent_pb2.ONLINE: + use_online_mode = True + if request.record_mode == voice_agent_pb2.MANUAL: if request.action == voice_agent_pb2.START: @@ -187,14 +296,32 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): del self.rvc_stream_uuids[stream_uuid] recorder.stop_recording() - recognizer_uuid = self.stt_model.setup_recognizer() - stt = self.stt_model.recognize_from_file(recognizer_uuid, audio_file) + + used_kaldi = False + + if use_online_mode and self.online_mode: + print("Recognizing voice command using online mode.") + if self.stt_online.initialized: + stt = self.stt_online.recognize_audio(audio_file=audio_file) + elif not self.stt_online.initialized: + self.stt_online.initialize_connection() + stt = self.stt_online.recognize_audio(audio_file=audio_file) + else: + recognizer_uuid = self.stt_model.setup_vosk_recognizer() + stt = self.stt_model.recognize_from_file(recognizer_uuid, audio_file,stt_framework=stt_framework) + used_kaldi = True + + if use_online_mode and self.online_mode and stt is None: + print("Online mode enabled but failed to recognize voice command. Switching to offline mode.") + recognizer_uuid = self.stt_model.setup_vosk_recognizer() + stt = self.stt_model.recognize_from_file(recognizer_uuid, audio_file,stt_framework=stt_framework) + used_kaldi = True + print(stt) if stt not in ["FILE_NOT_FOUND", "FILE_FORMAT_INVALID", "VOICE_NOT_RECOGNIZED", ""]: if request.nlu_model == voice_agent_pb2.SNIPS: extracted_intent = self.snips_interface.extract_intent(stt) intent, intent_actions = self.snips_interface.process_intent(extracted_intent) - if not intent or intent == "": status = voice_agent_pb2.INTENT_NOT_RECOGNIZED @@ -223,7 +350,9 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): status = voice_agent_pb2.VOICE_NOT_RECOGNIZED # cleanup the kaldi recognizer - self.stt_model.cleanup_recognizer(recognizer_uuid) + if used_kaldi: + self.stt_model.cleanup_recognizer(recognizer_uuid) + used_kaldi = False # delete the audio file if not self.store_voice_command: @@ -323,6 +452,9 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): """ Execute the voice command by sending the intent to Kuksa. """ + if self.vss_interface.is_connected==False: + self.logger.error("Kuksa client found disconnected. Trying to close old instance and re-connecting...") + self.connect_vss_client() # Log the unique request ID, client's IP address, and the endpoint request_id = generate_unique_uuid(8) client_ip = context.peer() @@ -335,36 +467,111 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): slot_name = slot.name slot_value = slot.value processed_slots.append({"name": slot_name, "value": slot_value}) - print(intent) print(processed_slots) + + # Check for the media control intents + if intent == "MediaControl": + for slot in processed_slots: + if slot["name"] == "media_control_action": + action = slot["value"] + + if action == "resume" or action == "play": + if self.media_controller.resume(): + exec_response = "Yay, I successfully resumed the media." + exec_status = voice_agent_pb2.EXEC_SUCCESS + else: + exec_response = "Uh oh, I failed to resume the media." + exec_status = voice_agent_pb2.EXEC_ERROR + + elif action == "pause": + if self.media_controller.pause(): + exec_response = "Yay, I successfully paused the media." + exec_status = voice_agent_pb2.EXEC_SUCCESS + else: + exec_response = "Uh oh, I failed to pause the media." + exec_status = voice_agent_pb2.EXEC_ERROR + + elif action == "next": + if self.media_controller.next(): + exec_response = "Yay, I successfully played the next track." + exec_status = voice_agent_pb2.EXEC_SUCCESS + else: + exec_response = "Uh oh, I failed to play the next track." + exec_status = voice_agent_pb2.EXEC_ERROR + + elif action == "previous": + if self.media_controller.previous(): + exec_response = "Yay, I successfully played the previous track." + exec_status = voice_agent_pb2.EXEC_SUCCESS + else: + exec_response = "Uh oh, I failed to play the previous track." + exec_status = voice_agent_pb2.EXEC_ERROR + + elif action == "stop": + if self.media_controller.stop(): + exec_response = "Yay, I successfully stopped the media." + exec_status = voice_agent_pb2.EXEC_SUCCESS + else: + exec_response = "Uh oh, I failed to stop the media." + exec_status = voice_agent_pb2.EXEC_ERROR + else: + exec_response = "Sorry, I failed to execute command against intent 'MediaControl'. Maybe try again with more specific instructions." + exec_status = voice_agent_pb2.EXEC_ERROR + + + response = voice_agent_pb2.ExecuteResult( + response=exec_response, + status=exec_status + ) + return response + + execution_list = self.mapper.parse_intent(intent, processed_slots, req_id=request_id) exec_response = f"Sorry, I failed to execute command against intent '{intent}'. Maybe try again with more specific instructions." exec_status = voice_agent_pb2.EXEC_ERROR - # Check for kuksa status, and try re-connecting again if status is False - if not self.kuksa_client.get_kuksa_status(): + if self.vss_interface.is_connected: + self.logger.info(f"[ReqID#{request_id}] Kuksa client found connected.") + else: self.logger.error(f"[ReqID#{request_id}] Kuksa client found disconnected. Trying to close old instance and re-connecting...") - self.kuksa_client.close_kuksa_client() - self.kuksa_client.connect_kuksa_client() - self.kuksa_client.authorize_kuksa_client() - + exec_response = "Uh oh, I failed to connect to Kuksa." + exec_status = voice_agent_pb2.KUKSA_CONN_ERROR + self.disconnect_vss_client() + return voice_agent_pb2.ExecuteResult( + response=exec_response, + status=exec_status + ) + + if not self.vss_interface.is_connected: + self.logger.error(f"[ReqID#{request_id}] Kuksa client failed to connect.") + exec_response = "Uh oh, I failed to connect to Kuksa." + exec_status = voice_agent_pb2.KUKSA_CONN_ERROR + response = voice_agent_pb2.ExecuteResult( + response=exec_response, + status=exec_status + ) + return response for execution_item in execution_list: print(execution_item) action = execution_item["action"] signal = execution_item["signal"] - if self.kuksa_client.get_kuksa_status(): + if self.vss_interface.is_connected: if action == "set" and "value" in execution_item: value = execution_item["value"] - if self.kuksa_client.send_values(signal, value): + response = self.set_current_values(signal, value) + if response is None or response is False: + exec_response = "Uh oh, I failed to send value to Kuksa." + exec_status = voice_agent_pb2.KUKSA_CONN_ERROR + else: exec_response = f"Yay, I successfully updated the intent '{intent}' to value '{value}'." exec_status = voice_agent_pb2.EXEC_SUCCESS - + elif action in ["increase", "decrease"]: if "value" in execution_item: value = execution_item["value"] - if self.kuksa_client.send_values(signal, value): + if self.set_current_values(signal, value): exec_response = f"Yay, I successfully updated the intent '{intent}' to value '{value}'." exec_status = voice_agent_pb2.EXEC_SUCCESS @@ -372,22 +579,26 @@ class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): # incoming values are always str as kuksa expects str during subscribe we need to convert # the value to int before performing any arithmetic operations and then convert back to str factor = int(execution_item["factor"]) - current_value = self.kuksa_client.get_value(signal) - if current_value: - current_value = int(current_value) - if action == "increase": - value = current_value + factor - value = str(value) - elif action == "decrease": - value = current_value - factor - value = str(value) - if self.kuksa_client.send_values(signal, value): - exec_response = f"Yay, I successfully updated the intent '{intent}' to value '{value}'." - exec_status = voice_agent_pb2.EXEC_SUCCESS - - else: + current_value = self.get_current_values(signal) + if current_value is None: exec_response = f"Uh oh, there is no value set for intent '{intent}'. Why not try setting a value first?" exec_status = voice_agent_pb2.KUKSA_CONN_ERROR + else: + if current_value: + current_value = int(current_value) + if action == "increase": + value = current_value + factor + value = str(value) + elif action == "decrease": + value = current_value - factor + value = str(value) + if self.set_current_values(signal, value): + exec_response = f"Yay, I successfully updated the intent '{intent}' to value '{value}'." + exec_status = voice_agent_pb2.EXEC_SUCCESS + + else: + exec_response = f"Uh oh, there is no value set for intent '{intent}'. Why not try setting a value first?" + exec_status = voice_agent_pb2.KUKSA_CONN_ERROR else: exec_response = "Uh oh, I failed to connect to Kuksa." diff --git a/agl_service_voiceagent/utils/audio_recorder.py b/agl_service_voiceagent/utils/audio_recorder.py index 2e8f11d..49716c9 100644 --- a/agl_service_voiceagent/utils/audio_recorder.py +++ b/agl_service_voiceagent/utils/audio_recorder.py @@ -66,6 +66,9 @@ class AudioRecorder: self.pipeline = Gst.Pipeline() autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None) queue = Gst.ElementFactory.make("queue", None) + queue.set_property("max-size-buffers", 0) + queue.set_property("max-size-bytes", 0) + queue.set_property("max-size-time", 0) audioconvert = Gst.ElementFactory.make("audioconvert", None) wavenc = Gst.ElementFactory.make("wavenc", None) diff --git a/agl_service_voiceagent/utils/media_controller.py b/agl_service_voiceagent/utils/media_controller.py new file mode 100644 index 0000000..60c2717 --- /dev/null +++ b/agl_service_voiceagent/utils/media_controller.py @@ -0,0 +1,132 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from mpd import MPDClient +import json +from agl_service_voiceagent.utils.config import get_config_value, get_logger +from agl_service_voiceagent.utils.common import load_json_file, words_to_number + + +class MediaController: + def __init__(self): + self.client = MPDClient() + self.ip = get_config_value('MPD_IP') + self.port = get_config_value('MPD_PORT') + self.is_connected = self.connect() + + def connect(self): + try: + self.client.connect(self.ip, self.port) + return True + except Exception as e: + print(f"[-] Error: Failed to connect to MPD server: {e}") + return False + + def play(self, uri): + ''' + Play the media file at the specified URI. + + Args: + uri (str): The URI of the media file to play. + + ''' + if not self.is_connected: + print("[-] Error: MPD client is not connected.") + return False + + try: + self.client.clear() + self.client.add(uri) + self.client.play() + return True + except Exception as e: + print(f"[-] Error: Failed to play media: {e}") + return False + + def stop(self): + ''' + Stop the media player. + ''' + if not self.is_connected: + print("[-] Error: MPD client is not connected.") + return False + try: + self.client.stop() + return True + except Exception as e: + print(f"[-] Error: Failed to stop media: {e}") + return False + + def pause(self): + ''' + Pause the media player. + ''' + if not self.is_connected: + print("[-] Error: MPD client is not connected.") + return False + try: + self.client.pause() + return True + except Exception as e: + print(f"[-] Error: Failed to pause media: {e}") + return False + + def resume(self): + ''' + Resume the media player. + ''' + if not self.is_connected: + print("[-] Error: MPD client is not connected.") + return False + try: + self.client.play() + return True + except Exception as e: + print(f"[-] Error: Failed to resume media: {e}") + return False + + def next(self): + ''' + Play the next track in the playlist. + ''' + if not self.is_connected: + print("[-] Error: MPD client is not connected.") + return False + try: + self.client.next() + return True + except Exception as e: + print(f"[-] Error: Failed to play next track: {e}") + return False + + def previous(self): + ''' + Play the previous track in the playlist. + ''' + if not self.is_connected: + print("[-] Error: MPD client is not connected.") + return False + try: + self.client.previous() + return True + except Exception as e: + print(f"[-] Error: Failed to play previous track: {e}") + return False + + def close(self): + self.client.close() + self.client.disconnect() + diff --git a/agl_service_voiceagent/utils/stt_model.py b/agl_service_voiceagent/utils/stt_model.py index d51ae31..7e8ad8b 100644 --- a/agl_service_voiceagent/utils/stt_model.py +++ b/agl_service_voiceagent/utils/stt_model.py @@ -20,12 +20,18 @@ import vosk import wave from agl_service_voiceagent.utils.common import generate_unique_uuid +# import the whisper model +import whisper +# for whisper timeout feature +from concurrent.futures import ThreadPoolExecutor +import subprocess + class STTModel: """ STTModel is a class for speech-to-text (STT) recognition using the Vosk speech recognition library. """ - def __init__(self, model_path, sample_rate=16000): + def __init__(self, vosk_model_path,whisper_model_path,whisper_cpp_path,whisper_cpp_model_path,sample_rate=16000): """ Initialize the STTModel instance with the provided model and sample rate. @@ -34,12 +40,15 @@ class STTModel: sample_rate (int, optional): The audio sample rate in Hz (default is 16000). """ self.sample_rate = sample_rate - self.model = vosk.Model(model_path) + self.vosk_model = vosk.Model(vosk_model_path) self.recognizer = {} self.chunk_size = 1024 + # self.whisper_model = whisper.load_model(whisper_model_path) + self.whisper_cpp_path = whisper_cpp_path + self.whisper_cpp_model_path = whisper_cpp_model_path - def setup_recognizer(self): + def setup_vosk_recognizer(self): """ Set up a Vosk recognizer for a new session and return a unique identifier (UUID) for the session. @@ -47,10 +56,9 @@ class STTModel: str: A unique identifier (UUID) for the session. """ uuid = generate_unique_uuid(6) - self.recognizer[uuid] = vosk.KaldiRecognizer(self.model, self.sample_rate) + self.recognizer[uuid] = vosk.KaldiRecognizer(self.vosk_model, self.sample_rate) return uuid - def init_recognition(self, uuid, audio_data): """ Initialize the Vosk recognizer for a session with audio data. @@ -64,8 +72,8 @@ class STTModel: """ return self.recognizer[uuid].AcceptWaveform(audio_data) - - def recognize(self, uuid, partial=False): + # Recognize speech using the Vosk recognizer + def recognize_using_vosk(self, uuid, partial=False): """ Recognize speech and return the result as a JSON object. @@ -84,14 +92,53 @@ class STTModel: self.recognizer[uuid].Reset() return result + # Recognize speech using the whisper model + def recognize_using_whisper(self,filename,language = None,timeout = 5,fp16=False): + """ + Recognize speech and return the result as a JSON object. + + Args: + filename (str): The path to the audio file. + timeout (int, optional): The timeout for recognition (default is 5 seconds). + fp16 (bool, optional): If True, use 16-bit floating point precision, (default is False) because cuda is not supported. + language (str, optional): The language code for recognition (default is None). + + Returns: + dict: A JSON object containing recognition results. + """ + def transcribe_with_whisper(): + return self.whisper_model.transcribe(filename, language = language,fp16=fp16) + + with ThreadPoolExecutor() as executor: + future = executor.submit(transcribe_with_whisper) + try: + return future.result(timeout=timeout) + except TimeoutError: + return {"error": "Transcription with Whisper exceeded the timeout."} + + def recognize_using_whisper_cpp(self,filename): + command = self.whisper_cpp_path + arguments = ["-m", self.whisper_cpp_model_path, "-f", filename, "-l", "en","-nt"] + + # Run the executable with the specified arguments + result = subprocess.run([command] + arguments, capture_output=True, text=True) + + if result.returncode == 0: + result = result.stdout.replace('\n', ' ').strip() + return {"text": result} + else: + print("Error:\n", result.stderr) + return {"error": result.stderr} + - def recognize_from_file(self, uuid, filename): + def recognize_from_file(self, uuid, filename,stt_framework="vosk"): """ Recognize speech from an audio file and return the recognized text. Args: uuid (str): The unique identifier (UUID) for the session. filename (str): The path to the audio file. + stt_model (str): The STT model to use for recognition (default is "vosk"). Returns: str: The recognized text or error messages. @@ -115,12 +162,28 @@ class STTModel: audio_data += chunk if audio_data: - if self.init_recognition(uuid, audio_data): - result = self.recognize(uuid) - return result['text'] - else: - result = self.recognize(uuid, partial=True) - return result['partial'] + # Perform speech recognition using the specified STT model + if stt_framework == "vosk": + if self.init_recognition(uuid, audio_data): + result = self.recognize_using_vosk(uuid) + return result['text'] + else: + result = self.recognize_using_vosk(uuid, partial=True) + return result['partial'] + + elif stt_framework == "whisper": + result = self.recognize_using_whisper_cpp(filename) + if 'error' in result: + print(result['error']) + # If Whisper times out, fall back to Vosk + if self.init_recognition(uuid, audio_data): + result = self.recognize_using_vosk(uuid) + return result['text'] + else: + result = self.recognize_using_vosk(uuid, partial=True) + return result['partial'] + else: + return result.get('text', '') else: print("Voice not recognized. Please speak again...") diff --git a/agl_service_voiceagent/utils/stt_online_service.py b/agl_service_voiceagent/utils/stt_online_service.py new file mode 100644 index 0000000..7bbdc5d --- /dev/null +++ b/agl_service_voiceagent/utils/stt_online_service.py @@ -0,0 +1,88 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import grpc +import sys +sys.path.append("../") +from agl_service_voiceagent.generated import audio_processing_pb2 +from agl_service_voiceagent.generated import audio_processing_pb2_grpc + +class STTOnlineService: + """ + STTOnlineService class is used to connect to an online gPRC based Whisper ASR service. + """ + def __init__(self, server_address, server_port,server_timeout=5): + """ + Initialize the online speech-to-text service. + + Args: + server_ip (str): The IP address of the online speech-to-text service. + server_port (int): The port number of the online speech-to-text service. + server_timeout (int, optional): The timeout value in seconds (default is 5). + """ + self.server_address = server_address + self.server_port = server_port + self.server_timeout = server_timeout + self.client = None + self.initialized = False + + + def initialize_connection(self): + """ + Initialize the connection to the online speech-to-text service. + """ + try: + channel = grpc.insecure_channel(f"{self.server_address}:{self.server_port}") + self.client = audio_processing_pb2_grpc.AudioProcessingStub(channel) + self.initialized = True + print("STTOnlineService initialized with server address:",self.server_address,"and port:",self.server_port,"and timeout:",self.server_timeout,"seconds.") + except Exception as e: + print("Error initializing online speech-to-text service:",e) + self.initialized = False + return self.initialized + + def close_connection(self): + """ + Close the connection to the online speech-to-text service. + """ + self.client = None + self.initialized = False + return not self.initialized + + def recognize_audio(self, audio_file): + """ + Recognize speech from audio data. + + Args: + audio_data (bytes): Audio data to process. + + Returns: + str: The recognized text. + """ + if not self.initialized: + print("STTOnlineService not initialized.") + return None + try: + with open(audio_file, 'rb') as audio_file: + audio_data = audio_file.read() + request = audio_processing_pb2.AudioRequest(audio_data=audio_data) + response = self.client.ProcessAudio(request,timeout=self.server_timeout) + return response.text + except Exception as e: + print("Error recognizing audio:",e) + return None + +
\ No newline at end of file diff --git a/agl_service_voiceagent/utils/vss_interface.py b/agl_service_voiceagent/utils/vss_interface.py new file mode 100644 index 0000000..a77e52c --- /dev/null +++ b/agl_service_voiceagent/utils/vss_interface.py @@ -0,0 +1,236 @@ +import json +import threading +from kuksa_client import KuksaClientThread +import sys +from pathlib import Path +import asyncio +import concurrent.futures +from kuksa_client.grpc.aio import VSSClient +from kuksa_client.grpc import Datapoint +import time +from agl_service_voiceagent.utils.config import get_config_value, get_logger + + +class VSSInterface: + """ + VSS Interface + + This class provides methods to initialize, authorize, connect, send values, + check the status, and close the Kuksa client. + """ + + _instance = None + _lock = threading.Lock() + + def __new__(cls): + """ + Get the unique instance of the class. + + Returns: + KuksaInterface: The instance of the class. + """ + with cls._lock: + if cls._instance is None: + cls._instance = super(VSSInterface, cls).__new__(cls) + cls._instance.init_client() + return cls._instance + + def init_client(self): + """ + Initialize the Kuksa client. + """ + # Defaults + self.hostname = str(get_config_value("hostname", "VSS")) + self.port = str(get_config_value("port", "VSS")) + self.token_filename = str(get_config_value("token_filename", "VSS")) + self.tls_server_name = str(get_config_value("tls_server_name", "VSS")) + self.verbose = False + self.insecure = bool(int(get_config_value("insecure", "VSS"))) + self.protocol = str(get_config_value("protocol", "VSS")) + self.ca_cert_filename = str(get_config_value("ca_cert_filename", "VSS")) + self.token = None + self.is_connected = False + self.logger = get_logger() + + self.set_token() + + # validate config + if not self.validate_config(): + exit(1) + + # define class methods + self.vss_client = None + + def validate_config(self): + """ + Validate the Kuksa client configuration. + + Returns: + bool: True if the configuration is valid, False otherwise. + """ + if self.hostname is None: + print("[-] Error: Kuksa IP address is not set.") + self.logger.error("Kuksa IP address is not set.") + return False + + if self.port is None: + print("[-] Error: Kuksa port is not set.") + self.logger.error("Kuksa port is not set.") + return False + + if self.token is None: + print("[-] Warning: Kuksa auth token is not set.") + self.logger.warning("Kuksa auth token is not set.") + + if self.protocol != "ws" and self.protocol != "grpc": + print("[-] Error: Invalid Kuksa protocol. Only 'ws' and 'grpc' are supported.") + self.logger.error("Invalid Kuksa protocol. Only 'ws' and 'grpc' are supported.") + return False + + return True + + def set_token(self): + """ + Set the Kuksa auth token. + """ + if self.token_filename != "": + token_file = open(self.token_filename, "r") + self.token = token_file.read() + else: + self.token = "" + + def get_vss_client(self): + """ + Get the VSS client instance. + + Returns: + VSSClientThread: The VSS client instance. + """ + if self.vss_client is None: + return None + return self.vss_client + + async def authorize_vss_client(self): + """ + Authorize the VSS client. + """ + if self.vss_client is None: + print("[-] Error: Failed to authorize Kuksa client. Kuksa client is not initialized.") + self.logger.error("Failed to authorize Kuksa client. Kuksa client is not initialized.") + return False + try: + await self.vss_client.authorize(self.token) + print(f"Authorized Kuksa client with token {self.token}") + return True + except Exception as e: + print(f"[-] Error: Failed to authorize Kuksa client: {e}") + self.logger.error(f"Failed to authorize Kuksa client: {e}") + return False + + async def get_server_info(self): + """ + Get the server information. + + Returns: + dict: The server information. + """ + if self.vss_client is None: + return None + try: + return await self.vss_client.get_server_info() + except Exception as e: + print(f"[-] Error: Failed to get server info: {e}") + self.logger.error(f"Failed to get server info: {e}") + return None + + async def connect_vss_client(self): + """ + Connect the VSS client. + """ + print(f"Connecting to KUKSA.val databroker at {self.hostname}:{self.port}") + try: + self.vss_client = VSSClient( + self.hostname, + self.port, + root_certificates=Path(self.ca_cert_filename), + token=self.token, + tls_server_name=self.tls_server_name, + ensure_startup_connection=True) + await self.vss_client.connect() + print(f"[+] Connected to KUKSA.val databroker at {self.hostname}:{self.port}") + self.is_connected = True + return True + except Exception as e: + print(f"[-] Error: Failed to connect to Kuksa val databroker: {e}") + self.logger.error(f"Failed to connect to Kuksa val databroker: {e}") + self.is_connected = False + return False + + + async def set_current_values(self, path=None, value=None): + """ + Set the current values. + + Args: + updates (dict): The updates to set. + """ + result = False + if self.vss_client is None: + print(f"[-] Error: Failed to send value '{value}' to Kuksa. Kuksa client is not initialized.") + self.logger.error(f"Failed to send value '{value}' to Kuksa. Kuksa client is not initialized.") + return result + try: + await self.vss_client.set_current_values({path: Datapoint(value)}) + result = True + except Exception as e: + print(f"[-] Error: Failed to send value '{value}' to Kuksa: {e}") + self.logger.error(f"Failed to send value '{value}' to Kuksa: {e}") + return result + + + async def get_current_values(self, path=None): + """ + Get the current values. + + Args: + paths (list): The paths to get. + + Returns: + dict: The current values. + + current_values = await client.get_current_values([ + 'Vehicle.Speed', + 'Vehicle.ADAS.ABS.IsActive', + ]) + speed_value = current_values['Vehicle.Speed'].value + """ + + if self.vss_client is None or self.is_connected is False: + return None + try: + result = await self.vss_client.get_current_values([path]) + return result[path].value + except Exception as e: + print(f"[-] Error: Failed to get current values: {e}") + self.logger.error(f"Failed to get current values: {e}") + return None + + async def disconnect_vss_client(self): + """ + Disconnect the VSS client. + """ + if self.vss_client is None: + print("[-] Error: Failed to disconnect Kuksa client. Kuksa client is not initialized.") + self.logger.error("Failed to disconnect Kuksa client. Kuksa client is not initialized.") + return False + try: + await self.vss_client.disconnect() + print("Disconnected from Kuksa val databroker.") + self.is_connected = False + return True + except Exception as e: + print(f"[-] Error: Failed to disconnect from Kuksa val databroker: {e}") + self.logger.error(f"Failed to disconnect from Kuksa val databroker: {e}") + return False + + diff --git a/agl_service_voiceagent/utils/wake_word.py b/agl_service_voiceagent/utils/wake_word.py index 47e547e..b672269 100644 --- a/agl_service_voiceagent/utils/wake_word.py +++ b/agl_service_voiceagent/utils/wake_word.py @@ -46,7 +46,7 @@ class WakeWordDetector: self.channels = channels self.bits_per_sample = bits_per_sample self.wake_word_model = stt_model # Speech to text model recognizer - self.recognizer_uuid = stt_model.setup_recognizer() + self.recognizer_uuid = stt_model.setup_vosk_recognizer() self.audio_buffer = bytearray() self.segment_size = int(self.sample_rate * 1.0) # Adjust the segment size (e.g., 1 second) @@ -140,7 +140,7 @@ class WakeWordDetector: # Perform wake word detection on the audio_data if self.wake_word_model.init_recognition(self.recognizer_uuid, audio_data): - stt_result = self.wake_word_model.recognize(self.recognizer_uuid) + stt_result = self.wake_word_model.recognize_using_vosk(self.recognizer_uuid) print("STT Result: ", stt_result) if self.wake_word in stt_result["text"]: self.wake_word_detected = True |