diff options
Diffstat (limited to 'agl_service_voiceagent/utils')
-rw-r--r-- | agl_service_voiceagent/utils/__init__.py | 0 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/audio_recorder.py | 145 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/common.py | 50 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/config.py | 34 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/kuksa_interface.py | 66 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/stt_model.py | 105 | ||||
-rw-r--r-- | agl_service_voiceagent/utils/wake_word.py | 150 |
7 files changed, 550 insertions, 0 deletions
diff --git a/agl_service_voiceagent/utils/__init__.py b/agl_service_voiceagent/utils/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/agl_service_voiceagent/utils/__init__.py diff --git a/agl_service_voiceagent/utils/audio_recorder.py b/agl_service_voiceagent/utils/audio_recorder.py new file mode 100644 index 0000000..61ce994 --- /dev/null +++ b/agl_service_voiceagent/utils/audio_recorder.py @@ -0,0 +1,145 @@ +# SPDX-License-Identifier: Apache-2.0 +# +# Copyright (c) 2023 Malik Talha +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 

import gi
import vosk
import time
gi.require_version('Gst', '1.0')
from gi.repository import Gst, GLib

Gst.init(None)
GLib.threads_init()  # NOTE(review): deprecated no-op on modern PyGObject — threads are auto-initialized

class AudioRecorder:
    """Records voice input from the default audio source into a timestamped
    WAV file via a GStreamer pipeline:
    autoaudiosrc -> queue -> audioconvert -> capsfilter -> wavenc -> filesink.
    """

    def __init__(self, stt_model, audio_files_basedir, channels=1, sample_rate=16000, bits_per_sample=16):
        # stt_model: speech-to-text model handle; stored as self.audio_model
        #            but not used by any code visible in this class.
        # audio_files_basedir: prefix for output .wav paths; presumably ends
        #            with '/' (paths are built by plain concatenation) — TODO confirm
        self.loop = GLib.MainLoop()
        self.mode = None  # set via set_pipeline_mode(); "auto" enables level-based auto-stop
        self.pipeline = None
        self.bus = None
        self.audio_files_basedir = audio_files_basedir
        self.sample_rate = sample_rate
        self.channels = channels
        self.bits_per_sample = bits_per_sample
        self.frame_size = int(self.sample_rate * 0.02)  # samples per 20 ms frame
        self.audio_model = stt_model
        self.buffer_duration = 1 # Buffer audio for atleast 1 second
        self.audio_buffer = bytearray()
        self.energy_threshold = 50000 # Adjust this threshold as needed
        self.silence_frames_threshold = 10
        self.frames_above_threshold = 0


    def create_pipeline(self):
        """Build the recording pipeline and return the output .wav path
        (audio_files_basedir + current Unix timestamp + '.wav')."""
        print("Creating pipeline for audio recording in {} mode...".format(self.mode))
        self.pipeline = Gst.Pipeline()
        autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None)
        queue = Gst.ElementFactory.make("queue", None)
        audioconvert = Gst.ElementFactory.make("audioconvert", None)
        wavenc = Gst.ElementFactory.make("wavenc", None)

        # Force raw signed 16-bit little-endian audio at the configured
        # rate/channel count between audioconvert and the encoder.
        capsfilter = Gst.ElementFactory.make("capsfilter", None)
        caps = Gst.Caps.new_empty_simple("audio/x-raw")
        caps.set_value("format", "S16LE")
        caps.set_value("rate", self.sample_rate)
        caps.set_value("channels", self.channels)
        capsfilter.set_property("caps", caps)

        self.pipeline.add(autoaudiosrc)
        self.pipeline.add(queue)
        self.pipeline.add(audioconvert)
        self.pipeline.add(wavenc)
        self.pipeline.add(capsfilter)

        autoaudiosrc.link(queue)
        queue.link(audioconvert)
        audioconvert.link(capsfilter)

        audio_file_name = f"{self.audio_files_basedir}{int(time.time())}.wav"

        filesink = Gst.ElementFactory.make("filesink", None)
        filesink.set_property("location", audio_file_name)
        self.pipeline.add(filesink)
        capsfilter.link(wavenc)
        wavenc.link(filesink)

        # Watch the bus so errors/EOS/state changes reach on_bus_message.
        self.bus = self.pipeline.get_bus()
        self.bus.add_signal_watch()
        self.bus.connect("message", self.on_bus_message)

        return audio_file_name


    def start_recording(self):
        """Set the pipeline to PLAYING; capture starts immediately."""
        self.pipeline.set_state(Gst.State.PLAYING)
        print("Recording Voice Input...")


    def stop_recording(self):
        """Stop capture and tear the pipeline down (finalizes the WAV file)."""
        print("Stopping recording...")
        self.frames_above_threshold = 0
        self.cleanup_pipeline()
        print("Recording finished!")


    def set_pipeline_mode(self, mode):
        # Expected values observed in this class: "auto" (level-triggered
        # stop) or anything else for manual control — TODO confirm full set.
        self.mode = mode


    # this method helps with error handling
    def on_bus_message(self, bus, message):
        """Bus callback: stops recording on EOS/error, logs warnings and
        pipeline state changes; in "auto" mode reacts to `level` element
        messages to stop when audio energy drops."""
        if message.type == Gst.MessageType.EOS:
            print("End-of-stream message received")
            self.stop_recording()

        elif message.type == Gst.MessageType.ERROR:
            err, debug_info = message.parse_error()
            print(f"Error received from element {message.src.get_name()}: {err.message}")
            print(f"Debugging information: {debug_info}")
            self.stop_recording()

        elif message.type == Gst.MessageType.WARNING:
            err, debug_info = message.parse_warning()
            print(f"Warning received from element {message.src.get_name()}: {err.message}")
            print(f"Debugging information: {debug_info}")

        elif message.type == Gst.MessageType.STATE_CHANGED:
            # Only report state changes of the top-level pipeline, not of
            # every child element.
            if isinstance(message.src, Gst.Pipeline):
                old_state, new_state, pending_state = message.parse_state_changed()
                print(("Pipeline state changed from %s to %s." %
                       (old_state.value_nick, new_state.value_nick)))

        elif self.mode == "auto" and message.type == Gst.MessageType.ELEMENT:
            # NOTE(review): create_pipeline() never adds a "level" element,
            # so this branch appears unreachable as written — confirm whether
            # a level element is meant to be inserted into the pipeline.
            if message.get_structure().get_name() == "level":
                rms = message.get_structure()["rms"][0]
                if rms > self.energy_threshold:
                    self.frames_above_threshold += 1
                    # if self.frames_above_threshold >= self.silence_frames_threshold:
                    #     self.start_recording()
                else:
                    if self.frames_above_threshold > 0:
                        self.frames_above_threshold -= 1
                        if self.frames_above_threshold == 0:
                            self.stop_recording()


    def cleanup_pipeline(self):
        """Release pipeline resources; safe to call when already cleaned up."""
        if self.pipeline is not None:
            print("Cleaning up pipeline...")
            self.pipeline.set_state(Gst.State.NULL)
            self.bus.remove_signal_watch()
            print("Pipeline cleanup complete!")
            self.bus = None
            self.pipeline = None

