diff options
Diffstat (limited to 'agl_service_voiceagent/servicers/voice_agent_servicer.py')
-rw-r--r-- | agl_service_voiceagent/servicers/voice_agent_servicer.py | 156 |
1 files changed, 156 insertions, 0 deletions
diff --git a/agl_service_voiceagent/servicers/voice_agent_servicer.py b/agl_service_voiceagent/servicers/voice_agent_servicer.py new file mode 100644 index 0000000..4038b85 --- /dev/null +++ b/agl_service_voiceagent/servicers/voice_agent_servicer.py @@ -0,0 +1,156 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import grpc +import time +import threading +from generated import voice_agent_pb2 +from generated import voice_agent_pb2_grpc +from agl_service_voiceagent.utils.audio_recorder import AudioRecorder +from agl_service_voiceagent.utils.wake_word import WakeWordDetector +from agl_service_voiceagent.utils.stt_model import STTModel +from agl_service_voiceagent.utils.config import get_config_value +from agl_service_voiceagent.nlu.snips_interface import SnipsInterface +from agl_service_voiceagent.nlu.rasa_interface import RASAInterface +from agl_service_voiceagent.utils.common import generate_unique_uuid, delete_file + + +class VoiceAgentServicer(voice_agent_pb2_grpc.VoiceAgentServiceServicer): + def __init__(self): + # Get the config values + self.service_version = get_config_value('SERVICE_VERSION') + self.wake_word = get_config_value('WAKE_WORD') + self.base_audio_dir = get_config_value('BASE_AUDIO_DIR') + self.channels = int(get_config_value('CHANNELS')) + self.sample_rate = int(get_config_value('SAMPLE_RATE')) + self.bits_per_sample = int(get_config_value('BITS_PER_SAMPLE')) + self.stt_model_path = get_config_value('STT_MODEL_PATH') + self.snips_model_path = get_config_value('SNIPS_MODEL_PATH') + self.rasa_model_path = get_config_value('RASA_MODEL_PATH') + self.rasa_server_port = int(get_config_value('RASA_SERVER_PORT')) + self.base_log_dir = get_config_value('BASE_LOG_DIR') + self.store_voice_command = bool(int(get_config_value('STORE_VOICE_COMMANDS'))) + + # Initialize class methods + self.stt_model = STTModel(self.stt_model_path, self.sample_rate) + self.snips_interface = SnipsInterface(self.snips_model_path) + self.rasa_interface = RASAInterface(self.rasa_server_port, self.rasa_model_path, self.base_log_dir) + self.rasa_interface.start_server() + self.rvc_stream_uuids = {} + + + def CheckServiceStatus(self, request, context): + response = voice_agent_pb2.ServiceStatus( + version=self.service_version, + status=True + ) + return response + + + def DetectWakeWord(self, request, context): + wake_word_detector = WakeWordDetector(self.wake_word, self.stt_model, self.channels, self.sample_rate, self.bits_per_sample) + wake_word_detector.create_pipeline() + detection_thread = threading.Thread(target=wake_word_detector.start_listening) + detection_thread.start() + while True: + status = wake_word_detector.get_wake_word_status() + time.sleep(1) + if not context.is_active(): + wake_word_detector.send_eos() + break + yield voice_agent_pb2.WakeWordStatus(status=status) + if status: + break + + detection_thread.join() + + + def RecognizeVoiceCommand(self, requests, context): + stt = "" + intent = "" + intent_slots = [] + + for request in requests: + if request.record_mode == voice_agent_pb2.MANUAL: + + if request.action == voice_agent_pb2.START: + status = voice_agent_pb2.REC_PROCESSING + stream_uuid = generate_unique_uuid(8) + recorder = AudioRecorder(self.stt_model, self.base_audio_dir, self.channels, self.sample_rate, self.bits_per_sample) + recorder.set_pipeline_mode("manual") + audio_file = recorder.create_pipeline() + + self.rvc_stream_uuids[stream_uuid] = { + "recorder": recorder, + "audio_file": audio_file + } + + recorder.start_recording() + + elif request.action == voice_agent_pb2.STOP: + stream_uuid = request.stream_id + status = voice_agent_pb2.REC_SUCCESS + + recorder = self.rvc_stream_uuids[stream_uuid]["recorder"] + audio_file = self.rvc_stream_uuids[stream_uuid]["audio_file"] + del self.rvc_stream_uuids[stream_uuid] + + recorder.stop_recording() + recognizer_uuid = self.stt_model.setup_recognizer() + stt = self.stt_model.recognize_from_file(recognizer_uuid, audio_file) + + if stt not in ["FILE_NOT_FOUND", "FILE_FORMAT_INVALID", "VOICE_NOT_RECOGNIZED", ""]: + if request.nlu_model == voice_agent_pb2.SNIPS: + extracted_intent = self.snips_interface.extract_intent(stt) + intent, intent_actions = self.snips_interface.process_intent(extracted_intent) + + if not intent or intent == "": + status = voice_agent_pb2.INTENT_NOT_RECOGNIZED + + for action, value in intent_actions.items(): + intent_slots.append(voice_agent_pb2.IntentSlot(name=action, value=value)) + + elif request.nlu_model == voice_agent_pb2.RASA: + extracted_intent = self.rasa_interface.extract_intent(stt) + intent, intent_actions = self.rasa_interface.process_intent(extracted_intent) + + if not intent or intent == "": + status = voice_agent_pb2.INTENT_NOT_RECOGNIZED + + for action, value in intent_actions.items(): + intent_slots.append(voice_agent_pb2.IntentSlot(name=action, value=value)) + + else: + stt = "" + status = voice_agent_pb2.VOICE_NOT_RECOGNIZED + + # cleanup the kaldi recognizer + self.stt_model.cleanup_recognizer(recognizer_uuid) + + # delete the audio file + if not self.store_voice_command: + delete_file(audio_file) + + + # Process the request and generate a RecognizeResult + response = voice_agent_pb2.RecognizeResult( + command=stt, + intent=intent, + intent_slots=intent_slots, + stream_id=stream_uuid, + status=status + ) + return response |