Diffstat (limited to 'snips_inference_agl/intent_parser/lookup_intent_parser.py')
-rw-r--r--  snips_inference_agl/intent_parser/lookup_intent_parser.py  509
1 file changed, 509 insertions, 0 deletions
diff --git a/snips_inference_agl/intent_parser/lookup_intent_parser.py b/snips_inference_agl/intent_parser/lookup_intent_parser.py
new file mode 100644
index 0000000..921dcc5
--- /dev/null
+++ b/snips_inference_agl/intent_parser/lookup_intent_parser.py
@@ -0,0 +1,509 @@
+from __future__ import unicode_literals
+
+import json
+import logging
+from builtins import str
+from collections import defaultdict
+from itertools import combinations
+from pathlib import Path
+
+from future.utils import iteritems, itervalues
+from snips_nlu_utils import normalize, hash_str
+
+from snips_inference_agl.common.log_utils import log_elapsed_time, log_result
+from snips_inference_agl.common.utils import (
+ check_persisted_path, deduplicate_overlapping_entities, fitted_required,
+ json_string)
+from snips_inference_agl.constants import (
+ DATA, END, ENTITIES, ENTITY, ENTITY_KIND, INTENTS, LANGUAGE, RES_INTENT,
+ RES_INTENT_NAME, RES_MATCH_RANGE, RES_SLOTS, SLOT_NAME, START, TEXT,
+ UTTERANCES, RES_PROBA)
+from snips_inference_agl.dataset import (
+ validate_and_format_dataset, extract_intent_entities)
+from snips_inference_agl.dataset.utils import get_stop_words_whitelist
+from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity
+from snips_inference_agl.exceptions import IntentNotFoundError, LoadingError
+from snips_inference_agl.intent_parser.intent_parser import IntentParser
+from snips_inference_agl.pipeline.configs import LookupIntentParserConfig
+from snips_inference_agl.preprocessing import tokenize_light
+from snips_inference_agl.resources import get_stop_words
+from snips_inference_agl.result import (
+ empty_result, intent_classification_result, parsing_result,
+ unresolved_slot, extraction_result)
+
+logger = logging.getLogger(__name__)
+
+
+@IntentParser.register("lookup_intent_parser")
+class LookupIntentParser(IntentParser):
+ """A deterministic Intent parser implementation based on a dictionary
+
+    This intent parser is very strict by nature, and tends to have very good
+    precision but low recall. For this reason, it is best used as a first
+    pass, before potentially falling back to another parser.
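+
+    Illustrative usage (a sketch; ``dataset`` stands for a valid Snips
+    dataset dict, and the usual shared resources are assumed to be
+    available)::
+
+        parser = LookupIntentParser().fit(dataset)
+        result = parser.parse("turn on the lights in the kitchen")
+        intent_name = result[RES_INTENT][RES_INTENT_NAME]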
+ """
+
+ config_type = LookupIntentParserConfig
+
+ def __init__(self, config=None, **shared):
+ """The lookup intent parser can be configured by passing a
+ :class:`.LookupIntentParserConfig`"""
+ super(LookupIntentParser, self).__init__(config, **shared)
+ self._language = None
+ self._stop_words = None
+ self._stop_words_whitelist = None
+ self._map = None
+ self._intents_names = []
+ self._slots_names = []
+ self._intents_mapping = dict()
+ self._slots_mapping = dict()
+ self._entity_scopes = None
+
+ @property
+ def language(self):
+ return self._language
+
+ @language.setter
+ def language(self, value):
+ self._language = value
+ if value is None:
+ self._stop_words = None
+ else:
+ if self.config.ignore_stop_words:
+ self._stop_words = get_stop_words(self.resources)
+ else:
+ self._stop_words = set()
+
+ @property
+ def fitted(self):
+ """Whether or not the intent parser has already been trained"""
+ return self._map is not None
+
+ @log_elapsed_time(
+ logger, logging.INFO, "Fitted lookup intent parser in {elapsed_time}")
+ def fit(self, dataset, force_retrain=True):
+ """Fits the intent parser with a valid Snips dataset"""
+ logger.info("Fitting lookup intent parser...")
+ dataset = validate_and_format_dataset(dataset)
+ self.load_resources_if_needed(dataset[LANGUAGE])
+ self.fit_builtin_entity_parser_if_needed(dataset)
+ self.fit_custom_entity_parser_if_needed(dataset)
+ self.language = dataset[LANGUAGE]
+ self._entity_scopes = _get_entity_scopes(dataset)
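+        # the lookup table maps the hash of a normalized utterance (with
+        # entity values replaced by placeholders) to a pair
+        # [intent_id, [slot_id, ...]]; see _build_io_mapping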
+ self._map = dict()
+ self._stop_words_whitelist = get_stop_words_whitelist(
+ dataset, self._stop_words)
+ entity_placeholders = _get_entity_placeholders(dataset, self.language)
+
+ ambiguous_keys = set()
+ for (key, val) in self._generate_io_mapping(dataset[INTENTS],
+ entity_placeholders):
+ key = hash_str(key)
+            # handle key collisions: flag ambiguous entries
+ if key in self._map and self._map[key] != val:
+ ambiguous_keys.add(key)
+ else:
+ self._map[key] = val
+
+ # delete ambiguous keys
+ for key in ambiguous_keys:
+ self._map.pop(key)
+
+ return self
+
+ @log_result(logger, logging.DEBUG, "LookupIntentParser result -> {result}")
+ @log_elapsed_time(logger, logging.DEBUG, "Parsed in {elapsed_time}.")
+ @fitted_required
+ def parse(self, text, intents=None, top_n=None):
+ """Performs intent parsing on the provided *text*
+
+ Intent and slots are extracted simultaneously through pattern matching
+
+ Args:
+ text (str): input
+ intents (str or list of str): if provided, reduces the scope of
+ intent parsing to the provided list of intents
+ top_n (int, optional): when provided, this method will return a
+ list of at most top_n most likely intents, instead of a single
+ parsing result.
+                Note that the returned list can contain fewer than ``top_n``
+                elements, for instance when the ``intents`` parameter is not
+                None, or when ``top_n`` is greater than the total number of
+                intents.
+
+ Returns:
+ dict or list: the most likely intent(s) along with the extracted
+ slots. See :func:`.parsing_result` and :func:`.extraction_result`
+ for the output format.
+
+ Raises:
+ NotTrained: when the intent parser is not fitted
+ """
+ if top_n is None:
+ top_intents = self._parse_top_intents(text, top_n=1,
+ intents=intents)
+ if top_intents:
+ intent = top_intents[0][RES_INTENT]
+ slots = top_intents[0][RES_SLOTS]
+ if intent[RES_PROBA] <= 0.5:
+                # return an empty result in case of ambiguity
+ return empty_result(text, probability=1.0)
+ return parsing_result(text, intent, slots)
+ return empty_result(text, probability=1.0)
+ return self._parse_top_intents(text, top_n=top_n, intents=intents)
+
+ def _parse_top_intents(self, text, top_n, intents=None):
+ if isinstance(intents, str):
+ intents = {intents}
+ elif isinstance(intents, list):
+ intents = set(intents)
+
+ if top_n < 1:
+            raise ValueError(
+                "top_n argument must be greater than or equal to 1, "
+                "but got: %s" % top_n)
+
+ results_per_intent = defaultdict(list)
+ for text_candidate, entities in self._get_candidates(text, intents):
+ val = self._map.get(hash_str(text_candidate))
+ if val is not None:
+ result = self._parse_map_output(text, val, entities, intents)
+ if result:
+ intent_name = result[RES_INTENT][RES_INTENT_NAME]
+ results_per_intent[intent_name].append(result)
+
+ results = []
+ for intent_results in itervalues(results_per_intent):
+ sorted_results = sorted(intent_results,
+ key=lambda res: len(res[RES_SLOTS]))
+ results.append(sorted_results[0])
+
+ # In some rare cases there can be multiple ambiguous intents
+ # In such cases, priority is given to results containing fewer slots
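+        # e.g. two results with 1 and 2 slots get raw weights 1/2 and 1/3,
+        # which normalize to probabilities 0.6 and 0.4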
+ weights = [1.0 / (1.0 + len(res[RES_SLOTS])) for res in results]
+ total_weight = sum(weights)
+
+ for res, weight in zip(results, weights):
+ res[RES_INTENT][RES_PROBA] = weight / total_weight
+
+ results = sorted(results, key=lambda r: -r[RES_INTENT][RES_PROBA])
+ return results[:top_n]
+
+ def _get_candidates(self, text, intents):
+ candidates = defaultdict(list)
+ for grouped_entity_scope in self._entity_scopes:
+ entity_scope = grouped_entity_scope["entity_scope"]
+ intent_group = grouped_entity_scope["intent_group"]
+ intent_group = [intent_ for intent_ in intent_group
+ if intents is None or intent_ in intents]
+ if not intent_group:
+ continue
+
+ builtin_entities = self.builtin_entity_parser.parse(
+ text, scope=entity_scope["builtin"], use_cache=True)
+ custom_entities = self.custom_entity_parser.parse(
+ text, scope=entity_scope["custom"], use_cache=True)
+ all_entities = builtin_entities + custom_entities
+ all_entities = deduplicate_overlapping_entities(all_entities)
+
+            # We generate all subsets of entities to match utterances
+            # containing ambiguous words, which can be either entity values
+            # or ordinary words
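+            # e.g. with matched entities [e1, e2], placeholder texts are
+            # generated for (), (e1, e2), (e1,) and (e2,) in that order
+            # (see _get_entities_combinations)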
+ for entities in _get_entities_combinations(all_entities):
+ processed_text = self._replace_entities_with_placeholders(
+ text, entities)
+ for intent in intent_group:
+ cleaned_text = self._preprocess_text(text, intent)
+ cleaned_processed_text = self._preprocess_text(
+ processed_text, intent)
+
+ raw_candidate = cleaned_text, []
+ placeholder_candidate = cleaned_processed_text, entities
+ intent_candidates = [raw_candidate, placeholder_candidate]
+ for text_input, text_entities in intent_candidates:
+ if text_input not in candidates \
+ or text_entities not in candidates[text_input]:
+ candidates[text_input].append(text_entities)
+ yield text_input, text_entities
+
+ def _parse_map_output(self, text, output, entities, intents):
+ """Parse the map output to the parser's result format"""
+ intent_id, slot_ids = output
+ intent_name = self._intents_names[intent_id]
+ if intents is not None and intent_name not in intents:
+ return None
+
+ parsed_intent = intent_classification_result(
+ intent_name=intent_name, probability=1.0)
+ slots = []
+        # invariant: the lookup map stores exactly one slot id per matched
+        # entity placeholder
+ assert len(slot_ids) == len(entities)
+ for slot_id, entity in zip(slot_ids, entities):
+ slot_name = self._slots_names[slot_id]
+ rng_start = entity[RES_MATCH_RANGE][START]
+ rng_end = entity[RES_MATCH_RANGE][END]
+ slot_value = text[rng_start:rng_end]
+ entity_name = entity[ENTITY_KIND]
+ slot = unresolved_slot(
+ [rng_start, rng_end], slot_value, entity_name, slot_name)
+ slots.append(slot)
+
+ return extraction_result(parsed_intent, slots)
+
+ @fitted_required
+ def get_intents(self, text):
+ """Returns the list of intents ordered by decreasing probability
+
+        The length of the returned list is exactly the number of intents in
+        the dataset, plus one for the None intent
+ """
+ nb_intents = len(self._intents_names)
+ top_intents = [intent_result[RES_INTENT] for intent_result in
+ self._parse_top_intents(text, top_n=nb_intents)]
+ matched_intents = {res[RES_INTENT_NAME] for res in top_intents}
+ for intent in self._intents_names:
+ if intent not in matched_intents:
+ top_intents.append(intent_classification_result(intent, 0.0))
+
+ # The None intent is not included in the lookup table and is thus
+ # never matched by the lookup parser
+ top_intents.append(intent_classification_result(None, 0.0))
+ return top_intents
+
+ @fitted_required
+ def get_slots(self, text, intent):
+ """Extracts slots from a text input, with the knowledge of the intent
+
+ Args:
+ text (str): input
+ intent (str): the intent which the input corresponds to
+
+ Returns:
+ list: the list of extracted slots
+
+ Raises:
+ IntentNotFoundError: When the intent was not part of the training
+ data
+ """
+ if intent is None:
+ return []
+
+ if intent not in self._intents_names:
+ raise IntentNotFoundError(intent)
+
+ slots = self.parse(text, intents=[intent])[RES_SLOTS]
+ if slots is None:
+ slots = []
+ return slots
+
+ def _get_intent_stop_words(self, intent):
+ whitelist = self._stop_words_whitelist.get(intent, set())
+ return self._stop_words.difference(whitelist)
+
+ def _get_intent_id(self, intent_name):
+ """generate a numeric id for an intent
+
+ Args:
+ intent_name (str): intent name
+
+ Returns:
+ int: numeric id
+
+ """
+ intent_id = self._intents_mapping.get(intent_name)
+ if intent_id is None:
+ intent_id = len(self._intents_names)
+ self._intents_names.append(intent_name)
+ self._intents_mapping[intent_name] = intent_id
+
+ return intent_id
+
+ def _get_slot_id(self, slot_name):
+ """generate a numeric id for a slot
+
+ Args:
+ slot_name (str): intent name
+
+ Returns:
+ int: numeric id
+
+ """
+ slot_id = self._slots_mapping.get(slot_name)
+ if slot_id is None:
+ slot_id = len(self._slots_names)
+ self._slots_names.append(slot_name)
+ self._slots_mapping[slot_name] = slot_id
+
+ return slot_id
+
+    def _preprocess_text(self, txt, intent):
+        """Lowercases the input, removes the intent's stop words, and
+        replaces characters that are tokenized out with whitespace"""
+ stop_words = self._get_intent_stop_words(intent)
+ tokens = tokenize_light(txt, self.language)
+ cleaned_string = " ".join(
+ [tkn for tkn in tokens if normalize(tkn) not in stop_words])
+ return cleaned_string.lower()
+
+ def _generate_io_mapping(self, intents, entity_placeholders):
+ """Generate input-output pairs"""
+ for intent_name, intent in sorted(iteritems(intents)):
+ intent_id = self._get_intent_id(intent_name)
+ for entry in intent[UTTERANCES]:
+ yield self._build_io_mapping(
+ intent_id, entry, entity_placeholders)
+
+ def _build_io_mapping(self, intent_id, utterance, entity_placeholders):
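+        # one lookup entry per utterance: the key is the normalized utterance
+        # with slot chunks replaced by entity placeholders, and the output is
+        # [intent_id, [slot_id, ...]] with slot ids in order of appearance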
+ input_ = []
+ output = [intent_id]
+ slots = []
+ for chunk in utterance[DATA]:
+ if SLOT_NAME in chunk:
+ slot_name = chunk[SLOT_NAME]
+ slot_id = self._get_slot_id(slot_name)
+ entity_name = chunk[ENTITY]
+ placeholder = entity_placeholders[entity_name]
+ input_.append(placeholder)
+ slots.append(slot_id)
+ else:
+ input_.append(chunk[TEXT])
+ output.append(slots)
+
+ intent = self._intents_names[intent_id]
+ key = self._preprocess_text(" ".join(input_), intent)
+
+ return key, output
+
+ def _replace_entities_with_placeholders(self, text, entities):
+ if not entities:
+ return text
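+        # e.g. "turn on the lights" with a "light" entity matched on
+        # "lights" becomes "turn on the %LIGHT%"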
+ entities = sorted(entities, key=lambda e: e[RES_MATCH_RANGE][START])
+ processed_text = ""
+ current_idx = 0
+ for ent in entities:
+ start = ent[RES_MATCH_RANGE][START]
+ end = ent[RES_MATCH_RANGE][END]
+ processed_text += text[current_idx:start]
+            placeholder = _get_entity_name_placeholder(
+                ent[ENTITY_KIND], self.language)
+            processed_text += placeholder
+ current_idx = end
+ processed_text += text[current_idx:]
+
+ return processed_text
+
+ @check_persisted_path
+ def persist(self, path):
+ """Persists the object at the given path"""
+ path.mkdir()
+ parser_json = json_string(self.to_dict())
+ parser_path = path / "intent_parser.json"
+
+ with parser_path.open(mode="w", encoding="utf8") as pfile:
+ pfile.write(parser_json)
+ self.persist_metadata(path)
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ """Loads a :class:`LookupIntentParser` instance from a path
+
+ The data at the given path must have been generated using
+ :func:`~LookupIntentParser.persist`
+ """
+ path = Path(path)
+ model_path = path / "intent_parser.json"
+ if not model_path.exists():
+ raise LoadingError(
+ "Missing lookup intent parser metadata file: %s"
+ % model_path.name)
+
+ with model_path.open(encoding="utf8") as pfile:
+ metadata = json.load(pfile)
+ return cls.from_dict(metadata, **shared)
+
+ def to_dict(self):
+ """Returns a json-serializable dict"""
+ stop_words_whitelist = None
+ if self._stop_words_whitelist is not None:
+ stop_words_whitelist = {
+ intent: sorted(values)
+ for intent, values in iteritems(self._stop_words_whitelist)}
+ return {
+ "config": self.config.to_dict(),
+ "language_code": self.language,
+ "map": self._map,
+ "slots_names": self._slots_names,
+ "intents_names": self._intents_names,
+ "entity_scopes": self._entity_scopes,
+ "stop_words_whitelist": stop_words_whitelist,
+ }
+
+ @classmethod
+ def from_dict(cls, unit_dict, **shared):
+ """Creates a :class:`LookupIntentParser` instance from a dict
+
+ The dict must have been generated with
+ :func:`~LookupIntentParser.to_dict`
+ """
+ config = cls.config_type.from_dict(unit_dict["config"])
+ parser = cls(config=config, **shared)
+ parser.language = unit_dict["language_code"]
+ # pylint:disable=protected-access
+ parser._map = _convert_dict_keys_to_int(unit_dict["map"])
+ parser._slots_names = unit_dict["slots_names"]
+ parser._intents_names = unit_dict["intents_names"]
+ parser._entity_scopes = unit_dict["entity_scopes"]
+ if parser.fitted:
+ whitelist = unit_dict["stop_words_whitelist"]
+ parser._stop_words_whitelist = {
+ intent: set(values) for intent, values in iteritems(whitelist)}
+ # pylint:enable=protected-access
+ return parser
+
+
+def _get_entity_scopes(dataset):
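+    # intents sharing exactly the same entity scope are grouped together, so
+    # that entity parsing is performed once per group in _get_candidates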
+ intent_entities = extract_intent_entities(dataset)
+ intent_groups = []
+ entity_scopes = []
+ for intent, entities in sorted(iteritems(intent_entities)):
+ scope = {
+ "builtin": list(
+ {ent for ent in entities if is_builtin_entity(ent)}),
+ "custom": list(
+ {ent for ent in entities if not is_builtin_entity(ent)})
+ }
+ if scope in entity_scopes:
+ group_idx = entity_scopes.index(scope)
+ intent_groups[group_idx].append(intent)
+ else:
+ entity_scopes.append(scope)
+ intent_groups.append([intent])
+ return [
+ {
+ "intent_group": intent_group,
+ "entity_scope": entity_scope
+ } for intent_group, entity_scope in zip(intent_groups, entity_scopes)
+ ]
+
+
+def _get_entity_placeholders(dataset, language):
+ return {
+ e: _get_entity_name_placeholder(e, language) for e in dataset[ENTITIES]
+ }
+
+
+def _get_entity_name_placeholder(entity_label, language):
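+    # e.g. "smart light" -> "%SMARTLIGHT%"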
+ return "%%%s%%" % "".join(tokenize_light(entity_label, language)).upper()
+
+
+def _convert_dict_keys_to_int(dct):
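+    # JSON object keys are always strings, so the integer hash keys of the
+    # persisted lookup map are cast back to int when loading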
+ if isinstance(dct, dict):
+ return {int(k): v for k, v in iteritems(dct)}
+ return dct
+
+
+def _get_entities_combinations(entities):
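+    # yields the empty combination first, then all non-empty combinations,
+    # largest first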
+ yield ()
+ for nb_entities in reversed(range(1, len(entities) + 1)):
+ for combination in combinations(entities, nb_entities):
+ yield combination