aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/dataset/dataset.py
blob: 2ad7867bcc7803c063e8f2dbc0ada0f7910a0241 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# coding=utf-8
from __future__ import print_function, unicode_literals

import io
from itertools import cycle

from snips_inference_agl.common.utils import unicode_string
from snips_inference_agl.dataset.entity import Entity
from snips_inference_agl.dataset.intent import Intent
from snips_inference_agl.exceptions import DatasetFormatError


class Dataset(object):
    """Dataset used in the main NLU training API

    Consists of intents and entities data. Instances are built directly from
    lists of :class:`.Intent` and :class:`.Entity` objects; the private
    :meth:`_load_dataset_parts` helper can produce such lists from a stream of
    YAML documents.

    On construction, entities referenced in intent utterances but not
    declared in ``entities`` are added, and every slot chunk without text is
    filled with an example value of its entity.

    Attributes:
        language (str): language of the intents
        intents (list of :class:`.Intent`): intents data
        entities (list of :class:`.Entity`): entities data
    """

    def __init__(self, language, intents, entities):
        """Build the dataset and complete its entities and slot values

        Note that ``entities`` is extended in place with any missing entity,
        and the slot chunks of ``intents`` whose text is ``None`` are
        modified in place.

        Raises:
            DatasetFormatError: if a slot chunk has no text and its entity
                provides no value to fill it with
        """
        self.language = language
        self.intents = intents
        self.entities = entities
        self._add_missing_entities()
        self._ensure_entity_values()

    @classmethod
    def _load_dataset_parts(cls, stream, stream_description):
        """Parse a stream of YAML documents into intents and entities

        Args:
            stream: YAML content accepted by ``yaml.safe_load_all`` (a string
                or a file-like object)
            stream_description (str): name of the stream (e.g. a file path),
                used only in error messages

        Returns:
            tuple: ``(intents, entities)``, a list of :class:`.Intent` and a
            list of :class:`.Entity`, in the order they appear in the stream

        Raises:
            DatasetFormatError: if a document is not a mapping, or if its
                ``type`` is neither ``"entity"`` nor ``"intent"``
        """
        from snips_inference_agl.dataset.yaml_wrapper import yaml

        intents = []
        entities = []
        for doc in yaml.safe_load_all(stream):
            # Documents that are not mappings (an empty document loads as
            # None) have no 'type' key: report them as a format error rather
            # than failing with an AttributeError on ``doc.get``.
            if not isinstance(doc, dict):
                raise DatasetFormatError(
                    "Invalid YAML document in file '%s': expected a mapping, "
                    "got %r" % (stream_description, doc))
            doc_type = doc.get("type")
            if doc_type == "entity":
                entities.append(Entity.from_yaml(doc))
            elif doc_type == "intent":
                intents.append(Intent.from_yaml(doc))
            else:
                raise DatasetFormatError(
                    "Invalid 'type' value in YAML file '%s': '%s'"
                    % (stream_description, doc_type))
        return intents, entities

    def _add_missing_entities(self):
        """Declare entities that only appear in intent utterances

        Every name in an intent's ``entities_names`` that has no matching
        entry in ``self.entities`` is appended as ``Entity(name=...)``, once,
        in the order it is first encountered.
        """
        entity_names = set(e.name for e in self.entities)

        for intent in self.intents:
            for entity_name in intent.entities_names:
                if entity_name not in entity_names:
                    entity_names.add(entity_name)
                    self.entities.append(Entity(name=entity_name))

    def _ensure_entity_values(self):
        """Fill slot chunks that have no text with example entity values

        Chunks whose ``text`` is ``None`` receive the next value of their
        entity's value cycle (see :meth:`_get_entity_values`), so successive
        empty chunks of the same entity draw its values in turn.

        Returns:
            Dataset: ``self``

        Raises:
            DatasetFormatError: if a chunk needs a value and its entity has
                no value at all
        """
        entities_values = {entity.name: self._get_entity_values(entity)
                           for entity in self.entities}
        # Sentinel distinguishing "no value available" from any real value;
        # avoids raising from inside an ``except StopIteration`` block, which
        # would chain a confusing StopIteration into the traceback.
        missing = object()
        for intent in self.intents:
            for utterance in intent.utterances:
                for chunk in utterance.slot_chunks:
                    if chunk.text is not None:
                        continue
                    value = next(entities_values[chunk.entity], missing)
                    if value is missing:
                        raise DatasetFormatError(
                            "At least one entity value must be provided for "
                            "entity '%s'" % chunk.entity)
                    chunk.text = value
        return self

    def _get_entity_values(self, entity):
        """Return an endless iterator over example values of ``entity``

        For a builtin entity, the examples come from ``snips_nlu_parsers``
        for the dataset language. For a custom entity, they are the
        variations of the entity's own utterances, followed by the distinct
        non-empty slot texts tagged with this entity in the intents that are
        not already among those values.

        Returns:
            itertools.cycle: cycles over the values; the first ``next`` call
            raises StopIteration when there are no values at all
        """
        if entity.is_builtin:
            # Only builtin entities need the parsers package.
            from snips_nlu_parsers import get_builtin_entity_examples
            return cycle(get_builtin_entity_examples(
                entity.name, self.language))
        values = [v for utterance in entity.utterances
                  for v in utterance.variations]
        values_set = set(values)
        for intent in self.intents:
            for utterance in intent.utterances:
                for chunk in utterance.slot_chunks:
                    if not chunk.text or chunk.entity != entity.name:
                        continue
                    if chunk.text not in values_set:
                        values_set.add(chunk.text)
                        values.append(chunk.text)
        return cycle(values)

    @property
    def json(self):
        """Dataset data in json format

        Returns:
            dict: ``{"language": ..., "intents": {name: intent json},
            "entities": {name: entity json}}``
        """
        intents = {intent_data.intent_name: intent_data.json
                   for intent_data in self.intents}
        entities = {entity.name: entity.json for entity in self.entities}
        return dict(language=self.language, intents=intents, entities=entities)