# ---- diff boundary: a/agl_service_voiceagent/utils/common.py (new file, index 0000000..682473e) ----
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (c) 2023 Malik Talha
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import uuid
import json


def add_trailing_slash(path):
    """Return *path* ending in exactly one '/'.

    Empty or falsy paths are returned unchanged; a path that already ends
    with '/' is not modified.
    """
    if path and not path.endswith('/'):
        path += '/'
    return path


def generate_unique_uuid(length):
    """Return a random numeric ID string of exactly *length* digits.

    Derived from the integer form of uuid4. Fix: the result is left-padded
    with zeros — the original plain slice could return fewer than *length*
    characters when the UUID integer happened to have fewer digits.
    """
    unique_id = str(uuid.uuid4().int)
    # Take the last `length` digits, then zero-pad to guarantee the length.
    return unique_id[-length:].zfill(length)


def load_json_file(file_path):
    """Parse *file_path* as JSON and return the decoded object.

    Raises:
        ValueError: if the file does not exist or is not valid JSON.
            (json.JSONDecodeError subclasses ValueError, so wrapping it is
            backward compatible for callers already catching ValueError.)
    """
    try:
        with open(file_path, 'r') as file:
            return json.load(file)
    except FileNotFoundError:
        raise ValueError(f"File '{file_path}' not found.")
    except json.JSONDecodeError as e:
        raise ValueError(f"File '{file_path}' is not valid JSON: {e}")


