diff options
Diffstat (limited to 'agl_service_voiceagent/utils/wake_word.py')
-rw-r--r-- | agl_service_voiceagent/utils/wake_word.py | 95 |
1 file changed, 79 insertions, 16 deletions
diff --git a/agl_service_voiceagent/utils/wake_word.py b/agl_service_voiceagent/utils/wake_word.py index 066ae6d..47e547e 100644 --- a/agl_service_voiceagent/utils/wake_word.py +++ b/agl_service_voiceagent/utils/wake_word.py @@ -15,7 +15,6 @@ # limitations under the License. import gi -import vosk gi.require_version('Gst', '1.0') from gi.repository import Gst, GLib @@ -23,7 +22,21 @@ Gst.init(None) GLib.threads_init() class WakeWordDetector: + """ + WakeWordDetector is a class for detecting a wake word in an audio stream using GStreamer and Vosk. + """ + def __init__(self, wake_word, stt_model, channels=1, sample_rate=16000, bits_per_sample=16): + """ + Initialize the WakeWordDetector instance with the provided parameters. + + Args: + wake_word (str): The wake word to detect in the audio stream. + stt_model (STTModel): An instance of the STTModel for speech-to-text recognition. + channels (int, optional): The number of audio channels (default is 1). + sample_rate (int, optional): The audio sample rate in Hz (default is 16000). + bits_per_sample (int, optional): The number of bits per sample (default is 16). + """ self.loop = GLib.MainLoop() self.pipeline = None self.bus = None @@ -32,16 +45,25 @@ class WakeWordDetector: self.sample_rate = sample_rate self.channels = channels self.bits_per_sample = bits_per_sample - self.frame_size = int(self.sample_rate * 0.02) - self.stt_model = stt_model # Speech to text model recognizer + self.wake_word_model = stt_model # Speech to text model recognizer self.recognizer_uuid = stt_model.setup_recognizer() - self.buffer_duration = 1 # Buffer audio for atleast 1 second self.audio_buffer = bytearray() + self.segment_size = int(self.sample_rate * 1.0) # Adjust the segment size (e.g., 1 second) + def get_wake_word_status(self): + """ + Get the status of wake word detection. + + Returns: + bool: True if the wake word has been detected, False otherwise. 
+ """ return self.wake_word_detected def create_pipeline(self): + """ + Create and configure the GStreamer audio processing pipeline for wake word detection. + """ print("Creating pipeline for Wake Word Detection...") self.pipeline = Gst.Pipeline() autoaudiosrc = Gst.ElementFactory.make("autoaudiosrc", None) @@ -77,49 +99,87 @@ class WakeWordDetector: self.bus.add_signal_watch() self.bus.connect("message", self.on_bus_message) + def on_new_buffer(self, appsink, data) -> Gst.FlowReturn: + """ + Callback function to handle new audio buffers from GStreamer appsink. + + Args: + appsink (Gst.AppSink): The GStreamer appsink. + data (object): User data (not used). + + Returns: + Gst.FlowReturn: Indicates the status of buffer processing. + """ sample = appsink.emit("pull-sample") buffer = sample.get_buffer() data = buffer.extract_dup(0, buffer.get_size()) + + # Add the new data to the buffer self.audio_buffer.extend(data) - if len(self.audio_buffer) >= self.sample_rate * self.buffer_duration * self.channels * self.bits_per_sample // 8: - self.process_audio_buffer() + # Process audio in segments + while len(self.audio_buffer) >= self.segment_size: + segment = self.audio_buffer[:self.segment_size] + self.process_audio_segment(segment) + + # Advance the buffer by the segment size + self.audio_buffer = self.audio_buffer[self.segment_size:] return Gst.FlowReturn.OK - - def process_audio_buffer(self): - # Process the accumulated audio data using the audio model - audio_data = bytes(self.audio_buffer) - if self.stt_model.init_recognition(self.recognizer_uuid, audio_data): - stt_result = self.stt_model.recognize(self.recognizer_uuid) + def process_audio_segment(self, segment): + """ + Process an audio segment for wake word detection. + + Args: + segment (bytes): The audio segment to process. 
+ """ + # Process the audio data segment + audio_data = bytes(segment) + + # Perform wake word detection on the audio_data + if self.wake_word_model.init_recognition(self.recognizer_uuid, audio_data): + stt_result = self.wake_word_model.recognize(self.recognizer_uuid) print("STT Result: ", stt_result) if self.wake_word in stt_result["text"]: self.wake_word_detected = True print("Wake word detected!") self.pipeline.send_event(Gst.Event.new_eos()) - self.audio_buffer.clear() # Clear the buffer - - def send_eos(self): + """ + Send an End-of-Stream (EOS) event to the pipeline. + """ self.pipeline.send_event(Gst.Event.new_eos()) self.audio_buffer.clear() def start_listening(self): + """ + Start listening for the wake word and enter the event loop. + """ self.pipeline.set_state(Gst.State.PLAYING) print("Listening for Wake Word...") self.loop.run() def stop_listening(self): + """ + Stop listening for the wake word and clean up the pipeline. + """ self.cleanup_pipeline() self.loop.quit() def on_bus_message(self, bus, message): + """ + Handle GStreamer bus messages and perform actions based on the message type. + + Args: + bus (Gst.Bus): The GStreamer bus. + message (Gst.Message): The GStreamer message to process. + """ if message.type == Gst.MessageType.EOS: print("End-of-stream message received") self.stop_listening() @@ -140,6 +200,9 @@ class WakeWordDetector: def cleanup_pipeline(self): + """ + Clean up the GStreamer pipeline and release associated resources. + """ if self.pipeline is not None: print("Cleaning up pipeline...") self.pipeline.set_state(Gst.State.NULL) @@ -147,4 +210,4 @@ class WakeWordDetector: print("Pipeline cleanup complete!") self.bus = None self.pipeline = None - self.stt_model.cleanup_recognizer(self.recognizer_uuid) + self.wake_word_model.cleanup_recognizer(self.recognizer_uuid) |