summaryrefslogtreecommitdiffstats
path: root/agl_service_voiceagent/utils
diff options
context:
space:
mode:
Diffstat (limited to 'agl_service_voiceagent/utils')
-rw-r--r--agl_service_voiceagent/utils/__init__.py0
-rw-r--r--agl_service_voiceagent/utils/audio_recorder.py145
-rw-r--r--agl_service_voiceagent/utils/common.py50
-rw-r--r--agl_service_voiceagent/utils/config.py34
-rw-r--r--agl_service_voiceagent/utils/kuksa_interface.py66
-rw-r--r--agl_service_voiceagent/utils/stt_model.py105
-rw-r--r--agl_service_voiceagent/utils/wake_word.py150
7 files changed, 550 insertions, 0 deletions
diff --git a/agl_service_voiceagent/utils/__init__.py b/agl_service_voiceagent/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/agl_service_voiceagent/utils/__init__.py
diff --git a/agl_service_voiceagent/utils/audio_recorder.py b/agl_service_voiceagent/utils/audio_recorder.py
new file mode 100644
index 0000000..61ce994
--- /dev/null
+++ b/agl_service_voiceagent/utils/audio_recorder.py
@@ -0,0 +1,145 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright (c) 2023 Malik Talha
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gi
+import vosk
+import time
+gi.require_version('Gst', '1.0')
+from gi.repository import Gst, GLib
+
+Gst.init(None)
+GLib.threads_init()
+
+class AudioRecorder:
+ def __init__(self, stt_model, audio_files_basedir, channels=1, sample_rate=16000, bits_per_sample=16):
+ self.loop = GLib.MainLoop()
+ self.mode = None
+ self.pipeline = None
+ self.bus = None
+ self.audio_files_basedir = audio_files_basedir
+ self.sample_rate = sample_rate
+ self.channels = channels
+ self.bits_per_sample = bits_per_sample
+ self.frame_size = int(self.sample_rate * 0.02)
+ self.audio_model = stt_model
+ self.buffer_duration = 1 # Buffer audio for atleast 1 second
+ self.audio_buffer = bytearray()
+ self.energy_threshold = 50000 # Adjust this threshold as needed
+ self.silence_frames_threshold = 10
+ self.frames_above_threshold = 0
+
+
+ def create_pipeline(self):
+ print("Creating pipeline for audio recording in {} mode...".format(self.mode))
+ self.pipeline = Gst.Pipeline()
+ autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None)
+ queue = Gst.ElementFactory.make("queue", None)
+ audioconvert = Gst.ElementFactory.make("audioconvert", None)
+ wavenc = Gst.ElementFactory.make("wavenc", None)
+
+ capsfilter = Gst.ElementFactory.make("capsfilter", None)
+ caps = Gst.Caps.new_empty_simple("audio/x-raw")
+ caps.set_value("format", "S16LE")
+ caps.set_value("rate", self.sample_rate)
+ caps.set_value("channels", self.channels)
+ capsfilter.set_property("caps", caps)
+
+ self.pipeline.add(autoaudiosrc)
+ self.pipeline.add(queue)
+ self.pipeline.add(audioconvert)
+ self.pipeline.add(wavenc)
+ self.pipeline.add(capsfilter)
+
+ autoaudiosrc.link(queue)
+ queue.link(audioconvert)
+ audioconvert.link(capsfilter)
+
+ audio_file_name = f"{self.audio_files_basedir}{int(time.time())}.wav"
+
+ filesink = Gst.ElementFactory.make("filesink", None)
+ filesink.set_property("location", audio_file_name)
+ self.pipeline.add(filesink)
+ capsfilter.link(wavenc)
+ wavenc.link(filesink)
+
+ self.bus = self.pipeline.get_bus()
+ self.bus.add_signal_watch()
+ self.bus.connect("message", self.on_bus_message)
+
+ return audio_file_name
+
+
+ def start_recording(self):
+ self.pipeline.set_state(Gst.State.PLAYING)
+ print("Recording Voice Input...")
+
+
+ def stop_recording(self):
+ print("Stopping recording...")
+ self.frames_above_threshold = 0
+ self.cleanup_pipeline()
+ print("Recording finished!")
+
+
+ def set_pipeline_mode(self, mode):
+ self.mode = mode
+
+
+ # this method helps with error handling
+ def on_bus_message(self, bus, message):
+ if message.type == Gst.MessageType.EOS:
+ print("End-of-stream message received")
+ self.stop_recording()
+
+ elif message.type == Gst.MessageType.ERROR:
+ err, debug_info = message.parse_error()
+ print(f"Error received from element {message.src.get_name()}: {err.message}")
+ print(f"Debugging information: {debug_info}")
+ self.stop_recording()
+
+ elif message.type == Gst.MessageType.WARNING:
+ err, debug_info = message.parse_warning()
+ print(f"Warning received from element {message.src.get_name()}: {err.message}")
+ print(f"Debugging information: {debug_info}")
+
+ elif message.type == Gst.MessageType.STATE_CHANGED:
+ if isinstance(message.src, Gst.Pipeline):
+ old_state, new_state, pending_state = message.parse_state_changed()
+ print(("Pipeline state changed from %s to %s." %
+ (old_state.value_nick, new_state.value_nick)))
+
+ elif self.mode == "auto" and message.type == Gst.MessageType.ELEMENT:
+ if message.get_structure().get_name() == "level":
+ rms = message.get_structure()["rms"][0]
+ if rms > self.energy_threshold:
+ self.frames_above_threshold += 1
+ # if self.frames_above_threshold >= self.silence_frames_threshold:
+ # self.start_recording()
+ else:
+ if self.frames_above_threshold > 0:
+ self.frames_above_threshold -= 1
+ if self.frames_above_threshold == 0:
+ self.stop_recording()
+
+
+ def cleanup_pipeline(self):
+ if self.pipeline is not None:
+ print("Cleaning up pipeline...")
+ self.pipeline.set_state(Gst.State.NULL)
+ self.bus.remove_signal_watch()
+ print("Pipeline cleanup complete!")
+ self.bus = None
+ self.pipeline = None
diff --git a/agl_service_voiceagent/utils/common.py b/agl_service_voiceagent/utils/common.py
new file mode 100644
index 0000000..682473e
--- /dev/null
+++ b/agl_service_voiceagent/utils/common.py
@@ -0,0 +1,50 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright (c) 2023 Malik Talha
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import uuid
+import json
+
+
+def add_trailing_slash(path):
+ if path and not path.endswith('/'):
+ path += '/'
+ return path
+
+
+def generate_unique_uuid(length):
+ unique_id = str(uuid.uuid4().int)
+ # Ensure the generated ID is exactly 'length' digits by taking the last 'length' characters
+ unique_id = unique_id[-length:]
+ return unique_id
+
+
+def load_json_file(file_path):
+ try:
+ with open(file_path, 'r') as file:
+ return json.load(file)
+ except FileNotFoundError:
+ raise ValueError(f"File '{file_path}' not found.")
+
+
+def delete_file(file_path):
+ if os.path.exists(file_path):
+ try:
+ os.remove(file_path)
+ except Exception as e:
+ print(f"Error deleting '{file_path}': {e}")
+ else:
+ print(f"File '{file_path}' does not exist.") \ No newline at end of file
diff --git a/agl_service_voiceagent/utils/config.py b/agl_service_voiceagent/utils/config.py
new file mode 100644
index 0000000..8d7f346
--- /dev/null
+++ b/agl_service_voiceagent/utils/config.py
@@ -0,0 +1,34 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright (c) 2023 Malik Talha
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import configparser
+
+# Get the absolute path to the directory of the current script
+current_dir = os.path.dirname(os.path.abspath(__file__))
+# Construct the path to the config.ini file located in the base directory
+config_path = os.path.join(current_dir, '..', 'config.ini')
+
+config = configparser.ConfigParser()
+config.read(config_path)
+
+def update_config_value(value, key, group="General"):
+ config.set(group, key, value)
+ with open(config_path, 'w') as configfile:
+ config.write(configfile)
+
+def get_config_value(key, group="General"):
+ return config.get(group, key)
diff --git a/agl_service_voiceagent/utils/kuksa_interface.py b/agl_service_voiceagent/utils/kuksa_interface.py
new file mode 100644
index 0000000..3e1c045
--- /dev/null
+++ b/agl_service_voiceagent/utils/kuksa_interface.py
@@ -0,0 +1,66 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright (c) 2023 Malik Talha
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from kuksa_client import KuksaClientThread
+from agl_service_voiceagent.utils.config import get_config_value
+
+class KuksaInterface:
+ def __init__(self):
+ # get config values
+ self.ip = get_config_value("ip", "Kuksa")
+ self.port = get_config_value("port", "Kuksa")
+ self.insecure = get_config_value("insecure", "Kuksa")
+ self.protocol = get_config_value("protocol", "Kuksa")
+ self.token = get_config_value("token", "Kuksa")
+
+ # define class methods
+ self.kuksa_client = None
+
+
+ def get_kuksa_client(self):
+ return self.kuksa_client
+
+
+ def get_kuksa_status(self):
+ if self.kuksa_client:
+ return self.kuksa_client.checkConnection()
+
+ return False
+
+
+ def connect_kuksa_client(self):
+ try:
+ self.kuksa_client = KuksaClientThread({
+ "ip": self.ip,
+ "port": self.port,
+ "insecure": self.insecure,
+ "protocol": self.protocol,
+ })
+ self.kuksa_client.authorize(self.token)
+
+ except Exception as e:
+ print("Error: ", e)
+
+
+ def send_values(self, Path=None, Value=None):
+ if self.kuksa_client is None:
+ print("Error: Kuksa client is not initialized.")
+
+ if self.get_kuksa_status():
+ self.kuksa_client.setValue(Path, Value)
+
+ else:
+ print("Error: Connection to Kuksa failed.")
diff --git a/agl_service_voiceagent/utils/stt_model.py b/agl_service_voiceagent/utils/stt_model.py
new file mode 100644
index 0000000..5337162
--- /dev/null
+++ b/agl_service_voiceagent/utils/stt_model.py
@@ -0,0 +1,105 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright (c) 2023 Malik Talha
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import json
+import vosk
+import wave
+from agl_service_voiceagent.utils.common import generate_unique_uuid
+
+class STTModel:
+ def __init__(self, model_path, sample_rate=16000):
+ self.sample_rate = sample_rate
+ self.model = vosk.Model(model_path)
+ self.recognizer = {}
+ self.chunk_size = 1024
+
+ def setup_recognizer(self):
+ uuid = generate_unique_uuid(6)
+ self.recognizer[uuid] = vosk.KaldiRecognizer(self.model, self.sample_rate)
+ return uuid
+
+ def init_recognition(self, uuid, audio_data):
+ return self.recognizer[uuid].AcceptWaveform(audio_data)
+
+ def recognize(self, uuid, partial=False):
+ self.recognizer[uuid].SetWords(True)
+ if partial:
+ result = json.loads(self.recognizer[uuid].PartialResult())
+ else:
+ result = json.loads(self.recognizer[uuid].Result())
+ self.recognizer[uuid].Reset()
+ return result
+
+ def recognize_from_file(self, uuid, filename):
+ if not os.path.exists(filename):
+ print(f"Audio file '{filename}' not found.")
+ return "FILE_NOT_FOUND"
+
+ wf = wave.open(filename, "rb")
+ if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
+ print("Audio file must be WAV format mono PCM.")
+ return "FILE_FORMAT_INVALID"
+
+ # audio_data = wf.readframes(wf.getnframes())
+ # we need to perform chunking as target AGL system can't handle an entire audio file
+ audio_data = b""
+ while True:
+ chunk = wf.readframes(self.chunk_size)
+ if not chunk:
+ break # End of file reached
+ audio_data += chunk
+
+ if audio_data:
+ if self.init_recognition(uuid, audio_data):
+ result = self.recognize(uuid)
+ return result['text']
+ else:
+ result = self.recognize(uuid, partial=True)
+ return result['partial']
+
+ else:
+ print("Voice not recognized. Please speak again...")
+ return "VOICE_NOT_RECOGNIZED"
+
+ def cleanup_recognizer(self, uuid):
+ del self.recognizer[uuid]
+
+import wave
+
+def read_wav_file(filename, chunk_size=1024):
+ try:
+ wf = wave.open(filename, "rb")
+ if wf.getnchannels() != 1 or wf.getsampwidth() != 2 or wf.getcomptype() != "NONE":
+ print("Audio file must be WAV format mono PCM.")
+ return "FILE_FORMAT_INVALID"
+
+ audio_data = b"" # Initialize an empty bytes object to store audio data
+ while True:
+ chunk = wf.readframes(chunk_size)
+ if not chunk:
+ break # End of file reached
+ audio_data += chunk
+
+ return audio_data
+ except Exception as e:
+ print(f"Error reading audio file: {e}")
+ return None
+
+# Example usage:
+filename = "your_audio.wav"
+audio_data = read_wav_file(filename)
+ \ No newline at end of file
diff --git a/agl_service_voiceagent/utils/wake_word.py b/agl_service_voiceagent/utils/wake_word.py
new file mode 100644
index 0000000..066ae6d
--- /dev/null
+++ b/agl_service_voiceagent/utils/wake_word.py
@@ -0,0 +1,150 @@
+# SPDX-License-Identifier: Apache-2.0
+#
+# Copyright (c) 2023 Malik Talha
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import gi
+import vosk
+gi.require_version('Gst', '1.0')
+from gi.repository import Gst, GLib
+
+Gst.init(None)
+GLib.threads_init()
+
+class WakeWordDetector:
+ def __init__(self, wake_word, stt_model, channels=1, sample_rate=16000, bits_per_sample=16):
+ self.loop = GLib.MainLoop()
+ self.pipeline = None
+ self.bus = None
+ self.wake_word = wake_word
+ self.wake_word_detected = False
+ self.sample_rate = sample_rate
+ self.channels = channels
+ self.bits_per_sample = bits_per_sample
+ self.frame_size = int(self.sample_rate * 0.02)
+ self.stt_model = stt_model # Speech to text model recognizer
+ self.recognizer_uuid = stt_model.setup_recognizer()
+ self.buffer_duration = 1 # Buffer audio for atleast 1 second
+ self.audio_buffer = bytearray()
+
+ def get_wake_word_status(self):
+ return self.wake_word_detected
+
+ def create_pipeline(self):
+ print("Creating pipeline for Wake Word Detection...")
+ self.pipeline = Gst.Pipeline()
+ autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None)
+ queue = Gst.ElementFactory.make("queue", None)
+ audioconvert = Gst.ElementFactory.make("audioconvert", None)
+ wavenc = Gst.ElementFactory.make("wavenc", None)
+
+ capsfilter = Gst.ElementFactory.make("capsfilter", None)
+ caps = Gst.Caps.new_empty_simple("audio/x-raw")
+ caps.set_value("format", "S16LE")
+ caps.set_value("rate", self.sample_rate)
+ caps.set_value("channels", self.channels)
+ capsfilter.set_property("caps", caps)
+
+ appsink = Gst.ElementFactory.make("appsink", None)
+ appsink.set_property("emit-signals", True)
+ appsink.set_property("sync", False) # Set sync property to False to enable async processing
+ appsink.connect("new-sample", self.on_new_buffer, None)
+
+ self.pipeline.add(autoaudiosrc)
+ self.pipeline.add(queue)
+ self.pipeline.add(audioconvert)
+ self.pipeline.add(wavenc)
+ self.pipeline.add(capsfilter)
+ self.pipeline.add(appsink)
+
+ autoaudiosrc.link(queue)
+ queue.link(audioconvert)
+ audioconvert.link(capsfilter)
+ capsfilter.link(appsink)
+
+ self.bus = self.pipeline.get_bus()
+ self.bus.add_signal_watch()
+ self.bus.connect("message", self.on_bus_message)
+
+ def on_new_buffer(self, appsink, data) -> Gst.FlowReturn:
+ sample = appsink.emit("pull-sample")
+ buffer = sample.get_buffer()
+ data = buffer.extract_dup(0, buffer.get_size())
+ self.audio_buffer.extend(data)
+
+ if len(self.audio_buffer) >= self.sample_rate * self.buffer_duration * self.channels * self.bits_per_sample // 8:
+ self.process_audio_buffer()
+
+ return Gst.FlowReturn.OK
+
+
+ def process_audio_buffer(self):
+ # Process the accumulated audio data using the audio model
+ audio_data = bytes(self.audio_buffer)
+ if self.stt_model.init_recognition(self.recognizer_uuid, audio_data):
+ stt_result = self.stt_model.recognize(self.recognizer_uuid)
+ print("STT Result: ", stt_result)
+ if self.wake_word in stt_result["text"]:
+ self.wake_word_detected = True
+ print("Wake word detected!")
+ self.pipeline.send_event(Gst.Event.new_eos())
+
+ self.audio_buffer.clear() # Clear the buffer
+
+
+ def send_eos(self):
+ self.pipeline.send_event(Gst.Event.new_eos())
+ self.audio_buffer.clear()
+
+
+ def start_listening(self):
+ self.pipeline.set_state(Gst.State.PLAYING)
+ print("Listening for Wake Word...")
+ self.loop.run()
+
+
+ def stop_listening(self):
+ self.cleanup_pipeline()
+ self.loop.quit()
+
+
+ def on_bus_message(self, bus, message):
+ if message.type == Gst.MessageType.EOS:
+ print("End-of-stream message received")
+ self.stop_listening()
+ elif message.type == Gst.MessageType.ERROR:
+ err, debug_info = message.parse_error()
+ print(f"Error received from element {message.src.get_name()}: {err.message}")
+ print(f"Debugging information: {debug_info}")
+ self.stop_listening()
+ elif message.type == Gst.MessageType.WARNING:
+ err, debug_info = message.parse_warning()
+ print(f"Warning received from element {message.src.get_name()}: {err.message}")
+ print(f"Debugging information: {debug_info}")
+ elif message.type == Gst.MessageType.STATE_CHANGED:
+ if isinstance(message.src, Gst.Pipeline):
+ old_state, new_state, pending_state = message.parse_state_changed()
+ print(("Pipeline state changed from %s to %s." %
+ (old_state.value_nick, new_state.value_nick)))
+
+
+ def cleanup_pipeline(self):
+ if self.pipeline is not None:
+ print("Cleaning up pipeline...")
+ self.pipeline.set_state(Gst.State.NULL)
+ self.bus.remove_signal_watch()
+ print("Pipeline cleanup complete!")
+ self.bus = None
+ self.pipeline = None
+ self.stt_model.cleanup_recognizer(self.recognizer_uuid)