def delete_file(file_path):
    """Best-effort removal of *file_path*.

    Failures are printed rather than raised so a failed cleanup never
    interrupts the caller; a missing file is reported, not an error.
    """
    if os.path.exists(file_path):
        try:
            os.remove(file_path)
        except Exception as e:
            print(f"Error deleting '{file_path}': {e}")
    else:
        print(f"File '{file_path}' does not exist.")
# ---- diff boundary: a/agl_service_voiceagent/utils/config.py (new file, index 0000000..8d7f346) ----
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (c) 2023 Malik Talha
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import configparser

# Get the absolute path to the directory of the current script
current_dir = os.path.dirname(os.path.abspath(__file__))
# Construct the path to the config.ini file located in the base directory
config_path = os.path.join(current_dir, '..', 'config.ini')

# Module-level parser shared by both helpers; read once at import time.
# configparser silently ignores a missing file, so a missing config.ini
# surfaces at lookup time (NoSectionError/NoOptionError), not at import.
config = configparser.ConfigParser()
config.read(config_path)


def update_config_value(value, key, group="General"):
    """Set *key* = *value* (both strings) in section *group* and persist
    the whole configuration back to config.ini.

    Fix: the section is created when it does not exist yet — previously
    this raised configparser.NoSectionError against a fresh or missing
    config file.
    """
    if not config.has_section(group):
        config.add_section(group)
    config.set(group, key, value)
    with open(config_path, 'w') as configfile:
        config.write(configfile)


def get_config_value(key, group="General"):
    """Return the string value of *key* from section *group*.

    Raises configparser.NoSectionError / NoOptionError when absent.
    """
    return config.get(group, key)

# ---- diff boundary: a/agl_service_voiceagent/utils/kuksa_interface.py (new file, index 0000000..3e1c045) ----
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (c) 2023 Malik Talha
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from kuksa_client import KuksaClientThread
from agl_service_voiceagent.utils.config import get_config_value

class KuksaInterface:
    """Small facade over KuksaClientThread, configured from the [Kuksa]
    section of config.ini (ip, port, insecure, protocol, token)."""

    def __init__(self):
        # get config values
        self.ip = get_config_value("ip", "Kuksa")
        self.port = get_config_value("port", "Kuksa")
        self.insecure = get_config_value("insecure", "Kuksa")
        self.protocol = get_config_value("protocol", "Kuksa")
        self.token = get_config_value("token", "Kuksa")

        # Client is created lazily by connect_kuksa_client(); None until then.
        self.kuksa_client = None


    def get_kuksa_client(self):
        """Return the underlying KuksaClientThread, or None before connect."""
        return self.kuksa_client


    def get_kuksa_status(self):
        """Return True when a client exists and reports a live connection."""
        if self.kuksa_client:
            return self.kuksa_client.checkConnection()

        return False


    def connect_kuksa_client(self):
        """Create and authorize the Kuksa client thread.

        Errors are printed rather than raised; on failure self.kuksa_client
        may be left None or unauthorized.
        """
        try:
            self.kuksa_client = KuksaClientThread({
                "ip": self.ip,
                "port": self.port,
                "insecure": self.insecure,
                "protocol": self.protocol,
            })
            self.kuksa_client.authorize(self.token)

        except Exception as e:
            print("Error: ", e)


    def send_values(self, Path=None, Value=None):
        """Write *Value* to the VSS signal *Path* via the connected client.

        Fix: return immediately when the client was never initialized —
        previously execution fell through and also printed a second,
        misleading "Connection to Kuksa failed" message.
        """
        if self.kuksa_client is None:
            print("Error: Kuksa client is not initialized.")
            return

        if self.get_kuksa_status():
            self.kuksa_client.setValue(Path, Value)

        else:
            print("Error: Connection to Kuksa failed.")

