aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/dataset
diff options
context:
space:
mode:
authorMalik Talha <talhamalik727x@gmail.com>2023-10-22 21:06:23 +0500
committerJan-Simon Moeller <jsmoeller@linuxfoundation.org>2023-10-23 14:38:13 +0000
commit697a1adce1e463079e640b55d6386cf82d7bd6bc (patch)
tree86e299cc7fe12b10c2e549f640924b61c7d07a95 /snips_inference_agl/dataset
parent97029ab8141e654a170a2282106f854037da294f (diff)
Add Snips Inference Module
Add slightly modified version of the original Snips NLU library. This module adds support for Python upto version 3.10. Bug-AGL: SPEC-4856 Signed-off-by: Malik Talha <talhamalik727x@gmail.com> Change-Id: I6d7e9eb181e6ff4aed9b6291027877ccb9f0d846
Diffstat (limited to 'snips_inference_agl/dataset')
-rw-r--r--snips_inference_agl/dataset/__init__.py7
-rw-r--r--snips_inference_agl/dataset/dataset.py102
-rw-r--r--snips_inference_agl/dataset/entity.py175
-rw-r--r--snips_inference_agl/dataset/intent.py339
-rw-r--r--snips_inference_agl/dataset/utils.py67
-rw-r--r--snips_inference_agl/dataset/validation.py254
-rw-r--r--snips_inference_agl/dataset/yaml_wrapper.py11
7 files changed, 955 insertions, 0 deletions
diff --git a/snips_inference_agl/dataset/__init__.py b/snips_inference_agl/dataset/__init__.py
new file mode 100644
index 0000000..4cbed08
--- /dev/null
+++ b/snips_inference_agl/dataset/__init__.py
@@ -0,0 +1,7 @@
+from snips_inference_agl.dataset.dataset import Dataset
+from snips_inference_agl.dataset.entity import Entity
+from snips_inference_agl.dataset.intent import Intent
+from snips_inference_agl.dataset.utils import (
+ extract_intent_entities, extract_utterance_entities,
+ get_dataset_gazetteer_entities, get_text_from_chunks)
+from snips_inference_agl.dataset.validation import validate_and_format_dataset
diff --git a/snips_inference_agl/dataset/dataset.py b/snips_inference_agl/dataset/dataset.py
new file mode 100644
index 0000000..2ad7867
--- /dev/null
+++ b/snips_inference_agl/dataset/dataset.py
@@ -0,0 +1,102 @@
+# coding=utf-8
+from __future__ import print_function, unicode_literals
+
+import io
+from itertools import cycle
+
+from snips_inference_agl.common.utils import unicode_string
+from snips_inference_agl.dataset.entity import Entity
+from snips_inference_agl.dataset.intent import Intent
+from snips_inference_agl.exceptions import DatasetFormatError
+
+
+class Dataset(object):
+ """Dataset used in the main NLU training API
+
+ Consists of intents and entities data. This object can be built either from
+ text files (:meth:`.Dataset.from_files`) or from YAML files
+ (:meth:`.Dataset.from_yaml_files`).
+
+ Attributes:
+ language (str): language of the intents
+ intents (list of :class:`.Intent`): intents data
+ entities (list of :class:`.Entity`): entities data
+ """
+
+ def __init__(self, language, intents, entities):
+ self.language = language
+ self.intents = intents
+ self.entities = entities
+ self._add_missing_entities()
+ self._ensure_entity_values()
+
+ @classmethod
+ def _load_dataset_parts(cls, stream, stream_description):
+ from snips_inference_agl.dataset.yaml_wrapper import yaml
+
+ intents = []
+ entities = []
+ for doc in yaml.safe_load_all(stream):
+ doc_type = doc.get("type")
+ if doc_type == "entity":
+ entities.append(Entity.from_yaml(doc))
+ elif doc_type == "intent":
+ intents.append(Intent.from_yaml(doc))
+ else:
+ raise DatasetFormatError(
+ "Invalid 'type' value in YAML file '%s': '%s'"
+ % (stream_description, doc_type))
+ return intents, entities
+
+ def _add_missing_entities(self):
+ entity_names = set(e.name for e in self.entities)
+
+ # Add entities appearing only in the intents utterances
+ for intent in self.intents:
+ for entity_name in intent.entities_names:
+ if entity_name not in entity_names:
+ entity_names.add(entity_name)
+ self.entities.append(Entity(name=entity_name))
+
+ def _ensure_entity_values(self):
+ entities_values = {entity.name: self._get_entity_values(entity)
+ for entity in self.entities}
+ for intent in self.intents:
+ for utterance in intent.utterances:
+ for chunk in utterance.slot_chunks:
+ if chunk.text is not None:
+ continue
+ try:
+ chunk.text = next(entities_values[chunk.entity])
+ except StopIteration:
+ raise DatasetFormatError(
+ "At least one entity value must be provided for "
+ "entity '%s'" % chunk.entity)
+ return self
+
+ def _get_entity_values(self, entity):
+ from snips_nlu_parsers import get_builtin_entity_examples
+
+ if entity.is_builtin:
+ return cycle(get_builtin_entity_examples(
+ entity.name, self.language))
+ values = [v for utterance in entity.utterances
+ for v in utterance.variations]
+ values_set = set(values)
+ for intent in self.intents:
+ for utterance in intent.utterances:
+ for chunk in utterance.slot_chunks:
+ if not chunk.text or chunk.entity != entity.name:
+ continue
+ if chunk.text not in values_set:
+ values_set.add(chunk.text)
+ values.append(chunk.text)
+ return cycle(values)
+
+ @property
+ def json(self):
+ """Dataset data in json format"""
+ intents = {intent_data.intent_name: intent_data.json
+ for intent_data in self.intents}
+ entities = {entity.name: entity.json for entity in self.entities}
+ return dict(language=self.language, intents=intents, entities=entities)
diff --git a/snips_inference_agl/dataset/entity.py b/snips_inference_agl/dataset/entity.py
new file mode 100644
index 0000000..65b9994
--- /dev/null
+++ b/snips_inference_agl/dataset/entity.py
@@ -0,0 +1,175 @@
+# coding=utf-8
+from __future__ import unicode_literals
+
+from builtins import str
+from io import IOBase
+
+from snips_inference_agl.constants import (
+ AUTOMATICALLY_EXTENSIBLE, DATA, MATCHING_STRICTNESS, SYNONYMS,
+ USE_SYNONYMS, VALUE)
+from snips_inference_agl.exceptions import EntityFormatError
+
+
+class Entity(object):
+ """Entity data of a :class:`.Dataset`
+
+ This class can represents both a custom or a builtin entity. When the
+ entity is a builtin one, only the `name` attribute is relevant.
+
+ Attributes:
+ name (str): name of the entity
+ utterances (list of :class:`.EntityUtterance`): entity utterances
+ (only for custom entities)
+ automatically_extensible (bool): whether or not the entity can be
+ extended to values not present in the data (only for custom
+ entities)
+ use_synonyms (bool): whether or not to map entity values using
+ synonyms (only for custom entities)
+ matching_strictness (float): controls the matching strictness of the
+ entity (only for custom entities). Must be between 0.0 and 1.0.
+ """
+
+ def __init__(self, name, utterances=None, automatically_extensible=True,
+ use_synonyms=True, matching_strictness=1.0):
+ if utterances is None:
+ utterances = []
+ self.name = name
+ self.utterances = utterances
+ self.automatically_extensible = automatically_extensible
+ self.use_synonyms = use_synonyms
+ self.matching_strictness = matching_strictness
+
+ @property
+ def is_builtin(self):
+ from snips_nlu_parsers import get_all_builtin_entities
+
+ return self.name in get_all_builtin_entities()
+
+ @classmethod
+ def from_yaml(cls, yaml_dict):
+ """Build an :class:`.Entity` from its YAML definition object
+
+ Args:
+ yaml_dict (dict or :class:`.IOBase`): object containing the YAML
+ definition of the entity. It can be either a stream, or the
+ corresponding python dict.
+
+ Examples:
+ An entity can be defined with a YAML document following the schema
+ illustrated in the example below:
+
+ >>> import io
+ >>> from snips_inference_agl.common.utils import json_string
+ >>> entity_yaml = io.StringIO('''
+ ... # City Entity
+ ... ---
+ ... type: entity
+ ... name: city
+ ... automatically_extensible: false # default value is true
+ ... use_synonyms: false # default value is true
+ ... matching_strictness: 0.8 # default value is 1.0
+ ... values:
+ ... - london
+ ... - [new york, big apple]
+ ... - [paris, city of lights]''')
+ >>> entity = Entity.from_yaml(entity_yaml)
+ >>> print(json_string(entity.json, indent=4, sort_keys=True))
+ {
+ "automatically_extensible": false,
+ "data": [
+ {
+ "synonyms": [],
+ "value": "london"
+ },
+ {
+ "synonyms": [
+ "big apple"
+ ],
+ "value": "new york"
+ },
+ {
+ "synonyms": [
+ "city of lights"
+ ],
+ "value": "paris"
+ }
+ ],
+ "matching_strictness": 0.8,
+ "use_synonyms": false
+ }
+
+ Raises:
+ EntityFormatError: When the YAML dict does not correspond to the
+ :ref:`expected entity format <yaml_entity_format>`
+ """
+ if isinstance(yaml_dict, IOBase):
+ from snips_inference_agl.dataset.yaml_wrapper import yaml
+
+ yaml_dict = yaml.safe_load(yaml_dict)
+
+ object_type = yaml_dict.get("type")
+ if object_type and object_type != "entity":
+ raise EntityFormatError("Wrong type: '%s'" % object_type)
+ entity_name = yaml_dict.get("name")
+ if not entity_name:
+ raise EntityFormatError("Missing 'name' attribute")
+ auto_extensible = yaml_dict.get(AUTOMATICALLY_EXTENSIBLE, True)
+ use_synonyms = yaml_dict.get(USE_SYNONYMS, True)
+ matching_strictness = yaml_dict.get("matching_strictness", 1.0)
+ utterances = []
+ for entity_value in yaml_dict.get("values", []):
+ if isinstance(entity_value, list):
+ utterance = EntityUtterance(entity_value[0], entity_value[1:])
+ elif isinstance(entity_value, str):
+ utterance = EntityUtterance(entity_value)
+ else:
+ raise EntityFormatError(
+ "YAML entity values must be either strings or lists, but "
+ "found: %s" % type(entity_value))
+ utterances.append(utterance)
+
+ return cls(name=entity_name,
+ utterances=utterances,
+ automatically_extensible=auto_extensible,
+ use_synonyms=use_synonyms,
+ matching_strictness=matching_strictness)
+
+ @property
+ def json(self):
+ """Returns the entity in json format"""
+ if self.is_builtin:
+ return dict()
+ return {
+ AUTOMATICALLY_EXTENSIBLE: self.automatically_extensible,
+ USE_SYNONYMS: self.use_synonyms,
+ DATA: [u.json for u in self.utterances],
+ MATCHING_STRICTNESS: self.matching_strictness
+ }
+
+
+class EntityUtterance(object):
+ """Represents a value of a :class:`.CustomEntity` with potential synonyms
+
+ Attributes:
+ value (str): entity value
+ synonyms (list of str): The values to remap to the utterance value
+ """
+
+ def __init__(self, value, synonyms=None):
+ self.value = value
+ if synonyms is None:
+ synonyms = []
+ self.synonyms = synonyms
+
+ @property
+ def variations(self):
+ return [self.value] + self.synonyms
+
+ @property
+ def json(self):
+ return {VALUE: self.value, SYNONYMS: self.synonyms}
+
+
+def utf_8_encoder(f):
+ for line in f:
+ yield line.encode("utf-8")
diff --git a/snips_inference_agl/dataset/intent.py b/snips_inference_agl/dataset/intent.py
new file mode 100644
index 0000000..0a915ce
--- /dev/null
+++ b/snips_inference_agl/dataset/intent.py
@@ -0,0 +1,339 @@
+from __future__ import absolute_import, print_function, unicode_literals
+
+from abc import ABCMeta, abstractmethod
+from builtins import object
+from io import IOBase
+
+from future.utils import with_metaclass
+
+from snips_inference_agl.constants import DATA, ENTITY, SLOT_NAME, TEXT, UTTERANCES
+from snips_inference_agl.exceptions import IntentFormatError
+
+
+class Intent(object):
+ """Intent data of a :class:`.Dataset`
+
+ Attributes:
+ intent_name (str): name of the intent
+ utterances (list of :class:`.IntentUtterance`): annotated intent
+ utterances
+ slot_mapping (dict): mapping between slot names and entities
+ """
+
+ def __init__(self, intent_name, utterances, slot_mapping=None):
+ if slot_mapping is None:
+ slot_mapping = dict()
+ self.intent_name = intent_name
+ self.utterances = utterances
+ self.slot_mapping = slot_mapping
+ self._complete_slot_name_mapping()
+ self._ensure_entity_names()
+
+ @classmethod
+ def from_yaml(cls, yaml_dict):
+ """Build an :class:`.Intent` from its YAML definition object
+
+ Args:
+ yaml_dict (dict or :class:`.IOBase`): object containing the YAML
+ definition of the intent. It can be either a stream, or the
+ corresponding python dict.
+
+ Examples:
+ An intent can be defined with a YAML document following the schema
+ illustrated in the example below:
+
+ >>> import io
+ >>> from snips_inference_agl.common.utils import json_string
+ >>> intent_yaml = io.StringIO('''
+ ... # searchFlight Intent
+ ... ---
+ ... type: intent
+ ... name: searchFlight
+ ... slots:
+ ... - name: origin
+ ... entity: city
+ ... - name: destination
+ ... entity: city
+ ... - name: date
+ ... entity: snips/datetime
+ ... utterances:
+ ... - find me a flight from [origin](Oslo) to [destination](Lima)
+ ... - I need a flight leaving to [destination](Berlin)''')
+ >>> intent = Intent.from_yaml(intent_yaml)
+ >>> print(json_string(intent.json, indent=4, sort_keys=True))
+ {
+ "utterances": [
+ {
+ "data": [
+ {
+ "text": "find me a flight from "
+ },
+ {
+ "entity": "city",
+ "slot_name": "origin",
+ "text": "Oslo"
+ },
+ {
+ "text": " to "
+ },
+ {
+ "entity": "city",
+ "slot_name": "destination",
+ "text": "Lima"
+ }
+ ]
+ },
+ {
+ "data": [
+ {
+ "text": "I need a flight leaving to "
+ },
+ {
+ "entity": "city",
+ "slot_name": "destination",
+ "text": "Berlin"
+ }
+ ]
+ }
+ ]
+ }
+
+ Raises:
+ IntentFormatError: When the YAML dict does not correspond to the
+ :ref:`expected intent format <yaml_intent_format>`
+ """
+
+ if isinstance(yaml_dict, IOBase):
+ from snips_inference_agl.dataset.yaml_wrapper import yaml
+
+ yaml_dict = yaml.safe_load(yaml_dict)
+
+ object_type = yaml_dict.get("type")
+ if object_type and object_type != "intent":
+ raise IntentFormatError("Wrong type: '%s'" % object_type)
+ intent_name = yaml_dict.get("name")
+ if not intent_name:
+ raise IntentFormatError("Missing 'name' attribute")
+ slot_mapping = dict()
+ for slot in yaml_dict.get("slots", []):
+ slot_mapping[slot["name"]] = slot["entity"]
+ utterances = [IntentUtterance.parse(u.strip())
+ for u in yaml_dict["utterances"] if u.strip()]
+ if not utterances:
+ raise IntentFormatError(
+ "Intent must contain at least one utterance")
+ return cls(intent_name, utterances, slot_mapping)
+
+ def _complete_slot_name_mapping(self):
+ for utterance in self.utterances:
+ for chunk in utterance.slot_chunks:
+ if chunk.entity and chunk.slot_name not in self.slot_mapping:
+ self.slot_mapping[chunk.slot_name] = chunk.entity
+ return self
+
+ def _ensure_entity_names(self):
+ for utterance in self.utterances:
+ for chunk in utterance.slot_chunks:
+ if chunk.entity:
+ continue
+ chunk.entity = self.slot_mapping.get(
+ chunk.slot_name, chunk.slot_name)
+ return self
+
+ @property
+ def json(self):
+ """Intent data in json format"""
+ return {
+ UTTERANCES: [
+ {DATA: [chunk.json for chunk in utterance.chunks]}
+ for utterance in self.utterances
+ ]
+ }
+
+ @property
+ def entities_names(self):
+ return set(chunk.entity for u in self.utterances
+ for chunk in u.chunks if isinstance(chunk, SlotChunk))
+
+
+class IntentUtterance(object):
+ def __init__(self, chunks):
+ self.chunks = chunks
+
+ @property
+ def text(self):
+ return "".join((chunk.text for chunk in self.chunks))
+
+ @property
+ def slot_chunks(self):
+ return (chunk for chunk in self.chunks if isinstance(chunk, SlotChunk))
+
+ @classmethod
+ def parse(cls, string):
+ """Parses an utterance
+
+ Args:
+ string (str): an utterance in the class:`.Utterance` format
+
+ Examples:
+
+ >>> from snips_inference_agl.dataset.intent import IntentUtterance
+ >>> u = IntentUtterance.\
+ parse("president of [country:default](France)")
+ >>> u.text
+ 'president of France'
+ >>> len(u.chunks)
+ 2
+ >>> u.chunks[0].text
+ 'president of '
+ >>> u.chunks[1].slot_name
+ 'country'
+ >>> u.chunks[1].entity
+ 'default'
+ """
+ sm = SM(string)
+ capture_text(sm)
+ return cls(sm.chunks)
+
+
+class Chunk(with_metaclass(ABCMeta, object)):
+ def __init__(self, text):
+ self.text = text
+
+ @abstractmethod
+ def json(self):
+ pass
+
+
+class SlotChunk(Chunk):
+ def __init__(self, slot_name, entity, text):
+ super(SlotChunk, self).__init__(text)
+ self.slot_name = slot_name
+ self.entity = entity
+
+ @property
+ def json(self):
+ return {
+ TEXT: self.text,
+ SLOT_NAME: self.slot_name,
+ ENTITY: self.entity,
+ }
+
+
+class TextChunk(Chunk):
+ @property
+ def json(self):
+ return {
+ TEXT: self.text
+ }
+
+
+class SM(object):
+ """State Machine for parsing"""
+
+ def __init__(self, input):
+ self.input = input
+ self.chunks = []
+ self.current = 0
+
+ @property
+ def end_of_input(self):
+ return self.current >= len(self.input)
+
+ def add_slot(self, name, entity=None):
+ """Adds a named slot
+
+ Args:
+ name (str): slot name
+ entity (str): entity name
+ """
+ chunk = SlotChunk(slot_name=name, entity=entity, text=None)
+ self.chunks.append(chunk)
+
+ def add_text(self, text):
+ """Adds a simple text chunk using the current position"""
+ chunk = TextChunk(text=text)
+ self.chunks.append(chunk)
+
+ def add_tagged(self, text):
+ """Adds text to the last slot"""
+ if not self.chunks:
+ raise AssertionError("Cannot add tagged text because chunks list "
+ "is empty")
+ self.chunks[-1].text = text
+
+ def find(self, s):
+ return self.input.find(s, self.current)
+
+ def move(self, pos):
+ """Moves the cursor of the state to position after given
+
+ Args:
+ pos (int): position to place the cursor just after
+ """
+ self.current = pos + 1
+
+ def peek(self):
+ if self.end_of_input:
+ return None
+ return self[0]
+
+ def read(self):
+ c = self[0]
+ self.current += 1
+ return c
+
+ def __getitem__(self, key):
+ current = self.current
+ if isinstance(key, int):
+ return self.input[current + key]
+ elif isinstance(key, slice):
+ start = current + key.start if key.start else current
+ return self.input[slice(start, key.stop, key.step)]
+ else:
+ raise TypeError("Bad key type: %s" % type(key))
+
+
+def capture_text(state):
+ next_pos = state.find('[')
+ sub = state[:] if next_pos < 0 else state[:next_pos]
+ if sub:
+ state.add_text(sub)
+ if next_pos >= 0:
+ state.move(next_pos)
+ capture_slot(state)
+
+
+def capture_slot(state):
+ next_colon_pos = state.find(':')
+ next_square_bracket_pos = state.find(']')
+ if next_square_bracket_pos < 0:
+ raise IntentFormatError(
+ "Missing ending ']' in annotated utterance \"%s\"" % state.input)
+ if next_colon_pos < 0 or next_square_bracket_pos < next_colon_pos:
+ slot_name = state[:next_square_bracket_pos]
+ state.move(next_square_bracket_pos)
+ state.add_slot(slot_name)
+ else:
+ slot_name = state[:next_colon_pos]
+ state.move(next_colon_pos)
+ entity = state[:next_square_bracket_pos]
+ state.move(next_square_bracket_pos)
+ state.add_slot(slot_name, entity)
+ if state.peek() == '(':
+ state.read()
+ capture_tagged(state)
+ else:
+ capture_text(state)
+
+
+def capture_tagged(state):
+ next_pos = state.find(')')
+ if next_pos < 1:
+ raise IntentFormatError(
+ "Missing ending ')' in annotated utterance \"%s\"" % state.input)
+ else:
+ tagged_text = state[:next_pos]
+ state.add_tagged(tagged_text)
+ state.move(next_pos)
+ capture_text(state)
diff --git a/snips_inference_agl/dataset/utils.py b/snips_inference_agl/dataset/utils.py
new file mode 100644
index 0000000..f147f0f
--- /dev/null
+++ b/snips_inference_agl/dataset/utils.py
@@ -0,0 +1,67 @@
+from __future__ import unicode_literals
+
+from future.utils import iteritems, itervalues
+
+from snips_inference_agl.constants import (
+ DATA, ENTITIES, ENTITY, INTENTS, TEXT, UTTERANCES)
+from snips_inference_agl.entity_parser.builtin_entity_parser import is_gazetteer_entity
+
+
+def extract_utterance_entities(dataset):
+ entities_values = {ent_name: set() for ent_name in dataset[ENTITIES]}
+
+ for intent in itervalues(dataset[INTENTS]):
+ for utterance in intent[UTTERANCES]:
+ for chunk in utterance[DATA]:
+ if ENTITY in chunk:
+ entities_values[chunk[ENTITY]].add(chunk[TEXT].strip())
+ return {k: list(v) for k, v in iteritems(entities_values)}
+
+
+def extract_intent_entities(dataset, entity_filter=None):
+ intent_entities = {intent: set() for intent in dataset[INTENTS]}
+ for intent_name, intent_data in iteritems(dataset[INTENTS]):
+ for utterance in intent_data[UTTERANCES]:
+ for chunk in utterance[DATA]:
+ if ENTITY in chunk:
+ if entity_filter and not entity_filter(chunk[ENTITY]):
+ continue
+ intent_entities[intent_name].add(chunk[ENTITY])
+ return intent_entities
+
+
+def extract_entity_values(dataset, apply_normalization):
+ from snips_nlu_utils import normalize
+
+ entities_per_intent = {intent: set() for intent in dataset[INTENTS]}
+ intent_entities = extract_intent_entities(dataset)
+ for intent, entities in iteritems(intent_entities):
+ for entity in entities:
+ entity_values = set(dataset[ENTITIES][entity][UTTERANCES])
+ if apply_normalization:
+ entity_values = {normalize(v) for v in entity_values}
+ entities_per_intent[intent].update(entity_values)
+ return entities_per_intent
+
+
+def get_text_from_chunks(chunks):
+ return "".join(chunk[TEXT] for chunk in chunks)
+
+
+def get_dataset_gazetteer_entities(dataset, intent=None):
+ if intent is not None:
+ return extract_intent_entities(dataset, is_gazetteer_entity)[intent]
+ return {e for e in dataset[ENTITIES] if is_gazetteer_entity(e)}
+
+
+def get_stop_words_whitelist(dataset, stop_words):
+ """Extracts stop words whitelists per intent consisting of entity values
+ that appear in the stop_words list"""
+ entity_values_per_intent = extract_entity_values(
+ dataset, apply_normalization=True)
+ stop_words_whitelist = dict()
+ for intent, entity_values in iteritems(entity_values_per_intent):
+ whitelist = stop_words.intersection(entity_values)
+ if whitelist:
+ stop_words_whitelist[intent] = whitelist
+ return stop_words_whitelist
diff --git a/snips_inference_agl/dataset/validation.py b/snips_inference_agl/dataset/validation.py
new file mode 100644
index 0000000..d6fc4a1
--- /dev/null
+++ b/snips_inference_agl/dataset/validation.py
@@ -0,0 +1,254 @@
+from __future__ import division, unicode_literals
+
+import json
+from builtins import str
+from collections import Counter
+from copy import deepcopy
+
+from future.utils import iteritems, itervalues
+
+from snips_inference_agl.common.dataset_utils import (validate_key, validate_keys,
+ validate_type)
+from snips_inference_agl.constants import (
+ AUTOMATICALLY_EXTENSIBLE, CAPITALIZE, DATA, ENTITIES, ENTITY, INTENTS,
+ LANGUAGE, MATCHING_STRICTNESS, SLOT_NAME, SYNONYMS, TEXT, USE_SYNONYMS,
+ UTTERANCES, VALIDATED, VALUE, LICENSE_INFO)
+from snips_inference_agl.dataset import extract_utterance_entities, Dataset
+from snips_inference_agl.entity_parser.builtin_entity_parser import (
+ BuiltinEntityParser, is_builtin_entity)
+from snips_inference_agl.exceptions import DatasetFormatError
+from snips_inference_agl.preprocessing import tokenize_light
+from snips_inference_agl.string_variations import get_string_variations
+
+NUMBER_VARIATIONS_THRESHOLD = 1e3
+VARIATIONS_GENERATION_THRESHOLD = 1e4
+
+
+def validate_and_format_dataset(dataset):
+ """Checks that the dataset is valid and format it
+
+ Raise:
+ DatasetFormatError: When the dataset format is wrong
+ """
+ from snips_nlu_parsers import get_all_languages
+
+ if isinstance(dataset, Dataset):
+ dataset = dataset.json
+
+ # Make this function idempotent
+ if dataset.get(VALIDATED, False):
+ return dataset
+ dataset = deepcopy(dataset)
+ dataset = json.loads(json.dumps(dataset))
+ validate_type(dataset, dict, object_label="dataset")
+ mandatory_keys = [INTENTS, ENTITIES, LANGUAGE]
+ for key in mandatory_keys:
+ validate_key(dataset, key, object_label="dataset")
+ validate_type(dataset[ENTITIES], dict, object_label="entities")
+ validate_type(dataset[INTENTS], dict, object_label="intents")
+ language = dataset[LANGUAGE]
+ validate_type(language, str, object_label="language")
+ if language not in get_all_languages():
+ raise DatasetFormatError("Unknown language: '%s'" % language)
+
+ dataset[INTENTS] = {
+ intent_name: intent_data
+ for intent_name, intent_data in sorted(iteritems(dataset[INTENTS]))}
+ for intent in itervalues(dataset[INTENTS]):
+ _validate_and_format_intent(intent, dataset[ENTITIES])
+
+ utterance_entities_values = extract_utterance_entities(dataset)
+ builtin_entity_parser = BuiltinEntityParser.build(dataset=dataset)
+
+ dataset[ENTITIES] = {
+ intent_name: entity_data
+ for intent_name, entity_data in sorted(iteritems(dataset[ENTITIES]))}
+
+ for entity_name, entity in iteritems(dataset[ENTITIES]):
+ uterrance_entities = utterance_entities_values[entity_name]
+ if is_builtin_entity(entity_name):
+ dataset[ENTITIES][entity_name] = \
+ _validate_and_format_builtin_entity(entity, uterrance_entities)
+ else:
+ dataset[ENTITIES][entity_name] = \
+ _validate_and_format_custom_entity(
+ entity, uterrance_entities, language,
+ builtin_entity_parser)
+ dataset[VALIDATED] = True
+ return dataset
+
+
+def _validate_and_format_intent(intent, entities):
+ validate_type(intent, dict, "intent")
+ validate_key(intent, UTTERANCES, object_label="intent dict")
+ validate_type(intent[UTTERANCES], list, object_label="utterances")
+ for utterance in intent[UTTERANCES]:
+ validate_type(utterance, dict, object_label="utterance")
+ validate_key(utterance, DATA, object_label="utterance")
+ validate_type(utterance[DATA], list, object_label="utterance data")
+ for chunk in utterance[DATA]:
+ validate_type(chunk, dict, object_label="utterance chunk")
+ validate_key(chunk, TEXT, object_label="chunk")
+ if ENTITY in chunk or SLOT_NAME in chunk:
+ mandatory_keys = [ENTITY, SLOT_NAME]
+ validate_keys(chunk, mandatory_keys, object_label="chunk")
+ if is_builtin_entity(chunk[ENTITY]):
+ continue
+ else:
+ validate_key(entities, chunk[ENTITY],
+ object_label=ENTITIES)
+ return intent
+
+
+def _has_any_capitalization(entity_utterances, language):
+ for utterance in entity_utterances:
+ tokens = tokenize_light(utterance, language)
+ if any(t.isupper() or t.istitle() for t in tokens):
+ return True
+ return False
+
+
+def _add_entity_variations(utterances, entity_variations, entity_value):
+ utterances[entity_value] = entity_value
+ for variation in entity_variations[entity_value]:
+ if variation:
+ utterances[variation] = entity_value
+ return utterances
+
+
+def _extract_entity_values(entity):
+ values = set()
+ for ent in entity[DATA]:
+ values.add(ent[VALUE])
+ if entity[USE_SYNONYMS]:
+ values.update(set(ent[SYNONYMS]))
+ return values
+
+
+def _validate_and_format_custom_entity(entity, utterance_entities, language,
+ builtin_entity_parser):
+ validate_type(entity, dict, object_label="entity")
+
+ # TODO: this is here temporarily, only to allow backward compatibility
+ if MATCHING_STRICTNESS not in entity:
+ strictness = entity.get("parser_threshold", 1.0)
+
+ entity[MATCHING_STRICTNESS] = strictness
+
+ mandatory_keys = [USE_SYNONYMS, AUTOMATICALLY_EXTENSIBLE, DATA,
+ MATCHING_STRICTNESS]
+ validate_keys(entity, mandatory_keys, object_label="custom entity")
+ validate_type(entity[USE_SYNONYMS], bool, object_label="use_synonyms")
+ validate_type(entity[AUTOMATICALLY_EXTENSIBLE], bool,
+ object_label="automatically_extensible")
+ validate_type(entity[DATA], list, object_label="entity data")
+ validate_type(entity[MATCHING_STRICTNESS], (float, int),
+ object_label="matching_strictness")
+
+ formatted_entity = dict()
+ formatted_entity[AUTOMATICALLY_EXTENSIBLE] = entity[
+ AUTOMATICALLY_EXTENSIBLE]
+ formatted_entity[MATCHING_STRICTNESS] = entity[MATCHING_STRICTNESS]
+ if LICENSE_INFO in entity:
+ formatted_entity[LICENSE_INFO] = entity[LICENSE_INFO]
+ use_synonyms = entity[USE_SYNONYMS]
+
+ # Validate format and filter out unused data
+ valid_entity_data = []
+ for entry in entity[DATA]:
+ validate_type(entry, dict, object_label="entity entry")
+ validate_keys(entry, [VALUE, SYNONYMS], object_label="entity entry")
+ entry[VALUE] = entry[VALUE].strip()
+ if not entry[VALUE]:
+ continue
+ validate_type(entry[SYNONYMS], list, object_label="entity synonyms")
+ entry[SYNONYMS] = [s.strip() for s in entry[SYNONYMS] if s.strip()]
+ valid_entity_data.append(entry)
+ entity[DATA] = valid_entity_data
+
+ # Compute capitalization before normalizing
+ # Normalization lowercase and hence lead to bad capitalization calculation
+ formatted_entity[CAPITALIZE] = _has_any_capitalization(utterance_entities,
+ language)
+
+ validated_utterances = dict()
+ # Map original values an synonyms
+ for data in entity[DATA]:
+ ent_value = data[VALUE]
+ validated_utterances[ent_value] = ent_value
+ if use_synonyms:
+ for s in data[SYNONYMS]:
+ if s not in validated_utterances:
+ validated_utterances[s] = ent_value
+
+ # Number variations in entities values are expensive since each entity
+ # value is parsed with the builtin entity parser before creating the
+ # variations. We avoid generating these variations if there's enough entity
+ # values
+
+ # Add variations if not colliding
+ all_original_values = _extract_entity_values(entity)
+ if len(entity[DATA]) < VARIATIONS_GENERATION_THRESHOLD:
+ variations_args = {
+ "case": True,
+ "and_": True,
+ "punctuation": True
+ }
+ else:
+ variations_args = {
+ "case": False,
+ "and_": False,
+ "punctuation": False
+ }
+
+ variations_args["numbers"] = len(
+ entity[DATA]) < NUMBER_VARIATIONS_THRESHOLD
+
+ variations = dict()
+ for data in entity[DATA]:
+ ent_value = data[VALUE]
+ values_to_variate = {ent_value}
+ if use_synonyms:
+ values_to_variate.update(set(data[SYNONYMS]))
+ variations[ent_value] = set(
+ v for value in values_to_variate
+ for v in get_string_variations(
+ value, language, builtin_entity_parser, **variations_args)
+ )
+ variation_counter = Counter(
+ [v for variations_ in itervalues(variations) for v in variations_])
+ non_colliding_variations = {
+ value: [
+ v for v in variations if
+ v not in all_original_values and variation_counter[v] == 1
+ ]
+ for value, variations in iteritems(variations)
+ }
+
+ for entry in entity[DATA]:
+ entry_value = entry[VALUE]
+ validated_utterances = _add_entity_variations(
+ validated_utterances, non_colliding_variations, entry_value)
+
+ # Merge utterances entities
+ utterance_entities_variations = {
+ ent: get_string_variations(
+ ent, language, builtin_entity_parser, **variations_args)
+ for ent in utterance_entities
+ }
+
+ for original_ent, variations in iteritems(utterance_entities_variations):
+ if not original_ent or original_ent in validated_utterances:
+ continue
+ validated_utterances[original_ent] = original_ent
+ for variation in variations:
+ if variation and variation not in validated_utterances \
+ and variation not in utterance_entities:
+ validated_utterances[variation] = original_ent
+ formatted_entity[UTTERANCES] = validated_utterances
+ return formatted_entity
+
+
+def _validate_and_format_builtin_entity(entity, utterance_entities):
+ validate_type(entity, dict, object_label="builtin entity")
+ return {UTTERANCES: set(utterance_entities)}
diff --git a/snips_inference_agl/dataset/yaml_wrapper.py b/snips_inference_agl/dataset/yaml_wrapper.py
new file mode 100644
index 0000000..ba8390d
--- /dev/null
+++ b/snips_inference_agl/dataset/yaml_wrapper.py
@@ -0,0 +1,11 @@
+import yaml
+
+
+def _construct_yaml_str(self, node):
+ # Override the default string handling function
+ # to always return unicode objects
+ return self.construct_scalar(node)
+
+
+yaml.Loader.add_constructor("tag:yaml.org,2002:str", _construct_yaml_str)
+yaml.SafeLoader.add_constructor("tag:yaml.org,2002:str", _construct_yaml_str)