aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/pipeline/configs/slot_filler.py
diff options
context:
space:
mode:
Diffstat (limited to 'snips_inference_agl/pipeline/configs/slot_filler.py')
-rw-r--r--snips_inference_agl/pipeline/configs/slot_filler.py145
1 files changed, 145 insertions, 0 deletions
diff --git a/snips_inference_agl/pipeline/configs/slot_filler.py b/snips_inference_agl/pipeline/configs/slot_filler.py
new file mode 100644
index 0000000..be36e9c
--- /dev/null
+++ b/snips_inference_agl/pipeline/configs/slot_filler.py
@@ -0,0 +1,145 @@
+from __future__ import unicode_literals
+
+from snips_inference_agl.common.from_dict import FromDict
+from snips_inference_agl.constants import STOP_WORDS
+from snips_inference_agl.pipeline.configs import (
+ Config, ProcessingUnitConfig, default_features_factories)
+from snips_inference_agl.resources import merge_required_resources
+
+
+class CRFSlotFillerConfig(FromDict, ProcessingUnitConfig):
+ # pylint: disable=line-too-long
+ """Configuration of a :class:`.CRFSlotFiller`
+
+ Args:
+ feature_factory_configs (list, optional): List of configurations that
+ specify the list of :class:`.CRFFeatureFactory` to use with the CRF
+ tagging_scheme (:class:`.TaggingScheme`, optional): Tagging scheme to
+ use to enrich CRF labels (default=BIO)
+ crf_args (dict, optional): Allow to overwrite the parameters of the CRF
+ defined in *sklearn_crfsuite*, see :class:`sklearn_crfsuite.CRF`
+ (default={"c1": .1, "c2": .1, "algorithm": "lbfgs"})
+ data_augmentation_config (dict or :class:`.SlotFillerDataAugmentationConfig`, optional):
+ Specify how to augment data before training the CRF, see the
+ corresponding config object for more details.
+ random_seed (int, optional): Specify to make the CRF training
+ deterministic and reproducible (default=None)
+ """
+
+ # pylint: enable=line-too-long
+
+ def __init__(self, feature_factory_configs=None,
+ tagging_scheme=None, crf_args=None,
+ data_augmentation_config=None):
+ if tagging_scheme is None:
+ from snips_inference_agl.slot_filler.crf_utils import TaggingScheme
+ tagging_scheme = TaggingScheme.BIO
+ if feature_factory_configs is None:
+ feature_factory_configs = default_features_factories()
+ if crf_args is None:
+ crf_args = _default_crf_args()
+ if data_augmentation_config is None:
+ data_augmentation_config = SlotFillerDataAugmentationConfig()
+ self.feature_factory_configs = feature_factory_configs
+ self._tagging_scheme = None
+ self.tagging_scheme = tagging_scheme
+ self.crf_args = crf_args
+ self._data_augmentation_config = None
+ self.data_augmentation_config = data_augmentation_config
+
+ @property
+ def tagging_scheme(self):
+ return self._tagging_scheme
+
+ @tagging_scheme.setter
+ def tagging_scheme(self, value):
+ from snips_inference_agl.slot_filler.crf_utils import TaggingScheme
+ if isinstance(value, TaggingScheme):
+ self._tagging_scheme = value
+ elif isinstance(value, int):
+ self._tagging_scheme = TaggingScheme(value)
+ else:
+ raise TypeError("Expected instance of TaggingScheme or int but"
+ "received: %s" % type(value))
+
+ @property
+ def data_augmentation_config(self):
+ return self._data_augmentation_config
+
+ @data_augmentation_config.setter
+ def data_augmentation_config(self, value):
+ if isinstance(value, dict):
+ self._data_augmentation_config = \
+ SlotFillerDataAugmentationConfig.from_dict(value)
+ elif isinstance(value, SlotFillerDataAugmentationConfig):
+ self._data_augmentation_config = value
+ else:
+ raise TypeError("Expected instance of "
+ "SlotFillerDataAugmentationConfig or dict but "
+ "received: %s" % type(value))
+
+ @property
+ def unit_name(self):
+ from snips_inference_agl.slot_filler import CRFSlotFiller
+ return CRFSlotFiller.unit_name
+
+ def get_required_resources(self):
+ # Import here to avoid circular imports
+ from snips_inference_agl.slot_filler.feature_factory import CRFFeatureFactory
+
+ resources = self.data_augmentation_config.get_required_resources()
+ for config in self.feature_factory_configs:
+ factory = CRFFeatureFactory.from_config(config)
+ resources = merge_required_resources(
+ resources, factory.get_required_resources())
+ return resources
+
+ def to_dict(self):
+ return {
+ "unit_name": self.unit_name,
+ "feature_factory_configs": self.feature_factory_configs,
+ "crf_args": self.crf_args,
+ "tagging_scheme": self.tagging_scheme.value,
+ "data_augmentation_config":
+ self.data_augmentation_config.to_dict()
+ }
+
+
+class SlotFillerDataAugmentationConfig(FromDict, Config):
+ """Specify how to augment data before training the CRF
+
+ Data augmentation essentially consists in creating additional utterances
+ by combining utterance patterns and slot values
+
+ Args:
+ min_utterances (int, optional): Specify the minimum amount of
+ utterances to generate per intent (default=200)
+ capitalization_ratio (float, optional): If an entity has one or more
+ capitalized values, the data augmentation will randomly capitalize
+ its values with a ratio of *capitalization_ratio* (default=.2)
+ add_builtin_entities_examples (bool, optional): If True, some builtin
+ entity examples will be automatically added to the training data.
+ Default is True.
+ """
+
+ def __init__(self, min_utterances=200, capitalization_ratio=.2,
+ add_builtin_entities_examples=True):
+ self.min_utterances = min_utterances
+ self.capitalization_ratio = capitalization_ratio
+ self.add_builtin_entities_examples = add_builtin_entities_examples
+
+ def get_required_resources(self):
+ return {
+ STOP_WORDS: True
+ }
+
+ def to_dict(self):
+ return {
+ "min_utterances": self.min_utterances,
+ "capitalization_ratio": self.capitalization_ratio,
+ "add_builtin_entities_examples": self.add_builtin_entities_examples
+ }
+
+
+def _default_crf_args():
+ return {"c1": .1, "c2": .1, "algorithm": "lbfgs"}