# ---- diff boundary: a/agl_service_voiceagent/utils/stt_model.py (new file, index 0000000..5337162) ----
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (c) 2023 Malik Talha
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import json
import vosk
import wave
from agl_service_voiceagent.utils.common import generate_unique_uuid

class STTModel:
    """Wrapper around a Vosk model that manages per-session KaldiRecognizer
    instances keyed by a short generated id."""

    def __init__(self, model_path, sample_rate=16000):
        self.sample_rate = sample_rate
        self.model = vosk.Model(model_path)
        self.recognizer = {}  # session uuid -> vosk.KaldiRecognizer
        self.chunk_size = 1024  # frames read per wave.readframes() call

    def setup_recognizer(self):
        """Create a new recognizer session and return its 6-digit handle."""
        uuid = generate_unique_uuid(6)
        self.recognizer[uuid] = vosk.KaldiRecognizer(self.model, self.sample_rate)
        return uuid

    def init_recognition(self, uuid, audio_data):
        """Feed raw PCM bytes to the session's recognizer.

        Returns True when Vosk considers the utterance complete (a final
        result is available), False while only a partial result exists.
        """
        return self.recognizer[uuid].AcceptWaveform(audio_data)

    def recognize(self, uuid, partial=False):
        """Return the decoded result dict (partial or final) and reset the
        recognizer for the next utterance."""
        self.recognizer[uuid].SetWords(True)
        if partial:
            result = json.loads(self.recognizer[uuid].PartialResult())
        else:
            result = json.loads(self.recognizer[uuid].Result())
        self.recognizer[uuid].Reset()
        return result

    def recognize_from_file(self, uuid, filename):
        """Transcribe a mono 16-bit PCM WAV file.

        Returns the transcribed text, or one of the sentinel strings
        "FILE_NOT_FOUND", "FILE_FORMAT_INVALID", "VOICE_NOT_RECOGNIZED".
        """
        if not os.path.exists(filename):
            # Fix: interpolate the actual filename — the original f-string
            # contained a literal placeholder instead of {filename}.
            print(f"Audio file '{filename}' not found.")
            return "FILE_NOT_FOUND"

        # Fix: wave handle is closed on every path via the context manager;
        # the original leaked it, including on the early format-error return.
        with wave.open(filename, "rb") as wf:
            if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
                print("Audio file must be WAV format mono PCM.")
                return "FILE_FORMAT_INVALID"

            # we need to perform chunking as target AGL system can't handle
            # an entire audio file in one read
            audio_data = b""
            while True:
                chunk = wf.readframes(self.chunk_size)
                if not chunk:
                    break  # End of file reached
                audio_data += chunk

        if audio_data:
            if self.init_recognition(uuid, audio_data):
                result = self.recognize(uuid)
                return result['text']
            else:
                result = self.recognize(uuid, partial=True)
                return result['partial']

        else:
            print("Voice not recognized. Please speak again...")
            return "VOICE_NOT_RECOGNIZED"

    def cleanup_recognizer(self, uuid):
        """Release the recognizer associated with *uuid*."""
        del self.recognizer[uuid]


