diff options
Diffstat (limited to 'agl_service_voiceagent/nlu/rasa_interface.py')
-rw-r--r-- | agl_service_voiceagent/nlu/rasa_interface.py | 41 |
1 files changed, 41 insertions, 0 deletions
diff --git a/agl_service_voiceagent/nlu/rasa_interface.py b/agl_service_voiceagent/nlu/rasa_interface.py index 0232126..537a318 100644 --- a/agl_service_voiceagent/nlu/rasa_interface.py +++ b/agl_service_voiceagent/nlu/rasa_interface.py @@ -21,7 +21,20 @@ import subprocess from concurrent.futures import ThreadPoolExecutor class RASAInterface: + """ + RASAInterface is a class for interfacing with a Rasa NLU server to extract intents and entities from text input. + """ + def __init__(self, port, model_path, log_dir, max_threads=5): + """ + Initialize the RASAInterface instance with the provided parameters. + + Args: + port (int): The port number on which the Rasa NLU server will run. + model_path (str): The path to the Rasa NLU model. + log_dir (str): The directory where server logs will be saved. + max_threads (int, optional): The maximum number of concurrent threads (default is 5). + """ self.port = port self.model_path = model_path self.max_threads = max_threads @@ -31,6 +44,9 @@ class RASAInterface: def _start_server(self): + """ + Start the Rasa NLU server in a subprocess and redirect its output to the log file. + """ command = ( f"rasa run --enable-api -m \"{self.model_path}\" -p {self.port}" ) @@ -41,6 +57,9 @@ class RASAInterface: def start_server(self): + """ + Start the Rasa NLU server in a separate thread and wait for it to initialize. + """ self.thread_pool.submit(self._start_server) # Wait for a brief moment to allow the server to start @@ -48,6 +67,9 @@ class RASAInterface: def stop_server(self): + """ + Stop the Rasa NLU server and shut down the thread pool. + """ if self.server_process: self.server_process.terminate() self.server_process.wait() @@ -56,6 +78,16 @@ class RASAInterface: def preprocess_text(self, text): + """ + Preprocess the input text by converting it to lowercase, removing leading/trailing spaces, + and removing special characters and punctuation. + + Args: + text (str): The input text to preprocess. + + Returns: + str: The preprocessed text. + """ # text to lower case and remove trailing and leading spaces preprocessed_text = text.lower().strip() # remove special characters, punctuation, and extra whitespaces @@ -77,6 +109,15 @@ class RASAInterface: def process_intent(self, intent_output): + """ + Extract intents and entities from preprocessed text using the Rasa NLU server. + + Args: + text (str): The preprocessed input text. + + Returns: + dict: Intent and entity extraction result as a dictionary. + """ intent = intent_output["intent"]["name"] entities = {} for entity in intent_output["entities"]: |