Diffstat (limited to 'snips_inference_agl/data_augmentation.py')
-rw-r--r--  snips_inference_agl/data_augmentation.py  121
1 file changed, 121 insertions, 0 deletions
diff --git a/snips_inference_agl/data_augmentation.py b/snips_inference_agl/data_augmentation.py
new file mode 100644
index 0000000..5a37f5e
--- /dev/null
+++ b/snips_inference_agl/data_augmentation.py
@@ -0,0 +1,121 @@
+from __future__ import unicode_literals
+
+from builtins import next
+from copy import deepcopy
+from itertools import cycle
+
+from future.utils import iteritems
+
+from snips_inference_agl.constants import (
+    CAPITALIZE, DATA, ENTITIES, ENTITY, INTENTS, TEXT, UTTERANCES)
+from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity
+from snips_inference_agl.languages import get_default_sep
+from snips_inference_agl.preprocessing import tokenize_light
+from snips_inference_agl.resources import get_stop_words
+
+
+def capitalize(text, language, resources):
+    tokens = tokenize_light(text, language)
+    stop_words = get_stop_words(resources)
+    return get_default_sep(language).join(
+        t.title() if t.lower() not in stop_words
+        else t.lower() for t in tokens)
+
+
+def capitalize_utterances(utterances, entities, language, ratio, resources,
+                          random_state):
+    capitalized_utterances = []
+    for utterance in utterances:
+        capitalized_utterance = deepcopy(utterance)
+        for i, chunk in enumerate(capitalized_utterance[DATA]):
+            capitalized_utterance[DATA][i][TEXT] = chunk[TEXT].lower()
+            if ENTITY not in chunk:
+                continue
+            entity_label = chunk[ENTITY]
+            if is_builtin_entity(entity_label):
+                continue
+            if not entities[entity_label][CAPITALIZE]:
+                continue
+            if random_state.rand() > ratio:
+                continue
+            capitalized_utterance[DATA][i][TEXT] = capitalize(
+                chunk[TEXT], language, resources)
+        capitalized_utterances.append(capitalized_utterance)
+    return capitalized_utterances
+
+
+def generate_utterance(contexts_iterator, entities_iterators):
+    context = deepcopy(next(contexts_iterator))
+    context_data = []
+    for chunk in context[DATA]:
+        if ENTITY in chunk:
+            chunk[TEXT] = deepcopy(
+                next(entities_iterators[chunk[ENTITY]]))
+        chunk[TEXT] = chunk[TEXT].strip() + " "
+        context_data.append(chunk)
+    context[DATA] = context_data
+    return context
+
+
+def get_contexts_iterator(dataset, intent_name, random_state):
+    shuffled_utterances = random_state.permutation(
+        dataset[INTENTS][intent_name][UTTERANCES])
+    return cycle(shuffled_utterances)
+
+
+def get_entities_iterators(intent_entities, language,
+                           add_builtin_entities_examples, random_state):
+    from snips_nlu_parsers import get_builtin_entity_examples
+
+    entities_its = dict()
+    for entity_name, entity in iteritems(intent_entities):
+        utterance_values = random_state.permutation(sorted(entity[UTTERANCES]))
+        if add_builtin_entities_examples and is_builtin_entity(entity_name):
+            entity_examples = get_builtin_entity_examples(
+                entity_name, language)
+            # Builtin entity examples must be kept first in the iterator to
+            # ensure that they are used when augmenting data
+            iterator_values = entity_examples + list(utterance_values)
+        else:
+            iterator_values = utterance_values
+        entities_its[entity_name] = cycle(iterator_values)
+    return entities_its
+
+
+def get_intent_entities(dataset, intent_name):
+    intent_entities = set()
+    for utterance in dataset[INTENTS][intent_name][UTTERANCES]:
+        for chunk in utterance[DATA]:
+            if ENTITY in chunk:
+                intent_entities.add(chunk[ENTITY])
+    return sorted(intent_entities)
+
+
+def num_queries_to_generate(dataset, intent_name, min_utterances):
+    nb_utterances = len(dataset[INTENTS][intent_name][UTTERANCES])
+    return max(nb_utterances, min_utterances)
+
+
+def augment_utterances(dataset, intent_name, language, min_utterances,
+                       capitalization_ratio, add_builtin_entities_examples,
+                       resources, random_state):
+    contexts_it = get_contexts_iterator(dataset, intent_name, random_state)
+    intent_entities = {e: dataset[ENTITIES][e]
+                       for e in get_intent_entities(dataset, intent_name)}
+    entities_its = get_entities_iterators(intent_entities, language,
+                                          add_builtin_entities_examples,
+                                          random_state)
+    generated_utterances = []
+    nb_to_generate = num_queries_to_generate(dataset, intent_name,
+                                             min_utterances)
+    while nb_to_generate > 0:
+        generated_utterance = generate_utterance(contexts_it, entities_its)
+        generated_utterances.append(generated_utterance)
+        nb_to_generate -= 1
+
+    generated_utterances = capitalize_utterances(
+        generated_utterances, dataset[ENTITIES], language,
+        ratio=capitalization_ratio, resources=resources,
+        random_state=random_state)
+
+    return generated_utterances
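
Usage note: the snippet below is a minimal sketch of how augment_utterances
might be driven; it is not part of the commit. It assumes a dataset dict that
has already been validated (each entity carries the "utterances" and
"capitalize" keys the code above reads), a loaded resources dict providing
stop words, and a numpy RandomState (the module calls random_state.rand() and
random_state.permutation()). The intent name "turnLightOn" and the variables
dataset and resources are hypothetical placeholders.

    from numpy.random import RandomState

    from snips_inference_agl.data_augmentation import augment_utterances

    # dataset: validated Snips-style dataset dict; resources: loaded language
    # resources (both assumed to exist in the caller's scope).
    augmented = augment_utterances(
        dataset,
        intent_name="turnLightOn",          # hypothetical intent
        language="en",
        min_utterances=200,                 # pad the intent up to 200 samples
        capitalization_ratio=0.2,           # capitalize ~20% of eligible chunks
        add_builtin_entities_examples=True,
        resources=resources,
        random_state=RandomState(42),       # reproducible shuffling/sampling
    )
    # Returns max(len(original utterances), min_utterances) utterances, each
    # keeping the {"data": [{"text": ..., "entity": ...}]} chunk structure,
    # with entity chunk texts replaced by values cycled from the shuffled
    # entity iterators.

Because both the context and entity iterators are built with itertools.cycle
over permutations drawn from the given RandomState, the augmented output is
deterministic for a fixed seed, and builtin entity examples (when requested)
are consumed before the dataset's own entity values.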