about summary refs log tree commit diff stats
path: root/agl_service_voiceagent/utils/wake_word.py
diff options
context:
space:
mode:
Diffstat (limited to 'agl_service_voiceagent/utils/wake_word.py')
-rw-r--r--  agl_service_voiceagent/utils/wake_word.py  95
1 files changed, 79 insertions, 16 deletions
diff --git a/agl_service_voiceagent/utils/wake_word.py b/agl_service_voiceagent/utils/wake_word.py
index 066ae6d..47e547e 100644
--- a/agl_service_voiceagent/utils/wake_word.py
+++ b/agl_service_voiceagent/utils/wake_word.py
@@ -15,7 +15,6 @@
# limitations under the License.
import gi
-import vosk
gi.require_version('Gst', '1.0')
from gi.repository import Gst, GLib
@@ -23,7 +22,21 @@ Gst.init(None)
GLib.threads_init()
class WakeWordDetector:
+ """
+ WakeWordDetector is a class for detecting a wake word in an audio stream using GStreamer and Vosk.
+ """
+
def __init__(self, wake_word, stt_model, channels=1, sample_rate=16000, bits_per_sample=16):
+ """
+ Initialize the WakeWordDetector instance with the provided parameters.
+
+ Args:
+ wake_word (str): The wake word to detect in the audio stream.
+ stt_model (STTModel): An instance of the STTModel for speech-to-text recognition.
+ channels (int, optional): The number of audio channels (default is 1).
+ sample_rate (int, optional): The audio sample rate in Hz (default is 16000).
+ bits_per_sample (int, optional): The number of bits per sample (default is 16).
+ """
self.loop = GLib.MainLoop()
self.pipeline = None
self.bus = None
@@ -32,16 +45,25 @@ class WakeWordDetector:
self.sample_rate = sample_rate
self.channels = channels
self.bits_per_sample = bits_per_sample
- self.frame_size = int(self.sample_rate * 0.02)
- self.stt_model = stt_model # Speech to text model recognizer
+ self.wake_word_model = stt_model # Speech to text model recognizer
self.recognizer_uuid = stt_model.setup_recognizer()
- self.buffer_duration = 1 # Buffer audio for atleast 1 second
self.audio_buffer = bytearray()
+ self.segment_size = int(self.sample_rate * 1.0) # Adjust the segment size (e.g., 1 second)
+
def get_wake_word_status(self):
+ """
+ Get the status of wake word detection.
+
+ Returns:
+ bool: True if the wake word has been detected, False otherwise.
+ """
return self.wake_word_detected
def create_pipeline(self):
+ """
+ Create and configure the GStreamer audio processing pipeline for wake word detection.
+ """
print("Creating pipeline for Wake Word Detection...")
self.pipeline = Gst.Pipeline()
autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None)
@@ -77,49 +99,87 @@ class WakeWordDetector:
self.bus.add_signal_watch()
self.bus.connect("message", self.on_bus_message)
+
def on_new_buffer(self, appsink, data) -> Gst.FlowReturn:
+ """
+ Callback function to handle new audio buffers from GStreamer appsink.
+
+ Args:
+ appsink (Gst.AppSink): The GStreamer appsink.
+ data (object): User data (not used).
+
+ Returns:
+ Gst.FlowReturn: Indicates the status of buffer processing.
+ """
sample = appsink.emit("pull-sample")
buffer = sample.get_buffer()
data = buffer.extract_dup(0, buffer.get_size())
+
+ # Add the new data to the buffer
self.audio_buffer.extend(data)
- if len(self.audio_buffer) >= self.sample_rate * self.buffer_duration * self.channels * self.bits_per_sample // 8:
- self.process_audio_buffer()
+ # Process audio in segments
+ while len(self.audio_buffer) >= self.segment_size:
+ segment = self.audio_buffer[:self.segment_size]
+ self.process_audio_segment(segment)
+
+ # Advance the buffer by the segment size
+ self.audio_buffer = self.audio_buffer[self.segment_size:]
return Gst.FlowReturn.OK
-
- def process_audio_buffer(self):
- # Process the accumulated audio data using the audio model
- audio_data = bytes(self.audio_buffer)
- if self.stt_model.init_recognition(self.recognizer_uuid, audio_data):
- stt_result = self.stt_model.recognize(self.recognizer_uuid)
+ def process_audio_segment(self, segment):
+ """
+ Process an audio segment for wake word detection.
+
+ Args:
+ segment (bytes): The audio segment to process.
+ """
+ # Process the audio data segment
+ audio_data = bytes(segment)
+
+ # Perform wake word detection on the audio_data
+ if self.wake_word_model.init_recognition(self.recognizer_uuid, audio_data):
+ stt_result = self.wake_word_model.recognize(self.recognizer_uuid)
print("STT Result: ", stt_result)
if self.wake_word in stt_result["text"]:
self.wake_word_detected = True
print("Wake word detected!")
self.pipeline.send_event(Gst.Event.new_eos())
- self.audio_buffer.clear() # Clear the buffer
-
-
def send_eos(self):
+ """
+ Send an End-of-Stream (EOS) event to the pipeline.
+ """
self.pipeline.send_event(Gst.Event.new_eos())
self.audio_buffer.clear()
def start_listening(self):
+ """
+ Start listening for the wake word and enter the event loop.
+ """
self.pipeline.set_state(Gst.State.PLAYING)
print("Listening for Wake Word...")
self.loop.run()
def stop_listening(self):
+ """
+ Stop listening for the wake word and clean up the pipeline.
+ """
self.cleanup_pipeline()
self.loop.quit()
def on_bus_message(self, bus, message):
+ """
+ Handle GStreamer bus messages and perform actions based on the message type.
+
+ Args:
+ bus (Gst.Bus): The GStreamer bus.
+ message (Gst.Message): The GStreamer message to process.
+ """
if message.type == Gst.MessageType.EOS:
print("End-of-stream message received")
self.stop_listening()
@@ -140,6 +200,9 @@ class WakeWordDetector:
def cleanup_pipeline(self):
+ """
+ Clean up the GStreamer pipeline and release associated resources.
+ """
if self.pipeline is not None:
print("Cleaning up pipeline...")
self.pipeline.set_state(Gst.State.NULL)
@@ -147,4 +210,4 @@ class WakeWordDetector:
print("Pipeline cleanup complete!")
self.bus = None
self.pipeline = None
- self.stt_model.cleanup_recognizer(self.recognizer_uuid)
+ self.wake_word_model.cleanup_recognizer(self.recognizer_uuid)