Diffstat (limited to 'snips_inference_agl/dataset/dataset.py')
-rw-r--r--  snips_inference_agl/dataset/dataset.py | 102
1 file changed, 102 insertions, 0 deletions
diff --git a/snips_inference_agl/dataset/dataset.py b/snips_inference_agl/dataset/dataset.py
new file mode 100644
index 0000000..2ad7867
--- /dev/null
+++ b/snips_inference_agl/dataset/dataset.py
@@ -0,0 +1,102 @@
+# coding=utf-8
+from __future__ import print_function, unicode_literals
+
+from itertools import cycle
+
+from snips_inference_agl.dataset.entity import Entity
+from snips_inference_agl.dataset.intent import Intent
+from snips_inference_agl.exceptions import DatasetFormatError
+
+
+class Dataset(object):
+ """Dataset used in the main NLU training API
+
+    Consists of intents and entities data. This object is built from intents
+    and entities parsed from a stream of YAML documents
+    (see :meth:`._load_dataset_parts`).
+
+ Attributes:
+ language (str): language of the intents
+ intents (list of :class:`.Intent`): intents data
+ entities (list of :class:`.Entity`): entities data
+ """
+
+ def __init__(self, language, intents, entities):
+ self.language = language
+ self.intents = intents
+ self.entities = entities
+ self._add_missing_entities()
+ self._ensure_entity_values()
+
+ @classmethod
+ def _load_dataset_parts(cls, stream, stream_description):
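+        """Parses a stream of YAML documents into lists of
+        :class:`.Intent` and :class:`.Entity` objects, dispatching on
+        each document's "type" field"""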
+ from snips_inference_agl.dataset.yaml_wrapper import yaml
+
+ intents = []
+ entities = []
+ for doc in yaml.safe_load_all(stream):
+ doc_type = doc.get("type")
+ if doc_type == "entity":
+ entities.append(Entity.from_yaml(doc))
+ elif doc_type == "intent":
+ intents.append(Intent.from_yaml(doc))
+ else:
+ raise DatasetFormatError(
+ "Invalid 'type' value in YAML file '%s': '%s'"
+ % (stream_description, doc_type))
+ return intents, entities
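+
+    # Illustrative sketch of the YAML stream this method expects (the
+    # intent and entity names below are made up; the exact accepted
+    # fields are defined by Intent.from_yaml and Entity.from_yaml):
+    #
+    #   ---
+    #   type: intent
+    #   name: turnLightOn
+    #   slots:
+    #     - name: room
+    #       entity: room
+    #   utterances:
+    #     - Turn on the lights in the [room](kitchen)
+    #   ---
+    #   type: entity
+    #   name: room
+    #   values:
+    #     - kitchen
+    #     - [living room, lounge]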
+
+ def _add_missing_entities(self):
+ entity_names = set(e.name for e in self.entities)
+
+ # Add entities appearing only in the intents utterances
+ for intent in self.intents:
+ for entity_name in intent.entities_names:
+ if entity_name not in entity_names:
+ entity_names.add(entity_name)
+ self.entities.append(Entity(name=entity_name))
+
+ def _ensure_entity_values(self):
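+        """Fills slot chunks that have no text with values drawn from
+        the corresponding entity, cycling through the available values"""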
+ entities_values = {entity.name: self._get_entity_values(entity)
+ for entity in self.entities}
+ for intent in self.intents:
+ for utterance in intent.utterances:
+ for chunk in utterance.slot_chunks:
+ if chunk.text is not None:
+ continue
+ try:
+ chunk.text = next(entities_values[chunk.entity])
+ except StopIteration:
+ raise DatasetFormatError(
+ "At least one entity value must be provided for "
+ "entity '%s'" % chunk.entity)
+ return self
+
+ def _get_entity_values(self, entity):
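+        """Returns an iterator cycling over the entity's values: builtin
+        entity examples for builtin entities, otherwise the values declared
+        on the entity plus any chunk texts used in intent utterances"""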
+ from snips_nlu_parsers import get_builtin_entity_examples
+
+ if entity.is_builtin:
+ return cycle(get_builtin_entity_examples(
+ entity.name, self.language))
+ values = [v for utterance in entity.utterances
+ for v in utterance.variations]
+ values_set = set(values)
+ for intent in self.intents:
+ for utterance in intent.utterances:
+ for chunk in utterance.slot_chunks:
+ if not chunk.text or chunk.entity != entity.name:
+ continue
+ if chunk.text not in values_set:
+ values_set.add(chunk.text)
+ values.append(chunk.text)
+ return cycle(values)
+
+ @property
+ def json(self):
+ """Dataset data in json format"""
+ intents = {intent_data.intent_name: intent_data.json
+ for intent_data in self.intents}
+ entities = {entity.name: entity.json for entity in self.entities}
+ return dict(language=self.language, intents=intents, entities=entities)
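+
+
+# Usage sketch (illustrative; the file name and language code below are
+# assumptions, not part of this module):
+#
+#   import io
+#   from snips_inference_agl.dataset.dataset import Dataset
+#
+#   with io.open("dataset.yaml", encoding="utf-8") as f:
+#       intents, entities = Dataset._load_dataset_parts(f, "dataset.yaml")
+#   dataset = Dataset("en", intents, entities)
+#   print(dataset.json)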