Diffstat (limited to 'snips_inference_agl/dataset/validation.py')
-rw-r--r--  snips_inference_agl/dataset/validation.py  254
1 file changed, 254 insertions, 0 deletions
diff --git a/snips_inference_agl/dataset/validation.py b/snips_inference_agl/dataset/validation.py
new file mode 100644
index 0000000..d6fc4a1
--- /dev/null
+++ b/snips_inference_agl/dataset/validation.py
@@ -0,0 +1,254 @@
+from __future__ import division, unicode_literals
+
+import json
+from builtins import str
+from collections import Counter
+from copy import deepcopy
+
+from future.utils import iteritems, itervalues
+
+from snips_inference_agl.common.dataset_utils import (validate_key, validate_keys,
+ validate_type)
+from snips_inference_agl.constants import (
+ AUTOMATICALLY_EXTENSIBLE, CAPITALIZE, DATA, ENTITIES, ENTITY, INTENTS,
+ LANGUAGE, MATCHING_STRICTNESS, SLOT_NAME, SYNONYMS, TEXT, USE_SYNONYMS,
+ UTTERANCES, VALIDATED, VALUE, LICENSE_INFO)
+from snips_inference_agl.dataset import extract_utterance_entities, Dataset
+from snips_inference_agl.entity_parser.builtin_entity_parser import (
+ BuiltinEntityParser, is_builtin_entity)
+from snips_inference_agl.exceptions import DatasetFormatError
+from snips_inference_agl.preprocessing import tokenize_light
+from snips_inference_agl.string_variations import get_string_variations
+
+NUMBER_VARIATIONS_THRESHOLD = 1e3
+VARIATIONS_GENERATION_THRESHOLD = 1e4
+
+
+def validate_and_format_dataset(dataset):
+ """Checks that the dataset is valid and format it
+
+    Raises:
+ DatasetFormatError: When the dataset format is wrong
+ """
+ from snips_nlu_parsers import get_all_languages
+
+ if isinstance(dataset, Dataset):
+ dataset = dataset.json
+
+ # Make this function idempotent
+ if dataset.get(VALIDATED, False):
+ return dataset
+ dataset = deepcopy(dataset)
+ dataset = json.loads(json.dumps(dataset))
+ validate_type(dataset, dict, object_label="dataset")
+ mandatory_keys = [INTENTS, ENTITIES, LANGUAGE]
+ for key in mandatory_keys:
+ validate_key(dataset, key, object_label="dataset")
+ validate_type(dataset[ENTITIES], dict, object_label="entities")
+ validate_type(dataset[INTENTS], dict, object_label="intents")
+ language = dataset[LANGUAGE]
+ validate_type(language, str, object_label="language")
+ if language not in get_all_languages():
+ raise DatasetFormatError("Unknown language: '%s'" % language)
+
+ dataset[INTENTS] = {
+ intent_name: intent_data
+ for intent_name, intent_data in sorted(iteritems(dataset[INTENTS]))}
+ for intent in itervalues(dataset[INTENTS]):
+ _validate_and_format_intent(intent, dataset[ENTITIES])
+
+ utterance_entities_values = extract_utterance_entities(dataset)
+ builtin_entity_parser = BuiltinEntityParser.build(dataset=dataset)
+
+ dataset[ENTITIES] = {
+ intent_name: entity_data
+ for intent_name, entity_data in sorted(iteritems(dataset[ENTITIES]))}
+
+ for entity_name, entity in iteritems(dataset[ENTITIES]):
+        utterance_entities = utterance_entities_values[entity_name]
+        if is_builtin_entity(entity_name):
+            dataset[ENTITIES][entity_name] = \
+                _validate_and_format_builtin_entity(entity, utterance_entities)
+        else:
+            dataset[ENTITIES][entity_name] = \
+                _validate_and_format_custom_entity(
+                    entity, utterance_entities, language,
+                    builtin_entity_parser)
+ dataset[VALIDATED] = True
+ return dataset
+
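+
+# Minimal usage sketch (illustrative; assumes a dataset dict following the
+# standard Snips JSON dataset format, and the "dataset.json" path is
+# hypothetical):
+#
+#     with io.open("dataset.json", encoding="utf8") as f:
+#         dataset_json = json.load(f)
+#     validated = validate_and_format_dataset(dataset_json)
+#     assert validated[VALIDATED]
+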
+
+def _validate_and_format_intent(intent, entities):
+ validate_type(intent, dict, "intent")
+ validate_key(intent, UTTERANCES, object_label="intent dict")
+ validate_type(intent[UTTERANCES], list, object_label="utterances")
+ for utterance in intent[UTTERANCES]:
+ validate_type(utterance, dict, object_label="utterance")
+ validate_key(utterance, DATA, object_label="utterance")
+ validate_type(utterance[DATA], list, object_label="utterance data")
+ for chunk in utterance[DATA]:
+ validate_type(chunk, dict, object_label="utterance chunk")
+ validate_key(chunk, TEXT, object_label="chunk")
+ if ENTITY in chunk or SLOT_NAME in chunk:
+ mandatory_keys = [ENTITY, SLOT_NAME]
+ validate_keys(chunk, mandatory_keys, object_label="chunk")
+                if not is_builtin_entity(chunk[ENTITY]):
+                    validate_key(entities, chunk[ENTITY],
+                                 object_label=ENTITIES)
+ return intent
+
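+
+# Illustrative shape of an utterance accepted by the validator above: every
+# chunk carries "text", and slot chunks additionally carry both "entity" and
+# "slot_name":
+#
+#     {"data": [{"text": "turn on the "},
+#               {"text": "kitchen light", "entity": "light",
+#                "slot_name": "light_name"}]}
+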
+
+def _has_any_capitalization(entity_utterances, language):
+ for utterance in entity_utterances:
+ tokens = tokenize_light(utterance, language)
+ if any(t.isupper() or t.istitle() for t in tokens):
+ return True
+ return False
+
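+
+# Illustrative: _has_any_capitalization(["Los Angeles"], "en") returns True,
+# whereas _has_any_capitalization(["los angeles"], "en") returns False.
+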
+
+def _add_entity_variations(utterances, entity_variations, entity_value):
+ utterances[entity_value] = entity_value
+ for variation in entity_variations[entity_value]:
+ if variation:
+ utterances[variation] = entity_value
+ return utterances
+
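+
+# Illustrative: with entity_variations == {"New York": {"new york", "NY"}},
+# _add_entity_variations maps "New York", "new york" and "NY" all to the
+# canonical value "New York" in the utterances dict.
+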
+
+def _extract_entity_values(entity):
+ values = set()
+ for ent in entity[DATA]:
+ values.add(ent[VALUE])
+ if entity[USE_SYNONYMS]:
+ values.update(set(ent[SYNONYMS]))
+ return values
+
+
+def _validate_and_format_custom_entity(entity, utterance_entities, language,
+ builtin_entity_parser):
+ validate_type(entity, dict, object_label="entity")
+
+ # TODO: this is here temporarily, only to allow backward compatibility
+    if MATCHING_STRICTNESS not in entity:
+        strictness = entity.get("parser_threshold", 1.0)
+        entity[MATCHING_STRICTNESS] = strictness
+
+ mandatory_keys = [USE_SYNONYMS, AUTOMATICALLY_EXTENSIBLE, DATA,
+ MATCHING_STRICTNESS]
+ validate_keys(entity, mandatory_keys, object_label="custom entity")
+ validate_type(entity[USE_SYNONYMS], bool, object_label="use_synonyms")
+ validate_type(entity[AUTOMATICALLY_EXTENSIBLE], bool,
+ object_label="automatically_extensible")
+ validate_type(entity[DATA], list, object_label="entity data")
+ validate_type(entity[MATCHING_STRICTNESS], (float, int),
+ object_label="matching_strictness")
+
+ formatted_entity = dict()
+ formatted_entity[AUTOMATICALLY_EXTENSIBLE] = entity[
+ AUTOMATICALLY_EXTENSIBLE]
+ formatted_entity[MATCHING_STRICTNESS] = entity[MATCHING_STRICTNESS]
+ if LICENSE_INFO in entity:
+ formatted_entity[LICENSE_INFO] = entity[LICENSE_INFO]
+ use_synonyms = entity[USE_SYNONYMS]
+
+ # Validate format and filter out unused data
+ valid_entity_data = []
+ for entry in entity[DATA]:
+ validate_type(entry, dict, object_label="entity entry")
+ validate_keys(entry, [VALUE, SYNONYMS], object_label="entity entry")
+ entry[VALUE] = entry[VALUE].strip()
+ if not entry[VALUE]:
+ continue
+ validate_type(entry[SYNONYMS], list, object_label="entity synonyms")
+ entry[SYNONYMS] = [s.strip() for s in entry[SYNONYMS] if s.strip()]
+ valid_entity_data.append(entry)
+ entity[DATA] = valid_entity_data
+
+ # Compute capitalization before normalizing
+    # Normalization lowercases values and would hence make the
+    # capitalization computation wrong
+ formatted_entity[CAPITALIZE] = _has_any_capitalization(utterance_entities,
+ language)
+
+ validated_utterances = dict()
+    # Map original values and synonyms
+ for data in entity[DATA]:
+ ent_value = data[VALUE]
+ validated_utterances[ent_value] = ent_value
+ if use_synonyms:
+ for s in data[SYNONYMS]:
+ if s not in validated_utterances:
+ validated_utterances[s] = ent_value
+
+    # Number variations of entity values are expensive since each entity
+    # value is parsed with the builtin entity parser before the variations
+    # are created. We skip generating these variations when there are
+    # already enough entity values
+
+ # Add variations if not colliding
+ all_original_values = _extract_entity_values(entity)
+ if len(entity[DATA]) < VARIATIONS_GENERATION_THRESHOLD:
+ variations_args = {
+ "case": True,
+ "and_": True,
+ "punctuation": True
+ }
+ else:
+ variations_args = {
+ "case": False,
+ "and_": False,
+ "punctuation": False
+ }
+
+ variations_args["numbers"] = len(
+ entity[DATA]) < NUMBER_VARIATIONS_THRESHOLD
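+    # Illustrative counts: 500 entries keep every variation enabled, 5000
+    # entries disable only the costly number variations (5000 >= 1e3 but
+    # < 1e4), and 20000 entries disable all of them.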
+
+ variations = dict()
+ for data in entity[DATA]:
+ ent_value = data[VALUE]
+ values_to_variate = {ent_value}
+ if use_synonyms:
+ values_to_variate.update(set(data[SYNONYMS]))
+ variations[ent_value] = set(
+ v for value in values_to_variate
+ for v in get_string_variations(
+ value, language, builtin_entity_parser, **variations_args)
+ )
+ variation_counter = Counter(
+ [v for variations_ in itervalues(variations) for v in variations_])
+    non_colliding_variations = {
+        value: [
+            v for v in value_variations
+            if v not in all_original_values and variation_counter[v] == 1
+        ]
+        for value, value_variations in iteritems(variations)
+    }
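+    # Illustrative: if the variation "ny" were generated for both "New York"
+    # and "NYC", its counter value would be 2 and it would be dropped from
+    # both variation lists.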
+
+ for entry in entity[DATA]:
+ entry_value = entry[VALUE]
+ validated_utterances = _add_entity_variations(
+ validated_utterances, non_colliding_variations, entry_value)
+
+    # Merge entity values observed in utterances
+ utterance_entities_variations = {
+ ent: get_string_variations(
+ ent, language, builtin_entity_parser, **variations_args)
+ for ent in utterance_entities
+ }
+
+ for original_ent, variations in iteritems(utterance_entities_variations):
+ if not original_ent or original_ent in validated_utterances:
+ continue
+ validated_utterances[original_ent] = original_ent
+ for variation in variations:
+ if variation and variation not in validated_utterances \
+ and variation not in utterance_entities:
+ validated_utterances[variation] = original_ent
+ formatted_entity[UTTERANCES] = validated_utterances
+ return formatted_entity
+
+
+def _validate_and_format_builtin_entity(entity, utterance_entities):
+ validate_type(entity, dict, object_label="builtin entity")
+ return {UTTERANCES: set(utterance_entities)}
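+
+# Illustrative: for a builtin entity such as "snips/datetime", no synonym
+# mapping is kept, only the raw values observed in the utterances:
+#
+#     _validate_and_format_builtin_entity({}, ["tomorrow", "at 8 pm"])
+#     # -> {"utterances": {"tomorrow", "at 8 pm"}}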