Diffstat (limited to 'snips_inference_agl/entity_parser/custom_entity_parser.py')
-rw-r--r-- | snips_inference_agl/entity_parser/custom_entity_parser.py | 209 |
1 files changed, 209 insertions, 0 deletions
diff --git a/snips_inference_agl/entity_parser/custom_entity_parser.py b/snips_inference_agl/entity_parser/custom_entity_parser.py
new file mode 100644
index 0000000..949df1f
--- /dev/null
+++ b/snips_inference_agl/entity_parser/custom_entity_parser.py
@@ -0,0 +1,209 @@
+# coding=utf-8
+from __future__ import unicode_literals
+
+import json
+import operator
+from copy import deepcopy
+from pathlib import Path
+
+from future.utils import iteritems, viewvalues
+
+from snips_inference_agl.common.utils import json_string
+from snips_inference_agl.constants import (
+    END, ENTITIES, LANGUAGE, MATCHING_STRICTNESS, START, UTTERANCES,
+    LICENSE_INFO)
+from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity
+from snips_inference_agl.entity_parser.custom_entity_parser_usage import (
+    CustomEntityParserUsage)
+from snips_inference_agl.entity_parser.entity_parser import EntityParser
+from snips_inference_agl.preprocessing import stem, tokenize, tokenize_light
+from snips_inference_agl.result import parsed_entity
+
+STOPWORDS_FRACTION = 1e-3
+
+
+class CustomEntityParser(EntityParser):
+    def __init__(self, parser, language, parser_usage):
+        super(CustomEntityParser, self).__init__()
+        self._parser = parser
+        self.language = language
+        self.parser_usage = parser_usage
+
+    def _parse(self, text, scope=None):
+        tokens = tokenize(text, self.language)
+        shifts = _compute_char_shifts(tokens)
+        cleaned_text = " ".join(token.value for token in tokens)
+
+        entities = self._parser.parse(cleaned_text, scope)
+        result = []
+        for entity in entities:
+            start = entity["range"]["start"]
+            start -= shifts[start]
+            end = entity["range"]["end"]
+            end -= shifts[end - 1]
+            entity_range = {START: start, END: end}
+            ent = parsed_entity(
+                entity_kind=entity["entity_identifier"],
+                entity_value=entity["value"],
+                entity_resolved_value=entity["resolved_value"],
+                entity_range=entity_range
+            )
+            result.append(ent)
+        return result
+
+    def persist(self, path):
+        path = Path(path)
+        path.mkdir()
+        parser_directory = "parser"
+        metadata = {
+            "language": self.language,
+            "parser_usage": self.parser_usage.value,
+            "parser_directory": parser_directory
+        }
+        with (path / "metadata.json").open(mode="w", encoding="utf8") as f:
+            f.write(json_string(metadata))
+        self._parser.persist(path / parser_directory)
+
+    @classmethod
+    def from_path(cls, path):
+        from snips_nlu_parsers import GazetteerEntityParser
+
+        path = Path(path)
+        with (path / "metadata.json").open(encoding="utf8") as f:
+            metadata = json.load(f)
+        language = metadata["language"]
+        parser_usage = CustomEntityParserUsage(metadata["parser_usage"])
+        parser_path = path / metadata["parser_directory"]
+        parser = GazetteerEntityParser.from_path(parser_path)
+        return cls(parser, language, parser_usage)
+
+    @classmethod
+    def build(cls, dataset, parser_usage, resources):
+        from snips_nlu_parsers import GazetteerEntityParser
+        from snips_inference_agl.dataset import validate_and_format_dataset
+
+        dataset = validate_and_format_dataset(dataset)
+        language = dataset[LANGUAGE]
+        custom_entities = {
+            entity_name: deepcopy(entity)
+            for entity_name, entity in iteritems(dataset[ENTITIES])
+            if not is_builtin_entity(entity_name)
+        }
+        if parser_usage == CustomEntityParserUsage.WITH_AND_WITHOUT_STEMS:
+            for ent in viewvalues(custom_entities):
+                stemmed_utterances = _stem_entity_utterances(
+                    ent[UTTERANCES], language, resources)
+                ent[UTTERANCES] = _merge_entity_utterances(
+                    ent[UTTERANCES], stemmed_utterances)
+        elif parser_usage == CustomEntityParserUsage.WITH_STEMS:
+            for ent in viewvalues(custom_entities):
+                ent[UTTERANCES] = _stem_entity_utterances(
+                    ent[UTTERANCES], language, resources)
+        elif parser_usage is None:
+            raise ValueError("A parser usage must be defined in order to fit "
+                             "a CustomEntityParser")
+        configuration = _create_custom_entity_parser_configuration(
+            custom_entities,
+            language=dataset[LANGUAGE],
+            stopwords_fraction=STOPWORDS_FRACTION,
+        )
+        parser = GazetteerEntityParser.build(configuration)
+        return cls(parser, language, parser_usage)
+
+
+def _stem_entity_utterances(entity_utterances, language, resources):
+    values = dict()
+    # Sort by resolved value, so that values conflict in a deterministic way
+    for raw_value, resolved_value in sorted(
+            iteritems(entity_utterances), key=operator.itemgetter(1)):
+        stemmed_value = stem(raw_value, language, resources)
+        if stemmed_value not in values:
+            values[stemmed_value] = resolved_value
+    return values
+
+
+def _merge_entity_utterances(raw_utterances, stemmed_utterances):
+    # Sort by resolved value, so that values conflict in a deterministic way
+    for raw_stemmed_value, resolved_value in sorted(
+            iteritems(stemmed_utterances), key=operator.itemgetter(1)):
+        if raw_stemmed_value not in raw_utterances:
+            raw_utterances[raw_stemmed_value] = resolved_value
+    return raw_utterances
+
+
+def _create_custom_entity_parser_configuration(
+        entities, stopwords_fraction, language):
+    """Dynamically creates the gazetteer parser configuration.
+
+    Args:
+        entities (dict): entities for the dataset
+        stopwords_fraction (float): fraction of the vocabulary of
+            the entity values that will be considered as stop words (
+            the top n_vocabulary * stopwords_fraction most frequent words will
+            be considered stop words)
+        language (str): language of the entities
+
+    Returns: the parser configuration as a dictionary
+    """
+
+    if not 0 < stopwords_fraction < 1:
+        raise ValueError("stopwords_fraction must be in ]0.0, 1.0[")
+
+    parser_configurations = []
+    for entity_name, entity in sorted(iteritems(entities)):
+        vocabulary = set(
+            t for raw_value in entity[UTTERANCES]
+            for t in tokenize_light(raw_value, language)
+        )
+        num_stopwords = int(stopwords_fraction * len(vocabulary))
+        config = {
+            "entity_identifier": entity_name,
+            "entity_parser": {
+                "threshold": entity[MATCHING_STRICTNESS],
+                "n_gazetteer_stop_words": num_stopwords,
+                "gazetteer": [
+                    {
+                        "raw_value": k,
+                        "resolved_value": v
+                    } for k, v in sorted(iteritems(entity[UTTERANCES]))
+                ]
+            }
+        }
+        if LICENSE_INFO in entity:
+            config["entity_parser"][LICENSE_INFO] = entity[LICENSE_INFO]
+        parser_configurations.append(config)
+
+    configuration = {
+        "entity_parsers": parser_configurations
+    }
+
+    return configuration
+
+
+def _compute_char_shifts(tokens):
+    """Compute the shifts in characters that occur when comparing the
+    tokens string with the string consisting of all tokens separated with a
+    space.
+
+    For instance, if "hello?world" is tokenized into ["hello", "?", "world"],
+    then the character shifts between "hello?world" and "hello ? world" are
+    [0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2]
+    """
+    characters_shifts = []
+    if not tokens:
+        return characters_shifts
+
+    current_shift = 0
+    for token_index, token in enumerate(tokens):
+        if token_index == 0:
+            previous_token_end = 0
+            previous_space_len = 0
+        else:
+            previous_token_end = tokens[token_index - 1].end
+            previous_space_len = 1
+        offset = (token.start - previous_token_end) - previous_space_len
+        current_shift -= offset
+        token_len = token.end - token.start
+        index_shift = token_len + previous_space_len
+        characters_shifts += [current_shift for _ in range(index_shift)]
+    return characters_shifts