aboutsummaryrefslogtreecommitdiffstats
path: root/agl_service_voiceagent/nlu/rasa_interface.py
diff options
context:
space:
mode:
Diffstat (limited to 'agl_service_voiceagent/nlu/rasa_interface.py')
-rw-r--r--agl_service_voiceagent/nlu/rasa_interface.py41
1 files changed, 41 insertions, 0 deletions
diff --git a/agl_service_voiceagent/nlu/rasa_interface.py b/agl_service_voiceagent/nlu/rasa_interface.py
index 0232126..537a318 100644
--- a/agl_service_voiceagent/nlu/rasa_interface.py
+++ b/agl_service_voiceagent/nlu/rasa_interface.py
@@ -21,7 +21,20 @@ import subprocess
from concurrent.futures import ThreadPoolExecutor
class RASAInterface:
+ """
+ RASAInterface is a class for interfacing with a Rasa NLU server to extract intents and entities from text input.
+ """
+
def __init__(self, port, model_path, log_dir, max_threads=5):
+ """
+ Initialize the RASAInterface instance with the provided parameters.
+
+ Args:
+ port (int): The port number on which the Rasa NLU server will run.
+ model_path (str): The path to the Rasa NLU model.
+ log_dir (str): The directory where server logs will be saved.
+ max_threads (int, optional): The maximum number of concurrent threads (default is 5).
+ """
self.port = port
self.model_path = model_path
self.max_threads = max_threads
@@ -31,6 +44,9 @@ class RASAInterface:
def _start_server(self):
+ """
+ Start the Rasa NLU server in a subprocess and redirect its output to the log file.
+ """
command = (
f"rasa run --enable-api -m \"{self.model_path}\" -p {self.port}"
)
@@ -41,6 +57,9 @@ class RASAInterface:
def start_server(self):
+ """
+ Start the Rasa NLU server in a separate thread and wait for it to initialize.
+ """
self.thread_pool.submit(self._start_server)
# Wait for a brief moment to allow the server to start
@@ -48,6 +67,9 @@ class RASAInterface:
def stop_server(self):
+ """
+ Stop the Rasa NLU server and shut down the thread pool.
+ """
if self.server_process:
self.server_process.terminate()
self.server_process.wait()
@@ -56,6 +78,16 @@ class RASAInterface:
def preprocess_text(self, text):
+ """
+ Preprocess the input text by converting it to lowercase, removing leading/trailing spaces,
+ and removing special characters and punctuation.
+
+ Args:
+ text (str): The input text to preprocess.
+
+ Returns:
+ str: The preprocessed text.
+ """
# text to lower case and remove trailing and leading spaces
preprocessed_text = text.lower().strip()
# remove special characters, punctuation, and extra whitespaces
@@ -77,6 +109,15 @@ class RASAInterface:
def process_intent(self, intent_output):
+ """
+ Extract intents and entities from preprocessed text using the Rasa NLU server.
+
+ Args:
+ text (str): The preprocessed input text.
+
+ Returns:
+ dict: Intent and entity extraction result as a dictionary.
+ """
intent = intent_output["intent"]["name"]
entities = {}
for entity in intent_output["entities"]: