path: root/snips_inference_agl/dataset/utils.py
blob: f147f0fa0833058cc7db198474ee6bac19974ab7
from __future__ import unicode_literals

from future.utils import iteritems, itervalues

from snips_inference_agl.constants import (
    DATA, ENTITIES, ENTITY, INTENTS, TEXT, UTTERANCES)
from snips_inference_agl.entity_parser.builtin_entity_parser import is_gazetteer_entity


def extract_utterance_entities(dataset):
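    """Maps each entity of the dataset to the list of values it takes in
    the intent utterances, with surrounding whitespace stripped"""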
    entities_values = {ent_name: set() for ent_name in dataset[ENTITIES]}

    for intent in itervalues(dataset[INTENTS]):
        for utterance in intent[UTTERANCES]:
            for chunk in utterance[DATA]:
                if ENTITY in chunk:
                    entities_values[chunk[ENTITY]].add(chunk[TEXT].strip())
    return {k: list(v) for k, v in iteritems(entities_values)}
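

# Illustrative usage: a minimal, made-up dataset assuming the standard Snips
# dataset layout, in which the constants imported above map to the literal
# keys "entities", "intents", "utterances", "data", "text" and "entity", and
# in which custom entities carry an "utterances" mapping of raw to resolved
# values (as extract_entity_values below relies on):
#
#     dataset = {
#         "entities": {"room": {"utterances": {"kitchen": "kitchen"}}},
#         "intents": {
#             "turnLightOn": {
#                 "utterances": [{
#                     "data": [
#                         {"text": "turn on the lights in the "},
#                         {"text": "kitchen ", "entity": "room",
#                          "slot_name": "room"},
#                     ]
#                 }]
#             }
#         },
#     }
#     extract_utterance_entities(dataset)  # -> {"room": ["kitchen"]}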


def extract_intent_entities(dataset, entity_filter=None):
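    """Maps each intent of the dataset to the set of entities appearing
    in its utterances, keeping only the entities accepted by entity_filter
    when one is provided"""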
    intent_entities = {intent: set() for intent in dataset[INTENTS]}
    for intent_name, intent_data in iteritems(dataset[INTENTS]):
        for utterance in intent_data[UTTERANCES]:
            for chunk in utterance[DATA]:
                if ENTITY in chunk:
                    if entity_filter and not entity_filter(chunk[ENTITY]):
                        continue
                    intent_entities[intent_name].add(chunk[ENTITY])
    return intent_entities
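

# Continuing the made-up dataset above:
#
#     extract_intent_entities(dataset)
#     # -> {"turnLightOn": {"room"}}
#     extract_intent_entities(dataset, entity_filter=is_gazetteer_entity)
#     # -> {"turnLightOn": set()}, assuming the custom entity "room" is not
#     #    recognized as a builtin gazetteer entity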


def extract_entity_values(dataset, apply_normalization):
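    """Maps each intent of the dataset to the set of values taken by its
    entities, normalized with snips_nlu_utils.normalize when
    apply_normalization is True"""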
    from snips_nlu_utils import normalize

    entities_per_intent = {intent: set() for intent in dataset[INTENTS]}
    intent_entities = extract_intent_entities(dataset)
    for intent, entities in iteritems(intent_entities):
        for entity in entities:
            entity_values = set(dataset[ENTITIES][entity][UTTERANCES])
            if apply_normalization:
                entity_values = {normalize(v) for v in entity_values}
            entities_per_intent[intent].update(entity_values)
    return entities_per_intent
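

# With the made-up dataset above, whose "room" entity maps the single raw
# value "kitchen" to itself:
#
#     extract_entity_values(dataset, apply_normalization=False)
#     # -> {"turnLightOn": {"kitchen"}}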


def get_text_from_chunks(chunks):
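    """Rebuilds the utterance text by concatenating the text of its
    chunks"""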
    return "".join(chunk[TEXT] for chunk in chunks)


def get_dataset_gazetteer_entities(dataset, intent=None):
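    """Lists the gazetteer entities of the dataset, restricted to those
    appearing in the given intent when intent is provided"""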
    if intent is not None:
        return extract_intent_entities(dataset, is_gazetteer_entity)[intent]
    return {e for e in dataset[ENTITIES] if is_gazetteer_entity(e)}
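

# On the made-up dataset above, and again assuming is_gazetteer_entity
# returns False for the custom entity "room":
#
#     get_dataset_gazetteer_entities(dataset)  # -> set()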


def get_stop_words_whitelist(dataset, stop_words):
    """Extracts stop words whitelists per intent consisting of entity values
    that appear in the stop_words list"""
    entity_values_per_intent = extract_entity_values(
        dataset, apply_normalization=True)
    stop_words_whitelist = dict()
    for intent, entity_values in iteritems(entity_values_per_intent):
        whitelist = stop_words.intersection(entity_values)
        if whitelist:
            stop_words_whitelist[intent] = whitelist
    return stop_words_whitelist
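

# Illustrative usage on the made-up dataset above, assuming normalize leaves
# the lowercase value "kitchen" unchanged:
#
#     get_stop_words_whitelist(dataset, {"the", "kitchen"})
#     # -> {"turnLightOn": {"kitchen"}}: "kitchen" is both a stop word and a
#     #    normalized entity value of the intent, while "the" only appears in
#     #    plain text chunks and is therefore not whitelisted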