Diffstat (limited to 'snips_inference_agl/slot_filler')
 snips_inference_agl/slot_filler/__init__.py            |   3
 snips_inference_agl/slot_filler/crf_slot_filler.py     | 467
 snips_inference_agl/slot_filler/crf_utils.py           | 219
 snips_inference_agl/slot_filler/feature.py             |  69
 snips_inference_agl/slot_filler/feature_factory.py     | 568
 snips_inference_agl/slot_filler/features_utils.py      |  47
 snips_inference_agl/slot_filler/keyword_slot_filler.py |  70
 snips_inference_agl/slot_filler/slot_filler.py         |  33
 8 files changed, 1476 insertions, 0 deletions
diff --git a/snips_inference_agl/slot_filler/__init__.py b/snips_inference_agl/slot_filler/__init__.py
new file mode 100644
index 0000000..70974aa
--- /dev/null
+++ b/snips_inference_agl/slot_filler/__init__.py
@@ -0,0 +1,3 @@
+from .crf_slot_filler import CRFSlotFiller
+from .feature import Feature
+from .slot_filler import SlotFiller
diff --git a/snips_inference_agl/slot_filler/crf_slot_filler.py b/snips_inference_agl/slot_filler/crf_slot_filler.py
new file mode 100644
index 0000000..e6ec7e6
--- /dev/null
+++ b/snips_inference_agl/slot_filler/crf_slot_filler.py
@@ -0,0 +1,467 @@
+from __future__ import unicode_literals
+
+import base64
+import json
+import logging
+import math
+import os
+import shutil
+import tempfile
+from builtins import range
+from copy import deepcopy
+from pathlib import Path
+
+from future.utils import iteritems
+
+from snips_inference_agl.common.dataset_utils import get_slot_name_mapping
+from snips_inference_agl.common.dict_utils import UnupdatableDict
+from snips_inference_agl.common.io_utils import mkdir_p
+from snips_inference_agl.common.log_utils import DifferedLoggingMessage, log_elapsed_time
+from snips_inference_agl.common.utils import (
+ check_persisted_path, fitted_required, json_string)
+from snips_inference_agl.constants import DATA, LANGUAGE
+from snips_inference_agl.data_augmentation import augment_utterances
+from snips_inference_agl.dataset import validate_and_format_dataset
+from snips_inference_agl.exceptions import LoadingError
+from snips_inference_agl.pipeline.configs import CRFSlotFillerConfig
+from snips_inference_agl.preprocessing import tokenize
+from snips_inference_agl.slot_filler.crf_utils import (
+ OUTSIDE, TAGS, TOKENS, tags_to_slots, utterance_to_sample)
+from snips_inference_agl.slot_filler.feature import TOKEN_NAME
+from snips_inference_agl.slot_filler.feature_factory import CRFFeatureFactory
+from snips_inference_agl.slot_filler.slot_filler import SlotFiller
+
+CRF_MODEL_FILENAME = "model.crfsuite"
+
+logger = logging.getLogger(__name__)
+
+
+@SlotFiller.register("crf_slot_filler")
+class CRFSlotFiller(SlotFiller):
+ """Slot filler which uses Linear-Chain Conditional Random Fields underneath
+
+ Check https://en.wikipedia.org/wiki/Conditional_random_field to learn
+ more about CRFs
+ """
+
+ config_type = CRFSlotFillerConfig
+
+ def __init__(self, config=None, **shared):
+ """The CRF slot filler can be configured by passing a
+ :class:`.CRFSlotFillerConfig`"""
+ # The CRFSlotFillerConfig must be deep-copied as it is mutated when
+ # fitting the feature factories
+ config = deepcopy(config)
+ super(CRFSlotFiller, self).__init__(config, **shared)
+ self.crf_model = None
+ self.features_factories = [
+ CRFFeatureFactory.from_config(conf, **shared)
+ for conf in self.config.feature_factory_configs]
+ self._features = None
+ self.language = None
+ self.intent = None
+ self.slot_name_mapping = None
+
+ @property
+ def features(self):
+ """List of :class:`.Feature` used by the CRF"""
+ if self._features is None:
+ self._features = []
+ feature_names = set()
+ for factory in self.features_factories:
+ for feature in factory.build_features():
+ if feature.name in feature_names:
+ raise KeyError("Duplicated feature: %s" % feature.name)
+ feature_names.add(feature.name)
+ self._features.append(feature)
+ return self._features
+
+ @property
+ def labels(self):
+ """List of CRF labels
+
+ These labels differ from the slot names as they contain an additional
+ prefix which depends on the :class:`.TaggingScheme` that is used
+ (BIO by default).
+ """
+ labels = []
+ if self.crf_model.tagger_ is not None:
+ labels = [_decode_tag(label) for label in
+ self.crf_model.tagger_.labels()]
+ return labels
+
+ @property
+ def fitted(self):
+ """Whether or not the slot filler has already been fitted"""
+ return self.slot_name_mapping is not None
+
+ @log_elapsed_time(logger, logging.INFO,
+ "Fitted CRFSlotFiller in {elapsed_time}")
+ # pylint:disable=arguments-differ
+ def fit(self, dataset, intent):
+ """Fits the slot filler
+
+ Args:
+ dataset (dict): A valid Snips dataset
+ intent (str): The specific intent of the dataset to train
+ the slot filler on
+
+ Returns:
+ :class:`CRFSlotFiller`: The same instance, trained
+ """
+ logger.info("Fitting %s slot filler...", intent)
+ dataset = validate_and_format_dataset(dataset)
+ self.load_resources_if_needed(dataset[LANGUAGE])
+ self.fit_builtin_entity_parser_if_needed(dataset)
+ self.fit_custom_entity_parser_if_needed(dataset)
+
+ for factory in self.features_factories:
+ factory.custom_entity_parser = self.custom_entity_parser
+ factory.builtin_entity_parser = self.builtin_entity_parser
+ factory.resources = self.resources
+
+ self.language = dataset[LANGUAGE]
+ self.intent = intent
+ self.slot_name_mapping = get_slot_name_mapping(dataset, intent)
+
+ if not self.slot_name_mapping:
+ # No need to train the CRF if the intent has no slots
+ return self
+
+ augmented_intent_utterances = augment_utterances(
+ dataset, self.intent, language=self.language,
+ resources=self.resources, random_state=self.random_state,
+ **self.config.data_augmentation_config.to_dict())
+
+ crf_samples = [
+ utterance_to_sample(u[DATA], self.config.tagging_scheme,
+ self.language)
+ for u in augmented_intent_utterances]
+
+ for factory in self.features_factories:
+ factory.fit(dataset, intent)
+
+ # Ensure that X, Y are safe and that the OUTSIDE label is learnt to
+ # avoid segfault at inference time
+ # pylint: disable=C0103
+ X = [self.compute_features(sample[TOKENS], drop_out=True)
+ for sample in crf_samples]
+ Y = [[tag for tag in sample[TAGS]] for sample in crf_samples]
+ X, Y = _ensure_safe(X, Y)
+
+ # ensure ascii tags
+ Y = [[_encode_tag(tag) for tag in y] for y in Y]
+
+ # pylint: enable=C0103
+ self.crf_model = _get_crf_model(self.config.crf_args)
+ self.crf_model.fit(X, Y)
+
+ logger.debug(
+ "Most relevant features for %s:\n%s", self.intent,
+ DifferedLoggingMessage(self.log_weights))
+ return self
+
+ # pylint:enable=arguments-differ
+
+ @fitted_required
+ def get_slots(self, text):
+ """Extracts slots from the provided text
+
+ Returns:
+ list of dict: The list of extracted slots
+
+ Raises:
+ NotTrained: When the slot filler is not fitted
+ """
+ if not self.slot_name_mapping:
+ # Early return if the intent has no slots
+ return []
+
+ tokens = tokenize(text, self.language)
+ if not tokens:
+ return []
+ features = self.compute_features(tokens)
+ tags = self.crf_model.predict_single(features)
+ logger.debug(DifferedLoggingMessage(
+ self.log_inference_weights, text, tokens=tokens, features=features,
+ tags=tags))
+ decoded_tags = [_decode_tag(t) for t in tags]
+ return tags_to_slots(text, tokens, decoded_tags,
+ self.config.tagging_scheme,
+ self.slot_name_mapping)
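+
+    # Typical inference flow (illustrative sketch; assumes a previously
+    # persisted slot filler directory):
+    #
+    #   >>> filler = CRFSlotFiller.from_path(slot_filler_dir, **shared)
+    #   >>> filler.get_slots("set the lights to blue")
+    #   # -> a list of unresolved slot dicts, see `unresolved_slot`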
+
+ def compute_features(self, tokens, drop_out=False):
+ """Computes features on the provided tokens
+
+        The *drop_out* parameter activates drop-out on features that have a
+        positive drop-out ratio. This should only be used during training.
+ """
+
+ cache = [{TOKEN_NAME: token} for token in tokens]
+ features = []
+ for i in range(len(tokens)):
+ token_features = UnupdatableDict()
+ for feature in self.features:
+ f_drop_out = feature.drop_out
+ if drop_out and self.random_state.rand() < f_drop_out:
+ continue
+ value = feature.compute(i, cache)
+ if value is not None:
+ token_features[feature.name] = value
+ features.append(token_features)
+ return features
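+
+    # Illustrative shape of the computed features (the exact keys depend on
+    # the configured feature factories):
+    #
+    #   >>> filler.compute_features(tokenize("turn on the lights", "en"))
+    #   [{"ngram_1": "turn", "is_first": "1", ...},
+    #    {"ngram_1": "on", ...},
+    #    {"ngram_1": "the", ...},
+    #    {"ngram_1": "lights", "is_last": "1", ...}]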
+
+ @fitted_required
+ def get_sequence_probability(self, tokens, labels):
+ """Gives the joint probability of a sequence of tokens and CRF labels
+
+ Args:
+ tokens (list of :class:`.Token`): list of tokens
+ labels (list of str): CRF labels with their tagging scheme prefix
+ ("B-color", "I-color", "O", etc)
+
+ Note:
+            The absolute value returned here is generally not very useful;
+            however, it can be used to compare one sequence of labels with
+            another.
+ """
+ if not self.slot_name_mapping:
+ return 0.0 if any(label != OUTSIDE for label in labels) else 1.0
+ features = self.compute_features(tokens)
+ return self._get_sequence_probability(features, labels)
+
+ @fitted_required
+ def _get_sequence_probability(self, features, labels):
+ # Use a default substitution label when a label was not seen during
+ # training
+ substitution_label = OUTSIDE if OUTSIDE in self.labels else \
+ self.labels[0]
+ cleaned_labels = [
+ _encode_tag(substitution_label if l not in self.labels else l)
+ for l in labels]
+ self.crf_model.tagger_.set(features)
+ return self.crf_model.tagger_.probability(cleaned_labels)
+
+ @fitted_required
+ def log_weights(self):
+ """Returns a logs for both the label-to-label and label-to-features
+ weights"""
+ if not self.slot_name_mapping:
+ return "No weights to display: intent '%s' has no slots" \
+ % self.intent
+ log = ""
+ transition_features = self.crf_model.transition_features_
+ transition_features = sorted(
+ iteritems(transition_features), key=_weight_absolute_value,
+ reverse=True)
+ log += "\nTransition weights: \n\n"
+ for (state_1, state_2), weight in transition_features:
+ log += "\n%s %s: %s" % (
+ _decode_tag(state_1), _decode_tag(state_2), weight)
+ feature_weights = self.crf_model.state_features_
+ feature_weights = sorted(
+ iteritems(feature_weights), key=_weight_absolute_value,
+ reverse=True)
+ log += "\n\nFeature weights: \n\n"
+ for (feat, tag), weight in feature_weights:
+ log += "\n%s %s: %s" % (feat, _decode_tag(tag), weight)
+ return log
+
+ def log_inference_weights(self, text, tokens, features, tags):
+ model_features = set(
+ f for (f, _), w in iteritems(self.crf_model.state_features_))
+ log = "Feature weights for \"%s\":\n\n" % text
+ max_index = len(tokens) - 1
+ tokens_logs = []
+ for i, (token, feats, tag) in enumerate(zip(tokens, features, tags)):
+ token_log = "# Token \"%s\" (tagged as %s):" \
+ % (token.value, _decode_tag(tag))
+ if i != 0:
+ weights = sorted(self._get_outgoing_weights(tags[i - 1]),
+ key=_weight_absolute_value, reverse=True)
+ if weights:
+ token_log += "\n\nTransition weights from previous tag:"
+ weight_lines = (
+ "- (%s, %s) -> %s"
+ % (_decode_tag(a), _decode_tag(b), w)
+ for (a, b), w in weights
+ )
+ token_log += "\n" + "\n".join(weight_lines)
+ else:
+ token_log += \
+ "\n\nNo transition from previous tag seen at" \
+ " train time !"
+
+ if i != max_index:
+ weights = sorted(self._get_incoming_weights(tags[i + 1]),
+ key=_weight_absolute_value, reverse=True)
+ if weights:
+ token_log += "\n\nTransition weights to next tag:"
+ weight_lines = (
+ "- (%s, %s) -> %s"
+ % (_decode_tag(a), _decode_tag(b), w)
+ for (a, b), w in weights
+ )
+ token_log += "\n" + "\n".join(weight_lines)
+ else:
+ token_log += \
+ "\n\nNo transition to next tag seen at train time !"
+ feats = [":".join(f) for f in iteritems(feats)]
+ weights = (w for f in feats for w in self._get_feature_weight(f))
+ weights = sorted(weights, key=_weight_absolute_value, reverse=True)
+ if weights:
+ token_log += "\n\nFeature weights:\n"
+ token_log += "\n".join(
+ "- (%s, %s) -> %s"
+ % (f, _decode_tag(t), w) for (f, t), w in weights
+ )
+ else:
+ token_log += "\n\nNo feature weights !"
+
+ unseen_features = sorted(
+ set(f for f in feats if f not in model_features))
+ if unseen_features:
+ token_log += "\n\nFeatures not seen at train time:\n%s" % \
+ "\n".join("- %s" % f for f in unseen_features)
+ tokens_logs.append(token_log)
+
+ log += "\n\n\n".join(tokens_logs)
+ return log
+
+ @fitted_required
+ def _get_incoming_weights(self, tag):
+ return [((first, second), w) for (first, second), w
+ in iteritems(self.crf_model.transition_features_)
+ if second == tag]
+
+ @fitted_required
+ def _get_outgoing_weights(self, tag):
+ return [((first, second), w) for (first, second), w
+ in iteritems(self.crf_model.transition_features_)
+ if first == tag]
+
+ @fitted_required
+ def _get_feature_weight(self, feature):
+ return [((f, tag), w) for (f, tag), w
+ in iteritems(self.crf_model.state_features_) if f == feature]
+
+ @check_persisted_path
+ def persist(self, path):
+ """Persists the object at the given path"""
+ path.mkdir()
+
+ crf_model_file = None
+ if self.crf_model is not None:
+ crf_model_file = CRF_MODEL_FILENAME
+ destination = path / crf_model_file
+ shutil.copy(self.crf_model.modelfile.name, str(destination))
+            # On Windows, crfsuite file permissions are already correct;
+            # on POSIX they must be set according to the system umask
+ if os.name == "posix":
+ umask = os.umask(0o022) # retrieve the system umask
+ os.umask(umask) # restore the sys umask to its original value
+ os.chmod(str(destination), 0o644 & ~umask)
+
+ model = {
+ "language_code": self.language,
+ "intent": self.intent,
+ "crf_model_file": crf_model_file,
+ "slot_name_mapping": self.slot_name_mapping,
+ "config": self.config.to_dict(),
+ }
+ model_json = json_string(model)
+ model_path = path / "slot_filler.json"
+ with model_path.open(mode="w", encoding="utf8") as f:
+ f.write(model_json)
+ self.persist_metadata(path)
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ """Loads a :class:`CRFSlotFiller` instance from a path
+
+ The data at the given path must have been generated using
+ :func:`~CRFSlotFiller.persist`
+ """
+ path = Path(path)
+ model_path = path / "slot_filler.json"
+ if not model_path.exists():
+ raise LoadingError(
+ "Missing slot filler model file: %s" % model_path.name)
+
+ with model_path.open(encoding="utf8") as f:
+ model = json.load(f)
+
+ slot_filler_config = cls.config_type.from_dict(model["config"])
+ slot_filler = cls(config=slot_filler_config, **shared)
+ slot_filler.language = model["language_code"]
+ slot_filler.intent = model["intent"]
+ slot_filler.slot_name_mapping = model["slot_name_mapping"]
+ crf_model_file = model["crf_model_file"]
+ if crf_model_file is not None:
+ crf = _crf_model_from_path(path / crf_model_file)
+ slot_filler.crf_model = crf
+ return slot_filler
+
+ def _cleanup(self):
+ if self.crf_model is not None:
+ self.crf_model.modelfile.cleanup()
+
+ def __del__(self):
+ self._cleanup()
+
+
+def _get_crf_model(crf_args):
+ from sklearn_crfsuite import CRF
+
+ model_filename = crf_args.get("model_filename", None)
+ if model_filename is not None:
+ directory = Path(model_filename).parent
+ if not directory.is_dir():
+ mkdir_p(directory)
+
+ return CRF(model_filename=model_filename, **crf_args)
+
+
+def _encode_tag(tag):
+ return base64.b64encode(tag.encode("utf8"))
+
+
+def _decode_tag(tag):
+ return base64.b64decode(tag).decode("utf8")
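+
+# Tags are base64-encoded to ensure crfsuite only sees ASCII labels, while
+# slot names may contain arbitrary unicode. Illustrative round trip:
+#
+#   >>> _decode_tag(_encode_tag("B-café"))
+#   'B-café'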
+
+
+def _crf_model_from_path(crf_model_path):
+ from sklearn_crfsuite import CRF
+
+ with crf_model_path.open(mode="rb") as f:
+ crf_model_data = f.read()
+ with tempfile.NamedTemporaryFile(suffix=".crfsuite", prefix="model",
+ delete=False) as f:
+ f.write(crf_model_data)
+ f.flush()
+ crf = CRF(model_filename=f.name)
+ return crf
+
+
+# pylint: disable=invalid-name
+def _ensure_safe(X, Y):
+ """Ensures that Y has at least one not empty label, otherwise the CRF model
+ does not contain any label and crashes at
+
+ Args:
+ X: features
+ Y: labels
+
+ Returns:
+ (safe_X, safe_Y): a pair of safe features and labels
+ """
+ safe_X = list(X)
+ safe_Y = list(Y)
+ if not any(X) or not any(Y):
+ safe_X.append([""]) # empty feature
+ safe_Y.append([OUTSIDE]) # outside label
+ return safe_X, safe_Y
+
+
+def _weight_absolute_value(x):
+ return math.fabs(x[1])
diff --git a/snips_inference_agl/slot_filler/crf_utils.py b/snips_inference_agl/slot_filler/crf_utils.py
new file mode 100644
index 0000000..817a59b
--- /dev/null
+++ b/snips_inference_agl/slot_filler/crf_utils.py
@@ -0,0 +1,219 @@
+from __future__ import unicode_literals
+
+from builtins import range
+from enum import Enum, unique
+
+from snips_inference_agl.constants import END, SLOT_NAME, START, TEXT
+from snips_inference_agl.preprocessing import Token, tokenize
+from snips_inference_agl.result import unresolved_slot
+
+BEGINNING_PREFIX = "B-"
+INSIDE_PREFIX = "I-"
+LAST_PREFIX = "L-"
+UNIT_PREFIX = "U-"
+OUTSIDE = "O"
+
+RANGE = "range"
+TAGS = "tags"
+TOKENS = "tokens"
+
+
+@unique
+class TaggingScheme(Enum):
+ """CRF Coding Scheme"""
+
+ IO = 0
+ """Inside-Outside scheme"""
+ BIO = 1
+ """Beginning-Inside-Outside scheme"""
+ BILOU = 2
+ """Beginning-Inside-Last-Outside-Unit scheme, sometimes referred as
+ BWEMO"""
+
+
+def tag_name_to_slot_name(tag):
+ return tag[2:]
+
+
+def start_of_io_slot(tags, i):
+ if i == 0:
+ return tags[i] != OUTSIDE
+ if tags[i] == OUTSIDE:
+ return False
+ return tags[i - 1] == OUTSIDE
+
+
+def end_of_io_slot(tags, i):
+ if i + 1 == len(tags):
+ return tags[i] != OUTSIDE
+ if tags[i] == OUTSIDE:
+ return False
+ return tags[i + 1] == OUTSIDE
+
+
+def start_of_bio_slot(tags, i):
+ if i == 0:
+ return tags[i] != OUTSIDE
+ if tags[i] == OUTSIDE:
+ return False
+ if tags[i].startswith(BEGINNING_PREFIX):
+ return True
+ if tags[i - 1] != OUTSIDE:
+ return False
+ return True
+
+
+def end_of_bio_slot(tags, i):
+ if i + 1 == len(tags):
+ return tags[i] != OUTSIDE
+ if tags[i] == OUTSIDE:
+ return False
+ if tags[i + 1].startswith(INSIDE_PREFIX):
+ return False
+ return True
+
+
+def start_of_bilou_slot(tags, i):
+ if i == 0:
+ return tags[i] != OUTSIDE
+ if tags[i] == OUTSIDE:
+ return False
+ if tags[i].startswith(BEGINNING_PREFIX):
+ return True
+ if tags[i].startswith(UNIT_PREFIX):
+ return True
+ if tags[i - 1].startswith(UNIT_PREFIX):
+ return True
+ if tags[i - 1].startswith(LAST_PREFIX):
+ return True
+ if tags[i - 1] != OUTSIDE:
+ return False
+ return True
+
+
+def end_of_bilou_slot(tags, i):
+ if i + 1 == len(tags):
+ return tags[i] != OUTSIDE
+ if tags[i] == OUTSIDE:
+ return False
+ if tags[i + 1] == OUTSIDE:
+ return True
+ if tags[i].startswith(LAST_PREFIX):
+ return True
+ if tags[i].startswith(UNIT_PREFIX):
+ return True
+ if tags[i + 1].startswith(BEGINNING_PREFIX):
+ return True
+ if tags[i + 1].startswith(UNIT_PREFIX):
+ return True
+ return False
+
+
+def _tags_to_preslots(tags, tokens, is_start_of_slot, is_end_of_slot):
+ slots = []
+ current_slot_start = 0
+ for i, tag in enumerate(tags):
+ if is_start_of_slot(tags, i):
+ current_slot_start = i
+ if is_end_of_slot(tags, i):
+ slots.append({
+ RANGE: {
+ START: tokens[current_slot_start].start,
+ END: tokens[i].end
+ },
+ SLOT_NAME: tag_name_to_slot_name(tag)
+ })
+ current_slot_start = i
+ return slots
+
+
+def tags_to_preslots(tokens, tags, tagging_scheme):
+ if tagging_scheme == TaggingScheme.IO:
+ slots = _tags_to_preslots(tags, tokens, start_of_io_slot,
+ end_of_io_slot)
+ elif tagging_scheme == TaggingScheme.BIO:
+ slots = _tags_to_preslots(tags, tokens, start_of_bio_slot,
+ end_of_bio_slot)
+ elif tagging_scheme == TaggingScheme.BILOU:
+ slots = _tags_to_preslots(tags, tokens, start_of_bilou_slot,
+ end_of_bilou_slot)
+ else:
+ raise ValueError("Unknown tagging scheme %s" % tagging_scheme)
+ return slots
+
+
+def tags_to_slots(text, tokens, tags, tagging_scheme, intent_slots_mapping):
+ slots = tags_to_preslots(tokens, tags, tagging_scheme)
+ return [
+ unresolved_slot(match_range=slot[RANGE],
+ value=text[slot[RANGE][START]:slot[RANGE][END]],
+ entity=intent_slots_mapping[slot[SLOT_NAME]],
+ slot_name=slot[SLOT_NAME])
+ for slot in slots
+ ]
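+
+# Illustrative example with BIO tags for "dim the living room lights",
+# assuming a hypothetical `room` slot mapped to a `room_entity` entity:
+#
+#   >>> text = "dim the living room lights"
+#   >>> tokens = tokenize(text, "en")
+#   >>> tags = ["O", "O", "B-room", "I-room", "O"]
+#   >>> tags_to_slots(text, tokens, tags, TaggingScheme.BIO,
+#   ...               {"room": "room_entity"})
+#   # -> one slot spanning characters 8..19 ("living room"), with
+#   #    entity "room_entity" and slot_name "room"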
+
+
+def positive_tagging(tagging_scheme, slot_name, slot_size):
+ if slot_name == OUTSIDE:
+ return [OUTSIDE for _ in range(slot_size)]
+
+ if tagging_scheme == TaggingScheme.IO:
+ tags = [INSIDE_PREFIX + slot_name for _ in range(slot_size)]
+ elif tagging_scheme == TaggingScheme.BIO:
+ if slot_size > 0:
+ tags = [BEGINNING_PREFIX + slot_name]
+ tags += [INSIDE_PREFIX + slot_name for _ in range(1, slot_size)]
+ else:
+ tags = []
+ elif tagging_scheme == TaggingScheme.BILOU:
+ if slot_size == 0:
+ tags = []
+ elif slot_size == 1:
+ tags = [UNIT_PREFIX + slot_name]
+ else:
+ tags = [BEGINNING_PREFIX + slot_name]
+ tags += [INSIDE_PREFIX + slot_name
+ for _ in range(1, slot_size - 1)]
+ tags.append(LAST_PREFIX + slot_name)
+ else:
+ raise ValueError("Invalid tagging scheme %s" % tagging_scheme)
+ return tags
+
+
+def negative_tagging(size):
+ return [OUTSIDE for _ in range(size)]
+
+
+def utterance_to_sample(query_data, tagging_scheme, language):
+ tokens, tags = [], []
+ current_length = 0
+ for chunk in query_data:
+ chunk_tokens = tokenize(chunk[TEXT], language)
+ tokens += [Token(t.value, current_length + t.start,
+ current_length + t.end) for t in chunk_tokens]
+ current_length += len(chunk[TEXT])
+ if SLOT_NAME not in chunk:
+ tags += negative_tagging(len(chunk_tokens))
+ else:
+ tags += positive_tagging(tagging_scheme, chunk[SLOT_NAME],
+ len(chunk_tokens))
+ return {TOKENS: tokens, TAGS: tags}
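+
+# Illustrative example: a dataset utterance is a list of chunks, where slot
+# chunks carry a "slot_name" (values are made up):
+#
+#   >>> query_data = [{"text": "dim the "},
+#   ...               {"text": "living room", "slot_name": "room"},
+#   ...               {"text": " lights"}]
+#   >>> utterance_to_sample(query_data, TaggingScheme.BIO, "en")
+#   # -> {"tokens": [<5 tokens>],
+#   #     "tags": ["O", "O", "B-room", "I-room", "O"]}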
+
+
+def get_scheme_prefix(index, indexes, tagging_scheme):
+ if tagging_scheme == TaggingScheme.IO:
+ return INSIDE_PREFIX
+ elif tagging_scheme == TaggingScheme.BIO:
+ if index == indexes[0]:
+ return BEGINNING_PREFIX
+ return INSIDE_PREFIX
+ elif tagging_scheme == TaggingScheme.BILOU:
+ if len(indexes) == 1:
+ return UNIT_PREFIX
+ if index == indexes[0]:
+ return BEGINNING_PREFIX
+ if index == indexes[-1]:
+ return LAST_PREFIX
+ return INSIDE_PREFIX
+ else:
+ raise ValueError("Invalid tagging scheme %s" % tagging_scheme)
diff --git a/snips_inference_agl/slot_filler/feature.py b/snips_inference_agl/slot_filler/feature.py
new file mode 100644
index 0000000..a6da552
--- /dev/null
+++ b/snips_inference_agl/slot_filler/feature.py
@@ -0,0 +1,69 @@
+from __future__ import unicode_literals
+
+from builtins import object
+
+TOKEN_NAME = "token"
+
+
+class Feature(object):
+ """CRF Feature which is used by :class:`.CRFSlotFiller`
+
+ Attributes:
+ base_name (str): Feature name (e.g. 'is_digit', 'is_first' etc)
+ func (function): The actual feature function for example:
+
+ def is_first(tokens, token_index):
+ return "1" if token_index == 0 else None
+
+        offset (int, optional): Token offset to consider when computing
+            the feature (e.g. -1 to compute the feature on the previous word)
+        drop_out (float, optional): Drop-out ratio to use when computing the
+            feature during training
+
+ Note:
+ The easiest way to add additional features to the existing ones is
+ to create a :class:`.CRFFeatureFactory`
+ """
+
+ def __init__(self, base_name, func, offset=0, drop_out=0):
+ if base_name == TOKEN_NAME:
+ raise ValueError("'%s' name is reserved" % TOKEN_NAME)
+ self.offset = offset
+ self._name = None
+ self._base_name = None
+ self.base_name = base_name
+ self.function = func
+ self.drop_out = drop_out
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def base_name(self):
+ return self._base_name
+
+ @base_name.setter
+ def base_name(self, value):
+ self._name = _offset_name(value, self.offset)
+ self._base_name = _offset_name(value, 0)
+
+ def compute(self, token_index, cache):
+ if not 0 <= (token_index + self.offset) < len(cache):
+ return None
+
+ if self.base_name in cache[token_index + self.offset]:
+ return cache[token_index + self.offset][self.base_name]
+
+ tokens = [c["token"] for c in cache]
+ value = self.function(tokens, token_index + self.offset)
+ cache[token_index + self.offset][self.base_name] = value
+ return value
+
+
+def _offset_name(name, offset):
+ if offset > 0:
+ return "%s[+%s]" % (name, offset)
+ if offset < 0:
+ return "%s[%s]" % (name, offset)
+ return name
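+
+# Illustrative offset naming:
+#
+#   >>> _offset_name("ngram_1", -1)
+#   'ngram_1[-1]'
+#   >>> _offset_name("ngram_1", 2)
+#   'ngram_1[+2]'
+#   >>> _offset_name("ngram_1", 0)
+#   'ngram_1'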
diff --git a/snips_inference_agl/slot_filler/feature_factory.py b/snips_inference_agl/slot_filler/feature_factory.py
new file mode 100644
index 0000000..50f4598
--- /dev/null
+++ b/snips_inference_agl/slot_filler/feature_factory.py
@@ -0,0 +1,568 @@
+from __future__ import unicode_literals
+
+import logging
+
+from abc import ABCMeta, abstractmethod
+from builtins import str
+
+from future.utils import with_metaclass
+
+from snips_inference_agl.common.abc_utils import classproperty
+from snips_inference_agl.common.registrable import Registrable
+from snips_inference_agl.common.utils import check_random_state
+from snips_inference_agl.constants import (
+ CUSTOM_ENTITY_PARSER_USAGE, END, GAZETTEERS, LANGUAGE, RES_MATCH_RANGE,
+ START, STEMS, WORD_CLUSTERS, CUSTOM_ENTITY_PARSER, BUILTIN_ENTITY_PARSER,
+ RESOURCES, RANDOM_STATE, AUTOMATICALLY_EXTENSIBLE, ENTITIES)
+from snips_inference_agl.dataset import (
+ extract_intent_entities, get_dataset_gazetteer_entities)
+from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity
+from snips_inference_agl.entity_parser.custom_entity_parser import (
+ CustomEntityParserUsage)
+from snips_inference_agl.languages import get_default_sep
+from snips_inference_agl.preprocessing import Token, normalize_token, stem_token
+from snips_inference_agl.resources import get_gazetteer, get_word_cluster
+from snips_inference_agl.slot_filler.crf_utils import TaggingScheme, get_scheme_prefix
+from snips_inference_agl.slot_filler.feature import Feature
+from snips_inference_agl.slot_filler.features_utils import (
+ entity_filter, get_word_chunk, initial_string_from_tokens)
+
+logger = logging.getLogger(__name__)
+
+
+class CRFFeatureFactory(with_metaclass(ABCMeta, Registrable)):
+ """Abstraction to implement to build CRF features
+
+    A :class:`CRFFeatureFactory` is initialized with a dict which describes
+    the feature; it must contain the following three keys:
+
+    - 'factory_name'
+    - 'args': the parameters of the feature, if any
+    - 'offsets': the offsets to consider when using the feature in the CRF.
+      An empty list corresponds to no feature.
+
+    In addition, a 'drop_out' to use at training time can be specified.
+    """
+
+ def __init__(self, factory_config, **shared):
+ self.factory_config = factory_config
+ self.resources = shared.get(RESOURCES)
+ self.builtin_entity_parser = shared.get(BUILTIN_ENTITY_PARSER)
+ self.custom_entity_parser = shared.get(CUSTOM_ENTITY_PARSER)
+ self.random_state = check_random_state(shared.get(RANDOM_STATE))
+
+ @classmethod
+ def from_config(cls, factory_config, **shared):
+ """Retrieve the :class:`CRFFeatureFactory` corresponding the provided
+ config
+
+ Raises:
+ NotRegisteredError: when the factory is not registered
+ """
+ factory_name = factory_config["factory_name"]
+ factory = cls.by_name(factory_name)
+ return factory(factory_config, **shared)
+
+ @classproperty
+ def name(cls): # pylint:disable=no-self-argument
+ return CRFFeatureFactory.registered_name(cls)
+
+ @property
+ def args(self):
+ return self.factory_config["args"]
+
+ @property
+ def offsets(self):
+ return self.factory_config["offsets"]
+
+ @property
+ def drop_out(self):
+ return self.factory_config.get("drop_out", 0.0)
+
+ def fit(self, dataset, intent): # pylint: disable=unused-argument
+ """Fit the factory, if needed, with the provided *dataset* and *intent*
+ """
+ return self
+
+ @abstractmethod
+ def build_features(self):
+ """Build a list of :class:`.Feature`"""
+ pass
+
+ def get_required_resources(self):
+ return None
+
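+# Example factory config, as it appears in a CRFSlotFillerConfig (values are
+# illustrative):
+#
+#   {
+#       "factory_name": "ngram",
+#       "args": {"n": 1, "use_stemming": False,
+#                "common_words_gazetteer_name": None},
+#       "offsets": [-2, -1, 0, 1, 2],
+#       "drop_out": 0.5
+#   }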
+
+class SingleFeatureFactory(with_metaclass(ABCMeta, CRFFeatureFactory)):
+ """A CRF feature factory which produces only one feature"""
+
+ @property
+ def feature_name(self):
+ # by default, use the factory name
+ return self.name
+
+ @abstractmethod
+ def compute_feature(self, tokens, token_index):
+ pass
+
+ def build_features(self):
+ return [
+ Feature(
+ base_name=self.feature_name,
+ func=self.compute_feature,
+ offset=offset,
+ drop_out=self.drop_out) for offset in self.offsets
+ ]
+
+
+@CRFFeatureFactory.register("is_digit")
+class IsDigitFactory(SingleFeatureFactory):
+ """Feature: is the considered token a digit?"""
+
+ def compute_feature(self, tokens, token_index):
+ return "1" if tokens[token_index].value.isdigit() else None
+
+
+@CRFFeatureFactory.register("is_first")
+class IsFirstFactory(SingleFeatureFactory):
+ """Feature: is the considered token the first in the input?"""
+
+ def compute_feature(self, tokens, token_index):
+ return "1" if token_index == 0 else None
+
+
+@CRFFeatureFactory.register("is_last")
+class IsLastFactory(SingleFeatureFactory):
+ """Feature: is the considered token the last in the input?"""
+
+ def compute_feature(self, tokens, token_index):
+ return "1" if token_index == len(tokens) - 1 else None
+
+
+@CRFFeatureFactory.register("ngram")
+class NgramFactory(SingleFeatureFactory):
+ """Feature: the n-gram consisting of the considered token and potentially
+ the following ones
+
+ This feature has several parameters:
+
+ - 'n' (int): Corresponds to the size of the n-gram. n=1 corresponds to a
+ unigram, n=2 is a bigram etc
+ - 'use_stemming' (bool): Whether or not to stem the n-gram
+    - 'common_words_gazetteer_name' (str, optional): If defined, use a
+        gazetteer of common words and replace out-of-corpus ngrams with the
+        alias 'rare_word'
+    """
+
+ def __init__(self, factory_config, **shared):
+ super(NgramFactory, self).__init__(factory_config, **shared)
+ self.n = self.args["n"]
+ if self.n < 1:
+ raise ValueError("n should be >= 1")
+
+ self.use_stemming = self.args["use_stemming"]
+ self.common_words_gazetteer_name = self.args[
+ "common_words_gazetteer_name"]
+ self._gazetteer = None
+ self._language = None
+ self.language = self.args.get("language_code")
+
+ @property
+ def language(self):
+ return self._language
+
+ @language.setter
+ def language(self, value):
+ if value is not None:
+ self._language = value
+ self.args["language_code"] = self.language
+
+ @property
+ def gazetteer(self):
+ # Load the gazetteer lazily
+ if self.common_words_gazetteer_name is None:
+ return None
+ if self._gazetteer is None:
+ self._gazetteer = get_gazetteer(
+ self.resources, self.common_words_gazetteer_name)
+ return self._gazetteer
+
+ @property
+ def feature_name(self):
+ return "ngram_%s" % self.n
+
+ def fit(self, dataset, intent):
+ self.language = dataset[LANGUAGE]
+
+ def compute_feature(self, tokens, token_index):
+ max_len = len(tokens)
+ end = token_index + self.n
+ if 0 <= token_index < max_len and end <= max_len:
+ if self.gazetteer is None:
+ if self.use_stemming:
+ stems = (stem_token(t, self.resources)
+ for t in tokens[token_index:end])
+ return get_default_sep(self.language).join(stems)
+ normalized_values = (normalize_token(t)
+ for t in tokens[token_index:end])
+ return get_default_sep(self.language).join(normalized_values)
+ words = []
+ for t in tokens[token_index:end]:
+ if self.use_stemming:
+ value = stem_token(t, self.resources)
+ else:
+ value = normalize_token(t)
+ words.append(value if value in self.gazetteer else "rare_word")
+ return get_default_sep(self.language).join(words)
+ return None
+
+ def get_required_resources(self):
+ resources = dict()
+ if self.common_words_gazetteer_name is not None:
+ resources[GAZETTEERS] = {self.common_words_gazetteer_name}
+ if self.use_stemming:
+ resources[STEMS] = True
+ return resources
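+
+    # Illustrative values of an `ngram_2` feature on the tokenized input
+    # "turn on the lights" (no gazetteer, no stemming):
+    #
+    #   index 0 -> "turn on"
+    #   index 1 -> "on the"
+    #   index 2 -> "the lights"
+    #   index 3 -> None (the bigram would overflow the utterance)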
+
+
+@CRFFeatureFactory.register("shape_ngram")
+class ShapeNgramFactory(SingleFeatureFactory):
+ """Feature: the shape of the n-gram consisting of the considered token and
+ potentially the following ones
+
+    This feature has one parameter, *n*, which corresponds to the size of the
+ n-gram.
+
+ Possible types of shape are:
+
+ - 'xxx' -> lowercased
+ - 'Xxx' -> Capitalized
+ - 'XXX' -> UPPERCASED
+ - 'xX' -> None of the above
+ """
+
+ def __init__(self, factory_config, **shared):
+ super(ShapeNgramFactory, self).__init__(factory_config, **shared)
+ self.n = self.args["n"]
+ if self.n < 1:
+ raise ValueError("n should be >= 1")
+ self._language = None
+ self.language = self.args.get("language_code")
+
+ @property
+ def language(self):
+ return self._language
+
+ @language.setter
+ def language(self, value):
+ if value is not None:
+ self._language = value
+ self.args["language_code"] = value
+
+ @property
+ def feature_name(self):
+ return "shape_ngram_%s" % self.n
+
+ def fit(self, dataset, intent):
+ self.language = dataset[LANGUAGE]
+
+ def compute_feature(self, tokens, token_index):
+ from snips_nlu_utils import get_shape
+
+ max_len = len(tokens)
+ end = token_index + self.n
+ if 0 <= token_index < max_len and end <= max_len:
+ return get_default_sep(self.language).join(
+ get_shape(t.value) for t in tokens[token_index:end])
+ return None
+
+
+@CRFFeatureFactory.register("word_cluster")
+class WordClusterFactory(SingleFeatureFactory):
+ """Feature: The cluster which the considered token belongs to, if any
+
+ This feature has several parameters:
+
+ - 'cluster_name' (str): the name of the word cluster to use
+ - 'use_stemming' (bool): whether or not to stem the token before looking
+ for its cluster
+
+    Typical word clusters are Brown clusters, in which words are clustered
+    into a binary tree, resulting in cluster identifiers of the form
+    '100111001'. See https://en.wikipedia.org/wiki/Brown_clustering
+ """
+
+ def __init__(self, factory_config, **shared):
+ super(WordClusterFactory, self).__init__(factory_config, **shared)
+ self.cluster_name = self.args["cluster_name"]
+ self.use_stemming = self.args["use_stemming"]
+ self._cluster = None
+
+ @property
+ def cluster(self):
+ if self._cluster is None:
+ self._cluster = get_word_cluster(self.resources, self.cluster_name)
+ return self._cluster
+
+ @property
+ def feature_name(self):
+ return "word_cluster_%s" % self.cluster_name
+
+ def compute_feature(self, tokens, token_index):
+ if self.use_stemming:
+ value = stem_token(tokens[token_index], self.resources)
+ else:
+ value = normalize_token(tokens[token_index])
+ return self.cluster.get(value, None)
+
+ def get_required_resources(self):
+ return {
+ WORD_CLUSTERS: {self.cluster_name},
+ STEMS: self.use_stemming
+ }
+
+
+@CRFFeatureFactory.register("entity_match")
+class CustomEntityMatchFactory(CRFFeatureFactory):
+ """Features: does the considered token belongs to the values of one of the
+ entities in the training dataset
+
+ This factory builds as many features as there are entities in the dataset,
+ one per entity.
+
+ It has the following parameters:
+
+ - 'use_stemming' (bool): whether or not to stem the token before looking
+ for it among the (stemmed) entity values
+ - 'tagging_scheme_code' (int): Represents a :class:`.TaggingScheme`. This
+        allows giving more information about the match.
+ - 'entity_filter' (dict): a filter applied to select the custom entities
+ for which the custom match feature will be computed. Available
+ filters:
+        - 'automatically_extensible': if True, selects only automatically
+            extensible entities; if False, selects only entities that are
+            not automatically extensible
+ """
+
+ def __init__(self, factory_config, **shared):
+ super(CustomEntityMatchFactory, self).__init__(factory_config,
+ **shared)
+ self.use_stemming = self.args["use_stemming"]
+ self.tagging_scheme = TaggingScheme(
+ self.args["tagging_scheme_code"])
+ self._entities = None
+ self.entities = self.args.get("entities")
+ ent_filter = self.args.get("entity_filter")
+ if ent_filter:
+ try:
+ _check_custom_entity_filter(ent_filter)
+ except _InvalidCustomEntityFilter as e:
+ logger.warning(
+ "Invalid filter '%s', invalid arguments have been ignored:"
+ " %s", ent_filter, e,
+ )
+ self.entity_filter = ent_filter or dict()
+
+ @property
+ def entities(self):
+ return self._entities
+
+ @entities.setter
+ def entities(self, value):
+ if value is not None:
+ self._entities = value
+ self.args["entities"] = value
+
+ def fit(self, dataset, intent):
+ entities_names = extract_intent_entities(
+ dataset, lambda e: not is_builtin_entity(e))[intent]
+ extensible = self.entity_filter.get(AUTOMATICALLY_EXTENSIBLE)
+ if extensible is not None:
+ entities_names = [
+ e for e in entities_names
+ if dataset[ENTITIES][e][AUTOMATICALLY_EXTENSIBLE] == extensible
+ ]
+ self.entities = list(entities_names)
+ return self
+
+ def _transform(self, tokens):
+ if self.use_stemming:
+ light_tokens = (stem_token(t, self.resources) for t in tokens)
+ else:
+ light_tokens = (normalize_token(t) for t in tokens)
+ current_index = 0
+ transformed_tokens = []
+ for light_token in light_tokens:
+ transformed_token = Token(
+ value=light_token,
+ start=current_index,
+ end=current_index + len(light_token))
+ transformed_tokens.append(transformed_token)
+ current_index = transformed_token.end + 1
+ return transformed_tokens
+
+ def build_features(self):
+ features = []
+ for entity_name in self.entities:
+ # We need to call this wrapper in order to properly capture
+ # `entity_name`
+ entity_match = self._build_entity_match_fn(entity_name)
+
+ for offset in self.offsets:
+ feature = Feature("entity_match_%s" % entity_name,
+ entity_match, offset, self.drop_out)
+ features.append(feature)
+ return features
+
+ def _build_entity_match_fn(self, entity):
+
+ def entity_match(tokens, token_index):
+ transformed_tokens = self._transform(tokens)
+ text = initial_string_from_tokens(transformed_tokens)
+ token_start = transformed_tokens[token_index].start
+ token_end = transformed_tokens[token_index].end
+ custom_entities = self.custom_entity_parser.parse(
+ text, scope=[entity], use_cache=True)
+            # only keep custom entities (of type `entity`) which overlap
+            # with the current token
+ custom_entities = [ent for ent in custom_entities
+ if entity_filter(ent, token_start, token_end)]
+ if custom_entities:
+ # In most cases, 0 or 1 entity will be found. We fall back to
+ # the first entity if 2 or more were found
+ ent = custom_entities[0]
+ indexes = []
+ for index, token in enumerate(transformed_tokens):
+ if entity_filter(ent, token.start, token.end):
+ indexes.append(index)
+ return get_scheme_prefix(token_index, indexes,
+ self.tagging_scheme)
+ return None
+
+ return entity_match
+
+ def get_required_resources(self):
+ if self.use_stemming:
+ return {
+ STEMS: True,
+ CUSTOM_ENTITY_PARSER_USAGE: CustomEntityParserUsage.WITH_STEMS
+ }
+ return {
+ STEMS: False,
+ CUSTOM_ENTITY_PARSER_USAGE:
+ CustomEntityParserUsage.WITHOUT_STEMS
+ }
+
+
+class _InvalidCustomEntityFilter(ValueError):
+ pass
+
+
+CUSTOM_ENTITIES_FILTER_KEYS = {"automatically_extensible"}
+
+
+# pylint: disable=redefined-outer-name
+def _check_custom_entity_filter(entity_filter):
+ for k in entity_filter:
+ if k not in CUSTOM_ENTITIES_FILTER_KEYS:
+ msg = "Invalid custom entity filter key '%s'. Accepted filter " \
+ "keys are %s" % (k, list(CUSTOM_ENTITIES_FILTER_KEYS))
+ raise _InvalidCustomEntityFilter(msg)
+
+
+@CRFFeatureFactory.register("builtin_entity_match")
+class BuiltinEntityMatchFactory(CRFFeatureFactory):
+ """Features: is the considered token part of a builtin entity such as a
+ date, a temperature etc
+
+ This factory builds as many features as there are builtin entities
+ available in the considered language.
+
+ It has one parameter, *tagging_scheme_code*, which represents a
+    :class:`.TaggingScheme`. This allows giving more information about the
+    match.
+ """
+
+ def __init__(self, factory_config, **shared):
+ super(BuiltinEntityMatchFactory, self).__init__(factory_config,
+ **shared)
+ self.tagging_scheme = TaggingScheme(
+ self.args["tagging_scheme_code"])
+ self.builtin_entities = None
+ self.builtin_entities = self.args.get("entity_labels")
+ self._language = None
+ self.language = self.args.get("language_code")
+
+ @property
+ def language(self):
+ return self._language
+
+ @language.setter
+ def language(self, value):
+ if value is not None:
+ self._language = value
+ self.args["language_code"] = self.language
+
+ def fit(self, dataset, intent):
+ self.language = dataset[LANGUAGE]
+ self.builtin_entities = sorted(
+ self._get_builtin_entity_scope(dataset, intent))
+ self.args["entity_labels"] = self.builtin_entities
+
+ def build_features(self):
+ features = []
+
+ for builtin_entity in self.builtin_entities:
+ # We need to call this wrapper in order to properly capture
+ # `builtin_entity`
+ builtin_entity_match = self._build_entity_match_fn(builtin_entity)
+ for offset in self.offsets:
+ feature_name = "builtin_entity_match_%s" % builtin_entity
+ feature = Feature(feature_name, builtin_entity_match, offset,
+ self.drop_out)
+ features.append(feature)
+
+ return features
+
+ def _build_entity_match_fn(self, builtin_entity):
+
+ def builtin_entity_match(tokens, token_index):
+ text = initial_string_from_tokens(tokens)
+ start = tokens[token_index].start
+ end = tokens[token_index].end
+
+ builtin_entities = self.builtin_entity_parser.parse(
+ text, scope=[builtin_entity], use_cache=True)
+ # only keep builtin entities (of type `builtin_entity`) which
+ # overlap with the current token
+ builtin_entities = [ent for ent in builtin_entities
+ if entity_filter(ent, start, end)]
+ if builtin_entities:
+ # In most cases, 0 or 1 entity will be found. We fall back to
+ # the first entity if 2 or more were found
+ ent = builtin_entities[0]
+ entity_start = ent[RES_MATCH_RANGE][START]
+ entity_end = ent[RES_MATCH_RANGE][END]
+ indexes = []
+ for index, token in enumerate(tokens):
+ if (entity_start <= token.start < entity_end) \
+ and (entity_start < token.end <= entity_end):
+ indexes.append(index)
+ return get_scheme_prefix(token_index, indexes,
+ self.tagging_scheme)
+ return None
+
+ return builtin_entity_match
+
+ @staticmethod
+ def _get_builtin_entity_scope(dataset, intent=None):
+ from snips_nlu_parsers import get_supported_grammar_entities
+
+ language = dataset[LANGUAGE]
+ grammar_entities = list(get_supported_grammar_entities(language))
+ gazetteer_entities = list(
+ get_dataset_gazetteer_entities(dataset, intent))
+ return grammar_entities + gazetteer_entities
diff --git a/snips_inference_agl/slot_filler/features_utils.py b/snips_inference_agl/slot_filler/features_utils.py
new file mode 100644
index 0000000..483e9c0
--- /dev/null
+++ b/snips_inference_agl/slot_filler/features_utils.py
@@ -0,0 +1,47 @@
+from __future__ import unicode_literals
+
+from copy import deepcopy
+
+from snips_inference_agl.common.dict_utils import LimitedSizeDict
+from snips_inference_agl.constants import END, RES_MATCH_RANGE, START
+
+_NGRAMS_CACHE = LimitedSizeDict(size_limit=1000)
+
+
+def get_all_ngrams(tokens):
+ from snips_nlu_utils import compute_all_ngrams
+
+ if not tokens:
+ return []
+ key = "<||>".join(tokens)
+ if key not in _NGRAMS_CACHE:
+ ngrams = compute_all_ngrams(tokens, len(tokens))
+ _NGRAMS_CACHE[key] = ngrams
+ return deepcopy(_NGRAMS_CACHE[key])
+
+
+def get_word_chunk(word, chunk_size, chunk_start, reverse=False):
+ if chunk_size < 1:
+ raise ValueError("chunk size should be >= 1")
+ if chunk_size > len(word):
+ return None
+ start = chunk_start - chunk_size if reverse else chunk_start
+ end = chunk_start if reverse else chunk_start + chunk_size
+ return word[start:end]
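+
+# Illustrative examples:
+#
+#   >>> get_word_chunk("lights", chunk_size=3, chunk_start=0)
+#   'lig'
+#   >>> get_word_chunk("lights", chunk_size=3, chunk_start=6, reverse=True)
+#   'hts'
+#   >>> get_word_chunk("on", chunk_size=3, chunk_start=0) is None  # too short
+#   True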
+
+
+def initial_string_from_tokens(tokens):
+ current_index = 0
+ s = ""
+ for t in tokens:
+ if t.start > current_index:
+ s += " " * (t.start - current_index)
+ s += t.value
+ current_index = t.end
+ return s
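+
+# Illustrative example: gaps between token spans are restored as spaces, so
+# the original text is rebuilt up to its whitespace characters:
+#
+#   >>> initial_string_from_tokens(tokenize("turn  on", "en"))
+#   'turn  on'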
+
+
+def entity_filter(entity, start, end):
+ entity_start = entity[RES_MATCH_RANGE][START]
+ entity_end = entity[RES_MATCH_RANGE][END]
+ return entity_start <= start < end <= entity_end
diff --git a/snips_inference_agl/slot_filler/keyword_slot_filler.py b/snips_inference_agl/slot_filler/keyword_slot_filler.py
new file mode 100644
index 0000000..087d997
--- /dev/null
+++ b/snips_inference_agl/slot_filler/keyword_slot_filler.py
@@ -0,0 +1,70 @@
+from __future__ import unicode_literals
+
+import json
+
+from snips_inference_agl.common.utils import json_string
+from snips_inference_agl.preprocessing import tokenize
+from snips_inference_agl.result import unresolved_slot
+from snips_inference_agl.slot_filler import SlotFiller
+
+
+@SlotFiller.register("keyword_slot_filler")
+class KeywordSlotFiller(SlotFiller):
+ def __init__(self, config=None, **shared):
+ super(KeywordSlotFiller, self).__init__(config, **shared)
+ self.slots_keywords = None
+ self.language = None
+
+ @property
+ def fitted(self):
+ return self.slots_keywords is not None
+
+ def fit(self, dataset, intent):
+ self.language = dataset["language"]
+ self.slots_keywords = dict()
+ utterances = dataset["intents"][intent]["utterances"]
+ for utterance in utterances:
+ for chunk in utterance["data"]:
+ if "slot_name" in chunk:
+ text = chunk["text"]
+ if self.config.get("lowercase", False):
+ text = text.lower()
+ self.slots_keywords[text] = [
+ chunk["entity"],
+ chunk["slot_name"]
+ ]
+ return self
+
+ def get_slots(self, text):
+ tokens = tokenize(text, self.language)
+ slots = []
+ for token in tokens:
+ normalized_value = token.value
+ if self.config.get("lowercase", False):
+ normalized_value = normalized_value.lower()
+ if normalized_value in self.slots_keywords:
+ entity = self.slots_keywords[normalized_value][0]
+ slot_name = self.slots_keywords[normalized_value][1]
+ slot = unresolved_slot((token.start, token.end), token.value,
+ entity, slot_name)
+ slots.append(slot)
+ return slots
+
+ def persist(self, path):
+ model = {
+ "language": self.language,
+ "slots_keywords": self.slots_keywords,
+ "config": self.config.to_dict()
+ }
+ with path.open(mode="w", encoding="utf8") as f:
+ f.write(json_string(model))
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ with path.open() as f:
+ model = json.load(f)
+ slot_filler = cls()
+ slot_filler.language = model["language"]
+ slot_filler.slots_keywords = model["slots_keywords"]
+ slot_filler.config = cls.config_type.from_dict(model["config"])
+ return slot_filler
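+
+
+# Minimal usage sketch (illustrative; assumes `path` points to a previously
+# persisted KeywordSlotFiller):
+#
+#   >>> filler = KeywordSlotFiller.from_path(path)
+#   >>> filler.get_slots("turn on the lights")
+#   # -> unresolved slot dicts for any keyword matches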
diff --git a/snips_inference_agl/slot_filler/slot_filler.py b/snips_inference_agl/slot_filler/slot_filler.py
new file mode 100644
index 0000000..a1fc937
--- /dev/null
+++ b/snips_inference_agl/slot_filler/slot_filler.py
@@ -0,0 +1,33 @@
+from abc import abstractmethod, ABCMeta
+
+from future.utils import with_metaclass
+
+from snips_inference_agl.common.abc_utils import classproperty
+from snips_inference_agl.pipeline.processing_unit import ProcessingUnit
+
+
+class SlotFiller(with_metaclass(ABCMeta, ProcessingUnit)):
+ """Abstraction which performs slot filling
+
+ A custom slot filler must inherit this class to be used in a
+ :class:`.ProbabilisticIntentParser`
+ """
+
+ @classproperty
+ def unit_name(cls): # pylint:disable=no-self-argument
+ return SlotFiller.registered_name(cls)
+
+ @abstractmethod
+ def fit(self, dataset, intent):
+ """Fit the slot filler with a valid Snips dataset"""
+ pass
+
+ @abstractmethod
+ def get_slots(self, text):
+ """Performs slot extraction (slot filling) on the provided *text*
+
+ Returns:
+ list of dict: The list of extracted slots. See
+ :func:`.unresolved_slot` for the output format of a slot
+ """
+ pass