def read_wav_file(filename, chunk_size=1024):
    """Standalone helper: return the raw PCM frames of a mono 16-bit WAV
    file, "FILE_FORMAT_INVALID" on a format mismatch, or None on I/O error.
    """
    try:
        # Fix: context manager closes the file on all paths (original leaked
        # the handle on both the format-error return and success).
        with wave.open(filename, "rb") as wf:
            if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
                print("Audio file must be WAV format mono PCM.")
                return "FILE_FORMAT_INVALID"

            audio_data = b""  # accumulate frames chunk by chunk
            while True:
                chunk = wf.readframes(chunk_size)
                if not chunk:
                    break  # End of file reached
                audio_data += chunk

            return audio_data
    except Exception as e:
        print(f"Error reading audio file: {e}")
        return None

# Fix: removed the duplicate mid-module `import wave` and the leftover
# example invocation (read_wav_file("your_audio.wav")) that executed — and
# printed an error for the nonexistent file — on every import of this module.
# ---- diff boundary: a/agl_service_voiceagent/utils/wake_word.py (new file, index 0000000..066ae6d) ----
# SPDX-License-Identifier: Apache-2.0
#
# Copyright (c) 2023 Malik Talha
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gi
import vosk
gi.require_version('Gst', '1.0')
from gi.repository import Gst, GLib

Gst.init(None)
GLib.threads_init()  # NOTE(review): deprecated no-op on modern PyGObject — threads are auto-initialized

