diff options
author | Malik Talha <talhamalik727x@gmail.com> | 2023-10-22 21:06:23 +0500 |
---|---|---|
committer | Jan-Simon Moeller <jsmoeller@linuxfoundation.org> | 2023-10-23 14:38:13 +0000 |
commit | 697a1adce1e463079e640b55d6386cf82d7bd6bc (patch) | |
tree | 86e299cc7fe12b10c2e549f640924b61c7d07a95 /snips_inference_agl/dataset | |
parent | 97029ab8141e654a170a2282106f854037da294f (diff) |
Add Snips Inference Module
Add slightly modified version of the original Snips NLU
library. This module adds support for Python upto version
3.10.
Bug-AGL: SPEC-4856
Signed-off-by: Malik Talha <talhamalik727x@gmail.com>
Change-Id: I6d7e9eb181e6ff4aed9b6291027877ccb9f0d846
Diffstat (limited to 'snips_inference_agl/dataset')
-rw-r--r-- | snips_inference_agl/dataset/__init__.py | 7 | ||||
-rw-r--r-- | snips_inference_agl/dataset/dataset.py | 102 | ||||
-rw-r--r-- | snips_inference_agl/dataset/entity.py | 175 | ||||
-rw-r--r-- | snips_inference_agl/dataset/intent.py | 339 | ||||
-rw-r--r-- | snips_inference_agl/dataset/utils.py | 67 | ||||
-rw-r--r-- | snips_inference_agl/dataset/validation.py | 254 | ||||
-rw-r--r-- | snips_inference_agl/dataset/yaml_wrapper.py | 11 |
7 files changed, 955 insertions, 0 deletions
diff --git a/snips_inference_agl/dataset/__init__.py b/snips_inference_agl/dataset/__init__.py new file mode 100644 index 0000000..4cbed08 --- /dev/null +++ b/snips_inference_agl/dataset/__init__.py @@ -0,0 +1,7 @@ +from snips_inference_agl.dataset.dataset import Dataset +from snips_inference_agl.dataset.entity import Entity +from snips_inference_agl.dataset.intent import Intent +from snips_inference_agl.dataset.utils import ( + extract_intent_entities, extract_utterance_entities, + get_dataset_gazetteer_entities, get_text_from_chunks) +from snips_inference_agl.dataset.validation import validate_and_format_dataset diff --git a/snips_inference_agl/dataset/dataset.py b/snips_inference_agl/dataset/dataset.py new file mode 100644 index 0000000..2ad7867 --- /dev/null +++ b/snips_inference_agl/dataset/dataset.py @@ -0,0 +1,102 @@ +# coding=utf-8 +from __future__ import print_function, unicode_literals + +import io +from itertools import cycle + +from snips_inference_agl.common.utils import unicode_string +from snips_inference_agl.dataset.entity import Entity +from snips_inference_agl.dataset.intent import Intent +from snips_inference_agl.exceptions import DatasetFormatError + + +class Dataset(object): + """Dataset used in the main NLU training API + + Consists of intents and entities data. This object can be built either from + text files (:meth:`.Dataset.from_files`) or from YAML files + (:meth:`.Dataset.from_yaml_files`). + + Attributes: + language (str): language of the intents + intents (list of :class:`.Intent`): intents data + entities (list of :class:`.Entity`): entities data + """ + + def __init__(self, language, intents, entities): + self.language = language + self.intents = intents + self.entities = entities + self._add_missing_entities() + self._ensure_entity_values() + + @classmethod + def _load_dataset_parts(cls, stream, stream_description): + from snips_inference_agl.dataset.yaml_wrapper import yaml + + intents = [] + entities = [] + for doc in yaml.safe_load_all(stream): + doc_type = doc.get("type") + if doc_type == "entity": + entities.append(Entity.from_yaml(doc)) + elif doc_type == "intent": + intents.append(Intent.from_yaml(doc)) + else: + raise DatasetFormatError( + "Invalid 'type' value in YAML file '%s': '%s'" + % (stream_description, doc_type)) + return intents, entities + + def _add_missing_entities(self): + entity_names = set(e.name for e in self.entities) + + # Add entities appearing only in the intents utterances + for intent in self.intents: + for entity_name in intent.entities_names: + if entity_name not in entity_names: + entity_names.add(entity_name) + self.entities.append(Entity(name=entity_name)) + + def _ensure_entity_values(self): + entities_values = {entity.name: self._get_entity_values(entity) + for entity in self.entities} + for intent in self.intents: + for utterance in intent.utterances: + for chunk in utterance.slot_chunks: + if chunk.text is not None: + continue + try: + chunk.text = next(entities_values[chunk.entity]) + except StopIteration: + raise DatasetFormatError( + "At least one entity value must be provided for " + "entity '%s'" % chunk.entity) + return self + + def _get_entity_values(self, entity): + from snips_nlu_parsers import get_builtin_entity_examples + + if entity.is_builtin: + return cycle(get_builtin_entity_examples( + entity.name, self.language)) + values = [v for utterance in entity.utterances + for v in utterance.variations] + values_set = set(values) + for intent in self.intents: + for utterance in intent.utterances: + for chunk in utterance.slot_chunks: + if not chunk.text or chunk.entity != entity.name: + continue + if chunk.text not in values_set: + values_set.add(chunk.text) + values.append(chunk.text) + return cycle(values) + + @property + def json(self): + """Dataset data in json format""" + intents = {intent_data.intent_name: intent_data.json + for intent_data in self.intents} + entities = {entity.name: entity.json for entity in self.entities} + return dict(language=self.language, intents=intents, entities=entities) diff --git a/snips_inference_agl/dataset/entity.py b/snips_inference_agl/dataset/entity.py new file mode 100644 index 0000000..65b9994 --- /dev/null +++ b/snips_inference_agl/dataset/entity.py @@ -0,0 +1,175 @@ +# coding=utf-8 +from __future__ import unicode_literals + +from builtins import str +from io import IOBase + +from snips_inference_agl.constants import ( + AUTOMATICALLY_EXTENSIBLE, DATA, MATCHING_STRICTNESS, SYNONYMS, + USE_SYNONYMS, VALUE) +from snips_inference_agl.exceptions import EntityFormatError + + +class Entity(object): + """Entity data of a :class:`.Dataset` + + This class can represents both a custom or a builtin entity. When the + entity is a builtin one, only the `name` attribute is relevant. + + Attributes: + name (str): name of the entity + utterances (list of :class:`.EntityUtterance`): entity utterances + (only for custom entities) + automatically_extensible (bool): whether or not the entity can be + extended to values not present in the data (only for custom + entities) + use_synonyms (bool): whether or not to map entity values using + synonyms (only for custom entities) + matching_strictness (float): controls the matching strictness of the + entity (only for custom entities). Must be between 0.0 and 1.0. + """ + + def __init__(self, name, utterances=None, automatically_extensible=True, + use_synonyms=True, matching_strictness=1.0): + if utterances is None: + utterances = [] + self.name = name + self.utterances = utterances + self.automatically_extensible = automatically_extensible + self.use_synonyms = use_synonyms + self.matching_strictness = matching_strictness + + @property + def is_builtin(self): + from snips_nlu_parsers import get_all_builtin_entities + + return self.name in get_all_builtin_entities() + + @classmethod + def from_yaml(cls, yaml_dict): + """Build an :class:`.Entity` from its YAML definition object + + Args: + yaml_dict (dict or :class:`.IOBase`): object containing the YAML + definition of the entity. It can be either a stream, or the + corresponding python dict. + + Examples: + An entity can be defined with a YAML document following the schema + illustrated in the example below: + + >>> import io + >>> from snips_inference_agl.common.utils import json_string + >>> entity_yaml = io.StringIO(''' + ... # City Entity + ... --- + ... type: entity + ... name: city + ... automatically_extensible: false # default value is true + ... use_synonyms: false # default value is true + ... matching_strictness: 0.8 # default value is 1.0 + ... values: + ... - london + ... - [new york, big apple] + ... - [paris, city of lights]''') + >>> entity = Entity.from_yaml(entity_yaml) + >>> print(json_string(entity.json, indent=4, sort_keys=True)) + { + "automatically_extensible": false, + "data": [ + { + "synonyms": [], + "value": "london" + }, + { + "synonyms": [ + "big apple" + ], + "value": "new york" + }, + { + "synonyms": [ + "city of lights" + ], + "value": "paris" + } + ], + "matching_strictness": 0.8, + "use_synonyms": false + } + + Raises: + EntityFormatError: When the YAML dict does not correspond to the + :ref:`expected entity format <yaml_entity_format>` + """ + if isinstance(yaml_dict, IOBase): + from snips_inference_agl.dataset.yaml_wrapper import yaml + + yaml_dict = yaml.safe_load(yaml_dict) + + object_type = yaml_dict.get("type") + if object_type and object_type != "entity": + raise EntityFormatError("Wrong type: '%s'" % object_type) + entity_name = yaml_dict.get("name") + if not entity_name: + raise EntityFormatError("Missing 'name' attribute") + auto_extensible = yaml_dict.get(AUTOMATICALLY_EXTENSIBLE, True) + use_synonyms = yaml_dict.get(USE_SYNONYMS, True) + matching_strictness = yaml_dict.get("matching_strictness", 1.0) + utterances = [] + for entity_value in yaml_dict.get("values", []): + if isinstance(entity_value, list): + utterance = EntityUtterance(entity_value[0], entity_value[1:]) + elif isinstance(entity_value, str): + utterance = EntityUtterance(entity_value) + else: + raise EntityFormatError( + "YAML entity values must be either strings or lists, but " + "found: %s" % type(entity_value)) + utterances.append(utterance) + + return cls(name=entity_name, + utterances=utterances, + automatically_extensible=auto_extensible, + use_synonyms=use_synonyms, + matching_strictness=matching_strictness) + + @property + def json(self): + """Returns the entity in json format""" + if self.is_builtin: + return dict() + return { + AUTOMATICALLY_EXTENSIBLE: self.automatically_extensible, + USE_SYNONYMS: self.use_synonyms, + DATA: [u.json for u in self.utterances], + MATCHING_STRICTNESS: self.matching_strictness + } + + +class EntityUtterance(object): + """Represents a value of a :class:`.CustomEntity` with potential synonyms + + Attributes: + value (str): entity value + synonyms (list of str): The values to remap to the utterance value + """ + + def __init__(self, value, synonyms=None): + self.value = value + if synonyms is None: + synonyms = [] + self.synonyms = synonyms + + @property + def variations(self): + return [self.value] + self.synonyms + + @property + def json(self): + return {VALUE: self.value, SYNONYMS: self.synonyms} + + +def utf_8_encoder(f): + for line in f: + yield line.encode("utf-8") diff --git a/snips_inference_agl/dataset/intent.py b/snips_inference_agl/dataset/intent.py new file mode 100644 index 0000000..0a915ce --- /dev/null +++ b/snips_inference_agl/dataset/intent.py @@ -0,0 +1,339 @@ +from __future__ import absolute_import, print_function, unicode_literals + +from abc import ABCMeta, abstractmethod +from builtins import object +from io import IOBase + +from future.utils import with_metaclass + +from snips_inference_agl.constants import DATA, ENTITY, SLOT_NAME, TEXT, UTTERANCES +from snips_inference_agl.exceptions import IntentFormatError + + +class Intent(object): + """Intent data of a :class:`.Dataset` + + Attributes: + intent_name (str): name of the intent + utterances (list of :class:`.IntentUtterance`): annotated intent + utterances + slot_mapping (dict): mapping between slot names and entities + """ + + def __init__(self, intent_name, utterances, slot_mapping=None): + if slot_mapping is None: + slot_mapping = dict() + self.intent_name = intent_name + self.utterances = utterances + self.slot_mapping = slot_mapping + self._complete_slot_name_mapping() + self._ensure_entity_names() + + @classmethod + def from_yaml(cls, yaml_dict): + """Build an :class:`.Intent` from its YAML definition object + + Args: + yaml_dict (dict or :class:`.IOBase`): object containing the YAML + definition of the intent. It can be either a stream, or the + corresponding python dict. + + Examples: + An intent can be defined with a YAML document following the schema + illustrated in the example below: + + >>> import io + >>> from snips_inference_agl.common.utils import json_string + >>> intent_yaml = io.StringIO(''' + ... # searchFlight Intent + ... --- + ... type: intent + ... name: searchFlight + ... slots: + ... - name: origin + ... entity: city + ... - name: destination + ... entity: city + ... - name: date + ... entity: snips/datetime + ... utterances: + ... - find me a flight from [origin](Oslo) to [destination](Lima) + ... - I need a flight leaving to [destination](Berlin)''') + >>> intent = Intent.from_yaml(intent_yaml) + >>> print(json_string(intent.json, indent=4, sort_keys=True)) + { + "utterances": [ + { + "data": [ + { + "text": "find me a flight from " + }, + { + "entity": "city", + "slot_name": "origin", + "text": "Oslo" + }, + { + "text": " to " + }, + { + "entity": "city", + "slot_name": "destination", + "text": "Lima" + } + ] + }, + { + "data": [ + { + "text": "I need a flight leaving to " + }, + { + "entity": "city", + "slot_name": "destination", + "text": "Berlin" + } + ] + } + ] + } + + Raises: + IntentFormatError: When the YAML dict does not correspond to the + :ref:`expected intent format <yaml_intent_format>` + """ + + if isinstance(yaml_dict, IOBase): + from snips_inference_agl.dataset.yaml_wrapper import yaml + + yaml_dict = yaml.safe_load(yaml_dict) + + object_type = yaml_dict.get("type") + if object_type and object_type != "intent": + raise IntentFormatError("Wrong type: '%s'" % object_type) + intent_name = yaml_dict.get("name") + if not intent_name: + raise IntentFormatError("Missing 'name' attribute") + slot_mapping = dict() + for slot in yaml_dict.get("slots", []): + slot_mapping[slot["name"]] = slot["entity"] + utterances = [IntentUtterance.parse(u.strip()) + for u in yaml_dict["utterances"] if u.strip()] + if not utterances: + raise IntentFormatError( + "Intent must contain at least one utterance") + return cls(intent_name, utterances, slot_mapping) + + def _complete_slot_name_mapping(self): + for utterance in self.utterances: + for chunk in utterance.slot_chunks: + if chunk.entity and chunk.slot_name not in self.slot_mapping: + self.slot_mapping[chunk.slot_name] = chunk.entity + return self + + def _ensure_entity_names(self): + for utterance in self.utterances: + for chunk in utterance.slot_chunks: + if chunk.entity: + continue + chunk.entity = self.slot_mapping.get( + chunk.slot_name, chunk.slot_name) + return self + + @property + def json(self): + """Intent data in json format""" + return { + UTTERANCES: [ + {DATA: [chunk.json for chunk in utterance.chunks]} + for utterance in self.utterances + ] + } + + @property + def entities_names(self): + return set(chunk.entity for u in self.utterances + for chunk in u.chunks if isinstance(chunk, SlotChunk)) + + +class IntentUtterance(object): + def __init__(self, chunks): + self.chunks = chunks + + @property + def text(self): + return "".join((chunk.text for chunk in self.chunks)) + + @property + def slot_chunks(self): + return (chunk for chunk in self.chunks if isinstance(chunk, SlotChunk)) + + @classmethod + def parse(cls, string): + """Parses an utterance + + Args: + string (str): an utterance in the class:`.Utterance` format + + Examples: + + >>> from snips_inference_agl.dataset.intent import IntentUtterance + >>> u = IntentUtterance.\ + parse("president of [country:default](France)") + >>> u.text + 'president of France' + >>> len(u.chunks) + 2 + >>> u.chunks[0].text + 'president of ' + >>> u.chunks[1].slot_name + 'country' + >>> u.chunks[1].entity + 'default' + """ + sm = SM(string) + capture_text(sm) + return cls(sm.chunks) + + +class Chunk(with_metaclass(ABCMeta, object)): + def __init__(self, text): + self.text = text + + @abstractmethod + def json(self): + pass + + +class SlotChunk(Chunk): + def __init__(self, slot_name, entity, text): + super(SlotChunk, self).__init__(text) + self.slot_name = slot_name + self.entity = entity + + @property + def json(self): + return { + TEXT: self.text, + SLOT_NAME: self.slot_name, + ENTITY: self.entity, + } + + +class TextChunk(Chunk): + @property + def json(self): + return { + TEXT: self.text + } + + +class SM(object): + """State Machine for parsing""" + + def __init__(self, input): + self.input = input + self.chunks = [] + self.current = 0 + + @property + def end_of_input(self): + return self.current >= len(self.input) + + def add_slot(self, name, entity=None): + """Adds a named slot + + Args: + name (str): slot name + entity (str): entity name + """ + chunk = SlotChunk(slot_name=name, entity=entity, text=None) + self.chunks.append(chunk) + + def add_text(self, text): + """Adds a simple text chunk using the current position""" + chunk = TextChunk(text=text) + self.chunks.append(chunk) + + def add_tagged(self, text): + """Adds text to the last slot""" + if not self.chunks: + raise AssertionError("Cannot add tagged text because chunks list " + "is empty") + self.chunks[-1].text = text + + def find(self, s): + return self.input.find(s, self.current) + + def move(self, pos): + """Moves the cursor of the state to position after given + + Args: + pos (int): position to place the cursor just after + """ + self.current = pos + 1 + + def peek(self): + if self.end_of_input: + return None + return self[0] + + def read(self): + c = self[0] + self.current += 1 + return c + + def __getitem__(self, key): + current = self.current + if isinstance(key, int): + return self.input[current + key] + elif isinstance(key, slice): + start = current + key.start if key.start else current + return self.input[slice(start, key.stop, key.step)] + else: + raise TypeError("Bad key type: %s" % type(key)) + + +def capture_text(state): + next_pos = state.find('[') + sub = state[:] if next_pos < 0 else state[:next_pos] + if sub: + state.add_text(sub) + if next_pos >= 0: + state.move(next_pos) + capture_slot(state) + + +def capture_slot(state): + next_colon_pos = state.find(':') + next_square_bracket_pos = state.find(']') + if next_square_bracket_pos < 0: + raise IntentFormatError( + "Missing ending ']' in annotated utterance \"%s\"" % state.input) + if next_colon_pos < 0 or next_square_bracket_pos < next_colon_pos: + slot_name = state[:next_square_bracket_pos] + state.move(next_square_bracket_pos) + state.add_slot(slot_name) + else: + slot_name = state[:next_colon_pos] + state.move(next_colon_pos) + entity = state[:next_square_bracket_pos] + state.move(next_square_bracket_pos) + state.add_slot(slot_name, entity) + if state.peek() == '(': + state.read() + capture_tagged(state) + else: + capture_text(state) + + +def capture_tagged(state): + next_pos = state.find(')') + if next_pos < 1: + raise IntentFormatError( + "Missing ending ')' in annotated utterance \"%s\"" % state.input) + else: + tagged_text = state[:next_pos] + state.add_tagged(tagged_text) + state.move(next_pos) + capture_text(state) diff --git a/snips_inference_agl/dataset/utils.py b/snips_inference_agl/dataset/utils.py new file mode 100644 index 0000000..f147f0f --- /dev/null +++ b/snips_inference_agl/dataset/utils.py @@ -0,0 +1,67 @@ +from __future__ import unicode_literals + +from future.utils import iteritems, itervalues + +from snips_inference_agl.constants import ( + DATA, ENTITIES, ENTITY, INTENTS, TEXT, UTTERANCES) +from snips_inference_agl.entity_parser.builtin_entity_parser import is_gazetteer_entity + + +def extract_utterance_entities(dataset): + entities_values = {ent_name: set() for ent_name in dataset[ENTITIES]} + + for intent in itervalues(dataset[INTENTS]): + for utterance in intent[UTTERANCES]: + for chunk in utterance[DATA]: + if ENTITY in chunk: + entities_values[chunk[ENTITY]].add(chunk[TEXT].strip()) + return {k: list(v) for k, v in iteritems(entities_values)} + + +def extract_intent_entities(dataset, entity_filter=None): + intent_entities = {intent: set() for intent in dataset[INTENTS]} + for intent_name, intent_data in iteritems(dataset[INTENTS]): + for utterance in intent_data[UTTERANCES]: + for chunk in utterance[DATA]: + if ENTITY in chunk: + if entity_filter and not entity_filter(chunk[ENTITY]): + continue + intent_entities[intent_name].add(chunk[ENTITY]) + return intent_entities + + +def extract_entity_values(dataset, apply_normalization): + from snips_nlu_utils import normalize + + entities_per_intent = {intent: set() for intent in dataset[INTENTS]} + intent_entities = extract_intent_entities(dataset) + for intent, entities in iteritems(intent_entities): + for entity in entities: + entity_values = set(dataset[ENTITIES][entity][UTTERANCES]) + if apply_normalization: + entity_values = {normalize(v) for v in entity_values} + entities_per_intent[intent].update(entity_values) + return entities_per_intent + + +def get_text_from_chunks(chunks): + return "".join(chunk[TEXT] for chunk in chunks) + + +def get_dataset_gazetteer_entities(dataset, intent=None): + if intent is not None: + return extract_intent_entities(dataset, is_gazetteer_entity)[intent] + return {e for e in dataset[ENTITIES] if is_gazetteer_entity(e)} + + +def get_stop_words_whitelist(dataset, stop_words): + """Extracts stop words whitelists per intent consisting of entity values + that appear in the stop_words list""" + entity_values_per_intent = extract_entity_values( + dataset, apply_normalization=True) + stop_words_whitelist = dict() + for intent, entity_values in iteritems(entity_values_per_intent): + whitelist = stop_words.intersection(entity_values) + if whitelist: + stop_words_whitelist[intent] = whitelist + return stop_words_whitelist diff --git a/snips_inference_agl/dataset/validation.py b/snips_inference_agl/dataset/validation.py new file mode 100644 index 0000000..d6fc4a1 --- /dev/null +++ b/snips_inference_agl/dataset/validation.py @@ -0,0 +1,254 @@ +from __future__ import division, unicode_literals + +import json +from builtins import str +from collections import Counter +from copy import deepcopy + +from future.utils import iteritems, itervalues + +from snips_inference_agl.common.dataset_utils import (validate_key, validate_keys, + validate_type) +from snips_inference_agl.constants import ( + AUTOMATICALLY_EXTENSIBLE, CAPITALIZE, DATA, ENTITIES, ENTITY, INTENTS, + LANGUAGE, MATCHING_STRICTNESS, SLOT_NAME, SYNONYMS, TEXT, USE_SYNONYMS, + UTTERANCES, VALIDATED, VALUE, LICENSE_INFO) +from snips_inference_agl.dataset import extract_utterance_entities, Dataset +from snips_inference_agl.entity_parser.builtin_entity_parser import ( + BuiltinEntityParser, is_builtin_entity) +from snips_inference_agl.exceptions import DatasetFormatError +from snips_inference_agl.preprocessing import tokenize_light +from snips_inference_agl.string_variations import get_string_variations + +NUMBER_VARIATIONS_THRESHOLD = 1e3 +VARIATIONS_GENERATION_THRESHOLD = 1e4 + + +def validate_and_format_dataset(dataset): + """Checks that the dataset is valid and format it + + Raise: + DatasetFormatError: When the dataset format is wrong + """ + from snips_nlu_parsers import get_all_languages + + if isinstance(dataset, Dataset): + dataset = dataset.json + + # Make this function idempotent + if dataset.get(VALIDATED, False): + return dataset + dataset = deepcopy(dataset) + dataset = json.loads(json.dumps(dataset)) + validate_type(dataset, dict, object_label="dataset") + mandatory_keys = [INTENTS, ENTITIES, LANGUAGE] + for key in mandatory_keys: + validate_key(dataset, key, object_label="dataset") + validate_type(dataset[ENTITIES], dict, object_label="entities") + validate_type(dataset[INTENTS], dict, object_label="intents") + language = dataset[LANGUAGE] + validate_type(language, str, object_label="language") + if language not in get_all_languages(): + raise DatasetFormatError("Unknown language: '%s'" % language) + + dataset[INTENTS] = { + intent_name: intent_data + for intent_name, intent_data in sorted(iteritems(dataset[INTENTS]))} + for intent in itervalues(dataset[INTENTS]): + _validate_and_format_intent(intent, dataset[ENTITIES]) + + utterance_entities_values = extract_utterance_entities(dataset) + builtin_entity_parser = BuiltinEntityParser.build(dataset=dataset) + + dataset[ENTITIES] = { + intent_name: entity_data + for intent_name, entity_data in sorted(iteritems(dataset[ENTITIES]))} + + for entity_name, entity in iteritems(dataset[ENTITIES]): + uterrance_entities = utterance_entities_values[entity_name] + if is_builtin_entity(entity_name): + dataset[ENTITIES][entity_name] = \ + _validate_and_format_builtin_entity(entity, uterrance_entities) + else: + dataset[ENTITIES][entity_name] = \ + _validate_and_format_custom_entity( + entity, uterrance_entities, language, + builtin_entity_parser) + dataset[VALIDATED] = True + return dataset + + +def _validate_and_format_intent(intent, entities): + validate_type(intent, dict, "intent") + validate_key(intent, UTTERANCES, object_label="intent dict") + validate_type(intent[UTTERANCES], list, object_label="utterances") + for utterance in intent[UTTERANCES]: + validate_type(utterance, dict, object_label="utterance") + validate_key(utterance, DATA, object_label="utterance") + validate_type(utterance[DATA], list, object_label="utterance data") + for chunk in utterance[DATA]: + validate_type(chunk, dict, object_label="utterance chunk") + validate_key(chunk, TEXT, object_label="chunk") + if ENTITY in chunk or SLOT_NAME in chunk: + mandatory_keys = [ENTITY, SLOT_NAME] + validate_keys(chunk, mandatory_keys, object_label="chunk") + if is_builtin_entity(chunk[ENTITY]): + continue + else: + validate_key(entities, chunk[ENTITY], + object_label=ENTITIES) + return intent + + +def _has_any_capitalization(entity_utterances, language): + for utterance in entity_utterances: + tokens = tokenize_light(utterance, language) + if any(t.isupper() or t.istitle() for t in tokens): + return True + return False + + +def _add_entity_variations(utterances, entity_variations, entity_value): + utterances[entity_value] = entity_value + for variation in entity_variations[entity_value]: + if variation: + utterances[variation] = entity_value + return utterances + + +def _extract_entity_values(entity): + values = set() + for ent in entity[DATA]: + values.add(ent[VALUE]) + if entity[USE_SYNONYMS]: + values.update(set(ent[SYNONYMS])) + return values + + +def _validate_and_format_custom_entity(entity, utterance_entities, language, + builtin_entity_parser): + validate_type(entity, dict, object_label="entity") + + # TODO: this is here temporarily, only to allow backward compatibility + if MATCHING_STRICTNESS not in entity: + strictness = entity.get("parser_threshold", 1.0) + + entity[MATCHING_STRICTNESS] = strictness + + mandatory_keys = [USE_SYNONYMS, AUTOMATICALLY_EXTENSIBLE, DATA, + MATCHING_STRICTNESS] + validate_keys(entity, mandatory_keys, object_label="custom entity") + validate_type(entity[USE_SYNONYMS], bool, object_label="use_synonyms") + validate_type(entity[AUTOMATICALLY_EXTENSIBLE], bool, + object_label="automatically_extensible") + validate_type(entity[DATA], list, object_label="entity data") + validate_type(entity[MATCHING_STRICTNESS], (float, int), + object_label="matching_strictness") + + formatted_entity = dict() + formatted_entity[AUTOMATICALLY_EXTENSIBLE] = entity[ + AUTOMATICALLY_EXTENSIBLE] + formatted_entity[MATCHING_STRICTNESS] = entity[MATCHING_STRICTNESS] + if LICENSE_INFO in entity: + formatted_entity[LICENSE_INFO] = entity[LICENSE_INFO] + use_synonyms = entity[USE_SYNONYMS] + + # Validate format and filter out unused data + valid_entity_data = [] + for entry in entity[DATA]: + validate_type(entry, dict, object_label="entity entry") + validate_keys(entry, [VALUE, SYNONYMS], object_label="entity entry") + entry[VALUE] = entry[VALUE].strip() + if not entry[VALUE]: + continue + validate_type(entry[SYNONYMS], list, object_label="entity synonyms") + entry[SYNONYMS] = [s.strip() for s in entry[SYNONYMS] if s.strip()] + valid_entity_data.append(entry) + entity[DATA] = valid_entity_data + + # Compute capitalization before normalizing + # Normalization lowercase and hence lead to bad capitalization calculation + formatted_entity[CAPITALIZE] = _has_any_capitalization(utterance_entities, + language) + + validated_utterances = dict() + # Map original values an synonyms + for data in entity[DATA]: + ent_value = data[VALUE] + validated_utterances[ent_value] = ent_value + if use_synonyms: + for s in data[SYNONYMS]: + if s not in validated_utterances: + validated_utterances[s] = ent_value + + # Number variations in entities values are expensive since each entity + # value is parsed with the builtin entity parser before creating the + # variations. We avoid generating these variations if there's enough entity + # values + + # Add variations if not colliding + all_original_values = _extract_entity_values(entity) + if len(entity[DATA]) < VARIATIONS_GENERATION_THRESHOLD: + variations_args = { + "case": True, + "and_": True, + "punctuation": True + } + else: + variations_args = { + "case": False, + "and_": False, + "punctuation": False + } + + variations_args["numbers"] = len( + entity[DATA]) < NUMBER_VARIATIONS_THRESHOLD + + variations = dict() + for data in entity[DATA]: + ent_value = data[VALUE] + values_to_variate = {ent_value} + if use_synonyms: + values_to_variate.update(set(data[SYNONYMS])) + variations[ent_value] = set( + v for value in values_to_variate + for v in get_string_variations( + value, language, builtin_entity_parser, **variations_args) + ) + variation_counter = Counter( + [v for variations_ in itervalues(variations) for v in variations_]) + non_colliding_variations = { + value: [ + v for v in variations if + v not in all_original_values and variation_counter[v] == 1 + ] + for value, variations in iteritems(variations) + } + + for entry in entity[DATA]: + entry_value = entry[VALUE] + validated_utterances = _add_entity_variations( + validated_utterances, non_colliding_variations, entry_value) + + # Merge utterances entities + utterance_entities_variations = { + ent: get_string_variations( + ent, language, builtin_entity_parser, **variations_args) + for ent in utterance_entities + } + + for original_ent, variations in iteritems(utterance_entities_variations): + if not original_ent or original_ent in validated_utterances: + continue + validated_utterances[original_ent] = original_ent + for variation in variations: + if variation and variation not in validated_utterances \ + and variation not in utterance_entities: + validated_utterances[variation] = original_ent + formatted_entity[UTTERANCES] = validated_utterances + return formatted_entity + + +def _validate_and_format_builtin_entity(entity, utterance_entities): + validate_type(entity, dict, object_label="builtin entity") + return {UTTERANCES: set(utterance_entities)} diff --git a/snips_inference_agl/dataset/yaml_wrapper.py b/snips_inference_agl/dataset/yaml_wrapper.py new file mode 100644 index 0000000..ba8390d --- /dev/null +++ b/snips_inference_agl/dataset/yaml_wrapper.py @@ -0,0 +1,11 @@ +import yaml + + +def _construct_yaml_str(self, node): + # Override the default string handling function + # to always return unicode objects + return self.construct_scalar(node) + + +yaml.Loader.add_constructor("tag:yaml.org,2002:str", _construct_yaml_str) +yaml.SafeLoader.add_constructor("tag:yaml.org,2002:str", _construct_yaml_str) |