Diffstat (limited to 'snips_inference_agl/resources.py')
-rw-r--r--  snips_inference_agl/resources.py  267
1 file changed, 267 insertions(+), 0 deletions(-)
diff --git a/snips_inference_agl/resources.py b/snips_inference_agl/resources.py
new file mode 100644
index 0000000..3b3e478
--- /dev/null
+++ b/snips_inference_agl/resources.py
@@ -0,0 +1,267 @@
+from __future__ import unicode_literals
+
+import json
+from copy import deepcopy
+from pathlib import Path
+
+from snips_inference_agl.common.utils import get_package_path, is_package
+from snips_inference_agl.constants import (
+ CUSTOM_ENTITY_PARSER_USAGE, DATA_PATH, GAZETTEERS, NOISE,
+ STEMS, STOP_WORDS, WORD_CLUSTERS, METADATA)
+from snips_inference_agl.entity_parser.custom_entity_parser import (
+ CustomEntityParserUsage)
+
+
+class MissingResource(LookupError):
+ pass
+
+
+def load_resources(name, required_resources=None):
+ """Load language specific resources
+
+ Args:
+ name (str): Resource name as in ``snips-nlu download <name>``. Can also
+ be the name of a python package or a directory path.
+ required_resources (dict, optional): Resources requirement
+ dict which, when provided, allows to limit the amount of resources
+ to load. By default, all existing resources are loaded.
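+
+    Example, assuming the "en" resources have been downloaded:
+
+        >>> resources = load_resources("en")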
+ """
+ if name in set(d.name for d in DATA_PATH.iterdir()):
+ return load_resources_from_dir(DATA_PATH / name, required_resources)
+ elif is_package(name):
+ package_path = get_package_path(name)
+ resources_sub_dir = get_resources_sub_directory(package_path)
+ return load_resources_from_dir(resources_sub_dir, required_resources)
+ elif Path(name).exists():
+ path = Path(name)
+ if (path / "__init__.py").exists():
+ path = get_resources_sub_directory(path)
+ return load_resources_from_dir(path, required_resources)
+ else:
+ raise MissingResource("Language resource '{r}' not found. This may be "
+ "solved by running "
+ "'python -m snips_nlu download {r}'"
+ .format(r=name))
+
+
+def load_resources_from_dir(resources_dir, required_resources=None):
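+    """Load the resources described by the ``metadata.json`` file found in
+    ``resources_dir``, optionally restricted to ``required_resources``"""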
+ with (resources_dir / "metadata.json").open(encoding="utf8") as f:
+ metadata = json.load(f)
+ metadata = _update_metadata(metadata, required_resources)
+ gazetteer_names = metadata["gazetteers"]
+ clusters_names = metadata["word_clusters"]
+ stop_words_filename = metadata["stop_words"]
+ stems_filename = metadata["stems"]
+ noise_filename = metadata["noise"]
+
+ gazetteers = _get_gazetteers(resources_dir / "gazetteers", gazetteer_names)
+ word_clusters = _get_word_clusters(resources_dir / "word_clusters",
+ clusters_names)
+
+ stems = None
+ stop_words = None
+ noise = None
+
+ if stems_filename is not None:
+ stems = _get_stems(resources_dir / "stemming", stems_filename)
+ if stop_words_filename is not None:
+ stop_words = _get_stop_words(resources_dir, stop_words_filename)
+ if noise_filename is not None:
+ noise = _get_noise(resources_dir, noise_filename)
+
+ return {
+ METADATA: metadata,
+ WORD_CLUSTERS: word_clusters,
+ GAZETTEERS: gazetteers,
+ STOP_WORDS: stop_words,
+ NOISE: noise,
+ STEMS: stems,
+ }
+
+
+def _update_metadata(metadata, required_resources):
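+    # Restrict the metadata to the required resources, after checking that
+    # each required gazetteer and word clusters file exists for the language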
+ if required_resources is None:
+ return metadata
+ metadata = deepcopy(metadata)
+ required_gazetteers = required_resources.get(GAZETTEERS, [])
+ required_word_clusters = required_resources.get(WORD_CLUSTERS, [])
+    for gazetteer in required_gazetteers:
+        if gazetteer not in metadata["gazetteers"]:
+            raise ValueError("Unknown gazetteer for language '%s': '%s'"
+                             % (metadata["language"], gazetteer))
+ for word_clusters in required_word_clusters:
+ if word_clusters not in metadata["word_clusters"]:
+ raise ValueError("Unknown word clusters for language '%s': '%s'"
+ % (metadata["language"], word_clusters))
+ metadata["gazetteers"] = required_gazetteers
+ metadata["word_clusters"] = required_word_clusters
+ if not required_resources.get(STEMS, False):
+ metadata["stems"] = None
+ if not required_resources.get(NOISE, False):
+ metadata["noise"] = None
+ if not required_resources.get(STOP_WORDS, False):
+ metadata["stop_words"] = None
+ return metadata
+
+
+def get_resources_sub_directory(resources_dir):
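+    # The actual resources live in a versioned "<name>-<version>"
+    # sub-directory whose name is rebuilt from the metadata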
+ resources_dir = Path(resources_dir)
+ with (resources_dir / "metadata.json").open(encoding="utf8") as f:
+ metadata = json.load(f)
+ resource_name = metadata["name"]
+ version = metadata["version"]
+ sub_dir_name = "{r}-{v}".format(r=resource_name, v=version)
+ return resources_dir / sub_dir_name
+
+
+def get_stop_words(resources):
+ return _get_resource(resources, STOP_WORDS)
+
+
+def get_noise(resources):
+ return _get_resource(resources, NOISE)
+
+
+def get_word_clusters(resources):
+ return _get_resource(resources, WORD_CLUSTERS)
+
+
+def get_word_cluster(resources, cluster_name):
+ word_clusters = get_word_clusters(resources)
+ if cluster_name not in word_clusters:
+ raise MissingResource("Word cluster '{}' not found" % cluster_name)
+ return word_clusters[cluster_name]
+
+
+def get_gazetteer(resources, gazetteer_name):
+ gazetteers = _get_resource(resources, GAZETTEERS)
+ if gazetteer_name not in gazetteers:
+ raise MissingResource("Gazetteer '%s' not found in resources"
+ % gazetteer_name)
+ return gazetteers[gazetteer_name]
+
+
+def get_stems(resources):
+ return _get_resource(resources, STEMS)
+
+
+def merge_required_resources(lhs, rhs):
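+    # Merge two resource requirement dicts, keeping the union of what both
+    # sides require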
+ if not lhs:
+ return dict() if rhs is None else rhs
+ if not rhs:
+ return dict() if lhs is None else lhs
+ merged_resources = dict()
+ if lhs.get(NOISE, False) or rhs.get(NOISE, False):
+ merged_resources[NOISE] = True
+ if lhs.get(STOP_WORDS, False) or rhs.get(STOP_WORDS, False):
+ merged_resources[STOP_WORDS] = True
+ if lhs.get(STEMS, False) or rhs.get(STEMS, False):
+ merged_resources[STEMS] = True
+ lhs_parser_usage = lhs.get(CUSTOM_ENTITY_PARSER_USAGE)
+ rhs_parser_usage = rhs.get(CUSTOM_ENTITY_PARSER_USAGE)
+ parser_usage = CustomEntityParserUsage.merge_usages(
+ lhs_parser_usage, rhs_parser_usage)
+ merged_resources[CUSTOM_ENTITY_PARSER_USAGE] = parser_usage
+ gazetteers = lhs.get(GAZETTEERS, set()).union(rhs.get(GAZETTEERS, set()))
+ if gazetteers:
+ merged_resources[GAZETTEERS] = gazetteers
+ word_clusters = lhs.get(WORD_CLUSTERS, set()).union(
+ rhs.get(WORD_CLUSTERS, set()))
+ if word_clusters:
+ merged_resources[WORD_CLUSTERS] = word_clusters
+ return merged_resources
+
+
+def _get_resource(resources, resource_name):
+ if resource_name not in resources or resources[resource_name] is None:
+ raise MissingResource("Resource '%s' not found" % resource_name)
+ return resources[resource_name]
+
+
+def _get_stop_words(resources_dir, stop_words_filename):
+ if not stop_words_filename:
+ return None
+ stop_words_path = (resources_dir / stop_words_filename).with_suffix(".txt")
+ return _load_stop_words(stop_words_path)
+
+
+def _load_stop_words(stop_words_path):
+ with stop_words_path.open(encoding="utf8") as f:
+        stop_words = set(l.strip() for l in f if l.strip())
+ return stop_words
+
+
+def _get_noise(resources_dir, noise_filename):
+ if not noise_filename:
+ return None
+ noise_path = (resources_dir / noise_filename).with_suffix(".txt")
+ return _load_noise(noise_path)
+
+
+def _load_noise(noise_path):
+ with noise_path.open(encoding="utf8") as f:
+ # Here we split on a " " knowing that it's always ignored by
+ # the tokenization (see tokenization unit tests)
+ # It is not important to tokenize precisely as this noise is just used
+ # to generate utterances for the None intent
+ noise = [word for l in f for word in l.split()]
+ return noise
+
+
+def _get_word_clusters(word_clusters_dir, clusters_names):
+ if not clusters_names:
+ return dict()
+
+ clusters = dict()
+ for clusters_name in clusters_names:
+ clusters_path = (word_clusters_dir / clusters_name).with_suffix(".txt")
+ clusters[clusters_name] = _load_word_clusters(clusters_path)
+ return clusters
+
+
+def _load_word_clusters(path):
+ clusters = dict()
+ with path.open(encoding="utf8") as f:
+ for line in f:
+ split = line.rstrip().split("\t")
+            if len(split) < 2:
+                continue
+            clusters[split[0]] = split[1]
+ return clusters
+
+
+def _get_gazetteers(gazetteers_dir, gazetteer_names):
+ if not gazetteer_names:
+ return dict()
+
+ gazetteers = dict()
+ for gazetteer_name in gazetteer_names:
+ gazetteer_path = (gazetteers_dir / gazetteer_name).with_suffix(".txt")
+ gazetteers[gazetteer_name] = _load_gazetteer(gazetteer_path)
+ return gazetteers
+
+
+def _load_gazetteer(path):
+ with path.open(encoding="utf8") as f:
+        gazetteer = set(v.strip() for v in f if v.strip())
+ return gazetteer
+
+
+def _get_stems(stems_dir, filename):
+ if not filename:
+ return None
+ stems_path = (stems_dir / filename).with_suffix(".txt")
+ return _load_stems(stems_path)
+
+
+def _load_stems(path):
+ with path.open(encoding="utf8") as f:
+ stems = dict()
+ for line in f:
+        elements = line.strip().split(",")
+ stem = elements[0]
+ for value in elements[1:]:
+ stems[value] = stem
+ return stems
+
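
A minimal usage sketch of the module above, assuming the "en" resources have
been downloaded beforehand (the resource name is illustrative):

    from snips_inference_agl.resources import load_resources, get_stop_words

    # Load every available resource for the "en" resource package
    resources = load_resources("en")

    # Accessor helpers raise MissingResource if a resource was not loaded
    stop_words = get_stop_words(resources)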