class WakeWordDetector:
    """Listens on the default audio source and sets a flag once `wake_word`
    appears (substring match) in the STT transcription of ~1-second chunks.

    Pipeline: autoaudiosrc -> queue -> audioconvert -> capsfilter -> appsink;
    raw S16LE samples are pulled from the appsink and fed to the STT model.
    """

    def __init__(self, wake_word, stt_model, channels=1, sample_rate=16000, bits_per_sample=16):
        # wake_word: phrase searched for in the STT output text
        # stt_model: STT wrapper providing setup_recognizer / init_recognition
        #            / recognize / cleanup_recognizer
        self.loop = GLib.MainLoop()
        self.pipeline = None
        self.bus = None
        self.wake_word = wake_word
        self.wake_word_detected = False
        self.sample_rate = sample_rate
        self.channels = channels
        self.bits_per_sample = bits_per_sample
        self.frame_size = int(self.sample_rate * 0.02)  # samples per 20 ms frame
        self.stt_model = stt_model # Speech to text model recognizer
        self.recognizer_uuid = stt_model.setup_recognizer()
        self.buffer_duration = 1 # Buffer audio for atleast 1 second
        self.audio_buffer = bytearray()

    def get_wake_word_status(self):
        # True once the wake word has been seen; never reset by this class.
        return self.wake_word_detected

    def create_pipeline(self):
        """Build the capture pipeline and attach bus/appsink callbacks."""
        print("Creating pipeline for Wake Word Detection...")
        self.pipeline = Gst.Pipeline()
        autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None)
        queue = Gst.ElementFactory.make("queue", None)
        audioconvert = Gst.ElementFactory.make("audioconvert", None)
        # NOTE(review): wavenc is created and added below but never linked —
        # apparently a dead element carried over from the recorder pipeline;
        # confirm and remove.
        wavenc = Gst.ElementFactory.make("wavenc", None)

        # Force raw signed 16-bit little-endian audio at the configured
        # rate/channel count before the appsink.
        capsfilter = Gst.ElementFactory.make("capsfilter", None)
        caps = Gst.Caps.new_empty_simple("audio/x-raw")
        caps.set_value("format", "S16LE")
        caps.set_value("rate", self.sample_rate)
        caps.set_value("channels", self.channels)
        capsfilter.set_property("caps", caps)

        appsink = Gst.ElementFactory.make("appsink", None)
        appsink.set_property("emit-signals", True)
        appsink.set_property("sync", False) # Set sync property to False to enable async processing
        appsink.connect("new-sample", self.on_new_buffer, None)

        self.pipeline.add(autoaudiosrc)
        self.pipeline.add(queue)
        self.pipeline.add(audioconvert)
        self.pipeline.add(wavenc)
        self.pipeline.add(capsfilter)
        self.pipeline.add(appsink)

        autoaudiosrc.link(queue)
        queue.link(audioconvert)
        audioconvert.link(capsfilter)
        capsfilter.link(appsink)

        # Watch the bus so EOS/errors/state changes reach on_bus_message.
        self.bus = self.pipeline.get_bus()
        self.bus.add_signal_watch()
        self.bus.connect("message", self.on_bus_message)

    def on_new_buffer(self, appsink, data) -> Gst.FlowReturn:
        """appsink "new-sample" callback: accumulate raw PCM bytes and run
        STT once at least `buffer_duration` seconds are buffered."""
        sample = appsink.emit("pull-sample")
        buffer = sample.get_buffer()
        data = buffer.extract_dup(0, buffer.get_size())
        self.audio_buffer.extend(data)

        # rate * seconds * channels * bytes-per-sample = bytes per window
        if len(self.audio_buffer) >= self.sample_rate * self.buffer_duration * self.channels * self.bits_per_sample // 8:
            self.process_audio_buffer()

        return Gst.FlowReturn.OK


    def process_audio_buffer(self):
        # Process the accumulated audio data using the audio model
        audio_data = bytes(self.audio_buffer)
        if self.stt_model.init_recognition(self.recognizer_uuid, audio_data):
            stt_result = self.stt_model.recognize(self.recognizer_uuid)
            print("STT Result: ", stt_result)
            if self.wake_word in stt_result["text"]:
                self.wake_word_detected = True
                print("Wake word detected!")
                # EOS travels to the bus and tears everything down via
                # on_bus_message -> stop_listening.
                self.pipeline.send_event(Gst.Event.new_eos())

        self.audio_buffer.clear() # Clear the buffer


    def send_eos(self):
        # Externally-triggered stop: push EOS and drop any buffered audio.
        self.pipeline.send_event(Gst.Event.new_eos())
        self.audio_buffer.clear()


    def start_listening(self):
        """Start the pipeline and block in the GLib main loop until EOS or
        an error stops it."""
        self.pipeline.set_state(Gst.State.PLAYING)
        print("Listening for Wake Word...")
        self.loop.run()


    def stop_listening(self):
        # Clean up first, then quit the loop so start_listening() returns.
        self.cleanup_pipeline()
        self.loop.quit()


    def on_bus_message(self, bus, message):
        """Bus callback: stops listening on EOS/error, logs warnings and
        pipeline state changes."""
        if message.type == Gst.MessageType.EOS:
            print("End-of-stream message received")
            self.stop_listening()
        elif message.type == Gst.MessageType.ERROR:
            err, debug_info = message.parse_error()
            print(f"Error received from element {message.src.get_name()}: {err.message}")
            print(f"Debugging information: {debug_info}")
            self.stop_listening()
        elif message.type == Gst.MessageType.WARNING:
            err, debug_info = message.parse_warning()
            print(f"Warning received from element {message.src.get_name()}: {err.message}")
            print(f"Debugging information: {debug_info}")
        elif message.type == Gst.MessageType.STATE_CHANGED:
            # Only report state changes of the top-level pipeline.
            if isinstance(message.src, Gst.Pipeline):
                old_state, new_state, pending_state = message.parse_state_changed()
                print(("Pipeline state changed from %s to %s." %
                       (old_state.value_nick, new_state.value_nick)))


    def cleanup_pipeline(self):
        """Tear down the pipeline, bus watch, and the STT recognizer session;
        safe to call when already cleaned up."""
        if self.pipeline is not None:
            print("Cleaning up pipeline...")
            self.pipeline.set_state(Gst.State.NULL)
            self.bus.remove_signal_watch()
            print("Pipeline cleanup complete!")
            self.bus = None
            self.pipeline = None
            self.stt_model.cleanup_recognizer(self.recognizer_uuid)