Diffstat (limited to 'snips_inference_agl/entity_parser/builtin_entity_parser.py')
-rw-r--r--  snips_inference_agl/entity_parser/builtin_entity_parser.py | 176
1 file changed, 176 insertions(+), 0 deletions(-)
diff --git a/snips_inference_agl/entity_parser/builtin_entity_parser.py b/snips_inference_agl/entity_parser/builtin_entity_parser.py
new file mode 100644
index 0000000..02fa610
--- /dev/null
+++ b/snips_inference_agl/entity_parser/builtin_entity_parser.py
@@ -0,0 +1,176 @@
+from __future__ import unicode_literals
+
+import json
+import shutil
+
+from future.builtins import str
+
+from snips_inference_agl.common.io_utils import temp_dir
+from snips_inference_agl.common.utils import json_string
+from snips_inference_agl.constants import DATA_PATH, ENTITIES, LANGUAGE
+from snips_inference_agl.entity_parser.entity_parser import EntityParser
+from snips_inference_agl.result import parsed_entity
+
+_BUILTIN_ENTITY_PARSERS = dict()
+
+try:
+ FileNotFoundError
+except NameError:
+ FileNotFoundError = IOError
+
+
+class BuiltinEntityParser(EntityParser):
+ def __init__(self, parser):
+ super(BuiltinEntityParser, self).__init__()
+ self._parser = parser
+
+ def _parse(self, text, scope=None):
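+        """Extract builtin entities from *text*, optionally restricted to
+        the entity kinds listed in *scope*. The text is lowercased so that
+        matching is case-insensitive."""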
+ entities = self._parser.parse(text.lower(), scope=scope)
+ result = []
+ for entity in entities:
+ ent = parsed_entity(
+ entity_kind=entity["entity_kind"],
+ entity_value=entity["value"],
+ entity_resolved_value=entity["entity"],
+ entity_range=entity["range"]
+ )
+ result.append(ent)
+ return result
+
+ def persist(self, path):
+ self._parser.persist(path)
+
+ @classmethod
+ def from_path(cls, path):
+ from snips_nlu_parsers import (
+ BuiltinEntityParser as _BuiltinEntityParser)
+
+ parser = _BuiltinEntityParser.from_path(path)
+ return cls(parser)
+
+ @classmethod
+ def build(cls, dataset=None, language=None, gazetteer_entity_scope=None):
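+        """Build a BuiltinEntityParser from a dataset or from an explicit
+        language and gazetteer entity scope, reusing any parser already
+        cached for the same language and scope."""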
+ from snips_nlu_parsers import get_supported_gazetteer_entities
+
+ global _BUILTIN_ENTITY_PARSERS
+
+ if dataset is not None:
+ language = dataset[LANGUAGE]
+ gazetteer_entity_scope = [entity for entity in dataset[ENTITIES]
+ if is_gazetteer_entity(entity)]
+
+ if language is None:
+ raise ValueError("Either a dataset or a language must be provided "
+ "in order to build a BuiltinEntityParser")
+
+ if gazetteer_entity_scope is None:
+ gazetteer_entity_scope = []
+ caching_key = _get_caching_key(language, gazetteer_entity_scope)
+ if caching_key not in _BUILTIN_ENTITY_PARSERS:
+ for entity in gazetteer_entity_scope:
+ if entity not in get_supported_gazetteer_entities(language):
+ raise ValueError(
+ "Gazetteer entity '%s' is not supported in "
+ "language '%s'" % (entity, language))
+ _BUILTIN_ENTITY_PARSERS[caching_key] = _build_builtin_parser(
+ language, gazetteer_entity_scope)
+ return _BUILTIN_ENTITY_PARSERS[caching_key]
+
+
+def _build_builtin_parser(language, gazetteer_entities):
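+    """Serialize the parser metadata, and the gazetteer parser if any, into
+    a temporary directory, then load the final parser from it."""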
+ from snips_nlu_parsers import BuiltinEntityParser as _BuiltinEntityParser
+
+ with temp_dir() as serialization_dir:
+ gazetteer_entity_parser = None
+ if gazetteer_entities:
+ gazetteer_entity_parser = _build_gazetteer_parser(
+ serialization_dir, gazetteer_entities, language)
+
+ metadata = {
+ "language": language.upper(),
+ "gazetteer_parser": gazetteer_entity_parser
+ }
+ metadata_path = serialization_dir / "metadata.json"
+ with metadata_path.open("w", encoding="utf-8") as f:
+ f.write(json_string(metadata))
+ parser = _BuiltinEntityParser.from_path(serialization_dir)
+ return BuiltinEntityParser(parser)
+
+
+def _build_gazetteer_parser(target_dir, gazetteer_entities, language):
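+    """Copy the compiled parser of each gazetteer entity into target_dir and
+    write the metadata file describing them, returning the name of the
+    gazetteer parser directory."""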
+ from snips_nlu_parsers import get_builtin_entity_shortname
+
+ gazetteer_parser_name = "gazetteer_entity_parser"
+ gazetteer_parser_path = target_dir / gazetteer_parser_name
+ gazetteer_parser_metadata = []
+ for ent in sorted(gazetteer_entities):
+        # Fetch the compiled parser from the installed resources
+ source_parser_path = find_gazetteer_entity_data_path(language, ent)
+ short_name = get_builtin_entity_shortname(ent).lower()
+ target_parser_path = gazetteer_parser_path / short_name
+ parser_metadata = {
+ "entity_identifier": ent,
+ "entity_parser": short_name
+ }
+ gazetteer_parser_metadata.append(parser_metadata)
+ # Copy the single entity parser
+ shutil.copytree(str(source_parser_path), str(target_parser_path))
+ # Dump the parser metadata
+ gazetteer_entity_parser_metadata = {
+ "parsers_metadata": gazetteer_parser_metadata
+ }
+ gazetteer_parser_metadata_path = gazetteer_parser_path / "metadata.json"
+ with gazetteer_parser_metadata_path.open("w", encoding="utf-8") as f:
+ f.write(json_string(gazetteer_entity_parser_metadata))
+ return gazetteer_parser_name
+
+
+def is_builtin_entity(entity_label):
+ from snips_nlu_parsers import get_all_builtin_entities
+
+ return entity_label in get_all_builtin_entities()
+
+
+def is_gazetteer_entity(entity_label):
+ from snips_nlu_parsers import get_all_gazetteer_entities
+
+ return entity_label in get_all_gazetteer_entities()
+
+
+def find_gazetteer_entity_data_path(language, entity_name):
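+    """Look under DATA_PATH for the resource directory whose metadata
+    matches entity_name and language."""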
+ for directory in DATA_PATH.iterdir():
+ if not directory.is_dir():
+ continue
+ metadata_path = directory / "metadata.json"
+ if not metadata_path.exists():
+ continue
+        with metadata_path.open(encoding="utf-8") as f:
+ metadata = json.load(f)
+ if metadata.get("entity_name") == entity_name \
+ and metadata.get("language") == language:
+ return directory / metadata["data_directory"]
+ raise FileNotFoundError(
+ "No data found for the '{e}' builtin entity in language '{lang}'. "
+ "You must download the corresponding resources by running "
+ "'python -m snips_nlu download-entity {e} {lang}' before you can use "
+ "this builtin entity.".format(e=entity_name, lang=language))
+
+
+def _get_caching_key(language, entity_scope):
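+    """Build a hashable cache key from the language and the sorted scope."""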
+ tuple_key = (language,)
+    tuple_key += tuple(sorted(entity_scope))
+ return tuple_key
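
For reference, a minimal usage sketch of the parser added above. It assumes
that EntityParser exposes a public parse() method wrapping _parse(), as in
upstream Snips NLU, that parsed_entity() keeps the upstream result keys
("value", "resolved_value", "range", "entity_kind"), and that the builtin
entity resources for the requested language are available; the language code
and entity name below are illustrative.

from snips_inference_agl.entity_parser.builtin_entity_parser import (
    BuiltinEntityParser)

# Build (or fetch from the module-level cache) a parser for English with no
# gazetteer entities in scope; the same (language, scope) pair always yields
# the same cached instance.
parser = BuiltinEntityParser.build(language="en")

# scope restricts the lookup to the given builtin entity kinds.
for entity in parser.parse("Book a table for tomorrow at 8pm",
                           scope=["snips/datetime"]):
    print(entity["entity_kind"], entity["value"], entity["range"])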