aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/pipeline/configs/slot_filler.py
blob: be36e9c0c74f9a14b0db37bcd0dd590b9d914a57 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
from __future__ import unicode_literals

from snips_inference_agl.common.from_dict import FromDict
from snips_inference_agl.constants import STOP_WORDS
from snips_inference_agl.pipeline.configs import (
    Config, ProcessingUnitConfig, default_features_factories)
from snips_inference_agl.resources import merge_required_resources


class CRFSlotFillerConfig(FromDict, ProcessingUnitConfig):
    # pylint: disable=line-too-long
    """Configuration of a :class:`.CRFSlotFiller`

    Args:
        feature_factory_configs (list, optional): List of configurations that
            specify the list of :class:`.CRFFeatureFactory` to use with the CRF
            (default=result of ``default_features_factories()``)
        tagging_scheme (:class:`.TaggingScheme` or int, optional): Tagging
            scheme to use to enrich CRF labels (default=BIO)
        crf_args (dict, optional): Allow to overwrite the parameters of the CRF
            defined in *sklearn_crfsuite*, see :class:`sklearn_crfsuite.CRF`
            (default={"c1": .1, "c2": .1, "algorithm": "lbfgs"})
        data_augmentation_config (dict or :class:`.SlotFillerDataAugmentationConfig`, optional):
            Specify how to augment data before training the CRF, see the
            corresponding config object for more details.

    Raises:
        TypeError: If *tagging_scheme* is neither a :class:`.TaggingScheme`
            nor an int, or if *data_augmentation_config* is neither a dict
            nor a :class:`.SlotFillerDataAugmentationConfig`
    """

    # pylint: enable=line-too-long

    def __init__(self, feature_factory_configs=None,
                 tagging_scheme=None, crf_args=None,
                 data_augmentation_config=None):
        if tagging_scheme is None:
            # Imported locally, like the other slot_filler imports of this
            # class
            from snips_inference_agl.slot_filler.crf_utils import TaggingScheme
            tagging_scheme = TaggingScheme.BIO
        if feature_factory_configs is None:
            feature_factory_configs = default_features_factories()
        if crf_args is None:
            crf_args = _default_crf_args()
        if data_augmentation_config is None:
            data_augmentation_config = SlotFillerDataAugmentationConfig()
        self.feature_factory_configs = feature_factory_configs
        # The backing fields are created first, then the public names are
        # assigned so that the property setters validate/convert the values
        self._tagging_scheme = None
        self.tagging_scheme = tagging_scheme
        self.crf_args = crf_args
        self._data_augmentation_config = None
        self.data_augmentation_config = data_augmentation_config

    @property
    def tagging_scheme(self):
        """:class:`.TaggingScheme` stored on this config"""
        return self._tagging_scheme

    @tagging_scheme.setter
    def tagging_scheme(self, value):
        """Store *value*, converting an int through ``TaggingScheme(value)``

        Raises:
            TypeError: If *value* is neither a TaggingScheme nor an int
        """
        from snips_inference_agl.slot_filler.crf_utils import TaggingScheme
        if isinstance(value, TaggingScheme):
            self._tagging_scheme = value
        elif isinstance(value, int):
            self._tagging_scheme = TaggingScheme(value)
        else:
            # The trailing space after "but" keeps the two concatenated
            # literals from running together in the message
            raise TypeError("Expected instance of TaggingScheme or int but "
                            "received: %s" % type(value))

    @property
    def data_augmentation_config(self):
        """:class:`.SlotFillerDataAugmentationConfig` stored on this config"""
        return self._data_augmentation_config

    @data_augmentation_config.setter
    def data_augmentation_config(self, value):
        """Store *value*, building the config object when given a dict

        Raises:
            TypeError: If *value* is neither a dict nor a
                SlotFillerDataAugmentationConfig
        """
        if isinstance(value, dict):
            self._data_augmentation_config = \
                SlotFillerDataAugmentationConfig.from_dict(value)
        elif isinstance(value, SlotFillerDataAugmentationConfig):
            self._data_augmentation_config = value
        else:
            raise TypeError("Expected instance of "
                            "SlotFillerDataAugmentationConfig or dict but "
                            "received: %s" % type(value))

    @property
    def unit_name(self):
        """Unit name, taken from ``CRFSlotFiller.unit_name``"""
        from snips_inference_agl.slot_filler import CRFSlotFiller
        return CRFSlotFiller.unit_name

    def get_required_resources(self):
        """Merge the resources required by the data augmentation config with
        those required by each configured feature factory

        Returns:
            The result of folding ``merge_required_resources`` over the data
            augmentation resources and every factory's resources
        """
        # Import here to avoid circular imports
        from snips_inference_agl.slot_filler.feature_factory import CRFFeatureFactory

        resources = self.data_augmentation_config.get_required_resources()
        for config in self.feature_factory_configs:
            factory = CRFFeatureFactory.from_config(config)
            resources = merge_required_resources(
                resources, factory.get_required_resources())
        return resources

    def to_dict(self):
        """Return a dict representation of this config

        The tagging scheme is stored through its ``value`` and the data
        augmentation config through its own ``to_dict()``.
        """
        return {
            "unit_name": self.unit_name,
            "feature_factory_configs": self.feature_factory_configs,
            "crf_args": self.crf_args,
            "tagging_scheme": self.tagging_scheme.value,
            "data_augmentation_config":
                self.data_augmentation_config.to_dict()
        }


class SlotFillerDataAugmentationConfig(FromDict, Config):
    """Settings that drive data augmentation ahead of CRF training

    Augmenting the data means generating extra utterances by combining
    utterance patterns with slot values

    Args:
        min_utterances (int, optional): Minimum number of utterances to
            generate for each intent (default=200)
        capitalization_ratio (float, optional): When an entity has at least
            one capitalized value, its values are randomly capitalized with a
            ratio of *capitalization_ratio* (default=.2)
        add_builtin_entities_examples (bool, optional): When True, some
            builtin entity examples are automatically added to the training
            data (default=True)
    """

    def __init__(self, min_utterances=200, capitalization_ratio=.2,
                 add_builtin_entities_examples=True):
        self.min_utterances = min_utterances
        self.capitalization_ratio = capitalization_ratio
        self.add_builtin_entities_examples = add_builtin_entities_examples

    def get_required_resources(self):
        """Return a new dict mapping ``STOP_WORDS`` to True"""
        required = {}
        required[STOP_WORDS] = True
        return required

    def to_dict(self):
        """Return the three settings of this config as a plain dict"""
        serialized = {}
        serialized["min_utterances"] = self.min_utterances
        serialized["capitalization_ratio"] = self.capitalization_ratio
        serialized["add_builtin_entities_examples"] = \
            self.add_builtin_entities_examples
        return serialized


def _default_crf_args():
    return {"c1": .1, "c2": .1, "algorithm": "lbfgs"}