about summary refs log tree commit diff stats
path: root/snips_inference_agl/common/dataset_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'snips_inference_agl/common/dataset_utils.py')
-rw-r--r-- snips_inference_agl/common/dataset_utils.py 48
1 files changed, 48 insertions, 0 deletions
diff --git a/snips_inference_agl/common/dataset_utils.py b/snips_inference_agl/common/dataset_utils.py
new file mode 100644
index 0000000..34648e6
--- /dev/null
+++ b/snips_inference_agl/common/dataset_utils.py
@@ -0,0 +1,48 @@
+from snips_inference_agl.constants import INTENTS, UTTERANCES, DATA, SLOT_NAME, ENTITY
+from snips_inference_agl.exceptions import DatasetFormatError
+
+
+def type_error(expected_type, found_type, object_label=None):
+    """Raises a DatasetFormatError describing a type mismatch
+
+    Args:
+        expected_type: the type that was expected (formatted with %s)
+        found_type: the type that was found instead (formatted with %s)
+        object_label: optional label of the offending object; when it is
+            not None it is quoted in the error message
+
+    Raises:
+        DatasetFormatError: always, this function never returns
+    """
+    if object_label is not None:
+        message = ("Invalid type for '%s': expected %s but found %s"
+                   % (object_label, expected_type, found_type))
+    else:
+        message = ("Invalid type: expected %s but found %s"
+                   % (expected_type, found_type))
+    raise DatasetFormatError(message)
+
+
+def validate_type(obj, expected_type, object_label=None):
+    """Checks that obj is an instance of expected_type
+
+    Args:
+        obj: the object to check
+        expected_type: type (or anything accepted by isinstance) that obj
+            must match
+        object_label: optional label forwarded to type_error for the error
+            message
+
+    Raises:
+        DatasetFormatError: through type_error, when the isinstance check
+            fails; the reported found type is type(obj)
+    """
+    if isinstance(obj, expected_type):
+        return
+    type_error(expected_type, type(obj), object_label)
+
+
+def missing_key_error(key, object_label=None):
+    """Raises a DatasetFormatError describing a missing key
+
+    Args:
+        key: the key that is missing (formatted with %s)
+        object_label: optional label of the object lacking the key; when it
+            is not None it is included in the error message
+
+    Raises:
+        DatasetFormatError: always, this function never returns
+    """
+    if object_label is not None:
+        message = ("Expected %s to have key: '%s'"
+                   % (object_label, key))
+    else:
+        message = "Missing key: '%s'" % key
+    raise DatasetFormatError(message)
+
+
+def validate_key(obj, key, object_label=None):
+    """Checks that key is contained in obj
+
+    Args:
+        obj: container tested with the ``in`` operator
+        key: the key that must be present
+        object_label: optional label forwarded to missing_key_error for the
+            error message
+
+    Raises:
+        DatasetFormatError: through missing_key_error, when key is not in
+            obj
+    """
+    if key in obj:
+        return
+    missing_key_error(key, object_label)
+
+
+def validate_keys(obj, keys, object_label=None):
+    """Checks that every key of keys is contained in obj
+
+    Keys are checked in iteration order and the first missing one is
+    reported.
+
+    Args:
+        obj: container tested with the ``in`` operator
+        keys: iterable of keys that must all be present
+        object_label: optional label forwarded to missing_key_error for the
+            error message
+
+    Raises:
+        DatasetFormatError: through missing_key_error, for the first key
+            that is not in obj
+    """
+    for required_key in keys:
+        if required_key not in obj:
+            missing_key_error(required_key, object_label)
+
+
+def get_slot_name_mapping(dataset, intent):
+    """Returns a dict which maps slot names to entities for the provided intent
+
+    Every chunk of every utterance of the intent is visited; chunks without
+    a slot name are skipped. When a slot name occurs several times, the
+    entity of its last occurrence is kept.
+
+    Args:
+        dataset: dataset indexed as dataset[INTENTS][intent][UTTERANCES]
+        intent: key of the intent inside dataset[INTENTS]
+    """
+    mapping = dict()
+    for utterance in dataset[INTENTS][intent][UTTERANCES]:
+        mapping.update({
+            chunk[SLOT_NAME]: chunk[ENTITY]
+            for chunk in utterance[DATA]
+            if SLOT_NAME in chunk
+        })
+    return mapping
+
+def get_slot_name_mappings(dataset):
+    """Returns a dict which maps intents to their slot name mapping
+
+    Args:
+        dataset: dataset whose dataset[INTENTS] is iterated; each item is
+            passed to get_slot_name_mapping as the intent
+    """
+    mappings = dict()
+    for intent_name in dataset[INTENTS]:
+        mappings[intent_name] = get_slot_name_mapping(dataset, intent_name)
+    return mappings
\ No newline at end of file