Diffstat (limited to 'snips_inference_agl/nlu_engine')
-rw-r--r--  snips_inference_agl/nlu_engine/__init__.py    |   1
-rw-r--r--  snips_inference_agl/nlu_engine/nlu_engine.py  | 330
2 files changed, 331 insertions, 0 deletions
diff --git a/snips_inference_agl/nlu_engine/__init__.py b/snips_inference_agl/nlu_engine/__init__.py
new file mode 100644
index 0000000..b45eb32
--- /dev/null
+++ b/snips_inference_agl/nlu_engine/__init__.py
@@ -0,0 +1 @@
+from snips_inference_agl.nlu_engine.nlu_engine import SnipsNLUEngine
diff --git a/snips_inference_agl/nlu_engine/nlu_engine.py b/snips_inference_agl/nlu_engine/nlu_engine.py
new file mode 100644
index 0000000..3f19aba
--- /dev/null
+++ b/snips_inference_agl/nlu_engine/nlu_engine.py
@@ -0,0 +1,330 @@
+from __future__ import unicode_literals
+
+import json
+import logging
+from builtins import str
+from pathlib import Path
+
+from future.utils import itervalues
+
+from snips_inference_agl.__about__ import __model_version__, __version__
+from snips_inference_agl.common.log_utils import log_elapsed_time
+from snips_inference_agl.common.utils import fitted_required
+from snips_inference_agl.constants import (
+ AUTOMATICALLY_EXTENSIBLE, BUILTIN_ENTITY_PARSER, CUSTOM_ENTITY_PARSER,
+ ENTITIES, ENTITY_KIND, LANGUAGE, RESOLVED_VALUE, RES_ENTITY,
+ RES_INTENT, RES_INTENT_NAME, RES_MATCH_RANGE, RES_PROBA, RES_SLOTS,
+ RES_VALUE, RESOURCES, BYPASS_VERSION_CHECK)
+# from snips_inference_agl.dataset import validate_and_format_dataset
+from snips_inference_agl.entity_parser import CustomEntityParser
+from snips_inference_agl.entity_parser.builtin_entity_parser import (
+ BuiltinEntityParser, is_builtin_entity)
+from snips_inference_agl.exceptions import (
+ InvalidInputError, IntentNotFoundError, LoadingError,
+ IncompatibleModelError)
+from snips_inference_agl.intent_parser import IntentParser
+from snips_inference_agl.pipeline.configs import NLUEngineConfig
+from snips_inference_agl.pipeline.processing_unit import ProcessingUnit
+from snips_inference_agl.resources import load_resources_from_dir
+from snips_inference_agl.result import (
+ builtin_slot, custom_slot, empty_result, extraction_result, is_empty,
+ parsing_result)
+
+logger = logging.getLogger(__name__)
+
+
+@ProcessingUnit.register("nlu_engine")
+class SnipsNLUEngine(ProcessingUnit):
+ """Main class to use for intent parsing
+
+    A :class:`SnipsNLUEngine` relies on a list of :class:`.IntentParser`
+    objects to parse intents: it calls them successively and uses the
+    first positive output.
+
+    With the default parameters, it will use the following two intent
+    parsers, in this order:
+
+ - a :class:`.DeterministicIntentParser`
+ - a :class:`.ProbabilisticIntentParser`
+
+    The rationale is to first use a conservative parser which has very
+    good precision but modest recall, so that simple patterns are caught,
+    and then fall back on a second, machine-learning based parser which
+    can handle unseen utterances while preserving good precision and
+    recall.
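+
+    Example (a hypothetical usage sketch; the engine path and the
+    utterance below are illustrative, not part of this module)::
+
+        engine = SnipsNLUEngine.from_path("/path/to/trained_engine")
+        result = engine.parse("turn on the lights")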
+ """
+
+ config_type = NLUEngineConfig
+
+ def __init__(self, config=None, **shared):
+ """The NLU engine can be configured by passing a
+ :class:`.NLUEngineConfig`"""
+ super(SnipsNLUEngine, self).__init__(config, **shared)
+ self.intent_parsers = []
+ """list of :class:`.IntentParser`"""
+ self.dataset_metadata = None
+
+ @classmethod
+ def default_config(cls):
+ # Do not use the global default config, and use per-language default
+ # configs instead
+ return None
+
+ @property
+ def fitted(self):
+ """Whether or not the nlu engine has already been fitted"""
+ return self.dataset_metadata is not None
+
+ @log_elapsed_time(logger, logging.DEBUG, "Parsed input in {elapsed_time}")
+ @fitted_required
+ def parse(self, text, intents=None, top_n=None):
+ """Performs intent parsing on the provided *text* by calling its intent
+ parsers successively
+
+ Args:
+ text (str): Input
+ intents (str or list of str, optional): If provided, reduces the
+ scope of intent parsing to the provided list of intents.
+ The ``None`` intent is never filtered out, meaning that it can
+ be returned even when using an intents scope.
+ top_n (int, optional): when provided, this method will return a
+ list of at most ``top_n`` most likely intents, instead of a
+ single parsing result.
+                Note that the returned list can contain fewer than
+                ``top_n`` elements, for instance when the parameter
+                ``intents`` is not None, or when ``top_n`` is greater
+                than the total number of intents.
+
+ Returns:
+ dict or list: the most likely intent(s) along with the extracted
+ slots. See :func:`.parsing_result` and :func:`.extraction_result`
+ for the output format.
+
+ Raises:
+ NotTrained: When the nlu engine is not fitted
+ InvalidInputError: When input type is not unicode
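+
+        Example (an illustrative sketch; the intent name below is an
+        assumption, not defined by this module)::
+
+            engine.parse("play some jazz", intents=["PlayMusic"])
+            engine.parse("play some jazz", top_n=3)  # list of results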
+ """
+ if not isinstance(text, str):
+ raise InvalidInputError("Expected unicode but received: %s"
+ % type(text))
+
+ if isinstance(intents, str):
+ intents = {intents}
+ elif isinstance(intents, list):
+ intents = set(intents)
+
+ if intents is not None:
+ for intent in intents:
+ if intent not in self.dataset_metadata["slot_name_mappings"]:
+ raise IntentNotFoundError(intent)
+
+ if top_n is None:
+ none_proba = 0.0
+ for parser in self.intent_parsers:
+ res = parser.parse(text, intents)
+ if is_empty(res):
+ none_proba = res[RES_INTENT][RES_PROBA]
+ continue
+ resolved_slots = self._resolve_slots(text, res[RES_SLOTS])
+ return parsing_result(text, intent=res[RES_INTENT],
+ slots=resolved_slots)
+ return empty_result(text, none_proba)
+
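+        # top_n mode: rank all intents, restrict them to the requested
+        # scope, then extract slots for each remaining candidate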
+ intents_results = self.get_intents(text)
+ if intents is not None:
+ intents_results = [res for res in intents_results
+ if res[RES_INTENT_NAME] is None
+ or res[RES_INTENT_NAME] in intents]
+ intents_results = intents_results[:top_n]
+ results = []
+ for intent_res in intents_results:
+ slots = self.get_slots(text, intent_res[RES_INTENT_NAME])
+ results.append(extraction_result(intent_res, slots))
+ return results
+
+ @log_elapsed_time(logger, logging.DEBUG, "Got intents in {elapsed_time}")
+ @fitted_required
+ def get_intents(self, text):
+ """Performs intent classification on the provided *text* and returns
+ the list of intents ordered by decreasing probability
+
+        The length of the returned list is exactly the number of intents
+        in the dataset, plus one for the None intent.
+
+ .. note::
+
+ The probabilities returned along with each intent are not
+ guaranteed to sum to 1.0. They should be considered as scores
+ between 0 and 1.
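+
+        Example (illustrative; the returned intent names depend on the
+        loaded dataset)::
+
+            intents = engine.get_intents("turn on the lights")
+            best = intents[0]  # highest-probability intent first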
+ """
+ results = None
+ for parser in self.intent_parsers:
+ parser_results = parser.get_intents(text)
+ if results is None:
+ results = {res[RES_INTENT_NAME]: res for res in parser_results}
+ continue
+
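+            # Merge with previous parsers' results, keeping the highest
+            # probability seen so far for each intent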
+ for res in parser_results:
+ intent = res[RES_INTENT_NAME]
+ proba = max(res[RES_PROBA], results[intent][RES_PROBA])
+ results[intent][RES_PROBA] = proba
+
+ return sorted(itervalues(results), key=lambda res: -res[RES_PROBA])
+
+ @log_elapsed_time(logger, logging.DEBUG, "Parsed slots in {elapsed_time}")
+ @fitted_required
+ def get_slots(self, text, intent):
+ """Extracts slots from a text input, with the knowledge of the intent
+
+ Args:
+ text (str): input
+ intent (str): the intent which the input corresponds to
+
+ Returns:
+ list: the list of extracted slots
+
+ Raises:
+ IntentNotFoundError: When the intent was not part of the training
+ data
+ InvalidInputError: When input type is not unicode
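+
+        Example (an illustrative sketch; the intent name is an
+        assumption, not defined by this module)::
+
+            slots = engine.get_slots("set the heat to 21 degrees",
+                                     "SetTemperature")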
+ """
+ if not isinstance(text, str):
+ raise InvalidInputError("Expected unicode but received: %s"
+ % type(text))
+
+ if intent is None:
+ return []
+
+ if intent not in self.dataset_metadata["slot_name_mappings"]:
+ raise IntentNotFoundError(intent)
+
+ for parser in self.intent_parsers:
+ slots = parser.get_slots(text, intent)
+ if not slots:
+ continue
+ return self._resolve_slots(text, slots)
+ return []
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ """Loads a :class:`SnipsNLUEngine` instance from a directory path
+
+ The data at the given path must have been generated using
+ :func:`~SnipsNLUEngine.persist`
+
+ Args:
+ path (str): The path where the nlu engine is stored
+
+ Raises:
+ LoadingError: when some files are missing
+ IncompatibleModelError: when trying to load an engine model which
+ is not compatible with the current version of the lib
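+
+        Example (an illustrative sketch; it assumes the
+        ``BYPASS_VERSION_CHECK`` constant resolves to the
+        ``bypass_version_check`` keyword)::
+
+            engine = SnipsNLUEngine.from_path("/path/to/engine")
+            # Load despite a model version mismatch (may fail later)
+            engine = SnipsNLUEngine.from_path(
+                "/path/to/engine", bypass_version_check=True)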
+ """
+ directory_path = Path(path)
+ model_path = directory_path / "nlu_engine.json"
+ if not model_path.exists():
+ raise LoadingError("Missing nlu engine model file: %s"
+ % model_path.name)
+
+ with model_path.open(encoding="utf8") as f:
+ model = json.load(f)
+ model_version = model.get("model_version")
+ if model_version is None or model_version != __model_version__:
+ bypass_version_check = shared.get(BYPASS_VERSION_CHECK, False)
+ if bypass_version_check:
+ logger.warning(
+ "Incompatible model version found. The library expected "
+ "'%s' but the loaded engine is '%s'. The NLU engine may "
+ "not load correctly.", __model_version__, model_version)
+ else:
+ raise IncompatibleModelError(model_version)
+
+ dataset_metadata = model["dataset_metadata"]
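+        # Load the language resources bundled with the persisted engine,
+        # unless the caller already provided them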
+ if shared.get(RESOURCES) is None and dataset_metadata is not None:
+ language = dataset_metadata["language_code"]
+ resources_dir = directory_path / "resources" / language
+ if resources_dir.is_dir():
+ resources = load_resources_from_dir(resources_dir)
+ shared[RESOURCES] = resources
+
+        if shared.get(BUILTIN_ENTITY_PARSER) is None:
+            # Use a dedicated name to avoid shadowing the `path` argument
+            parser_name = model["builtin_entity_parser"]
+            if parser_name is not None:
+                parser_path = directory_path / parser_name
+                shared[BUILTIN_ENTITY_PARSER] = BuiltinEntityParser.from_path(
+                    parser_path)
+
+        if shared.get(CUSTOM_ENTITY_PARSER) is None:
+            parser_name = model["custom_entity_parser"]
+            if parser_name is not None:
+                parser_path = directory_path / parser_name
+                shared[CUSTOM_ENTITY_PARSER] = CustomEntityParser.from_path(
+                    parser_path)
+
+ config = cls.config_type.from_dict(model["config"])
+ nlu_engine = cls(config=config, **shared)
+ nlu_engine.dataset_metadata = dataset_metadata
+ intent_parsers = []
+ for parser_idx, parser_name in enumerate(model["intent_parsers"]):
+ parser_config = config.intent_parsers_configs[parser_idx]
+ intent_parser_path = directory_path / parser_name
+ intent_parser = IntentParser.load_from_path(
+ intent_parser_path, parser_config.unit_name, **shared)
+ intent_parsers.append(intent_parser)
+ nlu_engine.intent_parsers = intent_parsers
+ return nlu_engine
+
+ def _resolve_slots(self, text, slots):
+ builtin_scope = [slot[RES_ENTITY] for slot in slots
+ if is_builtin_entity(slot[RES_ENTITY])]
+ custom_scope = [slot[RES_ENTITY] for slot in slots
+ if not is_builtin_entity(slot[RES_ENTITY])]
+        # Do not use cached builtin entities here, as datetime values must
+        # be computed using the current context
+ builtin_entities = self.builtin_entity_parser.parse(
+ text, builtin_scope, use_cache=False)
+ custom_entities = self.custom_entity_parser.parse(
+ text, custom_scope, use_cache=True)
+
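+        # Resolution strategy, per slot: first look for an entity match at
+        # the exact same character range, then re-parse the raw slot value
+        # alone, and finally fall back on the raw value when the entity is
+        # automatically extensible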
+ resolved_slots = []
+ for slot in slots:
+ entity_name = slot[RES_ENTITY]
+ raw_value = slot[RES_VALUE]
+ is_builtin = is_builtin_entity(entity_name)
+ if is_builtin:
+ entities = builtin_entities
+ parser = self.builtin_entity_parser
+ slot_builder = builtin_slot
+ use_cache = False
+ extensible = False
+ else:
+ entities = custom_entities
+ parser = self.custom_entity_parser
+ slot_builder = custom_slot
+ use_cache = True
+ extensible = self.dataset_metadata[ENTITIES][entity_name][
+ AUTOMATICALLY_EXTENSIBLE]
+
+ resolved_slot = None
+ for ent in entities:
+ if ent[ENTITY_KIND] == entity_name and \
+ ent[RES_MATCH_RANGE] == slot[RES_MATCH_RANGE]:
+ resolved_slot = slot_builder(slot, ent[RESOLVED_VALUE])
+ break
+ if resolved_slot is None:
+ matches = parser.parse(
+ raw_value, scope=[entity_name], use_cache=use_cache)
+ if matches:
+ match = matches[0]
+ if is_builtin or len(match[RES_VALUE]) == len(raw_value):
+ resolved_slot = slot_builder(
+ slot, match[RESOLVED_VALUE])
+
+ if resolved_slot is None and extensible:
+ resolved_slot = slot_builder(slot)
+
+ if resolved_slot is not None:
+ resolved_slots.append(resolved_slot)
+
+        return resolved_slots
\ No newline at end of file