Diffstat (limited to 'snips_inference_agl/entity_parser')
 snips_inference_agl/entity_parser/__init__.py                   |   6 +
 snips_inference_agl/entity_parser/builtin_entity_parser.py      | 162 ++++++++
 snips_inference_agl/entity_parser/custom_entity_parser.py       | 209 +++++++++
 snips_inference_agl/entity_parser/custom_entity_parser_usage.py |  23 +
 snips_inference_agl/entity_parser/entity_parser.py              |  85 +++
5 files changed, 485 insertions, 0 deletions
diff --git a/snips_inference_agl/entity_parser/__init__.py b/snips_inference_agl/entity_parser/__init__.py
new file mode 100644
index 0000000..c54f0b2
--- /dev/null
+++ b/snips_inference_agl/entity_parser/__init__.py
@@ -0,0 +1,6 @@
+# coding=utf-8
+from __future__ import unicode_literals
+
+from snips_inference_agl.entity_parser.builtin_entity_parser import BuiltinEntityParser
+from snips_inference_agl.entity_parser.custom_entity_parser import (
+    CustomEntityParser, CustomEntityParserUsage)
diff --git a/snips_inference_agl/entity_parser/builtin_entity_parser.py b/snips_inference_agl/entity_parser/builtin_entity_parser.py
new file mode 100644
index 0000000..02fa610
--- /dev/null
+++ b/snips_inference_agl/entity_parser/builtin_entity_parser.py
@@ -0,0 +1,162 @@
+from __future__ import unicode_literals
+
+import json
+import shutil
+
+from future.builtins import str
+
+from snips_inference_agl.common.io_utils import temp_dir
+from snips_inference_agl.common.utils import json_string
+from snips_inference_agl.constants import DATA_PATH, ENTITIES, LANGUAGE
+from snips_inference_agl.entity_parser.entity_parser import EntityParser
+from snips_inference_agl.result import parsed_entity
+
+_BUILTIN_ENTITY_PARSERS = dict()
+
+try:
+    FileNotFoundError
+except NameError:
+    FileNotFoundError = IOError
+
+
+class BuiltinEntityParser(EntityParser):
+    def __init__(self, parser):
+        super(BuiltinEntityParser, self).__init__()
+        self._parser = parser
+
+    def _parse(self, text, scope=None):
+        entities = self._parser.parse(text.lower(), scope=scope)
+        result = []
+        for entity in entities:
+            ent = parsed_entity(
+                entity_kind=entity["entity_kind"],
+                entity_value=entity["value"],
+                entity_resolved_value=entity["entity"],
+                entity_range=entity["range"]
+            )
+            result.append(ent)
+        return result
+
+    def persist(self, path):
+        self._parser.persist(path)
+
+    @classmethod
+    def from_path(cls, path):
+        from snips_nlu_parsers import (
+            BuiltinEntityParser as _BuiltinEntityParser)
+
+        parser = _BuiltinEntityParser.from_path(path)
+        return cls(parser)
+
+    @classmethod
+    def build(cls, dataset=None, language=None, gazetteer_entity_scope=None):
+        from snips_nlu_parsers import get_supported_gazetteer_entities
+
+        global _BUILTIN_ENTITY_PARSERS
+
+        if dataset is not None:
+            language = dataset[LANGUAGE]
+            gazetteer_entity_scope = [entity for entity in dataset[ENTITIES]
+                                      if is_gazetteer_entity(entity)]
+
+        if language is None:
+            raise ValueError("Either a dataset or a language must be provided "
+                             "in order to build a BuiltinEntityParser")
+
+        if gazetteer_entity_scope is None:
+            gazetteer_entity_scope = []
+        caching_key = _get_caching_key(language, gazetteer_entity_scope)
+        if caching_key not in _BUILTIN_ENTITY_PARSERS:
+            for entity in gazetteer_entity_scope:
+                if entity not in get_supported_gazetteer_entities(language):
+                    raise ValueError(
+                        "Gazetteer entity '%s' is not supported in "
+                        "language '%s'" % (entity, language))
+            _BUILTIN_ENTITY_PARSERS[caching_key] = _build_builtin_parser(
+                language, gazetteer_entity_scope)
+        return _BUILTIN_ENTITY_PARSERS[caching_key]
+
+
+def _build_builtin_parser(language, gazetteer_entities):
+    from snips_nlu_parsers import BuiltinEntityParser as _BuiltinEntityParser
+
+    with temp_dir() as serialization_dir:
+        gazetteer_entity_parser = None
+        if gazetteer_entities:
+            gazetteer_entity_parser = _build_gazetteer_parser(
+                serialization_dir, gazetteer_entities, language)
+
+        metadata = {
+            "language": language.upper(),
+            "gazetteer_parser": gazetteer_entity_parser
+        }
+        metadata_path = serialization_dir / "metadata.json"
+        with metadata_path.open("w", encoding="utf-8") as f:
+            f.write(json_string(metadata))
+        parser = _BuiltinEntityParser.from_path(serialization_dir)
+        return BuiltinEntityParser(parser)
+
+
+def _build_gazetteer_parser(target_dir, gazetteer_entities, language):
+    from snips_nlu_parsers import get_builtin_entity_shortname
+
+    gazetteer_parser_name = "gazetteer_entity_parser"
+    gazetteer_parser_path = target_dir / gazetteer_parser_name
+    gazetteer_parser_metadata = []
+    for ent in sorted(gazetteer_entities):
+        # Fetch the compiled parser in the resources
+        source_parser_path = find_gazetteer_entity_data_path(language, ent)
+        short_name = get_builtin_entity_shortname(ent).lower()
+        target_parser_path = gazetteer_parser_path / short_name
+        parser_metadata = {
+            "entity_identifier": ent,
+            "entity_parser": short_name
+        }
+        gazetteer_parser_metadata.append(parser_metadata)
+        # Copy the single entity parser
+        shutil.copytree(str(source_parser_path), str(target_parser_path))
+    # Dump the parser metadata
+    gazetteer_entity_parser_metadata = {
+        "parsers_metadata": gazetteer_parser_metadata
+    }
+    gazetteer_parser_metadata_path = gazetteer_parser_path / "metadata.json"
+    with gazetteer_parser_metadata_path.open("w", encoding="utf-8") as f:
+        f.write(json_string(gazetteer_entity_parser_metadata))
+    return gazetteer_parser_name
+
+
+def is_builtin_entity(entity_label):
+    from snips_nlu_parsers import get_all_builtin_entities
+
+    return entity_label in get_all_builtin_entities()
+
+
+def is_gazetteer_entity(entity_label):
+    from snips_nlu_parsers import get_all_gazetteer_entities
+
+    return entity_label in get_all_gazetteer_entities()
+
+
+def find_gazetteer_entity_data_path(language, entity_name):
+    for directory in DATA_PATH.iterdir():
+        if not directory.is_dir():
+            continue
+        metadata_path = directory / "metadata.json"
+        if not metadata_path.exists():
+            continue
+        with metadata_path.open(encoding="utf8") as f:
+            metadata = json.load(f)
+        if metadata.get("entity_name") == entity_name \
+                and metadata.get("language") == language:
+            return directory / metadata["data_directory"]
+    raise FileNotFoundError(
+        "No data found for the '{e}' builtin entity in language '{lang}'. "
+        "You must download the corresponding resources by running "
+        "'python -m snips_nlu download-entity {e} {lang}' before you can use "
+        "this builtin entity.".format(e=entity_name, lang=language))
+
+
+def _get_caching_key(language, entity_scope):
+    tuple_key = (language,)
+    tuple_key += tuple(entity for entity in sorted(entity_scope))
+    return tuple_key
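As a quick orientation, here is a minimal usage sketch of the builtin parser added above. This is not part of the change: it assumes snips_nlu_parsers and the "en" builtin entity resources are installed, the input sentence is invented, and the result keys are assumed to follow the parsed_entity helper from snips_inference_agl.result (value, resolved_value, range, entity_kind).

# Usage sketch (illustration only, not part of this change).
# Assumes snips_nlu_parsers and the "en" builtin entity data are installed.
from snips_inference_agl.entity_parser import BuiltinEntityParser

parser = BuiltinEntityParser.build(language="en")

# Text is lowercased internally; scope=None searches all builtin entity
# kinds. Each result is a dict built by the parsed_entity helper.
for entity in parser.parse("Remind me tomorrow at 9pm"):
    print(entity["entity_kind"], entity["value"], entity["range"])

# Identical (language, gazetteer scope) pairs reuse a cached parser via
# _BUILTIN_ENTITY_PARSERS, so a second build returns the same object.
same_parser = BuiltinEntityParser.build(language="en")
assert same_parser is parser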
" + "You must download the corresponding resources by running " + "'python -m snips_nlu download-entity {e} {lang}' before you can use " + "this builtin entity.".format(e=entity_name, lang=language)) + + +def _get_caching_key(language, entity_scope): + tuple_key = (language,) + tuple_key += tuple(entity for entity in sorted(entity_scope)) + return tuple_key diff --git a/snips_inference_agl/entity_parser/custom_entity_parser.py b/snips_inference_agl/entity_parser/custom_entity_parser.py new file mode 100644 index 0000000..949df1f --- /dev/null +++ b/snips_inference_agl/entity_parser/custom_entity_parser.py @@ -0,0 +1,209 @@ +# coding=utf-8 +from __future__ import unicode_literals + +import json +import operator +from copy import deepcopy +from pathlib import Path + +from future.utils import iteritems, viewvalues + +from snips_inference_agl.common.utils import json_string +from snips_inference_agl.constants import ( + END, ENTITIES, LANGUAGE, MATCHING_STRICTNESS, START, UTTERANCES, + LICENSE_INFO) +from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity +from snips_inference_agl.entity_parser.custom_entity_parser_usage import ( + CustomEntityParserUsage) +from snips_inference_agl.entity_parser.entity_parser import EntityParser +from snips_inference_agl.preprocessing import stem, tokenize, tokenize_light +from snips_inference_agl.result import parsed_entity + +STOPWORDS_FRACTION = 1e-3 + + +class CustomEntityParser(EntityParser): + def __init__(self, parser, language, parser_usage): + super(CustomEntityParser, self).__init__() + self._parser = parser + self.language = language + self.parser_usage = parser_usage + + def _parse(self, text, scope=None): + tokens = tokenize(text, self.language) + shifts = _compute_char_shifts(tokens) + cleaned_text = " ".join(token.value for token in tokens) + + entities = self._parser.parse(cleaned_text, scope) + result = [] + for entity in entities: + start = entity["range"]["start"] + start -= shifts[start] + end = entity["range"]["end"] + end -= shifts[end - 1] + entity_range = {START: start, END: end} + ent = parsed_entity( + entity_kind=entity["entity_identifier"], + entity_value=entity["value"], + entity_resolved_value=entity["resolved_value"], + entity_range=entity_range + ) + result.append(ent) + return result + + def persist(self, path): + path = Path(path) + path.mkdir() + parser_directory = "parser" + metadata = { + "language": self.language, + "parser_usage": self.parser_usage.value, + "parser_directory": parser_directory + } + with (path / "metadata.json").open(mode="w", encoding="utf8") as f: + f.write(json_string(metadata)) + self._parser.persist(path / parser_directory) + + @classmethod + def from_path(cls, path): + from snips_nlu_parsers import GazetteerEntityParser + + path = Path(path) + with (path / "metadata.json").open(encoding="utf8") as f: + metadata = json.load(f) + language = metadata["language"] + parser_usage = CustomEntityParserUsage(metadata["parser_usage"]) + parser_path = path / metadata["parser_directory"] + parser = GazetteerEntityParser.from_path(parser_path) + return cls(parser, language, parser_usage) + + @classmethod + def build(cls, dataset, parser_usage, resources): + from snips_nlu_parsers import GazetteerEntityParser + from snips_inference_agl.dataset import validate_and_format_dataset + + dataset = validate_and_format_dataset(dataset) + language = dataset[LANGUAGE] + custom_entities = { + entity_name: deepcopy(entity) + for entity_name, entity in iteritems(dataset[ENTITIES]) + if not 
diff --git a/snips_inference_agl/entity_parser/custom_entity_parser_usage.py b/snips_inference_agl/entity_parser/custom_entity_parser_usage.py
new file mode 100644
index 0000000..72d420a
--- /dev/null
+++ b/snips_inference_agl/entity_parser/custom_entity_parser_usage.py
@@ -0,0 +1,23 @@
+from __future__ import unicode_literals
+
+from enum import Enum, unique
+
+
+@unique
+class CustomEntityParserUsage(Enum):
+    WITH_STEMS = 0
+    """The parser is used with stemming"""
+    WITHOUT_STEMS = 1
+    """The parser is used without stemming"""
+    WITH_AND_WITHOUT_STEMS = 2
+    """The parser is used both with and without stemming"""
+
+    @classmethod
+    def merge_usages(cls, lhs_usage, rhs_usage):
+        if lhs_usage is None:
+            return rhs_usage
+        if rhs_usage is None:
+            return lhs_usage
+        if lhs_usage == rhs_usage:
+            return lhs_usage
+        return cls.WITH_AND_WITHOUT_STEMS
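merge_usages lets several components each declare the stemming variant they need and yields a single requirement covering all of them. A short sketch of its behavior, using only the enum above:

from snips_inference_agl.entity_parser import CustomEntityParserUsage

# One component needs stemmed entity values, another needs raw ones;
# the merged requirement covers both.
merged = CustomEntityParserUsage.merge_usages(
    CustomEntityParserUsage.WITH_STEMS,
    CustomEntityParserUsage.WITHOUT_STEMS)
assert merged is CustomEntityParserUsage.WITH_AND_WITHOUT_STEMS

# None acts as a neutral element, and equal usages are returned unchanged.
assert CustomEntityParserUsage.merge_usages(
    None, CustomEntityParserUsage.WITH_STEMS) \
    is CustomEntityParserUsage.WITH_STEMS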
world" are + [0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2] + """ + characters_shifts = [] + if not tokens: + return characters_shifts + + current_shift = 0 + for token_index, token in enumerate(tokens): + if token_index == 0: + previous_token_end = 0 + previous_space_len = 0 + else: + previous_token_end = tokens[token_index - 1].end + previous_space_len = 1 + offset = (token.start - previous_token_end) - previous_space_len + current_shift -= offset + token_len = token.end - token.start + index_shift = token_len + previous_space_len + characters_shifts += [current_shift for _ in range(index_shift)] + return characters_shifts diff --git a/snips_inference_agl/entity_parser/custom_entity_parser_usage.py b/snips_inference_agl/entity_parser/custom_entity_parser_usage.py new file mode 100644 index 0000000..72d420a --- /dev/null +++ b/snips_inference_agl/entity_parser/custom_entity_parser_usage.py @@ -0,0 +1,23 @@ +from __future__ import unicode_literals + +from enum import Enum, unique + + +@unique +class CustomEntityParserUsage(Enum): + WITH_STEMS = 0 + """The parser is used with stemming""" + WITHOUT_STEMS = 1 + """The parser is used without stemming""" + WITH_AND_WITHOUT_STEMS = 2 + """The parser is used both with and without stemming""" + + @classmethod + def merge_usages(cls, lhs_usage, rhs_usage): + if lhs_usage is None: + return rhs_usage + if rhs_usage is None: + return lhs_usage + if lhs_usage == rhs_usage: + return lhs_usage + return cls.WITH_AND_WITHOUT_STEMS diff --git a/snips_inference_agl/entity_parser/entity_parser.py b/snips_inference_agl/entity_parser/entity_parser.py new file mode 100644 index 0000000..46de55e --- /dev/null +++ b/snips_inference_agl/entity_parser/entity_parser.py @@ -0,0 +1,85 @@ +# coding=utf-8 +from __future__ import unicode_literals + +from abc import ABCMeta, abstractmethod + +from future.builtins import object +from future.utils import with_metaclass + +from snips_inference_agl.common.dict_utils import LimitedSizeDict + +# pylint: disable=ungrouped-imports + +try: + from abc import abstractclassmethod +except ImportError: + from snips_inference_agl.common.abc_utils import abstractclassmethod + + +# pylint: enable=ungrouped-imports + + +class EntityParser(with_metaclass(ABCMeta, object)): + """Abstraction of a entity parser implementing some basic caching + """ + + def __init__(self): + self._cache = LimitedSizeDict(size_limit=1000) + + def parse(self, text, scope=None, use_cache=True): + """Search the given text for entities defined in the scope. If no + scope is provided, search for all kinds of entities. + + Args: + text (str): input text + scope (list or set of str, optional): if provided the parser + will only look for entities which entity kind is given in + the scope. By default the scope is None and the parser + will search for all kinds of supported entities + use_cache (bool): if False the internal cache will not be use, + this can be useful if the output of the parser depends on + the current timestamp. Defaults to True. 
+
+        Returns:
+            list of dict: list of the parsed entities formatted as a dict
+                containing the string value, the resolved value, the
+                entity kind and the entity range
+        """
+        if not use_cache:
+            return self._parse(text, scope)
+        scope_key = tuple(sorted(scope)) if scope is not None else scope
+        cache_key = (text, scope_key)
+        if cache_key not in self._cache:
+            parser_result = self._parse(text, scope)
+            self._cache[cache_key] = parser_result
+        return self._cache[cache_key]
+
+    @abstractmethod
+    def _parse(self, text, scope=None):
+        """Internal parse method to implement in each subclass of
+        :class:`.EntityParser`
+
+        Args:
+            text (str): input text
+            scope (list or set of str, optional): if provided, the parser
+                will only look for entities whose entity kind is given in
+                the scope. By default the scope is None and the parser
+                will search for all kinds of supported entities
+
+        Note: caching is handled by the public :func:`.EntityParser.parse`
+            method, so implementations of ``_parse`` do not need to cache
+            anything themselves.
+
+        Returns:
+            list of dict: list of the parsed entities. These entities must
+                have the same output format as the
+                :func:`snips_inference_agl.utils.result.parsed_entity` function
+        """
+        pass
+
+    @abstractmethod
+    def persist(self, path):
+        pass
+
+    @abstractclassmethod
+    def from_path(cls, path):
+        pass
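To make the caching contract concrete, here is a minimal sketch of an EntityParser subclass. It is hypothetical and for illustration only; the real implementations are the builtin and custom parsers above.

# Minimal EntityParser subclass sketch (hypothetical, illustration only).
from snips_inference_agl.entity_parser.entity_parser import EntityParser

class NullEntityParser(EntityParser):
    """Parser that never matches anything; shows the minimal contract."""

    def _parse(self, text, scope=None):
        return []  # must return a list of parsed_entity-shaped dicts

    def persist(self, path):
        pass  # nothing to serialize

    @classmethod
    def from_path(cls, path):
        return cls()

parser = NullEntityParser()
# The first call computes the result; the second is served from the
# LimitedSizeDict cache keyed on (text, sorted scope tuple), which retains
# at most 1000 entries.
assert parser.parse("no entities here") == []
assert parser.parse("no entities here") == []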