Diffstat (limited to 'snips_inference_agl')
-rw-r--r--  snips_inference_agl/__about__.py | 24
-rw-r--r--  snips_inference_agl/__init__.py | 15
-rw-r--r--  snips_inference_agl/__main__.py | 6
-rw-r--r--  snips_inference_agl/cli/__init__.py | 39
-rw-r--r--  snips_inference_agl/cli/inference.py | 66
-rw-r--r--  snips_inference_agl/cli/utils.py | 79
-rw-r--r--  snips_inference_agl/cli/versions.py | 19
-rw-r--r--  snips_inference_agl/common/__init__.py | 0
-rw-r--r--  snips_inference_agl/common/abc_utils.py | 36
-rw-r--r--  snips_inference_agl/common/dataset_utils.py | 48
-rw-r--r--  snips_inference_agl/common/dict_utils.py | 36
-rw-r--r--  snips_inference_agl/common/from_dict.py | 30
-rw-r--r--  snips_inference_agl/common/io_utils.py | 36
-rw-r--r--  snips_inference_agl/common/log_utils.py | 61
-rw-r--r--  snips_inference_agl/common/registrable.py | 73
-rw-r--r--  snips_inference_agl/common/utils.py | 239
-rw-r--r--  snips_inference_agl/constants.py | 74
-rw-r--r--  snips_inference_agl/data_augmentation.py | 121
-rw-r--r--  snips_inference_agl/dataset/__init__.py | 7
-rw-r--r--  snips_inference_agl/dataset/dataset.py | 102
-rw-r--r--  snips_inference_agl/dataset/entity.py | 175
-rw-r--r--  snips_inference_agl/dataset/intent.py | 339
-rw-r--r--  snips_inference_agl/dataset/utils.py | 67
-rw-r--r--  snips_inference_agl/dataset/validation.py | 254
-rw-r--r--  snips_inference_agl/dataset/yaml_wrapper.py | 11
-rw-r--r--  snips_inference_agl/default_configs/__init__.py | 26
-rw-r--r--  snips_inference_agl/default_configs/config_de.py | 159
-rw-r--r--  snips_inference_agl/default_configs/config_en.py | 145
-rw-r--r--  snips_inference_agl/default_configs/config_es.py | 138
-rw-r--r--  snips_inference_agl/default_configs/config_fr.py | 137
-rw-r--r--  snips_inference_agl/default_configs/config_it.py | 137
-rw-r--r--  snips_inference_agl/default_configs/config_ja.py | 164
-rw-r--r--  snips_inference_agl/default_configs/config_ko.py | 155
-rw-r--r--  snips_inference_agl/default_configs/config_pt_br.py | 137
-rw-r--r--  snips_inference_agl/default_configs/config_pt_pt.py | 137
-rw-r--r--  snips_inference_agl/entity_parser/__init__.py | 6
-rw-r--r--  snips_inference_agl/entity_parser/builtin_entity_parser.py | 162
-rw-r--r--  snips_inference_agl/entity_parser/custom_entity_parser.py | 209
-rw-r--r--  snips_inference_agl/entity_parser/custom_entity_parser_usage.py | 23
-rw-r--r--  snips_inference_agl/entity_parser/entity_parser.py | 85
-rw-r--r--  snips_inference_agl/exceptions.py | 87
-rw-r--r--  snips_inference_agl/intent_classifier/__init__.py | 3
-rw-r--r--  snips_inference_agl/intent_classifier/featurizer.py | 452
-rw-r--r--  snips_inference_agl/intent_classifier/intent_classifier.py | 51
-rw-r--r--  snips_inference_agl/intent_classifier/log_reg_classifier.py | 211
-rw-r--r--  snips_inference_agl/intent_classifier/log_reg_classifier_utils.py | 94
-rw-r--r--  snips_inference_agl/intent_parser/__init__.py | 4
-rw-r--r--  snips_inference_agl/intent_parser/deterministic_intent_parser.py | 518
-rw-r--r--  snips_inference_agl/intent_parser/intent_parser.py | 85
-rw-r--r--  snips_inference_agl/intent_parser/lookup_intent_parser.py | 509
-rw-r--r--  snips_inference_agl/intent_parser/probabilistic_intent_parser.py | 250
-rw-r--r--  snips_inference_agl/languages.py | 44
-rw-r--r--  snips_inference_agl/nlu_engine/__init__.py | 1
-rw-r--r--  snips_inference_agl/nlu_engine/nlu_engine.py | 330
-rw-r--r--  snips_inference_agl/pipeline/__init__.py | 0
-rw-r--r--  snips_inference_agl/pipeline/configs/__init__.py | 10
-rw-r--r--  snips_inference_agl/pipeline/configs/config.py | 49
-rw-r--r--  snips_inference_agl/pipeline/configs/features.py | 81
-rw-r--r--  snips_inference_agl/pipeline/configs/intent_classifier.py | 307
-rw-r--r--  snips_inference_agl/pipeline/configs/intent_parser.py | 127
-rw-r--r--  snips_inference_agl/pipeline/configs/nlu_engine.py | 55
-rw-r--r--  snips_inference_agl/pipeline/configs/slot_filler.py | 145
-rw-r--r--  snips_inference_agl/pipeline/processing_unit.py | 177
-rw-r--r--  snips_inference_agl/preprocessing.py | 97
-rw-r--r--  snips_inference_agl/resources.py | 267
-rw-r--r--  snips_inference_agl/result.py | 342
-rw-r--r--  snips_inference_agl/slot_filler/__init__.py | 3
-rw-r--r--  snips_inference_agl/slot_filler/crf_slot_filler.py | 467
-rw-r--r--  snips_inference_agl/slot_filler/crf_utils.py | 219
-rw-r--r--  snips_inference_agl/slot_filler/feature.py | 69
-rw-r--r--  snips_inference_agl/slot_filler/feature_factory.py | 568
-rw-r--r--  snips_inference_agl/slot_filler/features_utils.py | 47
-rw-r--r--  snips_inference_agl/slot_filler/keyword_slot_filler.py | 70
-rw-r--r--  snips_inference_agl/slot_filler/slot_filler.py | 33
-rw-r--r--  snips_inference_agl/string_variations.py | 195
75 files changed, 9744 insertions, 0 deletions
diff --git a/snips_inference_agl/__about__.py b/snips_inference_agl/__about__.py
new file mode 100644
index 0000000..eec36b5
--- /dev/null
+++ b/snips_inference_agl/__about__.py
@@ -0,0 +1,24 @@
+# inspired from:
+# https://python-packaging-user-guide.readthedocs.io/guides/single-sourcing-package-version/
+# https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py
+
+# pylint:disable=line-too-long
+
+__title__ = "snips_inference_agl"
+__summary__ = "A modified version of original Snips NLU library for NLU inference on AGL platform."
+__github_url__ = "https://github.com/snipsco/snips-nlu"
+__tracker_url__ = "https://github.com/snipsco/snips-nlu/issues"
+__author__ = "Clement Doumouro, Adrien Ball"
+__email__ = "clement.doumouro@snips.ai, adrien.ball@snips.ai"
+__license__ = "Apache License, Version 2.0"
+
+__version__ = "0.20.2"
+__model_version__ = "0.20.0"
+
+__download_url__ = "https://github.com/snipsco/snips-nlu-language-resources/releases/download"
+__compatibility__ = "https://raw.githubusercontent.com/snipsco/snips-nlu-language-resources/master/compatibility.json"
+__shortcuts__ = "https://raw.githubusercontent.com/snipsco/snips-nlu-language-resources/master/shortcuts.json"
+
+__entities_download_url__ = "https://resources.snips.ai/nlu/gazetteer-entities"
+
+# pylint:enable=line-too-long
diff --git a/snips_inference_agl/__init__.py b/snips_inference_agl/__init__.py
new file mode 100644
index 0000000..fe05b12
--- /dev/null
+++ b/snips_inference_agl/__init__.py
@@ -0,0 +1,15 @@
+from deprecation import deprecated
+
+from snips_inference_agl.__about__ import __model_version__, __version__
+from snips_inference_agl.nlu_engine import SnipsNLUEngine
+from snips_inference_agl.pipeline.configs import NLUEngineConfig
+
+
+@deprecated(deprecated_in="0.19.7", removed_in="0.21.0",
+ current_version=__version__,
+ details="Loading resources in the client code is no longer "
+ "required")
+def load_resources(name, required_resources=None):
+ from snips_inference_agl.resources import load_resources as _load_resources
+
+ return _load_resources(name, required_resources)
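
Usage sketch (illustrative, not part of the diff): the re-exported SnipsNLUEngine above is this package's main inference entry point. The snippet assumes the package is installed together with a previously trained engine; the engine path is a placeholder, and the result keys come from constants.py further down in this patch.

from snips_inference_agl import SnipsNLUEngine

# Load a persisted engine directory produced by the original Snips NLU trainer
# (the path below is a hypothetical placeholder).
engine = SnipsNLUEngine.from_path("/path/to/trained_engine")

# Parse a query; "intent", "intentName" and "slots" are the result keys
# defined in constants.py.
result = engine.parse("turn on the lights in the kitchen")
print(result["intent"]["intentName"])
print(result["slots"])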
diff --git a/snips_inference_agl/__main__.py b/snips_inference_agl/__main__.py
new file mode 100644
index 0000000..7cf8cfe
--- /dev/null
+++ b/snips_inference_agl/__main__.py
@@ -0,0 +1,6 @@
+from __future__ import print_function, unicode_literals
+
+
+if __name__ == "__main__":
+ from snips_inference_agl.cli import main
+ main()
diff --git a/snips_inference_agl/cli/__init__.py b/snips_inference_agl/cli/__init__.py
new file mode 100644
index 0000000..ccfbf18
--- /dev/null
+++ b/snips_inference_agl/cli/__init__.py
@@ -0,0 +1,39 @@
+import argparse
+
+
+class Formatter(argparse.ArgumentDefaultsHelpFormatter):
+ def __init__(self, prog):
+ super(Formatter, self).__init__(prog, max_help_position=35, width=150)
+
+
+def get_arg_parser():
+ from snips_inference_agl.cli.inference import add_parse_parser
+ from snips_inference_agl.cli.versions import (
+ add_version_parser, add_model_version_parser)
+
+ arg_parser = argparse.ArgumentParser(
+ description="Snips NLU command line interface",
+ prog="python -m snips_nlu", formatter_class=Formatter)
+ arg_parser.add_argument("-v", "--version", action="store_true",
+ help="Print package version")
+ subparsers = arg_parser.add_subparsers(
+ title="available commands", metavar="command [options ...]")
+ add_parse_parser(subparsers, formatter_class=Formatter)
+ add_version_parser(subparsers, formatter_class=Formatter)
+ add_model_version_parser(subparsers, formatter_class=Formatter)
+ return arg_parser
+
+
+def main():
+ from snips_inference_agl.__about__ import __version__
+
+ arg_parser = get_arg_parser()
+ args = arg_parser.parse_args()
+
+ if hasattr(args, "func"):
+ args.func(args)
+ elif "version" in args:
+ print(__version__)
+ else:
+ arg_parser.print_help()
+ exit(1)
diff --git a/snips_inference_agl/cli/inference.py b/snips_inference_agl/cli/inference.py
new file mode 100644
index 0000000..4ef5618
--- /dev/null
+++ b/snips_inference_agl/cli/inference.py
@@ -0,0 +1,66 @@
+from __future__ import unicode_literals, print_function
+
+
+def add_parse_parser(subparsers, formatter_class):
+ subparser = subparsers.add_parser(
+ "parse", formatter_class=formatter_class,
+ help="Load a trained NLU engine and perform parsing")
+ subparser.add_argument("training_path", type=str,
+ help="Path to a trained engine")
+ subparser.add_argument("-q", "--query", type=str,
+ help="Query to parse. If provided, it disables the "
+ "interactive behavior.")
+ subparser.add_argument("-v", "--verbosity", action="count", default=0,
+ help="Increase output verbosity")
+ subparser.add_argument("-f", "--intents-filter", type=str,
+ help="Intents filter as a comma-separated list")
+ subparser.set_defaults(func=_parse)
+ return subparser
+
+
+def _parse(args_namespace):
+ return parse(args_namespace.training_path, args_namespace.query,
+ args_namespace.verbosity, args_namespace.intents_filter)
+
+
+def parse(training_path, query, verbose=False, intents_filter=None):
+ """Load a trained NLU engine and play with its parsing API interactively"""
+ import csv
+ import logging
+ from builtins import input, str
+ from snips_inference_agl import SnipsNLUEngine
+ from snips_inference_agl.cli.utils import set_nlu_logger
+
+ if verbose == 1:
+ set_nlu_logger(logging.INFO)
+ elif verbose >= 2:
+ set_nlu_logger(logging.DEBUG)
+ if intents_filter:
+ # use csv in order to properly handle commas and other special
+ # characters in intent names
+ intents_filter = next(csv.reader([intents_filter]))
+ else:
+ intents_filter = None
+
+ engine = SnipsNLUEngine.from_path(training_path)
+
+ if query:
+ print_parsing_result(engine, query, intents_filter)
+ return
+
+ while True:
+ query = input("Enter a query (type 'q' to quit): ").strip()
+ if not isinstance(query, str):
+ query = query.decode("utf-8")
+ if query == "q":
+ break
+ print_parsing_result(engine, query, intents_filter)
+
+
+def print_parsing_result(engine, query, intents_filter):
+ from snips_inference_agl.common.utils import unicode_string, json_string
+
+ query = unicode_string(query)
+ json_dump = json_string(engine.parse(query, intents_filter),
+ sort_keys=True, indent=2)
+ print(json_dump)
diff --git a/snips_inference_agl/cli/utils.py b/snips_inference_agl/cli/utils.py
new file mode 100644
index 0000000..0e5464c
--- /dev/null
+++ b/snips_inference_agl/cli/utils.py
@@ -0,0 +1,79 @@
+from __future__ import print_function, unicode_literals
+
+import logging
+import sys
+from enum import Enum, unique
+
+import requests
+
+import snips_inference_agl
+from snips_inference_agl import __about__
+
+
+@unique
+class PrettyPrintLevel(Enum):
+ INFO = 0
+ WARNING = 1
+ ERROR = 2
+ SUCCESS = 3
+
+
+FMT = "[%(levelname)s][%(asctime)s.%(msecs)03d][%(name)s]: %(message)s"
+DATE_FMT = "%H:%M:%S"
+
+
+def pretty_print(*texts, **kwargs):
+ """Print formatted message
+
+ Args:
+ *texts (str): Texts to print. Each argument is rendered as paragraph.
+ **kwargs: 'title' becomes coloured headline. exits=True performs sys
+ exit.
+ """
+ exits = kwargs.get("exits")
+ title = kwargs.get("title")
+ level = kwargs.get("level", PrettyPrintLevel.INFO)
+ title_color = _color_from_level(level)
+ if title:
+ title = "\033[{color}m{title}\033[0m\n".format(title=title,
+ color=title_color)
+ else:
+ title = ""
+ message = "\n\n".join([text for text in texts])
+ print("\n{title}{message}\n".format(title=title, message=message))
+ if exits is not None:
+ sys.exit(exits)
+
+
+def _color_from_level(level):
+ if level == PrettyPrintLevel.INFO:
+ return "92"
+ if level == PrettyPrintLevel.WARNING:
+ return "93"
+ if level == PrettyPrintLevel.ERROR:
+ return "91"
+ if level == PrettyPrintLevel.SUCCESS:
+ return "92"
+ else:
+ raise ValueError("Unknown PrettyPrintLevel: %s" % level)
+
+
+def get_json(url, desc):
+ r = requests.get(url)
+ if r.status_code != 200:
+ raise OSError("%s: Received status code %s when fetching the resource"
+ % (desc, r.status_code))
+ return r.json()
+
+
+def set_nlu_logger(level=logging.INFO):
+ logger = logging.getLogger(snips_inference_agl.__name__)
+ logger.setLevel(level)
+
+ formatter = logging.Formatter(FMT, DATE_FMT)
+
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setFormatter(formatter)
+ handler.setLevel(level)
+
+ logger.addHandler(handler)
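
Usage sketch (illustrative, not part of the diff) combining the two helpers above; the texts and title are arbitrary.

import logging

from snips_inference_agl.cli.utils import (
    PrettyPrintLevel, pretty_print, set_nlu_logger)

# Route the package logger to stdout with the format defined above
set_nlu_logger(logging.DEBUG)

# Print a green "success" headline followed by two paragraphs
pretty_print("Engine loaded.", "Ready to parse queries.",
             title="Snips NLU inference", level=PrettyPrintLevel.SUCCESS)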
diff --git a/snips_inference_agl/cli/versions.py b/snips_inference_agl/cli/versions.py
new file mode 100644
index 0000000..d7922ef
--- /dev/null
+++ b/snips_inference_agl/cli/versions.py
@@ -0,0 +1,19 @@
+from __future__ import print_function
+
+
+def add_version_parser(subparsers, formatter_class):
+ from snips_inference_agl.__about__ import __version__
+ subparser = subparsers.add_parser(
+ "version", formatter_class=formatter_class,
+ help="Print the package version")
+ subparser.set_defaults(func=lambda _: print(__version__))
+ return subparser
+
+
+def add_model_version_parser(subparsers, formatter_class):
+ from snips_inference_agl.__about__ import __model_version__
+ subparser = subparsers.add_parser(
+ "model-version", formatter_class=formatter_class,
+ help="Print the model version")
+ subparser.set_defaults(func=lambda _: print(__model_version__))
+ return subparser
diff --git a/snips_inference_agl/common/__init__.py b/snips_inference_agl/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/snips_inference_agl/common/__init__.py
diff --git a/snips_inference_agl/common/abc_utils.py b/snips_inference_agl/common/abc_utils.py
new file mode 100644
index 0000000..db7b933
--- /dev/null
+++ b/snips_inference_agl/common/abc_utils.py
@@ -0,0 +1,36 @@
+class abstractclassmethod(classmethod): # pylint: disable=invalid-name
+ __isabstractmethod__ = True
+
+ def __init__(self, callable):
+ callable.__isabstractmethod__ = True
+ super(abstractclassmethod, self).__init__(callable)
+
+
+class ClassPropertyDescriptor(object):
+ def __init__(self, fget, fset=None):
+ self.fget = fget
+ self.fset = fset
+
+ def __get__(self, obj, klass=None):
+ if klass is None:
+ klass = type(obj)
+ return self.fget.__get__(obj, klass)()
+
+ def __set__(self, obj, value):
+ if not self.fset:
+ raise AttributeError("can't set attribute")
+ type_ = type(obj)
+ return self.fset.__get__(obj, type_)(value)
+
+ def setter(self, func):
+ if not isinstance(func, (classmethod, staticmethod)):
+ func = classmethod(func)
+ self.fset = func
+ return self
+
+
+def classproperty(func):
+ if not isinstance(func, (classmethod, staticmethod)):
+ func = classmethod(func)
+
+ return ClassPropertyDescriptor(func)
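
Usage sketch (illustrative, not part of the diff) of the classproperty helper above, on a made-up Unit class.

from snips_inference_agl.common.abc_utils import classproperty

class Unit(object):
    @classproperty
    def unit_name(cls):
        # Computed from the class itself, so no instance is required
        return cls.__name__.lower()

print(Unit.unit_name)    # -> "unit"
print(Unit().unit_name)  # -> "unit" as well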
diff --git a/snips_inference_agl/common/dataset_utils.py b/snips_inference_agl/common/dataset_utils.py
new file mode 100644
index 0000000..34648e6
--- /dev/null
+++ b/snips_inference_agl/common/dataset_utils.py
@@ -0,0 +1,48 @@
+from snips_inference_agl.constants import INTENTS, UTTERANCES, DATA, SLOT_NAME, ENTITY
+from snips_inference_agl.exceptions import DatasetFormatError
+
+
+def type_error(expected_type, found_type, object_label=None):
+ if object_label is None:
+ raise DatasetFormatError("Invalid type: expected %s but found %s"
+ % (expected_type, found_type))
+ raise DatasetFormatError("Invalid type for '%s': expected %s but found %s"
+ % (object_label, expected_type, found_type))
+
+
+def validate_type(obj, expected_type, object_label=None):
+ if not isinstance(obj, expected_type):
+ type_error(expected_type, type(obj), object_label)
+
+
+def missing_key_error(key, object_label=None):
+ if object_label is None:
+ raise DatasetFormatError("Missing key: '%s'" % key)
+ raise DatasetFormatError("Expected %s to have key: '%s'"
+ % (object_label, key))
+
+
+def validate_key(obj, key, object_label=None):
+ if key not in obj:
+ missing_key_error(key, object_label)
+
+
+def validate_keys(obj, keys, object_label=None):
+ for key in keys:
+ validate_key(obj, key, object_label)
+
+
+def get_slot_name_mapping(dataset, intent):
+ """Returns a dict which maps slot names to entities for the provided intent
+ """
+ slot_name_mapping = dict()
+ for utterance in dataset[INTENTS][intent][UTTERANCES]:
+ for chunk in utterance[DATA]:
+ if SLOT_NAME in chunk:
+ slot_name_mapping[chunk[SLOT_NAME]] = chunk[ENTITY]
+ return slot_name_mapping
+
+def get_slot_name_mappings(dataset):
+ """Returns a dict which maps intents to their slot name mapping"""
+ return {intent: get_slot_name_mapping(dataset, intent)
+ for intent in dataset[INTENTS]} \ No newline at end of file
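
Usage sketch (illustrative, not part of the diff): get_slot_name_mappings applied to a tiny hand-written dataset dict. The intent, slot and entity names are made up; the literal keys match the constants imported above.

from snips_inference_agl.common.dataset_utils import get_slot_name_mappings

dataset = {
    "intents": {
        "searchFlight": {
            "utterances": [
                {"data": [
                    {"text": "book a flight to "},
                    {"text": "Paris", "slot_name": "destination",
                     "entity": "city"},
                ]},
            ],
        },
    },
}

print(get_slot_name_mappings(dataset))
# -> {'searchFlight': {'destination': 'city'}}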
diff --git a/snips_inference_agl/common/dict_utils.py b/snips_inference_agl/common/dict_utils.py
new file mode 100644
index 0000000..a70b217
--- /dev/null
+++ b/snips_inference_agl/common/dict_utils.py
@@ -0,0 +1,36 @@
+from collections import OrderedDict
+
+
+class LimitedSizeDict(OrderedDict):
+ def __init__(self, *args, **kwds):
+ if "size_limit" not in kwds:
+ raise ValueError("'size_limit' must be passed as a keyword "
+ "argument")
+ self.size_limit = kwds.pop("size_limit")
+ if len(args) > 1:
+ raise TypeError('expected at most 1 arguments, got %d' % len(args))
+ if len(args) == 1 and len(args[0]) + len(kwds) > self.size_limit:
+ raise ValueError("Tried to initialize LimitedSizedDict with more "
+ "value than permitted with 'limit_size'")
+ super(LimitedSizeDict, self).__init__(*args, **kwds)
+
+ def __setitem__(self, key, value, dict_setitem=OrderedDict.__setitem__):
+ dict_setitem(self, key, value)
+ self._check_size_limit()
+
+ def _check_size_limit(self):
+ if self.size_limit is not None:
+ while len(self) > self.size_limit:
+ self.popitem(last=False)
+
+ def __eq__(self, other):
+ if self.size_limit != other.size_limit:
+ return False
+ return super(LimitedSizeDict, self).__eq__(other)
+
+
+class UnupdatableDict(dict):
+ def __setitem__(self, key, value):
+ if key in self:
+ raise KeyError("Can't update key '%s'" % key)
+ super(UnupdatableDict, self).__setitem__(key, value)
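
Usage sketch (illustrative, not part of the diff): LimitedSizeDict behaves as a small FIFO cache, evicting its oldest entries once the size limit is exceeded.

from snips_inference_agl.common.dict_utils import LimitedSizeDict

cache = LimitedSizeDict(size_limit=2)
cache["a"] = 1
cache["b"] = 2
cache["c"] = 3        # "a", the oldest entry, is evicted here

print(list(cache))    # -> ['b', 'c']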
diff --git a/snips_inference_agl/common/from_dict.py b/snips_inference_agl/common/from_dict.py
new file mode 100644
index 0000000..2b776a6
--- /dev/null
+++ b/snips_inference_agl/common/from_dict.py
@@ -0,0 +1,30 @@
+try:
+ import funcsigs as inspect
+except ImportError:
+ import inspect
+
+from future.utils import iteritems
+
+KEYWORD_KINDS = {inspect.Parameter.POSITIONAL_OR_KEYWORD,
+ inspect.Parameter.KEYWORD_ONLY}
+
+
+class FromDict(object):
+ @classmethod
+ def from_dict(cls, dict):
+ if dict is None:
+ return cls()
+ params = inspect.signature(cls.__init__).parameters
+
+ if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in
+ params.values()):
+ return cls(**dict)
+
+ param_names = set()
+ for i, (name, param) in enumerate(iteritems(params)):
+ if i == 0 and name == "self":
+ continue
+ if param.kind in KEYWORD_KINDS:
+ param_names.add(name)
+ filtered_dict = {k: v for k, v in iteritems(dict) if k in param_names}
+ return cls(**filtered_dict)
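
Usage sketch (illustrative, not part of the diff) of the FromDict mixin; the config class and its parameters are made up. Keys that do not match constructor parameters are silently dropped.

from snips_inference_agl.common.from_dict import FromDict

class UnitConfig(FromDict):
    def __init__(self, window_size=5, use_stemming=False):
        self.window_size = window_size
        self.use_stemming = use_stemming

# "unknown_key" is not a parameter of __init__, so it is filtered out
config = UnitConfig.from_dict({"window_size": 3, "unknown_key": "ignored"})
print(config.window_size, config.use_stemming)  # -> 3 False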
diff --git a/snips_inference_agl/common/io_utils.py b/snips_inference_agl/common/io_utils.py
new file mode 100644
index 0000000..f4407a3
--- /dev/null
+++ b/snips_inference_agl/common/io_utils.py
@@ -0,0 +1,36 @@
+import errno
+import os
+import shutil
+from contextlib import contextmanager
+from pathlib import Path
+from tempfile import mkdtemp
+from zipfile import ZipFile, ZIP_DEFLATED
+
+
+def mkdir_p(path):
+ """Reproduces the 'mkdir -p shell' command
+
+ See
+ http://stackoverflow.com/questions/600268/mkdir-p-functionality-in-python
+ """
+ try:
+ os.makedirs(str(path))
+ except OSError as exc: # Python >2.5
+ if exc.errno == errno.EEXIST and path.is_dir():
+ pass
+ else:
+ raise
+
+
+@contextmanager
+def temp_dir():
+ tmp_dir = mkdtemp()
+ try:
+ yield Path(tmp_dir)
+ finally:
+ shutil.rmtree(tmp_dir)
+
+
+def unzip_archive(archive_file, destination_dir):
+ with ZipFile(archive_file, "r", ZIP_DEFLATED) as zipf:
+ zipf.extractall(str(destination_dir))
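
Usage sketch (illustrative, not part of the diff): the two helpers above combined to unpack a zipped engine into a throwaway directory; the archive path is a placeholder.

from snips_inference_agl.common.io_utils import temp_dir, unzip_archive

with temp_dir() as tmp_path:
    # tmp_path is a pathlib.Path; the directory is deleted when the block exits
    unzip_archive("/path/to/engine.zip", tmp_path)
    print(sorted(p.name for p in tmp_path.iterdir()))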
diff --git a/snips_inference_agl/common/log_utils.py b/snips_inference_agl/common/log_utils.py
new file mode 100644
index 0000000..47da34e
--- /dev/null
+++ b/snips_inference_agl/common/log_utils.py
@@ -0,0 +1,61 @@
+from __future__ import unicode_literals
+
+from builtins import str
+from datetime import datetime
+from functools import wraps
+
+from snips_inference_agl.common.utils import json_debug_string
+
+
+class DifferedLoggingMessage(object):
+
+ def __init__(self, fn, *args, **kwargs):
+ self.fn = fn
+ self.args = args
+ self.kwargs = kwargs
+
+ def __str__(self):
+ return str(self.fn(*self.args, **self.kwargs))
+
+
+def log_elapsed_time(logger, level, output_msg=None):
+ if output_msg is None:
+ output_msg = "Elapsed time ->:\n{elapsed_time}"
+
+ def get_wrapper(fn):
+ @wraps(fn)
+ def wrapped(*args, **kwargs):
+ start = datetime.now()
+ msg_fmt = dict()
+ res = fn(*args, **kwargs)
+ if "elapsed_time" in output_msg:
+ msg_fmt["elapsed_time"] = datetime.now() - start
+ logger.log(level, output_msg.format(**msg_fmt))
+ return res
+
+ return wrapped
+
+ return get_wrapper
+
+
+def log_result(logger, level, output_msg=None):
+ if output_msg is None:
+ output_msg = "Result ->:\n{result}"
+
+ def get_wrapper(fn):
+ @wraps(fn)
+ def wrapped(*args, **kwargs):
+ msg_fmt = dict()
+ res = fn(*args, **kwargs)
+ if "result" in output_msg:
+ try:
+ res_debug_string = json_debug_string(res)
+ except TypeError:
+ res_debug_string = str(res)
+ msg_fmt["result"] = res_debug_string
+ logger.log(level, output_msg.format(**msg_fmt))
+ return res
+
+ return wrapped
+
+ return get_wrapper
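
Usage sketch (illustrative, not part of the diff) of the log_elapsed_time decorator above; the logger name and the decorated function are arbitrary.

import logging

from snips_inference_agl.common.log_utils import log_elapsed_time

logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger("snips_demo")

@log_elapsed_time(logger, logging.DEBUG, "Parsed query in {elapsed_time}")
def fake_parse(text):
    # Stand-in for an actual parsing call
    return {"input": text}

fake_parse("hello")  # logs e.g. "Parsed query in 0:00:00.000004"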
diff --git a/snips_inference_agl/common/registrable.py b/snips_inference_agl/common/registrable.py
new file mode 100644
index 0000000..bbc7bdc
--- /dev/null
+++ b/snips_inference_agl/common/registrable.py
@@ -0,0 +1,73 @@
+# This module is largely inspired by the AllenNLP library
+# See github.com/allenai/allennlp/blob/master/allennlp/common/registrable.py
+
+from collections import defaultdict
+from future.utils import iteritems
+
+from snips_inference_agl.exceptions import AlreadyRegisteredError, NotRegisteredError
+
+
+class Registrable(object):
+ """
+ Any class that inherits from ``Registrable`` gains access to a named
+ registry for its subclasses. To register them, just decorate them with the
+ classmethod ``@BaseClass.register(name)``.
+
+    You can then call ``BaseClass.list_available()`` to get the keys of the
+    registered subclasses, and ``BaseClass.by_name(name)`` to retrieve the
+    corresponding subclass.
+
+ Note that if you use this class to implement a new ``Registrable``
+ abstract class, you must ensure that all subclasses of the abstract class
+ are loaded when the module is loaded, because the subclasses register
+ themselves in their respective files. You can achieve this by having the
+ abstract class and all subclasses in the __init__.py of the module in
+ which they reside (as this causes any import of either the abstract class
+ or a subclass to load all other subclasses and the abstract class).
+ """
+ _registry = defaultdict(dict)
+
+ @classmethod
+ def register(cls, name, override=False):
+ """Decorator used to add the decorated subclass to the registry of the
+ base class
+
+ Args:
+            name (str): name used to identify the registered subclass
+            override (bool, optional): controls the behavior when a subclass
+                is registered under an identifier that is already in use.
+                If True, the previously registered subclass is replaced by
+                the new one.
+
+ Raises:
+ AlreadyRegisteredError: when ``override`` is False, while trying
+ to register a subclass with a name already used by another
+ registered subclass
+ """
+ registry = Registrable._registry[cls]
+
+ def add_subclass_to_registry(subclass):
+ # Add to registry, raise an error if key has already been used.
+ if not override and name in registry:
+ raise AlreadyRegisteredError(name, cls, registry[name])
+ registry[name] = subclass
+ return subclass
+
+ return add_subclass_to_registry
+
+ @classmethod
+ def registered_name(cls, registered_class):
+ for name, subclass in iteritems(Registrable._registry[cls]):
+ if subclass == registered_class:
+ return name
+ raise NotRegisteredError(cls, registered_cls=registered_class)
+
+ @classmethod
+ def by_name(cls, name):
+ if name not in Registrable._registry[cls]:
+ raise NotRegisteredError(cls, name=name)
+ return Registrable._registry[cls][name]
+
+ @classmethod
+ def list_available(cls):
+ return list(Registrable._registry[cls].keys())
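
Usage sketch (illustrative, not part of the diff) of the registry mechanism described in the docstring above, with made-up class names.

from snips_inference_agl.common.registrable import Registrable

class IntentParser(Registrable):
    pass

@IntentParser.register("dummy_parser")
class DummyParser(IntentParser):
    pass

print(IntentParser.list_available())              # -> ['dummy_parser']
print(IntentParser.by_name("dummy_parser"))       # -> DummyParser class
print(IntentParser.registered_name(DummyParser))  # -> 'dummy_parser'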
diff --git a/snips_inference_agl/common/utils.py b/snips_inference_agl/common/utils.py
new file mode 100644
index 0000000..816dae7
--- /dev/null
+++ b/snips_inference_agl/common/utils.py
@@ -0,0 +1,239 @@
+from __future__ import unicode_literals
+
+import importlib
+import json
+import numbers
+import re
+from builtins import bytes as newbytes, str as newstr
+from datetime import datetime
+from functools import wraps
+from pathlib import Path
+
+from future.utils import text_type
+
+from snips_inference_agl.constants import (END, ENTITY_KIND, RES_MATCH_RANGE, RES_VALUE,
+ START)
+from snips_inference_agl.exceptions import NotTrained, PersistingError
+
+REGEX_PUNCT = {'\\', '.', '+', '*', '?', '(', ')', '|', '[', ']', '{', '}',
+ '^', '$', '#', '&', '-', '~'}
+
+
+# pylint:disable=line-too-long
+def regex_escape(s):
+ """Escapes all regular expression meta characters in *s*
+
+ The string returned may be safely used as a literal in a regular
+ expression.
+
+    This function is more precise than :func:`re.escape`, which escapes all
+    non-alphanumeric characters and can therefore cause cross-platform
+    compatibility issues.
+
+ References:
+
+ - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/lib.rs#L1685
+ - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/parser.rs#L1378
+ """
+ escaped_string = ""
+ for c in s:
+ if c in REGEX_PUNCT:
+ escaped_string += "\\"
+ escaped_string += c
+ return escaped_string
+
+
+# pylint:enable=line-too-long
+
+
+def check_random_state(seed):
+ """Turn seed into a :class:`numpy.random.RandomState` instance
+
+ If seed is None, return the RandomState singleton used by np.random.
+ If seed is an int, return a new RandomState instance seeded with seed.
+ If seed is already a RandomState instance, return it.
+ Otherwise raise ValueError.
+ """
+ import numpy as np
+
+ # pylint: disable=W0212
+ # pylint: disable=c-extension-no-member
+ if seed is None or seed is np.random:
+ return np.random.mtrand._rand # pylint: disable=c-extension-no-member
+ if isinstance(seed, (numbers.Integral, np.integer)):
+ return np.random.RandomState(seed)
+ if isinstance(seed, np.random.RandomState):
+ return seed
+ raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
+ ' instance' % seed)
+
+
+def ranges_overlap(lhs_range, rhs_range):
+ if isinstance(lhs_range, dict) and isinstance(rhs_range, dict):
+ return lhs_range[END] > rhs_range[START] \
+ and lhs_range[START] < rhs_range[END]
+ elif isinstance(lhs_range, (tuple, list)) \
+ and isinstance(rhs_range, (tuple, list)):
+ return lhs_range[1] > rhs_range[0] and lhs_range[0] < rhs_range[1]
+ else:
+ raise TypeError("Cannot check overlap on objects of type: %s and %s"
+ % (type(lhs_range), type(rhs_range)))
+
+
+def elapsed_since(time):
+ return datetime.now() - time
+
+
+def json_debug_string(dict_data):
+ return json.dumps(dict_data, ensure_ascii=False, indent=2, sort_keys=True)
+
+def json_string(json_object, indent=2, sort_keys=True):
+ json_dump = json.dumps(json_object, indent=indent, sort_keys=sort_keys,
+ separators=(',', ': '))
+ return unicode_string(json_dump)
+
+def unicode_string(string):
+ if isinstance(string, text_type):
+ return string
+ if isinstance(string, bytes):
+ return string.decode("utf8")
+ if isinstance(string, newstr):
+ return text_type(string)
+    if isinstance(string, newbytes):
+        return bytes(string).decode("utf8")
+
+ raise TypeError("Cannot convert %s into unicode string" % type(string))
+
+
+def check_persisted_path(func):
+ @wraps(func)
+ def func_wrapper(self, path, *args, **kwargs):
+ path = Path(path)
+ if path.exists():
+ raise PersistingError(path)
+ return func(self, path, *args, **kwargs)
+
+ return func_wrapper
+
+
+def fitted_required(func):
+ @wraps(func)
+ def func_wrapper(self, *args, **kwargs):
+ if not self.fitted:
+ raise NotTrained("%s must be fitted" % self.unit_name)
+ return func(self, *args, **kwargs)
+
+ return func_wrapper
+
+
+def is_package(name):
+ """Check if name maps to a package installed via pip.
+
+ Args:
+ name (str): Name of package
+
+ Returns:
+        bool: True if an installed package corresponds to this name, False
+ otherwise.
+ """
+ import pkg_resources
+
+ name = name.lower().replace("-", "_")
+ packages = pkg_resources.working_set.by_key.keys()
+ for package in packages:
+ if package.lower().replace("-", "_") == name:
+ return True
+ return False
+
+
+def get_package_path(name):
+ """Get the path to an installed package.
+
+ Args:
+ name (str): Package name
+
+ Returns:
+ class:`.Path`: Path to the installed package
+ """
+ name = name.lower().replace("-", "_")
+ pkg = importlib.import_module(name)
+ return Path(pkg.__file__).parent
+
+
+def deduplicate_overlapping_items(items, overlap_fn, sort_key_fn):
+ """Deduplicates the items by looping over the items, sorted using
+ sort_key_fn, and checking overlaps with previously seen items using
+ overlap_fn
+ """
+ sorted_items = sorted(items, key=sort_key_fn)
+ deduplicated_items = []
+ for item in sorted_items:
+ if not any(overlap_fn(item, dedup_item)
+ for dedup_item in deduplicated_items):
+ deduplicated_items.append(item)
+ return deduplicated_items
+
+
+def replace_entities_with_placeholders(text, entities, placeholder_fn):
+ """Processes the text in order to replace entity values with placeholders
+ as defined by the placeholder function
+ """
+ if not entities:
+ return dict(), text
+
+ entities = deduplicate_overlapping_entities(entities)
+ entities = sorted(
+ entities, key=lambda e: e[RES_MATCH_RANGE][START])
+
+ range_mapping = dict()
+ processed_text = ""
+ offset = 0
+ current_ix = 0
+ for ent in entities:
+ ent_start = ent[RES_MATCH_RANGE][START]
+ ent_end = ent[RES_MATCH_RANGE][END]
+ rng_start = ent_start + offset
+
+ processed_text += text[current_ix:ent_start]
+
+ entity_length = ent_end - ent_start
+ entity_place_holder = placeholder_fn(ent[ENTITY_KIND])
+
+ offset += len(entity_place_holder) - entity_length
+
+ processed_text += entity_place_holder
+ rng_end = ent_end + offset
+ new_range = (rng_start, rng_end)
+ range_mapping[new_range] = ent[RES_MATCH_RANGE]
+ current_ix = ent_end
+
+ processed_text += text[current_ix:]
+ return range_mapping, processed_text
+
+
+def deduplicate_overlapping_entities(entities):
+ """Deduplicates entities based on overlapping ranges"""
+
+ def overlap(lhs_entity, rhs_entity):
+ return ranges_overlap(lhs_entity[RES_MATCH_RANGE],
+ rhs_entity[RES_MATCH_RANGE])
+
+ def sort_key_fn(entity):
+ return -len(entity[RES_VALUE])
+
+ deduplicated_entities = deduplicate_overlapping_items(
+ entities, overlap, sort_key_fn)
+ return sorted(deduplicated_entities,
+ key=lambda entity: entity[RES_MATCH_RANGE][START])
+
+
+SEMVER_PATTERN = r"^(?P<major>0|[1-9]\d*)" \
+ r".(?P<minor>0|[1-9]\d*)" \
+ r".(?P<patch>0|[1-9]\d*)" \
+ r"(?:.(?P<subpatch>0|[1-9]\d*))?" \
+ r"(?:-(?P<prerelease>(?:0|[1-9]\d*|\d*[a-zA-Z-]" \
+ r"[0-9a-zA-Z-]*)" \
+ r"(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?" \
+ r"(?:\+(?P<buildmetadata>[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*)" \
+ r")?$"
+SEMVER_REGEX = re.compile(SEMVER_PATTERN)
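
Usage sketch (illustrative, not part of the diff): replace_entities_with_placeholders applied to a hand-built entity match whose keys ("range", "start", "end", "entity_kind", "value") follow constants.py below; the placeholder function is arbitrary.

from snips_inference_agl.common.utils import replace_entities_with_placeholders

text = "set the lights to blue"
entities = [{
    "range": {"start": 18, "end": 22},
    "entity_kind": "color",
    "value": "blue",
}]

mapping, processed = replace_entities_with_placeholders(
    text, entities, placeholder_fn=lambda kind: "%%%s%%" % kind.upper())

print(processed)  # -> "set the lights to %COLOR%"
print(mapping)    # -> {(18, 26): {'start': 18, 'end': 22}}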
diff --git a/snips_inference_agl/constants.py b/snips_inference_agl/constants.py
new file mode 100644
index 0000000..8affd08
--- /dev/null
+++ b/snips_inference_agl/constants.py
@@ -0,0 +1,74 @@
+from __future__ import unicode_literals
+
+from pathlib import Path
+
+# package
+ROOT_PATH = Path(__file__).parent.parent
+PACKAGE_NAME = "snips_inference_agl"
+DATA_PACKAGE_NAME = "data"
+DATA_PATH = ROOT_PATH / PACKAGE_NAME / DATA_PACKAGE_NAME
+PACKAGE_PATH = ROOT_PATH / PACKAGE_NAME
+
+# result
+RES_INPUT = "input"
+RES_INTENT = "intent"
+RES_SLOTS = "slots"
+RES_INTENT_NAME = "intentName"
+RES_PROBA = "probability"
+RES_SLOT_NAME = "slotName"
+RES_ENTITY = "entity"
+RES_VALUE = "value"
+RES_RAW_VALUE = "rawValue"
+RES_MATCH_RANGE = "range"
+
+# miscellaneous
+AUTOMATICALLY_EXTENSIBLE = "automatically_extensible"
+USE_SYNONYMS = "use_synonyms"
+SYNONYMS = "synonyms"
+DATA = "data"
+INTENTS = "intents"
+ENTITIES = "entities"
+ENTITY = "entity"
+ENTITY_KIND = "entity_kind"
+RESOLVED_VALUE = "resolved_value"
+SLOT_NAME = "slot_name"
+TEXT = "text"
+UTTERANCES = "utterances"
+LANGUAGE = "language"
+VALUE = "value"
+NGRAM = "ngram"
+CAPITALIZE = "capitalize"
+UNKNOWNWORD = "unknownword"
+VALIDATED = "validated"
+START = "start"
+END = "end"
+BUILTIN_ENTITY_PARSER = "builtin_entity_parser"
+CUSTOM_ENTITY_PARSER = "custom_entity_parser"
+MATCHING_STRICTNESS = "matching_strictness"
+LICENSE_INFO = "license_info"
+RANDOM_STATE = "random_state"
+BYPASS_VERSION_CHECK = "bypass_version_check"
+
+# resources
+RESOURCES = "resources"
+METADATA = "metadata"
+STOP_WORDS = "stop_words"
+NOISE = "noise"
+GAZETTEERS = "gazetteers"
+STEMS = "stems"
+CUSTOM_ENTITY_PARSER_USAGE = "custom_entity_parser_usage"
+WORD_CLUSTERS = "word_clusters"
+
+# builtin entities
+SNIPS_NUMBER = "snips/number"
+
+# languages
+LANGUAGE_DE = "de"
+LANGUAGE_EN = "en"
+LANGUAGE_ES = "es"
+LANGUAGE_FR = "fr"
+LANGUAGE_IT = "it"
+LANGUAGE_JA = "ja"
+LANGUAGE_KO = "ko"
+LANGUAGE_PT_BR = "pt_br"
+LANGUAGE_PT_PT = "pt_pt"
diff --git a/snips_inference_agl/data_augmentation.py b/snips_inference_agl/data_augmentation.py
new file mode 100644
index 0000000..5a37f5e
--- /dev/null
+++ b/snips_inference_agl/data_augmentation.py
@@ -0,0 +1,121 @@
+from __future__ import unicode_literals
+
+from builtins import next
+from copy import deepcopy
+from itertools import cycle
+
+from future.utils import iteritems
+
+from snips_inference_agl.constants import (
+ CAPITALIZE, DATA, ENTITIES, ENTITY, INTENTS, TEXT, UTTERANCES)
+from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity
+from snips_inference_agl.languages import get_default_sep
+from snips_inference_agl.preprocessing import tokenize_light
+from snips_inference_agl.resources import get_stop_words
+
+
+def capitalize(text, language, resources):
+ tokens = tokenize_light(text, language)
+ stop_words = get_stop_words(resources)
+ return get_default_sep(language).join(
+ t.title() if t.lower() not in stop_words
+ else t.lower() for t in tokens)
+
+
+def capitalize_utterances(utterances, entities, language, ratio, resources,
+ random_state):
+ capitalized_utterances = []
+ for utterance in utterances:
+ capitalized_utterance = deepcopy(utterance)
+ for i, chunk in enumerate(capitalized_utterance[DATA]):
+ capitalized_utterance[DATA][i][TEXT] = chunk[TEXT].lower()
+ if ENTITY not in chunk:
+ continue
+ entity_label = chunk[ENTITY]
+ if is_builtin_entity(entity_label):
+ continue
+ if not entities[entity_label][CAPITALIZE]:
+ continue
+ if random_state.rand() > ratio:
+ continue
+ capitalized_utterance[DATA][i][TEXT] = capitalize(
+ chunk[TEXT], language, resources)
+ capitalized_utterances.append(capitalized_utterance)
+ return capitalized_utterances
+
+
+def generate_utterance(contexts_iterator, entities_iterators):
+ context = deepcopy(next(contexts_iterator))
+ context_data = []
+ for chunk in context[DATA]:
+ if ENTITY in chunk:
+ chunk[TEXT] = deepcopy(
+ next(entities_iterators[chunk[ENTITY]]))
+ chunk[TEXT] = chunk[TEXT].strip() + " "
+ context_data.append(chunk)
+ context[DATA] = context_data
+ return context
+
+
+def get_contexts_iterator(dataset, intent_name, random_state):
+ shuffled_utterances = random_state.permutation(
+ dataset[INTENTS][intent_name][UTTERANCES])
+ return cycle(shuffled_utterances)
+
+
+def get_entities_iterators(intent_entities, language,
+ add_builtin_entities_examples, random_state):
+ from snips_nlu_parsers import get_builtin_entity_examples
+
+ entities_its = dict()
+ for entity_name, entity in iteritems(intent_entities):
+ utterance_values = random_state.permutation(sorted(entity[UTTERANCES]))
+ if add_builtin_entities_examples and is_builtin_entity(entity_name):
+ entity_examples = get_builtin_entity_examples(
+ entity_name, language)
+ # Builtin entity examples must be kept first in the iterator to
+ # ensure that they are used when augmenting data
+ iterator_values = entity_examples + list(utterance_values)
+ else:
+ iterator_values = utterance_values
+ entities_its[entity_name] = cycle(iterator_values)
+ return entities_its
+
+
+def get_intent_entities(dataset, intent_name):
+ intent_entities = set()
+ for utterance in dataset[INTENTS][intent_name][UTTERANCES]:
+ for chunk in utterance[DATA]:
+ if ENTITY in chunk:
+ intent_entities.add(chunk[ENTITY])
+ return sorted(intent_entities)
+
+
+def num_queries_to_generate(dataset, intent_name, min_utterances):
+ nb_utterances = len(dataset[INTENTS][intent_name][UTTERANCES])
+ return max(nb_utterances, min_utterances)
+
+
+def augment_utterances(dataset, intent_name, language, min_utterances,
+ capitalization_ratio, add_builtin_entities_examples,
+ resources, random_state):
+ contexts_it = get_contexts_iterator(dataset, intent_name, random_state)
+ intent_entities = {e: dataset[ENTITIES][e]
+ for e in get_intent_entities(dataset, intent_name)}
+ entities_its = get_entities_iterators(intent_entities, language,
+ add_builtin_entities_examples,
+ random_state)
+ generated_utterances = []
+ nb_to_generate = num_queries_to_generate(dataset, intent_name,
+ min_utterances)
+ while nb_to_generate > 0:
+ generated_utterance = generate_utterance(contexts_it, entities_its)
+ generated_utterances.append(generated_utterance)
+ nb_to_generate -= 1
+
+ generated_utterances = capitalize_utterances(
+ generated_utterances, dataset[ENTITIES], language,
+ ratio=capitalization_ratio, resources=resources,
+ random_state=random_state)
+
+ return generated_utterances
diff --git a/snips_inference_agl/dataset/__init__.py b/snips_inference_agl/dataset/__init__.py
new file mode 100644
index 0000000..4cbed08
--- /dev/null
+++ b/snips_inference_agl/dataset/__init__.py
@@ -0,0 +1,7 @@
+from snips_inference_agl.dataset.dataset import Dataset
+from snips_inference_agl.dataset.entity import Entity
+from snips_inference_agl.dataset.intent import Intent
+from snips_inference_agl.dataset.utils import (
+ extract_intent_entities, extract_utterance_entities,
+ get_dataset_gazetteer_entities, get_text_from_chunks)
+from snips_inference_agl.dataset.validation import validate_and_format_dataset
diff --git a/snips_inference_agl/dataset/dataset.py b/snips_inference_agl/dataset/dataset.py
new file mode 100644
index 0000000..2ad7867
--- /dev/null
+++ b/snips_inference_agl/dataset/dataset.py
@@ -0,0 +1,102 @@
+# coding=utf-8
+from __future__ import print_function, unicode_literals
+
+import io
+from itertools import cycle
+
+from snips_inference_agl.common.utils import unicode_string
+from snips_inference_agl.dataset.entity import Entity
+from snips_inference_agl.dataset.intent import Intent
+from snips_inference_agl.exceptions import DatasetFormatError
+
+
+class Dataset(object):
+ """Dataset used in the main NLU training API
+
+ Consists of intents and entities data. This object can be built either from
+ text files (:meth:`.Dataset.from_files`) or from YAML files
+ (:meth:`.Dataset.from_yaml_files`).
+
+ Attributes:
+ language (str): language of the intents
+ intents (list of :class:`.Intent`): intents data
+ entities (list of :class:`.Entity`): entities data
+ """
+
+ def __init__(self, language, intents, entities):
+ self.language = language
+ self.intents = intents
+ self.entities = entities
+ self._add_missing_entities()
+ self._ensure_entity_values()
+
+ @classmethod
+ def _load_dataset_parts(cls, stream, stream_description):
+ from snips_inference_agl.dataset.yaml_wrapper import yaml
+
+ intents = []
+ entities = []
+ for doc in yaml.safe_load_all(stream):
+ doc_type = doc.get("type")
+ if doc_type == "entity":
+ entities.append(Entity.from_yaml(doc))
+ elif doc_type == "intent":
+ intents.append(Intent.from_yaml(doc))
+ else:
+ raise DatasetFormatError(
+ "Invalid 'type' value in YAML file '%s': '%s'"
+ % (stream_description, doc_type))
+ return intents, entities
+
+ def _add_missing_entities(self):
+ entity_names = set(e.name for e in self.entities)
+
+ # Add entities appearing only in the intents utterances
+ for intent in self.intents:
+ for entity_name in intent.entities_names:
+ if entity_name not in entity_names:
+ entity_names.add(entity_name)
+ self.entities.append(Entity(name=entity_name))
+
+ def _ensure_entity_values(self):
+ entities_values = {entity.name: self._get_entity_values(entity)
+ for entity in self.entities}
+ for intent in self.intents:
+ for utterance in intent.utterances:
+ for chunk in utterance.slot_chunks:
+ if chunk.text is not None:
+ continue
+ try:
+ chunk.text = next(entities_values[chunk.entity])
+ except StopIteration:
+ raise DatasetFormatError(
+ "At least one entity value must be provided for "
+ "entity '%s'" % chunk.entity)
+ return self
+
+ def _get_entity_values(self, entity):
+ from snips_nlu_parsers import get_builtin_entity_examples
+
+ if entity.is_builtin:
+ return cycle(get_builtin_entity_examples(
+ entity.name, self.language))
+ values = [v for utterance in entity.utterances
+ for v in utterance.variations]
+ values_set = set(values)
+ for intent in self.intents:
+ for utterance in intent.utterances:
+ for chunk in utterance.slot_chunks:
+ if not chunk.text or chunk.entity != entity.name:
+ continue
+ if chunk.text not in values_set:
+ values_set.add(chunk.text)
+ values.append(chunk.text)
+ return cycle(values)
+
+ @property
+ def json(self):
+ """Dataset data in json format"""
+ intents = {intent_data.intent_name: intent_data.json
+ for intent_data in self.intents}
+ entities = {entity.name: entity.json for entity in self.entities}
+ return dict(language=self.language, intents=intents, entities=entities)
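
Usage sketch (illustrative, not part of the diff): building a Dataset in memory from the Intent and Entity helpers defined later in this patch, then reading its JSON view. Intent, entity and slot names are made up, and the snips_nlu_parsers dependency is assumed to be installed.

import io

from snips_inference_agl.dataset import Dataset, Entity, Intent

intent = Intent.from_yaml(io.StringIO("""
type: intent
name: turnLightOn
utterances:
  - turn on the lights in the [room](kitchen)
  - switch the [room](bedroom) lights on
"""))

entity = Entity.from_yaml(io.StringIO("""
type: entity
name: room
values:
  - kitchen
  - [bedroom, sleeping room]
"""))

dataset = Dataset(language="en", intents=[intent], entities=[entity])
print(sorted(dataset.json["intents"]))   # -> ['turnLightOn']
print(sorted(dataset.json["entities"]))  # -> ['room']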
diff --git a/snips_inference_agl/dataset/entity.py b/snips_inference_agl/dataset/entity.py
new file mode 100644
index 0000000..65b9994
--- /dev/null
+++ b/snips_inference_agl/dataset/entity.py
@@ -0,0 +1,175 @@
+# coding=utf-8
+from __future__ import unicode_literals
+
+from builtins import str
+from io import IOBase
+
+from snips_inference_agl.constants import (
+ AUTOMATICALLY_EXTENSIBLE, DATA, MATCHING_STRICTNESS, SYNONYMS,
+ USE_SYNONYMS, VALUE)
+from snips_inference_agl.exceptions import EntityFormatError
+
+
+class Entity(object):
+ """Entity data of a :class:`.Dataset`
+
+    This class can represent either a custom or a builtin entity. When the
+ entity is a builtin one, only the `name` attribute is relevant.
+
+ Attributes:
+ name (str): name of the entity
+ utterances (list of :class:`.EntityUtterance`): entity utterances
+ (only for custom entities)
+ automatically_extensible (bool): whether or not the entity can be
+ extended to values not present in the data (only for custom
+ entities)
+ use_synonyms (bool): whether or not to map entity values using
+ synonyms (only for custom entities)
+ matching_strictness (float): controls the matching strictness of the
+ entity (only for custom entities). Must be between 0.0 and 1.0.
+ """
+
+ def __init__(self, name, utterances=None, automatically_extensible=True,
+ use_synonyms=True, matching_strictness=1.0):
+ if utterances is None:
+ utterances = []
+ self.name = name
+ self.utterances = utterances
+ self.automatically_extensible = automatically_extensible
+ self.use_synonyms = use_synonyms
+ self.matching_strictness = matching_strictness
+
+ @property
+ def is_builtin(self):
+ from snips_nlu_parsers import get_all_builtin_entities
+
+ return self.name in get_all_builtin_entities()
+
+ @classmethod
+ def from_yaml(cls, yaml_dict):
+ """Build an :class:`.Entity` from its YAML definition object
+
+ Args:
+ yaml_dict (dict or :class:`.IOBase`): object containing the YAML
+ definition of the entity. It can be either a stream, or the
+ corresponding python dict.
+
+ Examples:
+ An entity can be defined with a YAML document following the schema
+ illustrated in the example below:
+
+ >>> import io
+ >>> from snips_inference_agl.common.utils import json_string
+ >>> entity_yaml = io.StringIO('''
+ ... # City Entity
+ ... ---
+ ... type: entity
+ ... name: city
+ ... automatically_extensible: false # default value is true
+ ... use_synonyms: false # default value is true
+ ... matching_strictness: 0.8 # default value is 1.0
+ ... values:
+ ... - london
+ ... - [new york, big apple]
+ ... - [paris, city of lights]''')
+ >>> entity = Entity.from_yaml(entity_yaml)
+ >>> print(json_string(entity.json, indent=4, sort_keys=True))
+ {
+ "automatically_extensible": false,
+ "data": [
+ {
+ "synonyms": [],
+ "value": "london"
+ },
+ {
+ "synonyms": [
+ "big apple"
+ ],
+ "value": "new york"
+ },
+ {
+ "synonyms": [
+ "city of lights"
+ ],
+ "value": "paris"
+ }
+ ],
+ "matching_strictness": 0.8,
+ "use_synonyms": false
+ }
+
+ Raises:
+ EntityFormatError: When the YAML dict does not correspond to the
+ :ref:`expected entity format <yaml_entity_format>`
+ """
+ if isinstance(yaml_dict, IOBase):
+ from snips_inference_agl.dataset.yaml_wrapper import yaml
+
+ yaml_dict = yaml.safe_load(yaml_dict)
+
+ object_type = yaml_dict.get("type")
+ if object_type and object_type != "entity":
+ raise EntityFormatError("Wrong type: '%s'" % object_type)
+ entity_name = yaml_dict.get("name")
+ if not entity_name:
+ raise EntityFormatError("Missing 'name' attribute")
+ auto_extensible = yaml_dict.get(AUTOMATICALLY_EXTENSIBLE, True)
+ use_synonyms = yaml_dict.get(USE_SYNONYMS, True)
+ matching_strictness = yaml_dict.get("matching_strictness", 1.0)
+ utterances = []
+ for entity_value in yaml_dict.get("values", []):
+ if isinstance(entity_value, list):
+ utterance = EntityUtterance(entity_value[0], entity_value[1:])
+ elif isinstance(entity_value, str):
+ utterance = EntityUtterance(entity_value)
+ else:
+ raise EntityFormatError(
+ "YAML entity values must be either strings or lists, but "
+ "found: %s" % type(entity_value))
+ utterances.append(utterance)
+
+ return cls(name=entity_name,
+ utterances=utterances,
+ automatically_extensible=auto_extensible,
+ use_synonyms=use_synonyms,
+ matching_strictness=matching_strictness)
+
+ @property
+ def json(self):
+ """Returns the entity in json format"""
+ if self.is_builtin:
+ return dict()
+ return {
+ AUTOMATICALLY_EXTENSIBLE: self.automatically_extensible,
+ USE_SYNONYMS: self.use_synonyms,
+ DATA: [u.json for u in self.utterances],
+ MATCHING_STRICTNESS: self.matching_strictness
+ }
+
+
+class EntityUtterance(object):
+ """Represents a value of a :class:`.CustomEntity` with potential synonyms
+
+ Attributes:
+ value (str): entity value
+ synonyms (list of str): The values to remap to the utterance value
+ """
+
+ def __init__(self, value, synonyms=None):
+ self.value = value
+ if synonyms is None:
+ synonyms = []
+ self.synonyms = synonyms
+
+ @property
+ def variations(self):
+ return [self.value] + self.synonyms
+
+ @property
+ def json(self):
+ return {VALUE: self.value, SYNONYMS: self.synonyms}
+
+
+def utf_8_encoder(f):
+ for line in f:
+ yield line.encode("utf-8")
diff --git a/snips_inference_agl/dataset/intent.py b/snips_inference_agl/dataset/intent.py
new file mode 100644
index 0000000..0a915ce
--- /dev/null
+++ b/snips_inference_agl/dataset/intent.py
@@ -0,0 +1,339 @@
+from __future__ import absolute_import, print_function, unicode_literals
+
+from abc import ABCMeta, abstractmethod
+from builtins import object
+from io import IOBase
+
+from future.utils import with_metaclass
+
+from snips_inference_agl.constants import DATA, ENTITY, SLOT_NAME, TEXT, UTTERANCES
+from snips_inference_agl.exceptions import IntentFormatError
+
+
+class Intent(object):
+ """Intent data of a :class:`.Dataset`
+
+ Attributes:
+ intent_name (str): name of the intent
+ utterances (list of :class:`.IntentUtterance`): annotated intent
+ utterances
+ slot_mapping (dict): mapping between slot names and entities
+ """
+
+ def __init__(self, intent_name, utterances, slot_mapping=None):
+ if slot_mapping is None:
+ slot_mapping = dict()
+ self.intent_name = intent_name
+ self.utterances = utterances
+ self.slot_mapping = slot_mapping
+ self._complete_slot_name_mapping()
+ self._ensure_entity_names()
+
+ @classmethod
+ def from_yaml(cls, yaml_dict):
+ """Build an :class:`.Intent` from its YAML definition object
+
+ Args:
+ yaml_dict (dict or :class:`.IOBase`): object containing the YAML
+ definition of the intent. It can be either a stream, or the
+ corresponding python dict.
+
+ Examples:
+ An intent can be defined with a YAML document following the schema
+ illustrated in the example below:
+
+ >>> import io
+ >>> from snips_inference_agl.common.utils import json_string
+ >>> intent_yaml = io.StringIO('''
+ ... # searchFlight Intent
+ ... ---
+ ... type: intent
+ ... name: searchFlight
+ ... slots:
+ ... - name: origin
+ ... entity: city
+ ... - name: destination
+ ... entity: city
+ ... - name: date
+ ... entity: snips/datetime
+ ... utterances:
+ ... - find me a flight from [origin](Oslo) to [destination](Lima)
+ ... - I need a flight leaving to [destination](Berlin)''')
+ >>> intent = Intent.from_yaml(intent_yaml)
+ >>> print(json_string(intent.json, indent=4, sort_keys=True))
+ {
+ "utterances": [
+ {
+ "data": [
+ {
+ "text": "find me a flight from "
+ },
+ {
+ "entity": "city",
+ "slot_name": "origin",
+ "text": "Oslo"
+ },
+ {
+ "text": " to "
+ },
+ {
+ "entity": "city",
+ "slot_name": "destination",
+ "text": "Lima"
+ }
+ ]
+ },
+ {
+ "data": [
+ {
+ "text": "I need a flight leaving to "
+ },
+ {
+ "entity": "city",
+ "slot_name": "destination",
+ "text": "Berlin"
+ }
+ ]
+ }
+ ]
+ }
+
+ Raises:
+ IntentFormatError: When the YAML dict does not correspond to the
+ :ref:`expected intent format <yaml_intent_format>`
+ """
+
+ if isinstance(yaml_dict, IOBase):
+ from snips_inference_agl.dataset.yaml_wrapper import yaml
+
+ yaml_dict = yaml.safe_load(yaml_dict)
+
+ object_type = yaml_dict.get("type")
+ if object_type and object_type != "intent":
+ raise IntentFormatError("Wrong type: '%s'" % object_type)
+ intent_name = yaml_dict.get("name")
+ if not intent_name:
+ raise IntentFormatError("Missing 'name' attribute")
+ slot_mapping = dict()
+ for slot in yaml_dict.get("slots", []):
+ slot_mapping[slot["name"]] = slot["entity"]
+ utterances = [IntentUtterance.parse(u.strip())
+ for u in yaml_dict["utterances"] if u.strip()]
+ if not utterances:
+ raise IntentFormatError(
+ "Intent must contain at least one utterance")
+ return cls(intent_name, utterances, slot_mapping)
+
+ def _complete_slot_name_mapping(self):
+ for utterance in self.utterances:
+ for chunk in utterance.slot_chunks:
+ if chunk.entity and chunk.slot_name not in self.slot_mapping:
+ self.slot_mapping[chunk.slot_name] = chunk.entity
+ return self
+
+ def _ensure_entity_names(self):
+ for utterance in self.utterances:
+ for chunk in utterance.slot_chunks:
+ if chunk.entity:
+ continue
+ chunk.entity = self.slot_mapping.get(
+ chunk.slot_name, chunk.slot_name)
+ return self
+
+ @property
+ def json(self):
+ """Intent data in json format"""
+ return {
+ UTTERANCES: [
+ {DATA: [chunk.json for chunk in utterance.chunks]}
+ for utterance in self.utterances
+ ]
+ }
+
+ @property
+ def entities_names(self):
+ return set(chunk.entity for u in self.utterances
+ for chunk in u.chunks if isinstance(chunk, SlotChunk))
+
+
+class IntentUtterance(object):
+ def __init__(self, chunks):
+ self.chunks = chunks
+
+ @property
+ def text(self):
+ return "".join((chunk.text for chunk in self.chunks))
+
+ @property
+ def slot_chunks(self):
+ return (chunk for chunk in self.chunks if isinstance(chunk, SlotChunk))
+
+ @classmethod
+ def parse(cls, string):
+ """Parses an utterance
+
+ Args:
+ string (str): an utterance in the class:`.Utterance` format
+
+ Examples:
+
+ >>> from snips_inference_agl.dataset.intent import IntentUtterance
+ >>> u = IntentUtterance.\
+ parse("president of [country:default](France)")
+ >>> u.text
+ 'president of France'
+ >>> len(u.chunks)
+ 2
+ >>> u.chunks[0].text
+ 'president of '
+ >>> u.chunks[1].slot_name
+ 'country'
+ >>> u.chunks[1].entity
+ 'default'
+ """
+ sm = SM(string)
+ capture_text(sm)
+ return cls(sm.chunks)
+
+
+class Chunk(with_metaclass(ABCMeta, object)):
+ def __init__(self, text):
+ self.text = text
+
+ @abstractmethod
+ def json(self):
+ pass
+
+
+class SlotChunk(Chunk):
+ def __init__(self, slot_name, entity, text):
+ super(SlotChunk, self).__init__(text)
+ self.slot_name = slot_name
+ self.entity = entity
+
+ @property
+ def json(self):
+ return {
+ TEXT: self.text,
+ SLOT_NAME: self.slot_name,
+ ENTITY: self.entity,
+ }
+
+
+class TextChunk(Chunk):
+ @property
+ def json(self):
+ return {
+ TEXT: self.text
+ }
+
+
+class SM(object):
+ """State Machine for parsing"""
+
+ def __init__(self, input):
+ self.input = input
+ self.chunks = []
+ self.current = 0
+
+ @property
+ def end_of_input(self):
+ return self.current >= len(self.input)
+
+ def add_slot(self, name, entity=None):
+ """Adds a named slot
+
+ Args:
+ name (str): slot name
+ entity (str): entity name
+ """
+ chunk = SlotChunk(slot_name=name, entity=entity, text=None)
+ self.chunks.append(chunk)
+
+ def add_text(self, text):
+ """Adds a simple text chunk using the current position"""
+ chunk = TextChunk(text=text)
+ self.chunks.append(chunk)
+
+ def add_tagged(self, text):
+ """Adds text to the last slot"""
+ if not self.chunks:
+ raise AssertionError("Cannot add tagged text because chunks list "
+ "is empty")
+ self.chunks[-1].text = text
+
+ def find(self, s):
+ return self.input.find(s, self.current)
+
+ def move(self, pos):
+ """Moves the cursor of the state to position after given
+
+ Args:
+ pos (int): position to place the cursor just after
+ """
+ self.current = pos + 1
+
+ def peek(self):
+ if self.end_of_input:
+ return None
+ return self[0]
+
+ def read(self):
+ c = self[0]
+ self.current += 1
+ return c
+
+ def __getitem__(self, key):
+ current = self.current
+ if isinstance(key, int):
+ return self.input[current + key]
+ elif isinstance(key, slice):
+ start = current + key.start if key.start else current
+ return self.input[slice(start, key.stop, key.step)]
+ else:
+ raise TypeError("Bad key type: %s" % type(key))
+
+
+def capture_text(state):
+ next_pos = state.find('[')
+ sub = state[:] if next_pos < 0 else state[:next_pos]
+ if sub:
+ state.add_text(sub)
+ if next_pos >= 0:
+ state.move(next_pos)
+ capture_slot(state)
+
+
+def capture_slot(state):
+ next_colon_pos = state.find(':')
+ next_square_bracket_pos = state.find(']')
+ if next_square_bracket_pos < 0:
+ raise IntentFormatError(
+ "Missing ending ']' in annotated utterance \"%s\"" % state.input)
+ if next_colon_pos < 0 or next_square_bracket_pos < next_colon_pos:
+ slot_name = state[:next_square_bracket_pos]
+ state.move(next_square_bracket_pos)
+ state.add_slot(slot_name)
+ else:
+ slot_name = state[:next_colon_pos]
+ state.move(next_colon_pos)
+ entity = state[:next_square_bracket_pos]
+ state.move(next_square_bracket_pos)
+ state.add_slot(slot_name, entity)
+ if state.peek() == '(':
+ state.read()
+ capture_tagged(state)
+ else:
+ capture_text(state)
+
+
+def capture_tagged(state):
+ next_pos = state.find(')')
+ if next_pos < 1:
+ raise IntentFormatError(
+ "Missing ending ')' in annotated utterance \"%s\"" % state.input)
+ else:
+ tagged_text = state[:next_pos]
+ state.add_tagged(tagged_text)
+ state.move(next_pos)
+ capture_text(state)
diff --git a/snips_inference_agl/dataset/utils.py b/snips_inference_agl/dataset/utils.py
new file mode 100644
index 0000000..f147f0f
--- /dev/null
+++ b/snips_inference_agl/dataset/utils.py
@@ -0,0 +1,67 @@
+from __future__ import unicode_literals
+
+from future.utils import iteritems, itervalues
+
+from snips_inference_agl.constants import (
+ DATA, ENTITIES, ENTITY, INTENTS, TEXT, UTTERANCES)
+from snips_inference_agl.entity_parser.builtin_entity_parser import is_gazetteer_entity
+
+
+def extract_utterance_entities(dataset):
+ entities_values = {ent_name: set() for ent_name in dataset[ENTITIES]}
+
+ for intent in itervalues(dataset[INTENTS]):
+ for utterance in intent[UTTERANCES]:
+ for chunk in utterance[DATA]:
+ if ENTITY in chunk:
+ entities_values[chunk[ENTITY]].add(chunk[TEXT].strip())
+ return {k: list(v) for k, v in iteritems(entities_values)}
+
+
+def extract_intent_entities(dataset, entity_filter=None):
+ intent_entities = {intent: set() for intent in dataset[INTENTS]}
+ for intent_name, intent_data in iteritems(dataset[INTENTS]):
+ for utterance in intent_data[UTTERANCES]:
+ for chunk in utterance[DATA]:
+ if ENTITY in chunk:
+ if entity_filter and not entity_filter(chunk[ENTITY]):
+ continue
+ intent_entities[intent_name].add(chunk[ENTITY])
+ return intent_entities
+
+
+def extract_entity_values(dataset, apply_normalization):
+ from snips_nlu_utils import normalize
+
+ entities_per_intent = {intent: set() for intent in dataset[INTENTS]}
+ intent_entities = extract_intent_entities(dataset)
+ for intent, entities in iteritems(intent_entities):
+ for entity in entities:
+ entity_values = set(dataset[ENTITIES][entity][UTTERANCES])
+ if apply_normalization:
+ entity_values = {normalize(v) for v in entity_values}
+ entities_per_intent[intent].update(entity_values)
+ return entities_per_intent
+
+
+def get_text_from_chunks(chunks):
+ return "".join(chunk[TEXT] for chunk in chunks)
+
+
+def get_dataset_gazetteer_entities(dataset, intent=None):
+ if intent is not None:
+ return extract_intent_entities(dataset, is_gazetteer_entity)[intent]
+ return {e for e in dataset[ENTITIES] if is_gazetteer_entity(e)}
+
+
+def get_stop_words_whitelist(dataset, stop_words):
+ """Extracts stop words whitelists per intent consisting of entity values
+ that appear in the stop_words list"""
+ entity_values_per_intent = extract_entity_values(
+ dataset, apply_normalization=True)
+ stop_words_whitelist = dict()
+ for intent, entity_values in iteritems(entity_values_per_intent):
+ whitelist = stop_words.intersection(entity_values)
+ if whitelist:
+ stop_words_whitelist[intent] = whitelist
+ return stop_words_whitelist
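A minimal sketch of what these helpers return on a toy dataset, assuming the DATA, ENTITIES, ENTITY, INTENTS, TEXT and UTTERANCES constants resolve to the plain strings used below:

from snips_inference_agl.dataset.utils import (
    extract_intent_entities, get_text_from_chunks)

toy_dataset = {
    "language": "en",
    "intents": {
        "turnOnRadio": {
            "utterances": [
                {"data": [
                    {"text": "play "},
                    {"text": "jazz", "entity": "radio_station",
                     "slot_name": "station"},
                ]}
            ]
        }
    },
    "entities": {"radio_station": {}},
}

print(extract_intent_entities(toy_dataset))
# {'turnOnRadio': {'radio_station'}}
print(get_text_from_chunks(
    toy_dataset["intents"]["turnOnRadio"]["utterances"][0]["data"]))
# 'play jazz'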
diff --git a/snips_inference_agl/dataset/validation.py b/snips_inference_agl/dataset/validation.py
new file mode 100644
index 0000000..d6fc4a1
--- /dev/null
+++ b/snips_inference_agl/dataset/validation.py
@@ -0,0 +1,254 @@
+from __future__ import division, unicode_literals
+
+import json
+from builtins import str
+from collections import Counter
+from copy import deepcopy
+
+from future.utils import iteritems, itervalues
+
+from snips_inference_agl.common.dataset_utils import (validate_key, validate_keys,
+ validate_type)
+from snips_inference_agl.constants import (
+ AUTOMATICALLY_EXTENSIBLE, CAPITALIZE, DATA, ENTITIES, ENTITY, INTENTS,
+ LANGUAGE, MATCHING_STRICTNESS, SLOT_NAME, SYNONYMS, TEXT, USE_SYNONYMS,
+ UTTERANCES, VALIDATED, VALUE, LICENSE_INFO)
+from snips_inference_agl.dataset import extract_utterance_entities, Dataset
+from snips_inference_agl.entity_parser.builtin_entity_parser import (
+ BuiltinEntityParser, is_builtin_entity)
+from snips_inference_agl.exceptions import DatasetFormatError
+from snips_inference_agl.preprocessing import tokenize_light
+from snips_inference_agl.string_variations import get_string_variations
+
+NUMBER_VARIATIONS_THRESHOLD = 1e3
+VARIATIONS_GENERATION_THRESHOLD = 1e4
+
+
+def validate_and_format_dataset(dataset):
+ """Checks that the dataset is valid and format it
+
+    Raises:
+ DatasetFormatError: When the dataset format is wrong
+ """
+ from snips_nlu_parsers import get_all_languages
+
+ if isinstance(dataset, Dataset):
+ dataset = dataset.json
+
+ # Make this function idempotent
+ if dataset.get(VALIDATED, False):
+ return dataset
+ dataset = deepcopy(dataset)
+ dataset = json.loads(json.dumps(dataset))
+ validate_type(dataset, dict, object_label="dataset")
+ mandatory_keys = [INTENTS, ENTITIES, LANGUAGE]
+ for key in mandatory_keys:
+ validate_key(dataset, key, object_label="dataset")
+ validate_type(dataset[ENTITIES], dict, object_label="entities")
+ validate_type(dataset[INTENTS], dict, object_label="intents")
+ language = dataset[LANGUAGE]
+ validate_type(language, str, object_label="language")
+ if language not in get_all_languages():
+ raise DatasetFormatError("Unknown language: '%s'" % language)
+
+ dataset[INTENTS] = {
+ intent_name: intent_data
+ for intent_name, intent_data in sorted(iteritems(dataset[INTENTS]))}
+ for intent in itervalues(dataset[INTENTS]):
+ _validate_and_format_intent(intent, dataset[ENTITIES])
+
+ utterance_entities_values = extract_utterance_entities(dataset)
+ builtin_entity_parser = BuiltinEntityParser.build(dataset=dataset)
+
+ dataset[ENTITIES] = {
+        entity_name: entity_data
+        for entity_name, entity_data in sorted(iteritems(dataset[ENTITIES]))}
+
+ for entity_name, entity in iteritems(dataset[ENTITIES]):
+        utterance_entities = utterance_entities_values[entity_name]
+        if is_builtin_entity(entity_name):
+            dataset[ENTITIES][entity_name] = \
+                _validate_and_format_builtin_entity(entity, utterance_entities)
+        else:
+            dataset[ENTITIES][entity_name] = \
+                _validate_and_format_custom_entity(
+                    entity, utterance_entities, language,
+                    builtin_entity_parser)
+ dataset[VALIDATED] = True
+ return dataset
+
+
+def _validate_and_format_intent(intent, entities):
+ validate_type(intent, dict, "intent")
+ validate_key(intent, UTTERANCES, object_label="intent dict")
+ validate_type(intent[UTTERANCES], list, object_label="utterances")
+ for utterance in intent[UTTERANCES]:
+ validate_type(utterance, dict, object_label="utterance")
+ validate_key(utterance, DATA, object_label="utterance")
+ validate_type(utterance[DATA], list, object_label="utterance data")
+ for chunk in utterance[DATA]:
+ validate_type(chunk, dict, object_label="utterance chunk")
+ validate_key(chunk, TEXT, object_label="chunk")
+ if ENTITY in chunk or SLOT_NAME in chunk:
+ mandatory_keys = [ENTITY, SLOT_NAME]
+ validate_keys(chunk, mandatory_keys, object_label="chunk")
+ if is_builtin_entity(chunk[ENTITY]):
+ continue
+ else:
+ validate_key(entities, chunk[ENTITY],
+ object_label=ENTITIES)
+ return intent
+
+
+def _has_any_capitalization(entity_utterances, language):
+ for utterance in entity_utterances:
+ tokens = tokenize_light(utterance, language)
+ if any(t.isupper() or t.istitle() for t in tokens):
+ return True
+ return False
+
+
+def _add_entity_variations(utterances, entity_variations, entity_value):
+ utterances[entity_value] = entity_value
+ for variation in entity_variations[entity_value]:
+ if variation:
+ utterances[variation] = entity_value
+ return utterances
+
+
+def _extract_entity_values(entity):
+ values = set()
+ for ent in entity[DATA]:
+ values.add(ent[VALUE])
+ if entity[USE_SYNONYMS]:
+ values.update(set(ent[SYNONYMS]))
+ return values
+
+
+def _validate_and_format_custom_entity(entity, utterance_entities, language,
+ builtin_entity_parser):
+ validate_type(entity, dict, object_label="entity")
+
+ # TODO: this is here temporarily, only to allow backward compatibility
+ if MATCHING_STRICTNESS not in entity:
+ strictness = entity.get("parser_threshold", 1.0)
+
+ entity[MATCHING_STRICTNESS] = strictness
+
+ mandatory_keys = [USE_SYNONYMS, AUTOMATICALLY_EXTENSIBLE, DATA,
+ MATCHING_STRICTNESS]
+ validate_keys(entity, mandatory_keys, object_label="custom entity")
+ validate_type(entity[USE_SYNONYMS], bool, object_label="use_synonyms")
+ validate_type(entity[AUTOMATICALLY_EXTENSIBLE], bool,
+ object_label="automatically_extensible")
+ validate_type(entity[DATA], list, object_label="entity data")
+ validate_type(entity[MATCHING_STRICTNESS], (float, int),
+ object_label="matching_strictness")
+
+ formatted_entity = dict()
+ formatted_entity[AUTOMATICALLY_EXTENSIBLE] = entity[
+ AUTOMATICALLY_EXTENSIBLE]
+ formatted_entity[MATCHING_STRICTNESS] = entity[MATCHING_STRICTNESS]
+ if LICENSE_INFO in entity:
+ formatted_entity[LICENSE_INFO] = entity[LICENSE_INFO]
+ use_synonyms = entity[USE_SYNONYMS]
+
+ # Validate format and filter out unused data
+ valid_entity_data = []
+ for entry in entity[DATA]:
+ validate_type(entry, dict, object_label="entity entry")
+ validate_keys(entry, [VALUE, SYNONYMS], object_label="entity entry")
+ entry[VALUE] = entry[VALUE].strip()
+ if not entry[VALUE]:
+ continue
+ validate_type(entry[SYNONYMS], list, object_label="entity synonyms")
+ entry[SYNONYMS] = [s.strip() for s in entry[SYNONYMS] if s.strip()]
+ valid_entity_data.append(entry)
+ entity[DATA] = valid_entity_data
+
+ # Compute capitalization before normalizing
+    # Normalization lowercases values and hence skews the capitalization check
+ formatted_entity[CAPITALIZE] = _has_any_capitalization(utterance_entities,
+ language)
+
+ validated_utterances = dict()
+    # Map original values and synonyms
+ for data in entity[DATA]:
+ ent_value = data[VALUE]
+ validated_utterances[ent_value] = ent_value
+ if use_synonyms:
+ for s in data[SYNONYMS]:
+ if s not in validated_utterances:
+ validated_utterances[s] = ent_value
+
+    # Number variations in entity values are expensive since each entity
+    # value is parsed with the builtin entity parser before creating the
+    # variations. We avoid generating these variations if there are enough
+    # entity values
+
+ # Add variations if not colliding
+ all_original_values = _extract_entity_values(entity)
+ if len(entity[DATA]) < VARIATIONS_GENERATION_THRESHOLD:
+ variations_args = {
+ "case": True,
+ "and_": True,
+ "punctuation": True
+ }
+ else:
+ variations_args = {
+ "case": False,
+ "and_": False,
+ "punctuation": False
+ }
+
+ variations_args["numbers"] = len(
+ entity[DATA]) < NUMBER_VARIATIONS_THRESHOLD
+
+ variations = dict()
+ for data in entity[DATA]:
+ ent_value = data[VALUE]
+ values_to_variate = {ent_value}
+ if use_synonyms:
+ values_to_variate.update(set(data[SYNONYMS]))
+ variations[ent_value] = set(
+ v for value in values_to_variate
+ for v in get_string_variations(
+ value, language, builtin_entity_parser, **variations_args)
+ )
+ variation_counter = Counter(
+ [v for variations_ in itervalues(variations) for v in variations_])
+ non_colliding_variations = {
+ value: [
+ v for v in variations if
+ v not in all_original_values and variation_counter[v] == 1
+ ]
+ for value, variations in iteritems(variations)
+ }
+
+ for entry in entity[DATA]:
+ entry_value = entry[VALUE]
+ validated_utterances = _add_entity_variations(
+ validated_utterances, non_colliding_variations, entry_value)
+
+ # Merge utterances entities
+ utterance_entities_variations = {
+ ent: get_string_variations(
+ ent, language, builtin_entity_parser, **variations_args)
+ for ent in utterance_entities
+ }
+
+ for original_ent, variations in iteritems(utterance_entities_variations):
+ if not original_ent or original_ent in validated_utterances:
+ continue
+ validated_utterances[original_ent] = original_ent
+ for variation in variations:
+ if variation and variation not in validated_utterances \
+ and variation not in utterance_entities:
+ validated_utterances[variation] = original_ent
+ formatted_entity[UTTERANCES] = validated_utterances
+ return formatted_entity
+
+
+def _validate_and_format_builtin_entity(entity, utterance_entities):
+ validate_type(entity, dict, object_label="builtin entity")
+ return {UTTERANCES: set(utterance_entities)}
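Assuming snips_nlu_parsers is installed and the dataset constants map to the plain strings used below, a sketch of the validation round-trip: entity values, synonyms and non-colliding variations end up in a single utterances mapping pointing at the resolved value.

from snips_inference_agl.dataset.validation import validate_and_format_dataset

raw_dataset = {
    "language": "en",
    "intents": {
        "turnOnRadio": {
            "utterances": [
                {"data": [
                    {"text": "play "},
                    {"text": "BBC", "entity": "radio_station",
                     "slot_name": "station"},
                ]}
            ]
        }
    },
    "entities": {
        "radio_station": {
            "use_synonyms": True,
            "automatically_extensible": True,
            "matching_strictness": 1.0,
            "data": [{"value": "BBC", "synonyms": ["the beeb"]}],
        }
    },
}

validated = validate_and_format_dataset(raw_dataset)
print(validated["validated"])  # True, so calling the function again is a no-op
print(validated["entities"]["radio_station"]["utterances"]["the beeb"])  # 'BBC'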
diff --git a/snips_inference_agl/dataset/yaml_wrapper.py b/snips_inference_agl/dataset/yaml_wrapper.py
new file mode 100644
index 0000000..ba8390d
--- /dev/null
+++ b/snips_inference_agl/dataset/yaml_wrapper.py
@@ -0,0 +1,11 @@
+import yaml
+
+
+def _construct_yaml_str(self, node):
+ # Override the default string handling function
+ # to always return unicode objects
+ return self.construct_scalar(node)
+
+
+yaml.Loader.add_constructor("tag:yaml.org,2002:str", _construct_yaml_str)
+yaml.SafeLoader.add_constructor("tag:yaml.org,2002:str", _construct_yaml_str)
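This wrapper exists because, on Python 2, PyYAML may return bytes for ASCII-only scalars; registering the constructor on both Loader and SafeLoader forces text (unicode) strings. A quick check, assuming the package is importable:

import yaml

import snips_inference_agl.dataset.yaml_wrapper  # noqa: registers the constructors

doc = yaml.safe_load("type: intent\nname: turnOnRadio\n")
print(type(doc["name"]))  # a text (unicode) type on both Python 2 and 3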
diff --git a/snips_inference_agl/default_configs/__init__.py b/snips_inference_agl/default_configs/__init__.py
new file mode 100644
index 0000000..fc66d33
--- /dev/null
+++ b/snips_inference_agl/default_configs/__init__.py
@@ -0,0 +1,26 @@
+from __future__ import unicode_literals
+
+from snips_inference_agl.constants import (
+ LANGUAGE_DE, LANGUAGE_EN, LANGUAGE_ES, LANGUAGE_FR, LANGUAGE_IT,
+ LANGUAGE_JA, LANGUAGE_KO, LANGUAGE_PT_BR, LANGUAGE_PT_PT)
+from .config_de import CONFIG as CONFIG_DE
+from .config_en import CONFIG as CONFIG_EN
+from .config_es import CONFIG as CONFIG_ES
+from .config_fr import CONFIG as CONFIG_FR
+from .config_it import CONFIG as CONFIG_IT
+from .config_ja import CONFIG as CONFIG_JA
+from .config_ko import CONFIG as CONFIG_KO
+from .config_pt_br import CONFIG as CONFIG_PT_BR
+from .config_pt_pt import CONFIG as CONFIG_PT_PT
+
+DEFAULT_CONFIGS = {
+ LANGUAGE_DE: CONFIG_DE,
+ LANGUAGE_EN: CONFIG_EN,
+ LANGUAGE_ES: CONFIG_ES,
+ LANGUAGE_FR: CONFIG_FR,
+ LANGUAGE_IT: CONFIG_IT,
+ LANGUAGE_JA: CONFIG_JA,
+ LANGUAGE_KO: CONFIG_KO,
+ LANGUAGE_PT_BR: CONFIG_PT_BR,
+ LANGUAGE_PT_PT: CONFIG_PT_PT,
+}
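The mapping makes the per-language defaults addressable by language code. A short usage sketch, assuming the LANGUAGE_* constants resolve to ISO codes such as "en":

from snips_inference_agl.default_configs import DEFAULT_CONFIGS

config = DEFAULT_CONFIGS["en"]
print(config["unit_name"])  # 'nlu_engine'
print([p["unit_name"] for p in config["intent_parsers_configs"]])
# ['lookup_intent_parser', 'probabilistic_intent_parser']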
diff --git a/snips_inference_agl/default_configs/config_de.py b/snips_inference_agl/default_configs/config_de.py
new file mode 100644
index 0000000..200fc30
--- /dev/null
+++ b/snips_inference_agl/default_configs/config_de.py
@@ -0,0 +1,159 @@
+from __future__ import unicode_literals
+
+CONFIG = {
+ "unit_name": "nlu_engine",
+ "intent_parsers_configs": [
+ {
+ "unit_name": "lookup_intent_parser",
+ "ignore_stop_words": True
+ },
+ {
+ "unit_name": "probabilistic_intent_parser",
+ "slot_filler_config": {
+ "unit_name": "crf_slot_filler",
+ "feature_factory_configs": [
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_200000_words_stemmed",
+ "use_stemming": True,
+ "n": 1
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, -1, 0, 1, 2]
+ },
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_200000_words_stemmed",
+ "use_stemming": True,
+ "n": 2
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, 1]
+ },
+ {
+ "args": {
+ "prefix_size": 2
+ },
+ "factory_name": "prefix",
+ "offsets": [0]
+ },
+ {
+ "args": {"prefix_size": 5},
+ "factory_name": "prefix",
+ "offsets": [0]
+ },
+ {
+ "args": {"suffix_size": 2},
+ "factory_name": "suffix",
+ "offsets": [0]
+ },
+ {
+ "args": {"suffix_size": 5},
+ "factory_name": "suffix",
+ "offsets": [0]
+ },
+ {
+ "args": {},
+ "factory_name": "is_digit",
+ "offsets": [-1, 0, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_first",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {},
+ "factory_name": "is_last",
+ "offsets": [0, 1, 2]
+ },
+ {
+ "args": {"n": 1},
+ "factory_name": "shape_ngram",
+ "offsets": [0]
+ },
+ {
+ "args": {"n": 2},
+ "factory_name": "shape_ngram",
+ "offsets": [-1, 0]
+ },
+ {
+ "args": {"n": 3},
+ "factory_name": "shape_ngram",
+ "offsets": [-1]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": False
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": True
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0],
+ "drop_out": 0.5
+ },
+ {
+ "args": {"tagging_scheme_code": 1},
+ "factory_name": "builtin_entity_match",
+ "offsets": [-2, -1, 0]
+ }
+ ],
+ "crf_args": {
+ "c1": 0.1,
+ "c2": 0.1,
+ "algorithm": "lbfgs"
+ },
+ "tagging_scheme": 1,
+ "data_augmentation_config": {
+ "min_utterances": 200,
+ "capitalization_ratio": 0.2,
+ "add_builtin_entities_examples": True
+ }
+ },
+ "intent_classifier_config": {
+ "unit_name": "log_reg_intent_classifier",
+ "data_augmentation_config": {
+ "min_utterances": 20,
+ "noise_factor": 5,
+ "add_builtin_entities_examples": False,
+ "max_unknown_words": None,
+ "unknown_word_prob": 0.0,
+ "unknown_words_replacement_string": None
+ },
+ "featurizer_config": {
+ "unit_name": "featurizer",
+ "pvalue_threshold": 0.4,
+ "added_cooccurrence_feature_ratio": 0.0,
+ "tfidf_vectorizer_config": {
+ "unit_name": "tfidf_vectorizer",
+ "use_stemming": True,
+ "word_clusters_name": None
+ },
+ "cooccurrence_vectorizer_config": {
+ "unit_name": "cooccurrence_vectorizer",
+ "window_size": None,
+ "filter_stop_words": True,
+ "unknown_words_replacement_string": None,
+ "keep_order": True
+ }
+ },
+ "noise_reweight_factor": 1,
+ }
+ }
+ ]
+}
diff --git a/snips_inference_agl/default_configs/config_en.py b/snips_inference_agl/default_configs/config_en.py
new file mode 100644
index 0000000..12f7ae1
--- /dev/null
+++ b/snips_inference_agl/default_configs/config_en.py
@@ -0,0 +1,145 @@
+from __future__ import unicode_literals
+
+CONFIG = {
+ "unit_name": "nlu_engine",
+ "intent_parsers_configs": [
+ {
+ "unit_name": "lookup_intent_parser",
+ "ignore_stop_words": True
+ },
+ {
+ "unit_name": "probabilistic_intent_parser",
+ "slot_filler_config": {
+ "unit_name": "crf_slot_filler",
+ "feature_factory_configs": [
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_10000_words_stemmed",
+ "use_stemming": True,
+ "n": 1
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, -1, 0, 1, 2]
+ },
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_10000_words_stemmed",
+ "use_stemming": True,
+ "n": 2
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_digit",
+ "offsets": [-1, 0, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_first",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {},
+ "factory_name": "is_last",
+ "offsets": [0, 1, 2]
+ },
+ {
+ "args": {"n": 1},
+ "factory_name": "shape_ngram",
+ "offsets": [0]
+ },
+ {
+ "args": {"n": 2},
+ "factory_name": "shape_ngram",
+ "offsets": [-1, 0]
+ },
+ {
+ "args": {"n": 3},
+ "factory_name": "shape_ngram",
+ "offsets": [-1]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": False
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": True
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0],
+ "drop_out": 0.5
+ },
+ {
+ "args": {"tagging_scheme_code": 1},
+ "factory_name": "builtin_entity_match",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {
+ "cluster_name": "brown_clusters",
+ "use_stemming": False
+ },
+ "factory_name": "word_cluster",
+ "offsets": [-2, -1, 0, 1]
+ }
+ ],
+ "crf_args": {
+ "c1": 0.1,
+ "c2": 0.1,
+ "algorithm": "lbfgs"
+ },
+ "tagging_scheme": 1,
+ "data_augmentation_config": {
+ "min_utterances": 200,
+ "capitalization_ratio": 0.2,
+ "add_builtin_entities_examples": True
+ }
+ },
+ "intent_classifier_config": {
+ "unit_name": "log_reg_intent_classifier",
+ "data_augmentation_config": {
+ "min_utterances": 20,
+ "noise_factor": 5,
+ "add_builtin_entities_examples": False,
+ "max_unknown_words": None,
+ "unknown_word_prob": 0.0,
+ "unknown_words_replacement_string": None
+ },
+ "featurizer_config": {
+ "unit_name": "featurizer",
+ "pvalue_threshold": 0.4,
+ "added_cooccurrence_feature_ratio": 0.0,
+ "tfidf_vectorizer_config": {
+ "unit_name": "tfidf_vectorizer",
+ "use_stemming": False,
+ "word_clusters_name": None
+ },
+ "cooccurrence_vectorizer_config": {
+ "unit_name": "cooccurrence_vectorizer",
+ "window_size": None,
+ "filter_stop_words": True,
+ "unknown_words_replacement_string": None,
+ "keep_order": True
+ }
+ },
+ "noise_reweight_factor": 1,
+ }
+ }
+ ]
+}
diff --git a/snips_inference_agl/default_configs/config_es.py b/snips_inference_agl/default_configs/config_es.py
new file mode 100644
index 0000000..28969ce
--- /dev/null
+++ b/snips_inference_agl/default_configs/config_es.py
@@ -0,0 +1,138 @@
+from __future__ import unicode_literals
+
+CONFIG = {
+ "unit_name": "nlu_engine",
+ "intent_parsers_configs": [
+ {
+ "unit_name": "lookup_intent_parser",
+ "ignore_stop_words": True
+ },
+ {
+ "unit_name": "probabilistic_intent_parser",
+ "slot_filler_config": {
+ "unit_name": "crf_slot_filler",
+ "feature_factory_configs": [
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_10000_words_stemmed",
+ "use_stemming": True,
+ "n": 1
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, -1, 0, 1, 2]
+ },
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_10000_words_stemmed",
+ "use_stemming": True,
+ "n": 2
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_digit",
+ "offsets": [-1, 0, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_first",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {},
+ "factory_name": "is_last",
+ "offsets": [0, 1, 2]
+ },
+ {
+ "args": {"n": 1},
+ "factory_name": "shape_ngram",
+ "offsets": [0]
+ },
+ {
+ "args": {"n": 2},
+ "factory_name": "shape_ngram",
+ "offsets": [-1, 0]
+ },
+ {
+ "args": {"n": 3},
+ "factory_name": "shape_ngram",
+ "offsets": [-1]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": False
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": True
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0],
+ "drop_out": 0.5
+ },
+ {
+ "args": {"tagging_scheme_code": 1},
+ "factory_name": "builtin_entity_match",
+ "offsets": [-2, -1, 0]
+ }
+ ],
+ "crf_args": {
+ "c1": 0.1,
+ "c2": 0.1,
+ "algorithm": "lbfgs"
+ },
+ "tagging_scheme": 1,
+ "data_augmentation_config": {
+ "min_utterances": 200,
+ "capitalization_ratio": 0.2,
+ "add_builtin_entities_examples": True
+ },
+
+ },
+ "intent_classifier_config": {
+ "unit_name": "log_reg_intent_classifier",
+ "data_augmentation_config": {
+ "min_utterances": 20,
+ "noise_factor": 5,
+ "add_builtin_entities_examples": False,
+ "max_unknown_words": None,
+ "unknown_word_prob": 0.0,
+ "unknown_words_replacement_string": None
+ },
+ "featurizer_config": {
+ "unit_name": "featurizer",
+ "pvalue_threshold": 0.4,
+ "added_cooccurrence_feature_ratio": 0.0,
+ "tfidf_vectorizer_config": {
+ "unit_name": "tfidf_vectorizer",
+ "use_stemming": True,
+ "word_clusters_name": None
+ },
+ "cooccurrence_vectorizer_config": {
+ "unit_name": "cooccurrence_vectorizer",
+ "window_size": None,
+ "filter_stop_words": True,
+ "unknown_words_replacement_string": None,
+ "keep_order": True
+ }
+ },
+ "noise_reweight_factor": 1,
+ }
+ }
+ ]
+}
diff --git a/snips_inference_agl/default_configs/config_fr.py b/snips_inference_agl/default_configs/config_fr.py
new file mode 100644
index 0000000..a2da590
--- /dev/null
+++ b/snips_inference_agl/default_configs/config_fr.py
@@ -0,0 +1,137 @@
+from __future__ import unicode_literals
+
+CONFIG = {
+ "unit_name": "nlu_engine",
+ "intent_parsers_configs": [
+ {
+ "unit_name": "lookup_intent_parser",
+ "ignore_stop_words": True
+ },
+ {
+ "unit_name": "probabilistic_intent_parser",
+ "slot_filler_config": {
+ "unit_name": "crf_slot_filler",
+ "feature_factory_configs": [
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_10000_words_stemmed",
+ "use_stemming": True,
+ "n": 1
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, -1, 0, 1, 2]
+ },
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_10000_words_stemmed",
+ "use_stemming": True,
+ "n": 2
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_digit",
+ "offsets": [-1, 0, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_first",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {},
+ "factory_name": "is_last",
+ "offsets": [0, 1, 2]
+ },
+ {
+ "args": {"n": 1},
+ "factory_name": "shape_ngram",
+ "offsets": [0]
+ },
+ {
+ "args": {"n": 2},
+ "factory_name": "shape_ngram",
+ "offsets": [-1, 0]
+ },
+ {
+ "args": {"n": 3},
+ "factory_name": "shape_ngram",
+ "offsets": [-1]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": False
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": True
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0],
+ "drop_out": 0.5
+ },
+ {
+ "args": {"tagging_scheme_code": 1},
+ "factory_name": "builtin_entity_match",
+ "offsets": [-2, -1, 0]
+ }
+ ],
+ "crf_args": {
+ "c1": 0.1,
+ "c2": 0.1,
+ "algorithm": "lbfgs"
+ },
+ "tagging_scheme": 1,
+ "data_augmentation_config": {
+ "min_utterances": 200,
+ "capitalization_ratio": 0.2,
+ "add_builtin_entities_examples": True
+ }
+ },
+ "intent_classifier_config": {
+ "unit_name": "log_reg_intent_classifier",
+ "data_augmentation_config": {
+ "min_utterances": 20,
+ "noise_factor": 5,
+ "add_builtin_entities_examples": False,
+ "max_unknown_words": None,
+ "unknown_word_prob": 0.0,
+ "unknown_words_replacement_string": None
+ },
+ "featurizer_config": {
+ "unit_name": "featurizer",
+ "pvalue_threshold": 0.4,
+ "added_cooccurrence_feature_ratio": 0.0,
+ "tfidf_vectorizer_config": {
+ "unit_name": "tfidf_vectorizer",
+ "use_stemming": True,
+ "word_clusters_name": None
+ },
+ "cooccurrence_vectorizer_config": {
+ "unit_name": "cooccurrence_vectorizer",
+ "window_size": None,
+ "filter_stop_words": True,
+ "unknown_words_replacement_string": None,
+ "keep_order": True
+ }
+ },
+ "noise_reweight_factor": 1,
+ }
+ }
+ ]
+}
diff --git a/snips_inference_agl/default_configs/config_it.py b/snips_inference_agl/default_configs/config_it.py
new file mode 100644
index 0000000..a2da590
--- /dev/null
+++ b/snips_inference_agl/default_configs/config_it.py
@@ -0,0 +1,137 @@
+from __future__ import unicode_literals
+
+CONFIG = {
+ "unit_name": "nlu_engine",
+ "intent_parsers_configs": [
+ {
+ "unit_name": "lookup_intent_parser",
+ "ignore_stop_words": True
+ },
+ {
+ "unit_name": "probabilistic_intent_parser",
+ "slot_filler_config": {
+ "unit_name": "crf_slot_filler",
+ "feature_factory_configs": [
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_10000_words_stemmed",
+ "use_stemming": True,
+ "n": 1
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, -1, 0, 1, 2]
+ },
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_10000_words_stemmed",
+ "use_stemming": True,
+ "n": 2
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_digit",
+ "offsets": [-1, 0, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_first",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {},
+ "factory_name": "is_last",
+ "offsets": [0, 1, 2]
+ },
+ {
+ "args": {"n": 1},
+ "factory_name": "shape_ngram",
+ "offsets": [0]
+ },
+ {
+ "args": {"n": 2},
+ "factory_name": "shape_ngram",
+ "offsets": [-1, 0]
+ },
+ {
+ "args": {"n": 3},
+ "factory_name": "shape_ngram",
+ "offsets": [-1]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": False
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": True
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0],
+ "drop_out": 0.5
+ },
+ {
+ "args": {"tagging_scheme_code": 1},
+ "factory_name": "builtin_entity_match",
+ "offsets": [-2, -1, 0]
+ }
+ ],
+ "crf_args": {
+ "c1": 0.1,
+ "c2": 0.1,
+ "algorithm": "lbfgs"
+ },
+ "tagging_scheme": 1,
+ "data_augmentation_config": {
+ "min_utterances": 200,
+ "capitalization_ratio": 0.2,
+ "add_builtin_entities_examples": True
+ }
+ },
+ "intent_classifier_config": {
+ "unit_name": "log_reg_intent_classifier",
+ "data_augmentation_config": {
+ "min_utterances": 20,
+ "noise_factor": 5,
+ "add_builtin_entities_examples": False,
+ "max_unknown_words": None,
+ "unknown_word_prob": 0.0,
+ "unknown_words_replacement_string": None
+ },
+ "featurizer_config": {
+ "unit_name": "featurizer",
+ "pvalue_threshold": 0.4,
+ "added_cooccurrence_feature_ratio": 0.0,
+ "tfidf_vectorizer_config": {
+ "unit_name": "tfidf_vectorizer",
+ "use_stemming": True,
+ "word_clusters_name": None
+ },
+ "cooccurrence_vectorizer_config": {
+ "unit_name": "cooccurrence_vectorizer",
+ "window_size": None,
+ "filter_stop_words": True,
+ "unknown_words_replacement_string": None,
+ "keep_order": True
+ }
+ },
+ "noise_reweight_factor": 1,
+ }
+ }
+ ]
+}
diff --git a/snips_inference_agl/default_configs/config_ja.py b/snips_inference_agl/default_configs/config_ja.py
new file mode 100644
index 0000000..b28791f
--- /dev/null
+++ b/snips_inference_agl/default_configs/config_ja.py
@@ -0,0 +1,164 @@
+from __future__ import unicode_literals
+
+CONFIG = {
+ "unit_name": "nlu_engine",
+ "intent_parsers_configs": [
+ {
+ "unit_name": "lookup_intent_parser",
+ "ignore_stop_words": False
+ },
+ {
+ "unit_name": "probabilistic_intent_parser",
+ "slot_filler_config": {
+ "unit_name": "crf_slot_filler",
+ "feature_factory_configs": [
+ {
+ "args": {
+ "common_words_gazetteer_name": None,
+ "use_stemming": False,
+ "n": 1
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, -1, 0, 1, 2]
+ },
+ {
+ "args": {
+ "common_words_gazetteer_name": None,
+ "use_stemming": False,
+ "n": 2
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, 0, 1, 2]
+ },
+ {
+ "args": {"prefix_size": 1},
+ "factory_name": "prefix",
+ "offsets": [0, 1]
+ },
+ {
+ "args": {"prefix_size": 2},
+ "factory_name": "prefix",
+ "offsets": [0, 1]
+ },
+ {
+ "args": {"suffix_size": 1},
+ "factory_name": "suffix",
+ "offsets": [0, 1]
+ },
+ {
+ "args": {"suffix_size": 2},
+ "factory_name": "suffix",
+ "offsets": [0, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_digit",
+ "offsets": [-1, 0, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_first",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {},
+ "factory_name": "is_last",
+ "offsets": [0, 1, 2]
+ },
+ {
+ "args": {"n": 1},
+ "factory_name": "shape_ngram",
+ "offsets": [0]
+ },
+ {
+ "args": {"n": 2},
+ "factory_name": "shape_ngram",
+ "offsets": [-1, 0]
+ },
+ {
+ "args": {"n": 3},
+ "factory_name": "shape_ngram",
+ "offsets": [-1]
+ },
+ {
+ "args": {
+ "use_stemming": False,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": False
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-1, 0, 1, 2],
+ },
+ {
+ "args": {
+ "use_stemming": False,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": False
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-1, 0, 1, 2],
+ "drop_out": 0.5
+ },
+ {
+ "args": {"tagging_scheme_code": 1},
+ "factory_name": "builtin_entity_match",
+ "offsets": [-1, 0, 1, 2]
+ },
+ {
+ "args": {
+ "cluster_name": "w2v_clusters",
+ "use_stemming": False
+ },
+ "factory_name": "word_cluster",
+ "offsets": [-2, -1, 0, 1, 2]
+ }
+ ],
+ "crf_args": {
+ "c1": 0.1,
+ "c2": 0.1,
+ "algorithm": "lbfgs"
+ },
+ "tagging_scheme": 1,
+ "data_augmentation_config": {
+ "min_utterances": 200,
+ "capitalization_ratio": 0.2,
+ "add_builtin_entities_examples": True
+ },
+
+ },
+ "intent_classifier_config": {
+ "unit_name": "log_reg_intent_classifier",
+ "data_augmentation_config": {
+ "min_utterances": 20,
+ "noise_factor": 5,
+ "add_builtin_entities_examples": False,
+ "max_unknown_words": None,
+ "unknown_word_prob": 0.0,
+ "unknown_words_replacement_string": None
+ },
+ "featurizer_config": {
+ "unit_name": "featurizer",
+ "pvalue_threshold": 0.9,
+ "added_cooccurrence_feature_ratio": 0.0,
+ "tfidf_vectorizer_config": {
+ "unit_name": "tfidf_vectorizer",
+ "use_stemming": False,
+ "word_clusters_name": "w2v_clusters"
+ },
+ "cooccurrence_vectorizer_config": {
+ "unit_name": "cooccurrence_vectorizer",
+ "window_size": None,
+ "filter_stop_words": True,
+ "unknown_words_replacement_string": None,
+ "keep_order": True
+ }
+ },
+ "noise_reweight_factor": 1,
+ }
+ }
+ ]
+}
diff --git a/snips_inference_agl/default_configs/config_ko.py b/snips_inference_agl/default_configs/config_ko.py
new file mode 100644
index 0000000..1630796
--- /dev/null
+++ b/snips_inference_agl/default_configs/config_ko.py
@@ -0,0 +1,155 @@
+from __future__ import unicode_literals
+
+CONFIG = {
+ "unit_name": "nlu_engine",
+ "intent_parsers_configs": [
+ {
+ "unit_name": "lookup_intent_parser",
+ "ignore_stop_words": False
+ },
+ {
+ "unit_name": "probabilistic_intent_parser",
+ "slot_filler_config": {
+ "unit_name": "crf_slot_filler",
+ "feature_factory_configs": [
+ {
+ "args": {
+ "common_words_gazetteer_name": None,
+ "use_stemming": False,
+ "n": 1
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, -1, 0, 1, 2]
+ },
+ {
+ "args": {
+ "common_words_gazetteer_name": None,
+ "use_stemming": False,
+ "n": 2
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, 1]
+ },
+ {
+ "args": {"prefix_size": 1},
+ "factory_name": "prefix",
+ "offsets": [0]
+ },
+ {
+ "args": {"prefix_size": 2},
+ "factory_name": "prefix",
+ "offsets": [0]
+ },
+ {
+ "args": {"suffix_size": 1},
+ "factory_name": "suffix",
+ "offsets": [0]
+ },
+ {
+ "args": {"suffix_size": 2},
+ "factory_name": "suffix",
+ "offsets": [0]
+ },
+ {
+ "args": {},
+ "factory_name": "is_digit",
+ "offsets": [-1, 0, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_first",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {},
+ "factory_name": "is_last",
+ "offsets": [0, 1, 2]
+ },
+ {
+ "args": {"n": 1},
+ "factory_name": "shape_ngram",
+ "offsets": [0]
+ },
+ {
+ "args": {"n": 2},
+ "factory_name": "shape_ngram",
+ "offsets": [-1, 0]
+ },
+ {
+ "args": {"n": 3},
+ "factory_name": "shape_ngram",
+ "offsets": [-1]
+ },
+ {
+ "args": {
+ "use_stemming": False,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": False
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {
+ "use_stemming": False,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": True
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0],
+ "drop_out": 0.5
+ },
+ {
+ "args": {"tagging_scheme_code": 1},
+ "factory_name": "builtin_entity_match",
+ "offsets": [-2, -1, 0]
+ }
+ ],
+ "crf_args": {
+ "c1": 0.1,
+ "c2": 0.1,
+ "algorithm": "lbfgs"
+ },
+ "tagging_scheme": 1,
+ "data_augmentation_config": {
+ "min_utterances": 200,
+ "capitalization_ratio": 0.2,
+ "add_builtin_entities_examples": True
+ }
+ },
+ "intent_classifier_config": {
+ "unit_name": "log_reg_intent_classifier",
+ "data_augmentation_config": {
+ "min_utterances": 20,
+ "noise_factor": 5,
+ "add_builtin_entities_examples": False,
+ "max_unknown_words": None,
+ "unknown_word_prob": 0.0,
+ "unknown_words_replacement_string": None
+ },
+ "featurizer_config": {
+ "unit_name": "featurizer",
+ "pvalue_threshold": 0.4,
+ "added_cooccurrence_feature_ratio": 0.0,
+ "tfidf_vectorizer_config": {
+ "unit_name": "tfidf_vectorizer",
+ "use_stemming": False,
+ "word_clusters_name": None
+ },
+ "cooccurrence_vectorizer_config": {
+ "unit_name": "cooccurrence_vectorizer",
+ "window_size": None,
+ "filter_stop_words": True,
+ "unknown_words_replacement_string": None,
+ "keep_order": True
+ }
+ },
+ "noise_reweight_factor": 1,
+ }
+ }
+ ]
+}
diff --git a/snips_inference_agl/default_configs/config_pt_br.py b/snips_inference_agl/default_configs/config_pt_br.py
new file mode 100644
index 0000000..450f0db
--- /dev/null
+++ b/snips_inference_agl/default_configs/config_pt_br.py
@@ -0,0 +1,137 @@
+from __future__ import unicode_literals
+
+CONFIG = {
+ "unit_name": "nlu_engine",
+ "intent_parsers_configs": [
+ {
+ "unit_name": "lookup_intent_parser",
+ "ignore_stop_words": True
+ },
+ {
+ "unit_name": "probabilistic_intent_parser",
+ "slot_filler_config": {
+ "unit_name": "crf_slot_filler",
+ "feature_factory_configs": [
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_5000_words_stemmed",
+ "use_stemming": True,
+ "n": 1
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, -1, 0, 1, 2]
+ },
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_5000_words_stemmed",
+ "use_stemming": True,
+ "n": 2
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_digit",
+ "offsets": [-1, 0, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_first",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {},
+ "factory_name": "is_last",
+ "offsets": [0, 1, 2]
+ },
+ {
+ "args": {"n": 1},
+ "factory_name": "shape_ngram",
+ "offsets": [0]
+ },
+ {
+ "args": {"n": 2},
+ "factory_name": "shape_ngram",
+ "offsets": [-1, 0]
+ },
+ {
+ "args": {"n": 3},
+ "factory_name": "shape_ngram",
+ "offsets": [-1]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": False
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": True
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0],
+ "drop_out": 0.5
+ },
+ {
+ "args": {"tagging_scheme_code": 1},
+ "factory_name": "builtin_entity_match",
+ "offsets": [-2, -1, 0]
+ }
+ ],
+ "crf_args": {
+ "c1": 0.1,
+ "c2": 0.1,
+ "algorithm": "lbfgs"
+ },
+ "tagging_scheme": 1,
+ "data_augmentation_config": {
+ "min_utterances": 200,
+ "capitalization_ratio": 0.2,
+ "add_builtin_entities_examples": True
+ },
+ },
+ "intent_classifier_config": {
+ "unit_name": "log_reg_intent_classifier",
+ "data_augmentation_config": {
+ "min_utterances": 20,
+ "noise_factor": 5,
+ "add_builtin_entities_examples": False,
+ "max_unknown_words": None,
+ "unknown_word_prob": 0.0,
+ "unknown_words_replacement_string": None
+ },
+ "featurizer_config": {
+ "unit_name": "featurizer",
+ "pvalue_threshold": 0.4,
+ "added_cooccurrence_feature_ratio": 0.0,
+ "tfidf_vectorizer_config": {
+ "unit_name": "tfidf_vectorizer",
+ "use_stemming": True,
+ "word_clusters_name": None
+ },
+ "cooccurrence_vectorizer_config": {
+ "unit_name": "cooccurrence_vectorizer",
+ "window_size": None,
+ "filter_stop_words": True,
+ "unknown_words_replacement_string": None,
+ "keep_order": True
+ }
+                },
+                "noise_reweight_factor": 1,
+            }
+        }
+ ]
+}
diff --git a/snips_inference_agl/default_configs/config_pt_pt.py b/snips_inference_agl/default_configs/config_pt_pt.py
new file mode 100644
index 0000000..450f0db
--- /dev/null
+++ b/snips_inference_agl/default_configs/config_pt_pt.py
@@ -0,0 +1,137 @@
+from __future__ import unicode_literals
+
+CONFIG = {
+ "unit_name": "nlu_engine",
+ "intent_parsers_configs": [
+ {
+ "unit_name": "lookup_intent_parser",
+ "ignore_stop_words": True
+ },
+ {
+ "unit_name": "probabilistic_intent_parser",
+ "slot_filler_config": {
+ "unit_name": "crf_slot_filler",
+ "feature_factory_configs": [
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_5000_words_stemmed",
+ "use_stemming": True,
+ "n": 1
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, -1, 0, 1, 2]
+ },
+ {
+ "args": {
+ "common_words_gazetteer_name":
+ "top_5000_words_stemmed",
+ "use_stemming": True,
+ "n": 2
+ },
+ "factory_name": "ngram",
+ "offsets": [-2, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_digit",
+ "offsets": [-1, 0, 1]
+ },
+ {
+ "args": {},
+ "factory_name": "is_first",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {},
+ "factory_name": "is_last",
+ "offsets": [0, 1, 2]
+ },
+ {
+ "args": {"n": 1},
+ "factory_name": "shape_ngram",
+ "offsets": [0]
+ },
+ {
+ "args": {"n": 2},
+ "factory_name": "shape_ngram",
+ "offsets": [-1, 0]
+ },
+ {
+ "args": {"n": 3},
+ "factory_name": "shape_ngram",
+ "offsets": [-1]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": False
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {
+ "use_stemming": True,
+ "tagging_scheme_code": 2,
+ "entity_filter": {
+ "automatically_extensible": True
+ }
+ },
+ "factory_name": "entity_match",
+ "offsets": [-2, -1, 0],
+ "drop_out": 0.5
+ },
+ {
+ "args": {"tagging_scheme_code": 1},
+ "factory_name": "builtin_entity_match",
+ "offsets": [-2, -1, 0]
+ }
+ ],
+ "crf_args": {
+ "c1": 0.1,
+ "c2": 0.1,
+ "algorithm": "lbfgs"
+ },
+ "tagging_scheme": 1,
+ "data_augmentation_config": {
+ "min_utterances": 200,
+ "capitalization_ratio": 0.2,
+ "add_builtin_entities_examples": True
+ },
+ },
+ "intent_classifier_config": {
+ "unit_name": "log_reg_intent_classifier",
+ "data_augmentation_config": {
+ "min_utterances": 20,
+ "noise_factor": 5,
+ "add_builtin_entities_examples": False,
+ "max_unknown_words": None,
+ "unknown_word_prob": 0.0,
+ "unknown_words_replacement_string": None
+ },
+ "featurizer_config": {
+ "unit_name": "featurizer",
+ "pvalue_threshold": 0.4,
+ "added_cooccurrence_feature_ratio": 0.0,
+ "tfidf_vectorizer_config": {
+ "unit_name": "tfidf_vectorizer",
+ "use_stemming": True,
+ "word_clusters_name": None
+ },
+ "cooccurrence_vectorizer_config": {
+ "unit_name": "cooccurrence_vectorizer",
+ "window_size": None,
+ "filter_stop_words": True,
+ "unknown_words_replacement_string": None,
+ "keep_order": True
+ }
+                },
+                "noise_reweight_factor": 1,
+            }
+        }
+ ]
+}
diff --git a/snips_inference_agl/entity_parser/__init__.py b/snips_inference_agl/entity_parser/__init__.py
new file mode 100644
index 0000000..c54f0b2
--- /dev/null
+++ b/snips_inference_agl/entity_parser/__init__.py
@@ -0,0 +1,6 @@
+# coding=utf-8
+from __future__ import unicode_literals
+
+from snips_inference_agl.entity_parser.builtin_entity_parser import BuiltinEntityParser
+from snips_inference_agl.entity_parser.custom_entity_parser import (
+ CustomEntityParser, CustomEntityParserUsage)
diff --git a/snips_inference_agl/entity_parser/builtin_entity_parser.py b/snips_inference_agl/entity_parser/builtin_entity_parser.py
new file mode 100644
index 0000000..02fa610
--- /dev/null
+++ b/snips_inference_agl/entity_parser/builtin_entity_parser.py
@@ -0,0 +1,162 @@
+from __future__ import unicode_literals
+
+import json
+import shutil
+
+from future.builtins import str
+
+from snips_inference_agl.common.io_utils import temp_dir
+from snips_inference_agl.common.utils import json_string
+from snips_inference_agl.constants import DATA_PATH, ENTITIES, LANGUAGE
+from snips_inference_agl.entity_parser.entity_parser import EntityParser
+from snips_inference_agl.result import parsed_entity
+
+_BUILTIN_ENTITY_PARSERS = dict()
+
+try:
+ FileNotFoundError
+except NameError:
+ FileNotFoundError = IOError
+
+
+class BuiltinEntityParser(EntityParser):
+ def __init__(self, parser):
+ super(BuiltinEntityParser, self).__init__()
+ self._parser = parser
+
+ def _parse(self, text, scope=None):
+ entities = self._parser.parse(text.lower(), scope=scope)
+ result = []
+ for entity in entities:
+ ent = parsed_entity(
+ entity_kind=entity["entity_kind"],
+ entity_value=entity["value"],
+ entity_resolved_value=entity["entity"],
+ entity_range=entity["range"]
+ )
+ result.append(ent)
+ return result
+
+ def persist(self, path):
+ self._parser.persist(path)
+
+ @classmethod
+ def from_path(cls, path):
+ from snips_nlu_parsers import (
+ BuiltinEntityParser as _BuiltinEntityParser)
+
+ parser = _BuiltinEntityParser.from_path(path)
+ return cls(parser)
+
+ @classmethod
+ def build(cls, dataset=None, language=None, gazetteer_entity_scope=None):
+ from snips_nlu_parsers import get_supported_gazetteer_entities
+
+ global _BUILTIN_ENTITY_PARSERS
+
+ if dataset is not None:
+ language = dataset[LANGUAGE]
+ gazetteer_entity_scope = [entity for entity in dataset[ENTITIES]
+ if is_gazetteer_entity(entity)]
+
+ if language is None:
+ raise ValueError("Either a dataset or a language must be provided "
+ "in order to build a BuiltinEntityParser")
+
+ if gazetteer_entity_scope is None:
+ gazetteer_entity_scope = []
+ caching_key = _get_caching_key(language, gazetteer_entity_scope)
+ if caching_key not in _BUILTIN_ENTITY_PARSERS:
+ for entity in gazetteer_entity_scope:
+ if entity not in get_supported_gazetteer_entities(language):
+ raise ValueError(
+ "Gazetteer entity '%s' is not supported in "
+ "language '%s'" % (entity, language))
+ _BUILTIN_ENTITY_PARSERS[caching_key] = _build_builtin_parser(
+ language, gazetteer_entity_scope)
+ return _BUILTIN_ENTITY_PARSERS[caching_key]
+
+
+def _build_builtin_parser(language, gazetteer_entities):
+ from snips_nlu_parsers import BuiltinEntityParser as _BuiltinEntityParser
+
+ with temp_dir() as serialization_dir:
+ gazetteer_entity_parser = None
+ if gazetteer_entities:
+ gazetteer_entity_parser = _build_gazetteer_parser(
+ serialization_dir, gazetteer_entities, language)
+
+ metadata = {
+ "language": language.upper(),
+ "gazetteer_parser": gazetteer_entity_parser
+ }
+ metadata_path = serialization_dir / "metadata.json"
+ with metadata_path.open("w", encoding="utf-8") as f:
+ f.write(json_string(metadata))
+ parser = _BuiltinEntityParser.from_path(serialization_dir)
+ return BuiltinEntityParser(parser)
+
+
+def _build_gazetteer_parser(target_dir, gazetteer_entities, language):
+ from snips_nlu_parsers import get_builtin_entity_shortname
+
+ gazetteer_parser_name = "gazetteer_entity_parser"
+ gazetteer_parser_path = target_dir / gazetteer_parser_name
+ gazetteer_parser_metadata = []
+ for ent in sorted(gazetteer_entities):
+ # Fetch the compiled parser in the resources
+ source_parser_path = find_gazetteer_entity_data_path(language, ent)
+ short_name = get_builtin_entity_shortname(ent).lower()
+ target_parser_path = gazetteer_parser_path / short_name
+ parser_metadata = {
+ "entity_identifier": ent,
+ "entity_parser": short_name
+ }
+ gazetteer_parser_metadata.append(parser_metadata)
+ # Copy the single entity parser
+ shutil.copytree(str(source_parser_path), str(target_parser_path))
+ # Dump the parser metadata
+ gazetteer_entity_parser_metadata = {
+ "parsers_metadata": gazetteer_parser_metadata
+ }
+ gazetteer_parser_metadata_path = gazetteer_parser_path / "metadata.json"
+ with gazetteer_parser_metadata_path.open("w", encoding="utf-8") as f:
+ f.write(json_string(gazetteer_entity_parser_metadata))
+ return gazetteer_parser_name
+
+
+def is_builtin_entity(entity_label):
+ from snips_nlu_parsers import get_all_builtin_entities
+
+ return entity_label in get_all_builtin_entities()
+
+
+def is_gazetteer_entity(entity_label):
+ from snips_nlu_parsers import get_all_gazetteer_entities
+
+ return entity_label in get_all_gazetteer_entities()
+
+
+def find_gazetteer_entity_data_path(language, entity_name):
+ for directory in DATA_PATH.iterdir():
+ if not directory.is_dir():
+ continue
+ metadata_path = directory / "metadata.json"
+ if not metadata_path.exists():
+ continue
+ with metadata_path.open(encoding="utf8") as f:
+ metadata = json.load(f)
+ if metadata.get("entity_name") == entity_name \
+ and metadata.get("language") == language:
+ return directory / metadata["data_directory"]
+ raise FileNotFoundError(
+ "No data found for the '{e}' builtin entity in language '{lang}'. "
+ "You must download the corresponding resources by running "
+ "'python -m snips_nlu download-entity {e} {lang}' before you can use "
+ "this builtin entity.".format(e=entity_name, lang=language))
+
+
+def _get_caching_key(language, entity_scope):
+ tuple_key = (language,)
+ tuple_key += tuple(entity for entity in sorted(entity_scope))
+ return tuple_key
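A usage sketch, assuming snips_nlu_parsers is installed; building by language only (no gazetteer entities) needs no downloaded resources, and repeated build() calls with the same language and scope return the cached parser:

from snips_inference_agl.entity_parser import BuiltinEntityParser

parser = BuiltinEntityParser.build(language="en")
# Each match is a parsed_entity dict carrying the matched value, the resolved
# value, the entity kind and the character range.
print(parser.parse("set the volume to 20 percent", scope=["snips/number"]))

# The module-level cache is keyed on (language, sorted gazetteer scope)
assert BuiltinEntityParser.build(language="en") is parser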
diff --git a/snips_inference_agl/entity_parser/custom_entity_parser.py b/snips_inference_agl/entity_parser/custom_entity_parser.py
new file mode 100644
index 0000000..949df1f
--- /dev/null
+++ b/snips_inference_agl/entity_parser/custom_entity_parser.py
@@ -0,0 +1,209 @@
+# coding=utf-8
+from __future__ import unicode_literals
+
+import json
+import operator
+from copy import deepcopy
+from pathlib import Path
+
+from future.utils import iteritems, viewvalues
+
+from snips_inference_agl.common.utils import json_string
+from snips_inference_agl.constants import (
+ END, ENTITIES, LANGUAGE, MATCHING_STRICTNESS, START, UTTERANCES,
+ LICENSE_INFO)
+from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity
+from snips_inference_agl.entity_parser.custom_entity_parser_usage import (
+ CustomEntityParserUsage)
+from snips_inference_agl.entity_parser.entity_parser import EntityParser
+from snips_inference_agl.preprocessing import stem, tokenize, tokenize_light
+from snips_inference_agl.result import parsed_entity
+
+STOPWORDS_FRACTION = 1e-3
+
+
+class CustomEntityParser(EntityParser):
+ def __init__(self, parser, language, parser_usage):
+ super(CustomEntityParser, self).__init__()
+ self._parser = parser
+ self.language = language
+ self.parser_usage = parser_usage
+
+ def _parse(self, text, scope=None):
+ tokens = tokenize(text, self.language)
+ shifts = _compute_char_shifts(tokens)
+ cleaned_text = " ".join(token.value for token in tokens)
+
+ entities = self._parser.parse(cleaned_text, scope)
+ result = []
+ for entity in entities:
+ start = entity["range"]["start"]
+ start -= shifts[start]
+ end = entity["range"]["end"]
+ end -= shifts[end - 1]
+ entity_range = {START: start, END: end}
+ ent = parsed_entity(
+ entity_kind=entity["entity_identifier"],
+ entity_value=entity["value"],
+ entity_resolved_value=entity["resolved_value"],
+ entity_range=entity_range
+ )
+ result.append(ent)
+ return result
+
+ def persist(self, path):
+ path = Path(path)
+ path.mkdir()
+ parser_directory = "parser"
+ metadata = {
+ "language": self.language,
+ "parser_usage": self.parser_usage.value,
+ "parser_directory": parser_directory
+ }
+ with (path / "metadata.json").open(mode="w", encoding="utf8") as f:
+ f.write(json_string(metadata))
+ self._parser.persist(path / parser_directory)
+
+ @classmethod
+ def from_path(cls, path):
+ from snips_nlu_parsers import GazetteerEntityParser
+
+ path = Path(path)
+ with (path / "metadata.json").open(encoding="utf8") as f:
+ metadata = json.load(f)
+ language = metadata["language"]
+ parser_usage = CustomEntityParserUsage(metadata["parser_usage"])
+ parser_path = path / metadata["parser_directory"]
+ parser = GazetteerEntityParser.from_path(parser_path)
+ return cls(parser, language, parser_usage)
+
+ @classmethod
+ def build(cls, dataset, parser_usage, resources):
+ from snips_nlu_parsers import GazetteerEntityParser
+ from snips_inference_agl.dataset import validate_and_format_dataset
+
+ dataset = validate_and_format_dataset(dataset)
+ language = dataset[LANGUAGE]
+ custom_entities = {
+ entity_name: deepcopy(entity)
+ for entity_name, entity in iteritems(dataset[ENTITIES])
+ if not is_builtin_entity(entity_name)
+ }
+ if parser_usage == CustomEntityParserUsage.WITH_AND_WITHOUT_STEMS:
+ for ent in viewvalues(custom_entities):
+ stemmed_utterances = _stem_entity_utterances(
+ ent[UTTERANCES], language, resources)
+ ent[UTTERANCES] = _merge_entity_utterances(
+ ent[UTTERANCES], stemmed_utterances)
+ elif parser_usage == CustomEntityParserUsage.WITH_STEMS:
+ for ent in viewvalues(custom_entities):
+ ent[UTTERANCES] = _stem_entity_utterances(
+ ent[UTTERANCES], language, resources)
+ elif parser_usage is None:
+ raise ValueError("A parser usage must be defined in order to fit "
+ "a CustomEntityParser")
+ configuration = _create_custom_entity_parser_configuration(
+ custom_entities,
+ language=dataset[LANGUAGE],
+ stopwords_fraction=STOPWORDS_FRACTION,
+ )
+ parser = GazetteerEntityParser.build(configuration)
+ return cls(parser, language, parser_usage)
+
+
+def _stem_entity_utterances(entity_utterances, language, resources):
+ values = dict()
+    # Sort by resolved value so that value conflicts are resolved deterministically
+ for raw_value, resolved_value in sorted(
+ iteritems(entity_utterances), key=operator.itemgetter(1)):
+ stemmed_value = stem(raw_value, language, resources)
+ if stemmed_value not in values:
+ values[stemmed_value] = resolved_value
+ return values
+
+
+def _merge_entity_utterances(raw_utterances, stemmed_utterances):
+    # Sort by resolved value so that value conflicts are resolved deterministically
+ for raw_stemmed_value, resolved_value in sorted(
+ iteritems(stemmed_utterances), key=operator.itemgetter(1)):
+ if raw_stemmed_value not in raw_utterances:
+ raw_utterances[raw_stemmed_value] = resolved_value
+ return raw_utterances
+
+
+def _create_custom_entity_parser_configuration(
+ entities, stopwords_fraction, language):
+ """Dynamically creates the gazetteer parser configuration.
+
+ Args:
+        entities (dict): entities of the dataset
+        stopwords_fraction (float): fraction of the vocabulary of
+            the entity values that will be considered as stop words
+            (the top n_vocabulary * stopwords_fraction most frequent
+            words will be considered stop words)
+ language (str): language of the entities
+
+    Returns: the parser configuration as a dictionary
+ """
+
+ if not 0 < stopwords_fraction < 1:
+ raise ValueError("stopwords_fraction must be in ]0.0, 1.0[")
+
+ parser_configurations = []
+ for entity_name, entity in sorted(iteritems(entities)):
+ vocabulary = set(
+ t for raw_value in entity[UTTERANCES]
+ for t in tokenize_light(raw_value, language)
+ )
+ num_stopwords = int(stopwords_fraction * len(vocabulary))
+ config = {
+ "entity_identifier": entity_name,
+ "entity_parser": {
+ "threshold": entity[MATCHING_STRICTNESS],
+ "n_gazetteer_stop_words": num_stopwords,
+ "gazetteer": [
+ {
+ "raw_value": k,
+ "resolved_value": v
+ } for k, v in sorted(iteritems(entity[UTTERANCES]))
+ ]
+ }
+ }
+ if LICENSE_INFO in entity:
+ config["entity_parser"][LICENSE_INFO] = entity[LICENSE_INFO]
+ parser_configurations.append(config)
+
+ configuration = {
+ "entity_parsers": parser_configurations
+ }
+
+ return configuration
+
+
+def _compute_char_shifts(tokens):
+ """Compute the shifts in characters that occur when comparing the
+ tokens string with the string consisting of all tokens separated with a
+ space
+
+ For instance, if "hello?world" is tokenized in ["hello", "?", "world"],
+ then the character shifts between "hello?world" and "hello ? world" are
+ [0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2]
+ """
+ characters_shifts = []
+ if not tokens:
+ return characters_shifts
+
+ current_shift = 0
+ for token_index, token in enumerate(tokens):
+ if token_index == 0:
+ previous_token_end = 0
+ previous_space_len = 0
+ else:
+ previous_token_end = tokens[token_index - 1].end
+ previous_space_len = 1
+ offset = (token.start - previous_token_end) - previous_space_len
+ current_shift -= offset
+ token_len = token.end - token.start
+ index_shift = token_len + previous_space_len
+ characters_shifts += [current_shift for _ in range(index_shift)]
+ return characters_shifts
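The shift bookkeeping in _compute_char_shifts can be checked against its docstring example, assuming the tokenizer splits "hello?world" into ["hello", "?", "world"] as stated there:

from snips_inference_agl.entity_parser.custom_entity_parser import (
    _compute_char_shifts)
from snips_inference_agl.preprocessing import tokenize

tokens = tokenize("hello?world", "en")
shifts = _compute_char_shifts(tokens)
print(shifts)  # [0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2]

# A range found on the space-joined text "hello ? world" is mapped back to the
# original string by subtracting the shift at each boundary, as done in _parse:
end_in_joined = len("hello ? world")
print(end_in_joined - shifts[end_in_joined - 1])  # 11 == len("hello?world")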
diff --git a/snips_inference_agl/entity_parser/custom_entity_parser_usage.py b/snips_inference_agl/entity_parser/custom_entity_parser_usage.py
new file mode 100644
index 0000000..72d420a
--- /dev/null
+++ b/snips_inference_agl/entity_parser/custom_entity_parser_usage.py
@@ -0,0 +1,23 @@
+from __future__ import unicode_literals
+
+from enum import Enum, unique
+
+
+@unique
+class CustomEntityParserUsage(Enum):
+ WITH_STEMS = 0
+ """The parser is used with stemming"""
+ WITHOUT_STEMS = 1
+ """The parser is used without stemming"""
+ WITH_AND_WITHOUT_STEMS = 2
+ """The parser is used both with and without stemming"""
+
+ @classmethod
+ def merge_usages(cls, lhs_usage, rhs_usage):
+ if lhs_usage is None:
+ return rhs_usage
+ if rhs_usage is None:
+ return lhs_usage
+ if lhs_usage == rhs_usage:
+ return lhs_usage
+ return cls.WITH_AND_WITHOUT_STEMS
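merge_usages widens the requirement when two processing units disagree: a None side defers to the other, identical usages are kept, and anything else becomes WITH_AND_WITHOUT_STEMS. For instance:

from snips_inference_agl.entity_parser import CustomEntityParserUsage

merged = CustomEntityParserUsage.merge_usages(
    CustomEntityParserUsage.WITH_STEMS,
    CustomEntityParserUsage.WITHOUT_STEMS)
print(merged)  # CustomEntityParserUsage.WITH_AND_WITHOUT_STEMS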
diff --git a/snips_inference_agl/entity_parser/entity_parser.py b/snips_inference_agl/entity_parser/entity_parser.py
new file mode 100644
index 0000000..46de55e
--- /dev/null
+++ b/snips_inference_agl/entity_parser/entity_parser.py
@@ -0,0 +1,85 @@
+# coding=utf-8
+from __future__ import unicode_literals
+
+from abc import ABCMeta, abstractmethod
+
+from future.builtins import object
+from future.utils import with_metaclass
+
+from snips_inference_agl.common.dict_utils import LimitedSizeDict
+
+# pylint: disable=ungrouped-imports
+
+try:
+ from abc import abstractclassmethod
+except ImportError:
+ from snips_inference_agl.common.abc_utils import abstractclassmethod
+
+
+# pylint: enable=ungrouped-imports
+
+
+class EntityParser(with_metaclass(ABCMeta, object)):
+ """Abstraction of a entity parser implementing some basic caching
+ """
+
+ def __init__(self):
+ self._cache = LimitedSizeDict(size_limit=1000)
+
+ def parse(self, text, scope=None, use_cache=True):
+ """Search the given text for entities defined in the scope. If no
+ scope is provided, search for all kinds of entities.
+
+ Args:
+ text (str): input text
+ scope (list or set of str, optional): if provided the parser
+                will only look for entities whose entity kind is given in
+ the scope. By default the scope is None and the parser
+ will search for all kinds of supported entities
+            use_cache (bool): if False, the internal cache will not be used;
+ this can be useful if the output of the parser depends on
+ the current timestamp. Defaults to True.
+
+ Returns:
+ list of dict: list of the parsed entities formatted as a dict
+ containing the string value, the resolved value, the
+ entity kind and the entity range
+ """
+ if not use_cache:
+ return self._parse(text, scope)
+ scope_key = tuple(sorted(scope)) if scope is not None else scope
+ cache_key = (text, scope_key)
+ if cache_key not in self._cache:
+ parser_result = self._parse(text, scope)
+ self._cache[cache_key] = parser_result
+ return self._cache[cache_key]
+
+ @abstractmethod
+ def _parse(self, text, scope=None):
+ """Internal parse method to implement in each subclass of
+ :class:`.EntityParser`
+
+ Args:
+ text (str): input text
+ scope (list or set of str, optional): if provided the parser
+ will only look for entities which entity kind is given in
+ the scope. By default the scope is None and the parser
+ will search for all kinds of supported entities
+
+ Returns:
+ list of dict: list of the parsed entities. These entities must
+ have the same output format as the
+ :func:`snips_inference_agl.utils.result.parsed_entity` function
+ """
+ pass
+
+ @abstractmethod
+ def persist(self, path):
+ pass
+
+ @abstractclassmethod
+ def from_path(cls, path):
+ pass
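+# Minimal usage sketch (illustrative only; EntityParser is abstract, so in
+# practice one of the concrete subclasses defined in this package, such as
+# the builtin or custom entity parser, is used):
+#
+#     entities = parser.parse("set the heat to twenty one degrees",
+#                             scope=["snips/temperature"])
+#     # A second call with the same text and scope is served from the
+#     # internal LimitedSizeDict cache; pass use_cache=False to bypass it.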
diff --git a/snips_inference_agl/exceptions.py b/snips_inference_agl/exceptions.py
new file mode 100644
index 0000000..b1037a1
--- /dev/null
+++ b/snips_inference_agl/exceptions.py
@@ -0,0 +1,87 @@
+from snips_inference_agl.__about__ import __model_version__
+
+
+class SnipsNLUError(Exception):
+ """Base class for exceptions raised in the snips-nlu library"""
+
+
+class IncompatibleModelError(Exception):
+ """Raised when trying to load an incompatible NLU engine
+
+ This happens when the engine data was persisted with a previous version of
+ the library which is not compatible with the one used to load the model.
+ """
+
+ def __init__(self, persisted_version):
+ super(IncompatibleModelError, self).__init__(
+ "Incompatible data model: persisted model=%s, python lib model=%s"
+ % (persisted_version, __model_version__))
+
+
+class InvalidInputError(Exception):
+ """Raised when an incorrect input is passed to one of the APIs"""
+
+
+class NotTrained(SnipsNLUError):
+ """Raised when a processing unit is used while not fitted"""
+
+
+class IntentNotFoundError(SnipsNLUError):
+ """Raised when an intent is used although it was not part of the
+ training data"""
+
+ def __init__(self, intent):
+ super(IntentNotFoundError, self).__init__("Unknown intent '%s'"
+ % intent)
+
+
+class DatasetFormatError(SnipsNLUError):
+ """Raised when attempting to create a Snips NLU dataset using a wrong
+ format"""
+
+
+class EntityFormatError(DatasetFormatError):
+ """Raised when attempting to create a Snips NLU entity using a wrong
+ format"""
+
+
+class IntentFormatError(DatasetFormatError):
+ """Raised when attempting to create a Snips NLU intent using a wrong
+ format"""
+
+
+class AlreadyRegisteredError(SnipsNLUError):
+ """Raised when attempting to register a subclass which is already
+ registered"""
+
+ def __init__(self, name, new_class, existing_class):
+ msg = "Cannot register %s for %s as it has already been used to " \
+ "register %s" \
+ % (name, new_class.__name__, existing_class.__name__)
+ super(AlreadyRegisteredError, self).__init__(msg)
+
+
+class NotRegisteredError(SnipsNLUError):
+ """Raised when trying to use a subclass which was not registered"""
+
+ def __init__(self, registrable_cls, name=None, registered_cls=None):
+ if name is not None:
+ msg = "'%s' has not been registered for type %s. " \
+ % (name, registrable_cls)
+ else:
+ msg = "subclass %s has not been registered for type %s. " \
+ % (registered_cls, registrable_cls)
+ msg += "Use @BaseClass.register('my_component') to register a subclass"
+ super(NotRegisteredError, self).__init__(msg)
+
+
+class PersistingError(SnipsNLUError):
+ """Raised when trying to persist a processing unit to a path which already
+ exists"""
+
+ def __init__(self, path):
+ super(PersistingError, self).__init__("Path already exists: %s"
+ % str(path))
+
+
+class LoadingError(SnipsNLUError):
+ """Raised when trying to load a processing unit while some files are
+ missing"""
diff --git a/snips_inference_agl/intent_classifier/__init__.py b/snips_inference_agl/intent_classifier/__init__.py
new file mode 100644
index 0000000..89ccf95
--- /dev/null
+++ b/snips_inference_agl/intent_classifier/__init__.py
@@ -0,0 +1,3 @@
+from .intent_classifier import IntentClassifier
+from .log_reg_classifier import LogRegIntentClassifier
+from .featurizer import Featurizer, CooccurrenceVectorizer, TfidfVectorizer
diff --git a/snips_inference_agl/intent_classifier/featurizer.py b/snips_inference_agl/intent_classifier/featurizer.py
new file mode 100644
index 0000000..116837f
--- /dev/null
+++ b/snips_inference_agl/intent_classifier/featurizer.py
@@ -0,0 +1,452 @@
+from __future__ import division, unicode_literals
+
+import json
+from builtins import str, zip
+from copy import deepcopy
+from pathlib import Path
+
+from future.utils import iteritems
+
+from snips_inference_agl.common.utils import (
+ fitted_required, replace_entities_with_placeholders)
+from snips_inference_agl.constants import (
+ DATA, ENTITY, ENTITY_KIND, NGRAM, TEXT)
+from snips_inference_agl.dataset import get_text_from_chunks
+from snips_inference_agl.entity_parser.builtin_entity_parser import (
+ is_builtin_entity)
+from snips_inference_agl.exceptions import (LoadingError)
+from snips_inference_agl.languages import get_default_sep
+from snips_inference_agl.pipeline.configs import FeaturizerConfig
+from snips_inference_agl.pipeline.configs.intent_classifier import (
+ CooccurrenceVectorizerConfig, TfidfVectorizerConfig)
+from snips_inference_agl.pipeline.processing_unit import ProcessingUnit
+from snips_inference_agl.preprocessing import stem, tokenize_light
+from snips_inference_agl.resources import get_stop_words, get_word_cluster
+from snips_inference_agl.slot_filler.features_utils import get_all_ngrams
+
+
+@ProcessingUnit.register("featurizer")
+class Featurizer(ProcessingUnit):
+ """Feature extractor for text classification relying on ngrams tfidf and
+ optionally word cooccurrences features"""
+
+ config_type = FeaturizerConfig
+
+ def __init__(self, config=None, **shared):
+ super(Featurizer, self).__init__(config, **shared)
+ self.language = None
+ self.tfidf_vectorizer = None
+ self.cooccurrence_vectorizer = None
+
+ @property
+ def fitted(self):
+ if not self.tfidf_vectorizer or not self.tfidf_vectorizer.vocabulary:
+ return False
+ return True
+
+ def transform(self, utterances):
+ import scipy.sparse as sp
+
+ x = self.tfidf_vectorizer.transform(utterances)
+ if self.cooccurrence_vectorizer:
+ x_cooccurrence = self.cooccurrence_vectorizer.transform(utterances)
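+ # Concatenate the tf-idf ngram features with the cooccurrence
+ # features so that the classifier sees a single feature matrix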
+ x = sp.hstack((x, x_cooccurrence))
+ return x
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ path = Path(path)
+
+ model_path = path / "featurizer.json"
+ if not model_path.exists():
+ raise LoadingError("Missing featurizer model file: %s"
+ % model_path.name)
+ with model_path.open("r", encoding="utf-8") as f:
+ featurizer_dict = json.load(f)
+
+ featurizer_config = featurizer_dict["config"]
+ featurizer = cls(featurizer_config, **shared)
+
+ featurizer.language = featurizer_dict["language_code"]
+
+ tfidf_vectorizer = featurizer_dict["tfidf_vectorizer"]
+ if tfidf_vectorizer:
+ vectorizer_path = path / featurizer_dict["tfidf_vectorizer"]
+ tfidf_vectorizer = TfidfVectorizer.from_path(
+ vectorizer_path, **shared)
+ featurizer.tfidf_vectorizer = tfidf_vectorizer
+
+ cooccurrence_vectorizer = featurizer_dict["cooccurrence_vectorizer"]
+ if cooccurrence_vectorizer:
+ vectorizer_path = path / featurizer_dict["cooccurrence_vectorizer"]
+ cooccurrence_vectorizer = CooccurrenceVectorizer.from_path(
+ vectorizer_path, **shared)
+ featurizer.cooccurrence_vectorizer = cooccurrence_vectorizer
+
+ return featurizer
+
+
+@ProcessingUnit.register("tfidf_vectorizer")
+class TfidfVectorizer(ProcessingUnit):
+ """Wrapper of the scikit-learn TfidfVectorizer"""
+
+ config_type = TfidfVectorizerConfig
+
+ def __init__(self, config=None, **shared):
+ super(TfidfVectorizer, self).__init__(config, **shared)
+ self._tfidf_vectorizer = None
+ self._language = None
+ self.builtin_entity_scope = None
+
+ @property
+ def fitted(self):
+ return self._tfidf_vectorizer is not None and hasattr(
+ self._tfidf_vectorizer, "vocabulary_")
+
+ @fitted_required
+ def transform(self, x):
+ """Featurizes the given utterances after enriching them with builtin
+ entities matches, custom entities matches and the potential word
+ clusters matches
+
+ Args:
+ x (list of dict): list of utterances
+
+ Returns:
+ :class:`.scipy.sparse.csr_matrix`: A sparse matrix X of shape
+ (len(x), len(self.vocabulary)) where X[i, j] contains the tf-idf of
+ the ngram of index j of the vocabulary in the utterance i
+
+ Raises:
+ NotTrained: when the vectorizer is not fitted
+ """
+ utterances = [self._enrich_utterance(*data)
+ for data in zip(*self._preprocess(x))]
+ return self._tfidf_vectorizer.transform(utterances)
+
+ def _preprocess(self, utterances):
+ normalized_utterances = deepcopy(utterances)
+ for u in normalized_utterances:
+ nb_chunks = len(u[DATA])
+ for i, chunk in enumerate(u[DATA]):
+ chunk[TEXT] = _normalize_stem(
+ chunk[TEXT], self.language, self.resources,
+ self.config.use_stemming)
+ if i < nb_chunks - 1:
+ chunk[TEXT] += " "
+
+ # Extract builtin entities on unnormalized utterances
+ builtin_ents = [
+ self.builtin_entity_parser.parse(
+ get_text_from_chunks(u[DATA]),
+ self.builtin_entity_scope, use_cache=True)
+ for u in utterances
+ ]
+ # Extract custom entities on normalized utterances
+ custom_ents = [
+ self.custom_entity_parser.parse(
+ get_text_from_chunks(u[DATA]), use_cache=True)
+ for u in normalized_utterances
+ ]
+ if self.config.word_clusters_name:
+ # Extract word clusters on unnormalized utterances
+ original_utterances_text = [get_text_from_chunks(u[DATA])
+ for u in utterances]
+ w_clusters = [
+ _get_word_cluster_features(
+ tokenize_light(u.lower(), self.language),
+ self.config.word_clusters_name,
+ self.resources)
+ for u in original_utterances_text
+ ]
+ else:
+ w_clusters = [None for _ in normalized_utterances]
+
+ return normalized_utterances, builtin_ents, custom_ents, w_clusters
+
+ def _enrich_utterance(self, utterance, builtin_entities, custom_entities,
+ word_clusters):
+ custom_entities_features = [
+ _entity_name_to_feature(e[ENTITY_KIND], self.language)
+ for e in custom_entities]
+
+ builtin_entities_features = [
+ _builtin_entity_to_feature(ent[ENTITY_KIND], self.language)
+ for ent in builtin_entities
+ ]
+
+ # We remove values of builtin slots from the utterance to avoid
+ # learning specific samples such as '42' or 'tomorrow'
+ filtered_tokens = [
+ chunk[TEXT] for chunk in utterance[DATA]
+ if ENTITY not in chunk or not is_builtin_entity(chunk[ENTITY])
+ ]
+
+ features = get_default_sep(self.language).join(filtered_tokens)
+
+ if builtin_entities_features:
+ features += " " + " ".join(sorted(builtin_entities_features))
+ if custom_entities_features:
+ features += " " + " ".join(sorted(custom_entities_features))
+ if word_clusters:
+ features += " " + " ".join(sorted(word_clusters))
+
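+ # The resulting string is the utterance text (with builtin slot values
+ # removed) followed by the sorted entity and cluster features, e.g.
+ # (hypothetical): "set the temperature to in the bedroom
+ # builtinentityfeaturesnipstemperature entityfeatureroom"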
+ return features
+
+ @property
+ def language(self):
+ # This getter prevents the language from being set anywhere else
+ # than during fit
+ return self._language
+
+ @property
+ def vocabulary(self):
+ if self._tfidf_vectorizer and hasattr(
+ self._tfidf_vectorizer, "vocabulary_"):
+ return self._tfidf_vectorizer.vocabulary_
+ return None
+
+ @property
+ def idf_diag(self):
+ if self._tfidf_vectorizer and hasattr(
+ self._tfidf_vectorizer, "vocabulary_"):
+ return self._tfidf_vectorizer.idf_
+ return None
+
+ @classmethod
+ # pylint: disable=W0212
+ def from_path(cls, path, **shared):
+ import numpy as np
+ import scipy.sparse as sp
+ from sklearn.feature_extraction.text import (
+ TfidfTransformer, TfidfVectorizer as SklearnTfidfVectorizer)
+
+ path = Path(path)
+
+ model_path = path / "vectorizer.json"
+ if not model_path.exists():
+ raise LoadingError("Missing vectorizer model file: %s"
+ % model_path.name)
+ with model_path.open("r", encoding="utf-8") as f:
+ vectorizer_dict = json.load(f)
+
+ vectorizer = cls(vectorizer_dict["config"], **shared)
+ vectorizer._language = vectorizer_dict["language_code"]
+
+ builtin_entity_scope = vectorizer_dict["builtin_entity_scope"]
+ if builtin_entity_scope is not None:
+ builtin_entity_scope = set(builtin_entity_scope)
+ vectorizer.builtin_entity_scope = builtin_entity_scope
+
+ vectorizer_ = vectorizer_dict["vectorizer"]
+ if vectorizer_:
+ vocab = vectorizer_["vocab"]
+ idf_diag_data = vectorizer_["idf_diag"]
+ idf_diag_data = np.array(idf_diag_data)
+
+ idf_diag_shape = (len(idf_diag_data), len(idf_diag_data))
+ row = list(range(idf_diag_shape[0]))
+ col = list(range(idf_diag_shape[0]))
+ idf_diag = sp.csr_matrix(
+ (idf_diag_data, (row, col)), shape=idf_diag_shape)
+
+ tfidf_transformer = TfidfTransformer()
+ tfidf_transformer._idf_diag = idf_diag
+
+ vectorizer_ = SklearnTfidfVectorizer(
+ tokenizer=lambda x: tokenize_light(x, vectorizer._language))
+ vectorizer_.vocabulary_ = vocab
+
+ vectorizer_._tfidf = tfidf_transformer
+
+ vectorizer._tfidf_vectorizer = vectorizer_
+ return vectorizer
+
+
+@ProcessingUnit.register("cooccurrence_vectorizer")
+class CooccurrenceVectorizer(ProcessingUnit):
+ """Featurizer that takes utterances and extracts ordered word cooccurrence
+ features matrix from them"""
+
+ config_type = CooccurrenceVectorizerConfig
+
+ def __init__(self, config=None, **shared):
+ super(CooccurrenceVectorizer, self).__init__(config, **shared)
+ self._word_pairs = None
+ self._language = None
+ self.builtin_entity_scope = None
+
+ @property
+ def language(self):
+ # This getter prevents the language from being set anywhere else
+ # than during fit
+ return self._language
+
+ @property
+ def word_pairs(self):
+ return self._word_pairs
+
+ @property
+ def fitted(self):
+ """Whether or not the vectorizer is fitted"""
+ return self.word_pairs is not None
+
+ @fitted_required
+ def transform(self, x):
+ """Computes the cooccurrence feature matrix.
+
+ Args:
+ x (list of dict): list of utterances
+
+ Returns:
+ :class:`.scipy.sparse.csr_matrix`: A sparse matrix X of shape
+ (len(x), len(self.word_pairs)) where X[i, j] = 1.0 if
+ x[i][0] contains the words cooccurrence (w1, w2) and if
+ self.word_pairs[(w1, w2)] = j
+
+ Raises:
+ NotTrained: when the vectorizer is not fitted
+ """
+ import numpy as np
+ import scipy.sparse as sp
+
+ preprocessed = self._preprocess(x)
+ utterances = [
+ self._enrich_utterance(utterance, builtin_ents, custom_ent)
+ for utterance, builtin_ents, custom_ent in zip(*preprocessed)]
+
+ x_coo = sp.dok_matrix((len(x), len(self.word_pairs)), dtype=np.int32)
+ for i, u in enumerate(utterances):
+ for p in self._extract_word_pairs(u):
+ if p in self.word_pairs:
+ x_coo[i, self.word_pairs[p]] = 1
+
+ return x_coo.tocsr()
+
+ def _preprocess(self, x):
+ # Extract all entities on unnormalized data
+ builtin_ents = [
+ self.builtin_entity_parser.parse(
+ get_text_from_chunks(u[DATA]),
+ self.builtin_entity_scope,
+ use_cache=True
+ ) for u in x
+ ]
+ custom_ents = [
+ self.custom_entity_parser.parse(
+ get_text_from_chunks(u[DATA]), use_cache=True)
+ for u in x
+ ]
+ return x, builtin_ents, custom_ents
+
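+ # Worked example (derived from the implementation below, ignoring the
+ # optional stop-word filtering): with keep_order=True and
+ # window_size=None, the tokens ["set", "an", "alarm"] yield the pairs
+ # {("set", "an"), ("set", "alarm"), ("an", "alarm")}; with
+ # window_size=1 only adjacent pairs are kept.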
+ def _extract_word_pairs(self, utterance):
+ if self.config.filter_stop_words:
+ stop_words = get_stop_words(self.resources)
+ utterance = [t for t in utterance if t not in stop_words]
+ pairs = set()
+ for j, w1 in enumerate(utterance):
+ max_index = None
+ if self.config.window_size is not None:
+ max_index = j + self.config.window_size + 1
+ for w2 in utterance[j + 1:max_index]:
+ key = (w1, w2)
+ if not self.config.keep_order:
+ key = tuple(sorted(key))
+ pairs.add(key)
+ return pairs
+
+ def _enrich_utterance(self, x, builtin_ents, custom_ents):
+ utterance = get_text_from_chunks(x[DATA])
+ all_entities = builtin_ents + custom_ents
+ placeholder_fn = self._placeholder_fn
+ # Replace entities with placeholders
+ enriched_utterance = replace_entities_with_placeholders(
+ utterance, all_entities, placeholder_fn)[1]
+ # Tokenize
+ enriched_utterance = tokenize_light(enriched_utterance, self.language)
+ # Remove the unknownword strings if needed
+ if self.config.unknown_words_replacement_string:
+ enriched_utterance = [
+ t for t in enriched_utterance
+ if t != self.config.unknown_words_replacement_string
+ ]
+ return enriched_utterance
+
+ def _placeholder_fn(self, entity_name):
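+ # e.g. (illustrative) "snips/datetime" -> "SNIPSDATETIME"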
+ return "".join(
+ tokenize_light(str(entity_name), str(self.language))).upper()
+
+ @classmethod
+ # pylint: disable=protected-access
+ def from_path(cls, path, **shared):
+ path = Path(path)
+ model_path = path / "vectorizer.json"
+ if not model_path.exists():
+ raise LoadingError("Missing vectorizer model file: %s"
+ % model_path.name)
+
+ with model_path.open(encoding="utf8") as f:
+ vectorizer_dict = json.load(f)
+ config = vectorizer_dict.pop("config")
+
+ self = cls(config, **shared)
+ self._language = vectorizer_dict["language_code"]
+ self._word_pairs = None
+
+ builtin_entity_scope = vectorizer_dict["builtin_entity_scope"]
+ if builtin_entity_scope is not None:
+ builtin_entity_scope = set(builtin_entity_scope)
+ self.builtin_entity_scope = builtin_entity_scope
+
+ if vectorizer_dict["word_pairs"]:
+ self._word_pairs = {
+ tuple(p): int(i)
+ for i, p in iteritems(vectorizer_dict["word_pairs"])
+ }
+ return self
+
+
+def _entity_name_to_feature(entity_name, language):
+ return "entityfeature%s" % "".join(tokenize_light(
+ entity_name.lower(), language))
+
+
+def _builtin_entity_to_feature(builtin_entity_label, language):
+ return "builtinentityfeature%s" % "".join(tokenize_light(
+ builtin_entity_label.lower(), language))
+
+
+def _normalize_stem(text, language, resources, use_stemming):
+ from snips_nlu_utils import normalize
+
+ if use_stemming:
+ return stem(text, language, resources)
+ return normalize(text)
+
+
+def _get_word_cluster_features(query_tokens, clusters_name, resources):
+ if not clusters_name:
+ return []
+ ngrams = get_all_ngrams(query_tokens)
+ cluster_features = []
+ for ngram in ngrams:
+ cluster = get_word_cluster(resources, clusters_name).get(
+ ngram[NGRAM].lower(), None)
+ if cluster is not None:
+ cluster_features.append(cluster)
+ return cluster_features
diff --git a/snips_inference_agl/intent_classifier/intent_classifier.py b/snips_inference_agl/intent_classifier/intent_classifier.py
new file mode 100644
index 0000000..f9a7952
--- /dev/null
+++ b/snips_inference_agl/intent_classifier/intent_classifier.py
@@ -0,0 +1,51 @@
+from abc import ABCMeta
+
+from future.utils import with_metaclass
+
+from snips_inference_agl.pipeline.processing_unit import ProcessingUnit
+from snips_inference_agl.common.abc_utils import classproperty
+
+
+class IntentClassifier(with_metaclass(ABCMeta, ProcessingUnit)):
+ """Abstraction which performs intent classification
+
+ A custom intent classifier must inherit this class to be used in a
+ :class:`.ProbabilisticIntentParser`
+ """
+
+ @classproperty
+ def unit_name(cls): # pylint:disable=no-self-argument
+ return IntentClassifier.registered_name(cls)
+
+ # @abstractmethod
+ def get_intent(self, text, intents_filter):
+ """Performs intent classification on the provided *text*
+
+ Args:
+ text (str): Input
+ intents_filter (str or list of str): When defined, it will find
+ the most likely intent among the list, otherwise it will use
+ the whole list of intents defined in the dataset
+
+ Returns:
+ dict or None: The most likely intent along with its probability or
+ *None* if no intent was found. See
+ :func:`.intent_classification_result` for the output format.
+ """
+ pass
+
+ # @abstractmethod
+ def get_intents(self, text):
+ """Performs intent classification on the provided *text* and returns
+ the list of intents ordered by decreasing probability
+
+ The length of the returned list is exactly the number of intents in the
+ dataset + 1 for the None intent
+
+ .. note::
+
+ The probabilities returned along with each intent are not
+ guaranteed to sum to 1.0. They should be considered as scores
+ between 0 and 1.
+ """
+ pass
diff --git a/snips_inference_agl/intent_classifier/log_reg_classifier.py b/snips_inference_agl/intent_classifier/log_reg_classifier.py
new file mode 100644
index 0000000..09e537c
--- /dev/null
+++ b/snips_inference_agl/intent_classifier/log_reg_classifier.py
@@ -0,0 +1,211 @@
+from __future__ import unicode_literals
+
+import json
+import logging
+from builtins import str, zip
+from pathlib import Path
+
+from snips_inference_agl.common.log_utils import DifferedLoggingMessage
+from snips_inference_agl.common.utils import (fitted_required)
+from snips_inference_agl.constants import RES_PROBA
+from snips_inference_agl.exceptions import LoadingError
+from snips_inference_agl.intent_classifier.featurizer import Featurizer
+from snips_inference_agl.intent_classifier.intent_classifier import IntentClassifier
+from snips_inference_agl.intent_classifier.log_reg_classifier_utils import (text_to_utterance)
+from snips_inference_agl.pipeline.configs import LogRegIntentClassifierConfig
+from snips_inference_agl.result import intent_classification_result
+
+logger = logging.getLogger(__name__)
+
+# We set tol to 1e-3 to silence the following warning with Python 2 (
+# scikit-learn 0.20):
+#
+# FutureWarning: max_iter and tol parameters have been added in SGDClassifier
+# in 0.19. If max_iter is set but tol is left unset, the default value for tol
+# in 0.19 and 0.20 will be None (which is equivalent to -infinity, so it has no
+# effect) but will change in 0.21 to 1e-3. Specify tol to silence this warning.
+
+LOG_REG_ARGS = {
+ "loss": "log",
+ "penalty": "l2",
+ "max_iter": 1000,
+ "tol": 1e-3,
+ "n_jobs": -1
+}
+
+
+@IntentClassifier.register("log_reg_intent_classifier")
+class LogRegIntentClassifier(IntentClassifier):
+ """Intent classifier which uses a Logistic Regression underneath"""
+
+ config_type = LogRegIntentClassifierConfig
+
+ def __init__(self, config=None, **shared):
+ """The LogReg intent classifier can be configured by passing a
+ :class:`.LogRegIntentClassifierConfig`"""
+ super(LogRegIntentClassifier, self).__init__(config, **shared)
+ self.classifier = None
+ self.intent_list = None
+ self.featurizer = None
+
+ @property
+ def fitted(self):
+ """Whether or not the intent classifier has already been fitted"""
+ return self.intent_list is not None
+
+ @fitted_required
+ def get_intent(self, text, intents_filter=None):
+ """Performs intent classification on the provided *text*
+
+ Args:
+ text (str): Input
+ intents_filter (str or list of str): When defined, it will find
+ the most likely intent among the list, otherwise it will use
+ the whole list of intents defined in the dataset
+
+ Returns:
+ dict or None: The most likely intent along with its probability or
+ *None* if no intent was found
+
+ Raises:
+ :class:`snips_nlu.exceptions.NotTrained`: When the intent
+ classifier is not fitted
+
+ """
+ return self._get_intents(text, intents_filter)[0]
+
+ @fitted_required
+ def get_intents(self, text):
+ """Performs intent classification on the provided *text* and returns
+ the list of intents ordered by decreasing probability
+
+ The length of the returned list is exactly the number of intents in the
+ dataset + 1 for the None intent
+
+ Raises:
+ :class:`snips_nlu.exceptions.NotTrained`: when the intent
+ classifier is not fitted
+ """
+ return self._get_intents(text, intents_filter=None)
+
+ def _get_intents(self, text, intents_filter):
+ if isinstance(intents_filter, str):
+ intents_filter = {intents_filter}
+ elif isinstance(intents_filter, list):
+ intents_filter = set(intents_filter)
+
+ if not text or not self.intent_list or not self.featurizer:
+ results = [intent_classification_result(None, 1.0)]
+ results += [intent_classification_result(i, 0.0)
+ for i in self.intent_list if i is not None]
+ return results
+
+ if len(self.intent_list) == 1:
+ return [intent_classification_result(self.intent_list[0], 1.0)]
+
+ # pylint: disable=C0103
+ X = self.featurizer.transform([text_to_utterance(text)])
+ # pylint: enable=C0103
+ proba_vec = self._predict_proba(X)
+ logger.debug(
+ "%s", DifferedLoggingMessage(self.log_activation_weights, text, X))
+ results = [
+ intent_classification_result(i, proba)
+ for i, proba in zip(self.intent_list, proba_vec[0])
+ if intents_filter is None or i is None or i in intents_filter]
+
+ return sorted(results, key=lambda res: -res[RES_PROBA])
+
+ def _predict_proba(self, X): # pylint: disable=C0103
+ import numpy as np
+
+ self.classifier._check_proba() # pylint: disable=W0212
+
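+ # Apply the logistic sigmoid to the decision function in place:
+ # prob = 1 / (1 + exp(-decision)). Unlike SGDClassifier.predict_proba,
+ # the rows are not re-normalized, so the scores need not sum to 1.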
+ prob = self.classifier.decision_function(X)
+ prob *= -1
+ np.exp(prob, prob)
+ prob += 1
+ np.reciprocal(prob, prob)
+ if prob.ndim == 1:
+ return np.vstack([1 - prob, prob]).T
+ return prob
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ """Loads a :class:`LogRegIntentClassifier` instance from a path
+
+ The data at the given path must have been generated using
+ :func:`~LogRegIntentClassifier.persist`
+ """
+ import numpy as np
+ from sklearn.linear_model import SGDClassifier
+
+ path = Path(path)
+ model_path = path / "intent_classifier.json"
+ if not model_path.exists():
+ raise LoadingError("Missing intent classifier model file: %s"
+ % model_path.name)
+
+ with model_path.open(encoding="utf8") as f:
+ model_dict = json.load(f)
+
+ # Create the classifier
+ config = LogRegIntentClassifierConfig.from_dict(model_dict["config"])
+ intent_classifier = cls(config=config, **shared)
+ intent_classifier.intent_list = model_dict['intent_list']
+
+ # Create the underlying SGD classifier
+ sgd_classifier = None
+ coeffs = model_dict['coeffs']
+ intercept = model_dict['intercept']
+ t_ = model_dict["t_"]
+ if coeffs is not None and intercept is not None:
+ sgd_classifier = SGDClassifier(**LOG_REG_ARGS)
+ sgd_classifier.coef_ = np.array(coeffs)
+ sgd_classifier.intercept_ = np.array(intercept)
+ sgd_classifier.t_ = t_
+ intent_classifier.classifier = sgd_classifier
+
+ # Add the featurizer
+ featurizer = model_dict['featurizer']
+ if featurizer is not None:
+ featurizer_path = path / featurizer
+ intent_classifier.featurizer = Featurizer.from_path(
+ featurizer_path, **shared)
+
+ return intent_classifier
+
+ def log_activation_weights(self, text, x, top_n=50):
+ import numpy as np
+
+ if not hasattr(self.featurizer, "feature_index_to_feature_name"):
+ return None
+
+ log = "\n\nTop {} feature activations for: \"{}\":\n".format(
+ top_n, text)
+ activations = np.multiply(
+ self.classifier.coef_, np.asarray(x.todense()))
+ abs_activation = np.absolute(activations).flatten().squeeze()
+
+ if top_n > activations.size:
+ top_n = activations.size
+
+ top_n_activations_ix = np.argpartition(abs_activation, -top_n,
+ axis=None)[-top_n:]
+ top_n_activations_ix = np.unravel_index(
+ top_n_activations_ix, activations.shape)
+
+ index_to_feature = self.featurizer.feature_index_to_feature_name
+ features_intent_and_activation = [
+ (self.intent_list[i], index_to_feature[f], activations[i, f])
+ for i, f in zip(*top_n_activations_ix)]
+
+ features_intent_and_activation = sorted(
+ features_intent_and_activation, key=lambda x: abs(x[2]),
+ reverse=True)
+
+ for intent, feature, activation in features_intent_and_activation:
+ log += "\n\n\"{}\" -> ({}, {:.2f})".format(
+ intent, feature, float(activation))
+ log += "\n\n"
+ return log
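+# Minimal usage sketch (illustrative; the directory path is hypothetical and
+# the shared resources and entity parsers, omitted here for brevity, are
+# normally provided by the enclosing NLU engine when it loads its sub-units):
+#
+#     classifier = LogRegIntentClassifier.from_path(
+#         "trained_engine/probabilistic_intent_parser/intent_classifier_0")
+#     result = classifier.get_intent("turn on the lights in the kitchen")
+#     # -> {"intentName": "turnLightOn", "probability": 0.87}  (hypothetical)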
diff --git a/snips_inference_agl/intent_classifier/log_reg_classifier_utils.py b/snips_inference_agl/intent_classifier/log_reg_classifier_utils.py
new file mode 100644
index 0000000..75a8ab1
--- /dev/null
+++ b/snips_inference_agl/intent_classifier/log_reg_classifier_utils.py
@@ -0,0 +1,94 @@
+from __future__ import division, unicode_literals
+
+import itertools
+import re
+from builtins import next, range, str
+from copy import deepcopy
+from uuid import uuid4
+
+from future.utils import iteritems, itervalues
+
+from snips_inference_agl.constants import (DATA, ENTITY, INTENTS, TEXT,
+ UNKNOWNWORD, UTTERANCES)
+from snips_inference_agl.data_augmentation import augment_utterances
+from snips_inference_agl.dataset import get_text_from_chunks
+from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity
+from snips_inference_agl.preprocessing import tokenize_light
+from snips_inference_agl.resources import get_noise
+
+NOISE_NAME = str(uuid4())
+WORD_REGEX = re.compile(r"\w+(\s+\w+)*")
+UNKNOWNWORD_REGEX = re.compile(r"%s(\s+%s)*" % (UNKNOWNWORD, UNKNOWNWORD))
+
+
+def get_noise_it(noise, mean_length, std_length, random_state):
+ it = itertools.cycle(noise)
+ while True:
+ noise_length = int(random_state.normal(mean_length, std_length))
+ # pylint: disable=stop-iteration-return
+ yield " ".join(next(it) for _ in range(noise_length))
+ # pylint: enable=stop-iteration-return
+
+
+def generate_smart_noise(noise, augmented_utterances, replacement_string,
+ language):
+ text_utterances = [get_text_from_chunks(u[DATA])
+ for u in augmented_utterances]
+ vocab = [w for u in text_utterances for w in tokenize_light(u, language)]
+ vocab = set(vocab)
+ return [w if w in vocab else replacement_string for w in noise]
+
+
+def generate_noise_utterances(augmented_utterances, noise, num_intents,
+ data_augmentation_config, language,
+ random_state):
+ import numpy as np
+
+ if not augmented_utterances or not num_intents:
+ return []
+ avg_num_utterances = len(augmented_utterances) / float(num_intents)
+ if data_augmentation_config.unknown_words_replacement_string is not None:
+ noise = generate_smart_noise(
+ noise, augmented_utterances,
+ data_augmentation_config.unknown_words_replacement_string,
+ language)
+
+ noise_size = min(
+ int(data_augmentation_config.noise_factor * avg_num_utterances),
+ len(noise))
+ utterances_lengths = [
+ len(tokenize_light(get_text_from_chunks(u[DATA]), language))
+ for u in augmented_utterances]
+ mean_utterances_length = np.mean(utterances_lengths)
+ std_utterances_length = np.std(utterances_lengths)
+ noise_it = get_noise_it(noise, mean_utterances_length,
+ std_utterances_length, random_state)
+ # Remove duplicate 'unknownword unknownword'
+ return [
+ text_to_utterance(UNKNOWNWORD_REGEX.sub(UNKNOWNWORD, next(noise_it)))
+ for _ in range(noise_size)]
+
+
+def add_unknown_word_to_utterances(utterances, replacement_string,
+ unknown_word_prob, max_unknown_words,
+ random_state):
+ if not max_unknown_words:
+ return utterances
+
+ new_utterances = deepcopy(utterances)
+ for u in new_utterances:
+ if random_state.rand() < unknown_word_prob:
+ num_unknown = random_state.randint(1, max_unknown_words + 1)
+ # We choose to put the noise at the end of the sentence and not
+ # in the middle so that it doesn't impact the ngram computation
+ # too much
+ extra_chunk = {
+ TEXT: " " + " ".join(
+ replacement_string for _ in range(num_unknown))
+ }
+ u[DATA].append(extra_chunk)
+ return new_utterances
+
+
+def text_to_utterance(text):
+ return {DATA: [{TEXT: text}]}
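+# e.g. text_to_utterance("hello world") -> {DATA: [{TEXT: "hello world"}]},
+# i.e. a single-chunk utterance in the dataset format used above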
diff --git a/snips_inference_agl/intent_parser/__init__.py b/snips_inference_agl/intent_parser/__init__.py
new file mode 100644
index 0000000..1b0d446
--- /dev/null
+++ b/snips_inference_agl/intent_parser/__init__.py
@@ -0,0 +1,4 @@
+from .deterministic_intent_parser import DeterministicIntentParser
+from .intent_parser import IntentParser
+from .lookup_intent_parser import LookupIntentParser
+from .probabilistic_intent_parser import ProbabilisticIntentParser
diff --git a/snips_inference_agl/intent_parser/deterministic_intent_parser.py b/snips_inference_agl/intent_parser/deterministic_intent_parser.py
new file mode 100644
index 0000000..845e59d
--- /dev/null
+++ b/snips_inference_agl/intent_parser/deterministic_intent_parser.py
@@ -0,0 +1,518 @@
+from __future__ import unicode_literals
+
+import json
+import logging
+import re
+from builtins import str
+from collections import defaultdict
+from pathlib import Path
+
+from future.utils import iteritems, itervalues
+
+from snips_inference_agl.common.dataset_utils import get_slot_name_mappings
+from snips_inference_agl.common.log_utils import log_elapsed_time, log_result
+from snips_inference_agl.common.utils import (
+ check_persisted_path, deduplicate_overlapping_items, fitted_required,
+ json_string, ranges_overlap, regex_escape,
+ replace_entities_with_placeholders)
+from snips_inference_agl.constants import (
+ DATA, END, ENTITIES, ENTITY,
+ INTENTS, LANGUAGE, RES_INTENT, RES_INTENT_NAME,
+ RES_MATCH_RANGE, RES_SLOTS, RES_VALUE, SLOT_NAME, START, TEXT, UTTERANCES,
+ RES_PROBA)
+from snips_inference_agl.dataset import validate_and_format_dataset
+from snips_inference_agl.dataset.utils import get_stop_words_whitelist
+from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity
+from snips_inference_agl.exceptions import IntentNotFoundError, LoadingError
+from snips_inference_agl.intent_parser.intent_parser import IntentParser
+from snips_inference_agl.pipeline.configs import DeterministicIntentParserConfig
+from snips_inference_agl.preprocessing import normalize_token, tokenize, tokenize_light
+from snips_inference_agl.resources import get_stop_words
+from snips_inference_agl.result import (empty_result, extraction_result,
+ intent_classification_result, parsing_result,
+ unresolved_slot)
+
+WHITESPACE_PATTERN = r"\s*"
+
+logger = logging.getLogger(__name__)
+
+
+@IntentParser.register("deterministic_intent_parser")
+class DeterministicIntentParser(IntentParser):
+ """Intent parser using pattern matching in a deterministic manner
+
+ This intent parser is very strict by nature, and tends to have a very good
+ precision but a low recall. For this reason, it is typically used
+ first, before potentially falling back to another parser.
+ """
+
+ config_type = DeterministicIntentParserConfig
+
+ def __init__(self, config=None, **shared):
+ """The deterministic intent parser can be configured by passing a
+ :class:`.DeterministicIntentParserConfig`"""
+ super(DeterministicIntentParser, self).__init__(config, **shared)
+ self._language = None
+ self._slot_names_to_entities = None
+ self._group_names_to_slot_names = None
+ self._stop_words = None
+ self._stop_words_whitelist = None
+ self.slot_names_to_group_names = None
+ self.regexes_per_intent = None
+ self.entity_scopes = None
+
+ @property
+ def language(self):
+ return self._language
+
+ @language.setter
+ def language(self, value):
+ self._language = value
+ if value is None:
+ self._stop_words = None
+ else:
+ if self.config.ignore_stop_words:
+ self._stop_words = get_stop_words(self.resources)
+ else:
+ self._stop_words = set()
+
+ @property
+ def slot_names_to_entities(self):
+ return self._slot_names_to_entities
+
+ @slot_names_to_entities.setter
+ def slot_names_to_entities(self, value):
+ self._slot_names_to_entities = value
+ if value is None:
+ self.entity_scopes = None
+ else:
+ self.entity_scopes = {
+ intent: {
+ "builtin": {ent for ent in itervalues(slot_mapping)
+ if is_builtin_entity(ent)},
+ "custom": {ent for ent in itervalues(slot_mapping)
+ if not is_builtin_entity(ent)}
+ }
+ for intent, slot_mapping in iteritems(value)}
+
+ @property
+ def group_names_to_slot_names(self):
+ return self._group_names_to_slot_names
+
+ @group_names_to_slot_names.setter
+ def group_names_to_slot_names(self, value):
+ self._group_names_to_slot_names = value
+ if value is not None:
+ self.slot_names_to_group_names = {
+ slot_name: group for group, slot_name in iteritems(value)}
+
+ @property
+ def patterns(self):
+ """Dictionary of patterns per intent"""
+ if self.regexes_per_intent is not None:
+ return {i: [r.pattern for r in regex_list] for i, regex_list in
+ iteritems(self.regexes_per_intent)}
+ return None
+
+ @patterns.setter
+ def patterns(self, value):
+ if value is not None:
+ self.regexes_per_intent = dict()
+ for intent, pattern_list in iteritems(value):
+ regexes = [re.compile(r"%s" % p, re.IGNORECASE)
+ for p in pattern_list]
+ self.regexes_per_intent[intent] = regexes
+
+ @property
+ def fitted(self):
+ """Whether or not the intent parser has already been trained"""
+ return self.regexes_per_intent is not None
+
+ @log_elapsed_time(
+ logger, logging.INFO, "Fitted deterministic parser in {elapsed_time}")
+ def fit(self, dataset, force_retrain=True):
+ """Fits the intent parser with a valid Snips dataset"""
+ logger.info("Fitting deterministic intent parser...")
+ dataset = validate_and_format_dataset(dataset)
+ self.load_resources_if_needed(dataset[LANGUAGE])
+ self.fit_builtin_entity_parser_if_needed(dataset)
+ self.fit_custom_entity_parser_if_needed(dataset)
+ self.language = dataset[LANGUAGE]
+ self.regexes_per_intent = dict()
+ entity_placeholders = _get_entity_placeholders(dataset, self.language)
+ self.slot_names_to_entities = get_slot_name_mappings(dataset)
+ self.group_names_to_slot_names = _get_group_names_to_slot_names(
+ self.slot_names_to_entities)
+ self._stop_words_whitelist = get_stop_words_whitelist(
+ dataset, self._stop_words)
+
+ # Do not use ambiguous patterns that appear in more than one intent
+ all_patterns = set()
+ ambiguous_patterns = set()
+ intent_patterns = dict()
+ for intent_name, intent in iteritems(dataset[INTENTS]):
+ patterns = self._generate_patterns(intent_name, intent[UTTERANCES],
+ entity_placeholders)
+ patterns = [p for p in patterns
+ if len(p) < self.config.max_pattern_length]
+ existing_patterns = {p for p in patterns if p in all_patterns}
+ ambiguous_patterns.update(existing_patterns)
+ all_patterns.update(set(patterns))
+ intent_patterns[intent_name] = patterns
+
+ for intent_name, patterns in iteritems(intent_patterns):
+ patterns = [p for p in patterns if p not in ambiguous_patterns]
+ patterns = patterns[:self.config.max_queries]
+ regexes = [re.compile(p, re.IGNORECASE) for p in patterns]
+ self.regexes_per_intent[intent_name] = regexes
+ return self
+
+ @log_result(
+ logger, logging.DEBUG, "DeterministicIntentParser result -> {result}")
+ @log_elapsed_time(logger, logging.DEBUG, "Parsed in {elapsed_time}.")
+ @fitted_required
+ def parse(self, text, intents=None, top_n=None):
+ """Performs intent parsing on the provided *text*
+
+ Intent and slots are extracted simultaneously through pattern matching
+
+ Args:
+ text (str): input
+ intents (str or list of str): if provided, reduces the scope of
+ intent parsing to the provided list of intents
+ top_n (int, optional): when provided, this method will return a
+ list of at most top_n most likely intents, instead of a single
+ parsing result.
+ Note that the returned list can contain less than ``top_n``
+ elements, for instance when the parameter ``intents`` is not
+ None, or when ``top_n`` is greater than the total number of
+ intents.
+
+ Returns:
+ dict or list: the most likely intent(s) along with the extracted
+ slots. See :func:`.parsing_result` and :func:`.extraction_result`
+ for the output format.
+
+ Raises:
+ NotTrained: when the intent parser is not fitted
+ """
+ if top_n is None:
+ top_intents = self._parse_top_intents(text, top_n=1,
+ intents=intents)
+ if top_intents:
+ intent = top_intents[0][RES_INTENT]
+ slots = top_intents[0][RES_SLOTS]
+ if intent[RES_PROBA] <= 0.5:
+ # return None in case of ambiguity
+ return empty_result(text, probability=1.0)
+ return parsing_result(text, intent, slots)
+ return empty_result(text, probability=1.0)
+ return self._parse_top_intents(text, top_n=top_n, intents=intents)
+
+ def _parse_top_intents(self, text, top_n, intents=None):
+ if isinstance(intents, str):
+ intents = {intents}
+ elif isinstance(intents, list):
+ intents = set(intents)
+
+ if top_n < 1:
+ raise ValueError(
+ "top_n argument must be greater or equal to 1, but got: %s"
+ % top_n)
+
+ def placeholder_fn(entity_name):
+ return _get_entity_name_placeholder(entity_name, self.language)
+
+ results = []
+
+ for intent, entity_scope in iteritems(self.entity_scopes):
+ if intents is not None and intent not in intents:
+ continue
+ builtin_entities = self.builtin_entity_parser.parse(
+ text, scope=entity_scope["builtin"], use_cache=True)
+ custom_entities = self.custom_entity_parser.parse(
+ text, scope=entity_scope["custom"], use_cache=True)
+ all_entities = builtin_entities + custom_entities
+ mapping, processed_text = replace_entities_with_placeholders(
+ text, all_entities, placeholder_fn=placeholder_fn)
+ cleaned_text = self._preprocess_text(text, intent)
+ cleaned_processed_text = self._preprocess_text(processed_text,
+ intent)
+ for regex in self.regexes_per_intent[intent]:
+ res = self._get_matching_result(text, cleaned_text, regex,
+ intent)
+ if res is None and cleaned_text != cleaned_processed_text:
+ res = self._get_matching_result(
+ text, cleaned_processed_text, regex, intent, mapping)
+
+ if res is not None:
+ results.append(res)
+ break
+
+ # In some rare cases there can be multiple ambiguous intents
+ # In such cases, priority is given to results containing fewer slots
+ weights = [1.0 / (1.0 + len(res[RES_SLOTS])) for res in results]
+ total_weight = sum(weights)
+
+ for res, weight in zip(results, weights):
+ res[RES_INTENT][RES_PROBA] = weight / total_weight
+
+ results = sorted(results, key=lambda r: -r[RES_INTENT][RES_PROBA])
+
+ return results[:top_n]
+
+ @fitted_required
+ def get_intents(self, text):
+ """Returns the list of intents ordered by decreasing probability
+
+ The length of the returned list is exactly the number of intents in the
+ dataset + 1 for the None intent
+ """
+ nb_intents = len(self.regexes_per_intent)
+ top_intents = [intent_result[RES_INTENT] for intent_result in
+ self._parse_top_intents(text, top_n=nb_intents)]
+ matched_intents = {res[RES_INTENT_NAME] for res in top_intents}
+ for intent in self.regexes_per_intent:
+ if intent not in matched_intents:
+ top_intents.append(intent_classification_result(intent, 0.0))
+
+ # The None intent is not included in the regex patterns and is thus
+ # never matched by the deterministic parser
+ top_intents.append(intent_classification_result(None, 0.0))
+ return top_intents
+
+ @fitted_required
+ def get_slots(self, text, intent):
+ """Extracts slots from a text input, with the knowledge of the intent
+
+ Args:
+ text (str): input
+ intent (str): the intent which the input corresponds to
+
+ Returns:
+ list: the list of extracted slots
+
+ Raises:
+ IntentNotFoundError: When the intent was not part of the training
+ data
+ """
+ if intent is None:
+ return []
+
+ if intent not in self.regexes_per_intent:
+ raise IntentNotFoundError(intent)
+
+ slots = self.parse(text, intents=[intent])[RES_SLOTS]
+ if slots is None:
+ slots = []
+ return slots
+
+ def _get_intent_stop_words(self, intent):
+ whitelist = self._stop_words_whitelist.get(intent, set())
+ return self._stop_words.difference(whitelist)
+
+ def _preprocess_text(self, string, intent):
+ """Replaces stop words and characters that are tokenized out by
+ whitespaces"""
+ tokens = tokenize(string, self.language)
+ current_idx = 0
+ cleaned_string = ""
+ stop_words = self._get_intent_stop_words(intent)
+ for token in tokens:
+ if stop_words and normalize_token(token) in stop_words:
+ token.value = "".join(" " for _ in range(len(token.value)))
+ prefix_length = token.start - current_idx
+ cleaned_string += "".join((" " for _ in range(prefix_length)))
+ cleaned_string += token.value
+ current_idx = token.end
+ suffix_length = len(string) - current_idx
+ cleaned_string += "".join((" " for _ in range(suffix_length)))
+ return cleaned_string
+
+ def _get_matching_result(self, text, processed_text, regex, intent,
+ entities_ranges_mapping=None):
+ found_result = regex.match(processed_text)
+ if found_result is None:
+ return None
+ parsed_intent = intent_classification_result(intent_name=intent,
+ probability=1.0)
+ slots = []
+ for group_name in found_result.groupdict():
+ ref_group_name = group_name
+ if "_" in group_name:
+ ref_group_name = group_name.split("_")[0]
+ slot_name = self.group_names_to_slot_names[ref_group_name]
+ entity = self.slot_names_to_entities[intent][slot_name]
+ rng = (found_result.start(group_name),
+ found_result.end(group_name))
+ if entities_ranges_mapping is not None:
+ if rng in entities_ranges_mapping:
+ rng = entities_ranges_mapping[rng]
+ else:
+ shift = _get_range_shift(
+ rng, entities_ranges_mapping)
+ rng = {START: rng[0] + shift, END: rng[1] + shift}
+ else:
+ rng = {START: rng[0], END: rng[1]}
+ value = text[rng[START]:rng[END]]
+ parsed_slot = unresolved_slot(
+ match_range=rng, value=value, entity=entity,
+ slot_name=slot_name)
+ slots.append(parsed_slot)
+ parsed_slots = _deduplicate_overlapping_slots(slots, self.language)
+ parsed_slots = sorted(parsed_slots,
+ key=lambda s: s[RES_MATCH_RANGE][START])
+ return extraction_result(parsed_intent, parsed_slots)
+
+ def _generate_patterns(self, intent, intent_utterances,
+ entity_placeholders):
+ unique_patterns = set()
+ patterns = []
+ stop_words = self._get_intent_stop_words(intent)
+ for utterance in intent_utterances:
+ pattern = self._utterance_to_pattern(
+ utterance, stop_words, entity_placeholders)
+ if pattern not in unique_patterns:
+ unique_patterns.add(pattern)
+ patterns.append(pattern)
+ return patterns
+
+ def _utterance_to_pattern(self, utterance, stop_words,
+ entity_placeholders):
+ from snips_nlu_utils import normalize
+
+ slot_names_count = defaultdict(int)
+ pattern = []
+ for chunk in utterance[DATA]:
+ if SLOT_NAME in chunk:
+ slot_name = chunk[SLOT_NAME]
+ slot_names_count[slot_name] += 1
+ group_name = self.slot_names_to_group_names[slot_name]
+ count = slot_names_count[slot_name]
+ if count > 1:
+ group_name = "%s_%s" % (group_name, count)
+ placeholder = entity_placeholders[chunk[ENTITY]]
+ pattern.append(r"(?P<%s>%s)" % (group_name, placeholder))
+ else:
+ tokens = tokenize_light(chunk[TEXT], self.language)
+ pattern += [regex_escape(t.lower()) for t in tokens
+ if normalize(t) not in stop_words]
+
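+ # Illustrative output (hypothetical slot/entity names): an utterance
+ # "set the temperature to [roomTemperature](21 degrees)" becomes
+ # r"^\s*set\s*the\s*temperature\s*to\s*(?P<group0>%SNIPSTEMPERATURE%)\s*$"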
+ pattern = r"^%s%s%s$" % (WHITESPACE_PATTERN,
+ WHITESPACE_PATTERN.join(pattern),
+ WHITESPACE_PATTERN)
+ return pattern
+
+ @check_persisted_path
+ def persist(self, path):
+ """Persists the object at the given path"""
+ path.mkdir()
+ parser_json = json_string(self.to_dict())
+ parser_path = path / "intent_parser.json"
+
+ with parser_path.open(mode="w", encoding="utf8") as f:
+ f.write(parser_json)
+ self.persist_metadata(path)
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ """Loads a :class:`DeterministicIntentParser` instance from a path
+
+ The data at the given path must have been generated using
+ :func:`~DeterministicIntentParser.persist`
+ """
+ path = Path(path)
+ model_path = path / "intent_parser.json"
+ if not model_path.exists():
+ raise LoadingError(
+ "Missing deterministic intent parser metadata file: %s"
+ % model_path.name)
+
+ with model_path.open(encoding="utf8") as f:
+ metadata = json.load(f)
+ return cls.from_dict(metadata, **shared)
+
+ def to_dict(self):
+ """Returns a json-serializable dict"""
+ stop_words_whitelist = None
+ if self._stop_words_whitelist is not None:
+ stop_words_whitelist = {
+ intent: sorted(values)
+ for intent, values in iteritems(self._stop_words_whitelist)}
+ return {
+ "config": self.config.to_dict(),
+ "language_code": self.language,
+ "patterns": self.patterns,
+ "group_names_to_slot_names": self.group_names_to_slot_names,
+ "slot_names_to_entities": self.slot_names_to_entities,
+ "stop_words_whitelist": stop_words_whitelist
+ }
+
+ @classmethod
+ def from_dict(cls, unit_dict, **shared):
+ """Creates a :class:`DeterministicIntentParser` instance from a dict
+
+ The dict must have been generated with
+ :func:`~DeterministicIntentParser.to_dict`
+ """
+ config = cls.config_type.from_dict(unit_dict["config"])
+ parser = cls(config=config, **shared)
+ parser.patterns = unit_dict["patterns"]
+ parser.language = unit_dict["language_code"]
+ parser.group_names_to_slot_names = unit_dict[
+ "group_names_to_slot_names"]
+ parser.slot_names_to_entities = unit_dict["slot_names_to_entities"]
+ if parser.fitted:
+ whitelist = unit_dict.get("stop_words_whitelist", dict())
+ # pylint:disable=protected-access
+ parser._stop_words_whitelist = {
+ intent: set(values) for intent, values in iteritems(whitelist)}
+ # pylint:enable=protected-access
+ return parser
+
+
+def _get_range_shift(matched_range, ranges_mapping):
+ shift = 0
+ previous_replaced_range_end = None
+ matched_start = matched_range[0]
+ for replaced_range, orig_range in iteritems(ranges_mapping):
+ if replaced_range[1] <= matched_start:
+ if previous_replaced_range_end is None \
+ or replaced_range[1] > previous_replaced_range_end:
+ previous_replaced_range_end = replaced_range[1]
+ shift = orig_range[END] - replaced_range[1]
+ return shift
+
+
+def _get_group_names_to_slot_names(slot_names_mapping):
+ slot_names = {slot_name for mapping in itervalues(slot_names_mapping)
+ for slot_name in mapping}
+ return {"group%s" % i: name
+ for i, name in enumerate(sorted(slot_names))}
+
+
+def _get_entity_placeholders(dataset, language):
+ return {
+ e: _get_entity_name_placeholder(e, language)
+ for e in dataset[ENTITIES]
+ }
+
+
+def _deduplicate_overlapping_slots(slots, language):
+ def overlap(lhs_slot, rhs_slot):
+ return ranges_overlap(lhs_slot[RES_MATCH_RANGE],
+ rhs_slot[RES_MATCH_RANGE])
+
+ def sort_key_fn(slot):
+ tokens = tokenize(slot[RES_VALUE], language)
+ return -(len(tokens) + len(slot[RES_VALUE]))
+
+ deduplicated_slots = deduplicate_overlapping_items(
+ slots, overlap, sort_key_fn)
+ return sorted(deduplicated_slots,
+ key=lambda slot: slot[RES_MATCH_RANGE][START])
+
+
+def _get_entity_name_placeholder(entity_label, language):
+ return "%%%s%%" % "".join(
+ tokenize_light(entity_label, language)).upper()
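+# Minimal usage sketch (illustrative; the path is hypothetical and the shared
+# builtin/custom entity parsers, omitted here, are normally injected by the
+# SnipsNLUEngine that owns this parser):
+#
+#     parser = DeterministicIntentParser.from_path(
+#         "trained_engine/deterministic_intent_parser")
+#     parsing = parser.parse("turn on the lights in the kitchen")
+#     top3 = parser.parse("turn on the lights", top_n=3)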
diff --git a/snips_inference_agl/intent_parser/intent_parser.py b/snips_inference_agl/intent_parser/intent_parser.py
new file mode 100644
index 0000000..b269774
--- /dev/null
+++ b/snips_inference_agl/intent_parser/intent_parser.py
@@ -0,0 +1,85 @@
+from abc import abstractmethod, ABCMeta
+
+from future.utils import with_metaclass
+
+from snips_inference_agl.common.abc_utils import classproperty
+from snips_inference_agl.pipeline.processing_unit import ProcessingUnit
+
+
+class IntentParser(with_metaclass(ABCMeta, ProcessingUnit)):
+ """Abstraction which performs intent parsing
+
+ A custom intent parser must inherit this class to be used in a
+ :class:`.SnipsNLUEngine`
+ """
+
+ @classproperty
+ def unit_name(cls): # pylint:disable=no-self-argument
+ return IntentParser.registered_name(cls)
+
+ @abstractmethod
+ def fit(self, dataset, force_retrain):
+ """Fit the intent parser with a valid Snips dataset
+
+ Args:
+ dataset (dict): valid Snips NLU dataset
+ force_retrain (bool): specify whether or not sub units of the
+ intent parser that may be already trained should be retrained
+ """
+ pass
+
+ @abstractmethod
+ def parse(self, text, intents, top_n):
+ """Performs intent parsing on the provided *text*
+
+ Args:
+ text (str): input
+ intents (str or list of str): if provided, reduces the scope of
+ intent parsing to the provided list of intents
+ top_n (int, optional): when provided, this method will return a
+ list of at most top_n most likely intents, instead of a single
+ parsing result.
+ Note that the returned list can contain less than ``top_n``
+ elements, for instance when the parameter ``intents`` is not
+ None, or when ``top_n`` is greater than the total number of
+ intents.
+
+ Returns:
+ dict or list: the most likely intent(s) along with the extracted
+ slots. See :func:`.parsing_result` and :func:`.extraction_result`
+ for the output format.
+ """
+ pass
+
+ @abstractmethod
+ def get_intents(self, text):
+ """Performs intent classification on the provided *text* and returns
+ the list of intents ordered by decreasing probability
+
+ The length of the returned list is exactly the number of intents in the
+ dataset + 1 for the None intent
+
+ .. note::
+
+ The probabilities returned along with each intent are not
+ guaranteed to sum to 1.0. They should be considered as scores
+ between 0 and 1.
+ """
+ pass
+
+ @abstractmethod
+ def get_slots(self, text, intent):
+ """Extract slots from a text input, with the knowledge of the intent
+
+ Args:
+ text (str): input
+ intent (str): the intent which the input corresponds to
+
+ Returns:
+ list: the list of extracted slots
+
+ Raises:
+ IntentNotFoundError: when the intent was not part of the training
+ data
+ """
+ pass
diff --git a/snips_inference_agl/intent_parser/lookup_intent_parser.py b/snips_inference_agl/intent_parser/lookup_intent_parser.py
new file mode 100644
index 0000000..921dcc5
--- /dev/null
+++ b/snips_inference_agl/intent_parser/lookup_intent_parser.py
@@ -0,0 +1,509 @@
+from __future__ import unicode_literals
+
+import json
+import logging
+from builtins import str
+from collections import defaultdict
+from itertools import combinations
+from pathlib import Path
+
+from future.utils import iteritems, itervalues
+from snips_nlu_utils import normalize, hash_str
+
+from snips_inference_agl.common.log_utils import log_elapsed_time, log_result
+from snips_inference_agl.common.utils import (
+ check_persisted_path, deduplicate_overlapping_entities, fitted_required,
+ json_string)
+from snips_inference_agl.constants import (
+ DATA, END, ENTITIES, ENTITY, ENTITY_KIND, INTENTS, LANGUAGE, RES_INTENT,
+ RES_INTENT_NAME, RES_MATCH_RANGE, RES_SLOTS, SLOT_NAME, START, TEXT,
+ UTTERANCES, RES_PROBA)
+from snips_inference_agl.dataset import (
+ validate_and_format_dataset, extract_intent_entities)
+from snips_inference_agl.dataset.utils import get_stop_words_whitelist
+from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity
+from snips_inference_agl.exceptions import IntentNotFoundError, LoadingError
+from snips_inference_agl.intent_parser.intent_parser import IntentParser
+from snips_inference_agl.pipeline.configs import LookupIntentParserConfig
+from snips_inference_agl.preprocessing import tokenize_light
+from snips_inference_agl.resources import get_stop_words
+from snips_inference_agl.result import (
+ empty_result, intent_classification_result, parsing_result,
+ unresolved_slot, extraction_result)
+
+logger = logging.getLogger(__name__)
+
+
+@IntentParser.register("lookup_intent_parser")
+class LookupIntentParser(IntentParser):
+ """A deterministic Intent parser implementation based on a dictionary
+
+ This intent parser is very strict by nature, and tends to have a very good
+ precision but a low recall. For this reason, it is typically used
+ first, before potentially falling back to another parser.
+ """
+
+ config_type = LookupIntentParserConfig
+
+ def __init__(self, config=None, **shared):
+ """The lookup intent parser can be configured by passing a
+ :class:`.LookupIntentParserConfig`"""
+ super(LookupIntentParser, self).__init__(config, **shared)
+ self._language = None
+ self._stop_words = None
+ self._stop_words_whitelist = None
+ self._map = None
+ self._intents_names = []
+ self._slots_names = []
+ self._intents_mapping = dict()
+ self._slots_mapping = dict()
+ self._entity_scopes = None
+
+ @property
+ def language(self):
+ return self._language
+
+ @language.setter
+ def language(self, value):
+ self._language = value
+ if value is None:
+ self._stop_words = None
+ else:
+ if self.config.ignore_stop_words:
+ self._stop_words = get_stop_words(self.resources)
+ else:
+ self._stop_words = set()
+
+ @property
+ def fitted(self):
+ """Whether or not the intent parser has already been trained"""
+ return self._map is not None
+
+ @log_elapsed_time(
+ logger, logging.INFO, "Fitted lookup intent parser in {elapsed_time}")
+ def fit(self, dataset, force_retrain=True):
+ """Fits the intent parser with a valid Snips dataset"""
+ logger.info("Fitting lookup intent parser...")
+ dataset = validate_and_format_dataset(dataset)
+ self.load_resources_if_needed(dataset[LANGUAGE])
+ self.fit_builtin_entity_parser_if_needed(dataset)
+ self.fit_custom_entity_parser_if_needed(dataset)
+ self.language = dataset[LANGUAGE]
+ self._entity_scopes = _get_entity_scopes(dataset)
+ self._map = dict()
+ self._stop_words_whitelist = get_stop_words_whitelist(
+ dataset, self._stop_words)
+ entity_placeholders = _get_entity_placeholders(dataset, self.language)
+
+ ambiguous_keys = set()
+ for (key, val) in self._generate_io_mapping(dataset[INTENTS],
+ entity_placeholders):
+ key = hash_str(key)
+ # handle key collisions -*- flag ambiguous entries -*-
+ if key in self._map and self._map[key] != val:
+ ambiguous_keys.add(key)
+ else:
+ self._map[key] = val
+
+ # delete ambiguous keys
+ for key in ambiguous_keys:
+ self._map.pop(key)
+
+ return self
+
+ @log_result(logger, logging.DEBUG, "LookupIntentParser result -> {result}")
+ @log_elapsed_time(logger, logging.DEBUG, "Parsed in {elapsed_time}.")
+ @fitted_required
+ def parse(self, text, intents=None, top_n=None):
+ """Performs intent parsing on the provided *text*
+
+ Intent and slots are extracted simultaneously through pattern matching
+
+ Args:
+ text (str): input
+ intents (str or list of str): if provided, reduces the scope of
+ intent parsing to the provided list of intents
+ top_n (int, optional): when provided, this method will return a
+ list of at most top_n most likely intents, instead of a single
+ parsing result.
+ Note that the returned list can contain fewer than ``top_n``
+ elements, for instance when the parameter ``intents`` is not
+ None, or when ``top_n`` is greater than the total number of
+ intents.
+
+ Returns:
+ dict or list: the most likely intent(s) along with the extracted
+ slots. See :func:`.parsing_result` and :func:`.extraction_result`
+ for the output format.
+
+ Raises:
+ NotTrained: when the intent parser is not fitted
+ """
+ if top_n is None:
+ top_intents = self._parse_top_intents(text, top_n=1,
+ intents=intents)
+ if top_intents:
+ intent = top_intents[0][RES_INTENT]
+ slots = top_intents[0][RES_SLOTS]
+ if intent[RES_PROBA] <= 0.5:
+ # return None in case of ambiguity
+ return empty_result(text, probability=1.0)
+ return parsing_result(text, intent, slots)
+ return empty_result(text, probability=1.0)
+ return self._parse_top_intents(text, top_n=top_n, intents=intents)
+
+ def _parse_top_intents(self, text, top_n, intents=None):
+ if isinstance(intents, str):
+ intents = {intents}
+ elif isinstance(intents, list):
+ intents = set(intents)
+
+ if top_n < 1:
+ raise ValueError(
+ "top_n argument must be greater or equal to 1, but got: %s"
+ % top_n)
+
+ results_per_intent = defaultdict(list)
+ for text_candidate, entities in self._get_candidates(text, intents):
+ val = self._map.get(hash_str(text_candidate))
+ if val is not None:
+ result = self._parse_map_output(text, val, entities, intents)
+ if result:
+ intent_name = result[RES_INTENT][RES_INTENT_NAME]
+ results_per_intent[intent_name].append(result)
+
+ results = []
+ for intent_results in itervalues(results_per_intent):
+ sorted_results = sorted(intent_results,
+ key=lambda res: len(res[RES_SLOTS]))
+ results.append(sorted_results[0])
+
+ # In some rare cases there can be multiple ambiguous intents
+ # In such cases, priority is given to results containing fewer slots
+ weights = [1.0 / (1.0 + len(res[RES_SLOTS])) for res in results]
+ total_weight = sum(weights)
+
+ for res, weight in zip(results, weights):
+ res[RES_INTENT][RES_PROBA] = weight / total_weight
+
+ results = sorted(results, key=lambda r: -r[RES_INTENT][RES_PROBA])
+ return results[:top_n]
+
+ def _get_candidates(self, text, intents):
+ candidates = defaultdict(list)
+ for grouped_entity_scope in self._entity_scopes:
+ entity_scope = grouped_entity_scope["entity_scope"]
+ intent_group = grouped_entity_scope["intent_group"]
+ intent_group = [intent_ for intent_ in intent_group
+ if intents is None or intent_ in intents]
+ if not intent_group:
+ continue
+
+ builtin_entities = self.builtin_entity_parser.parse(
+ text, scope=entity_scope["builtin"], use_cache=True)
+ custom_entities = self.custom_entity_parser.parse(
+ text, scope=entity_scope["custom"], use_cache=True)
+ all_entities = builtin_entities + custom_entities
+ all_entities = deduplicate_overlapping_entities(all_entities)
+
+ # We generate all subsets of entities to match utterances
+ # containing ambivalent words which can be both entity values or
+ # random words
+ for entities in _get_entities_combinations(all_entities):
+ processed_text = self._replace_entities_with_placeholders(
+ text, entities)
+ for intent in intent_group:
+ cleaned_text = self._preprocess_text(text, intent)
+ cleaned_processed_text = self._preprocess_text(
+ processed_text, intent)
+
+ raw_candidate = cleaned_text, []
+ placeholder_candidate = cleaned_processed_text, entities
+ intent_candidates = [raw_candidate, placeholder_candidate]
+ for text_input, text_entities in intent_candidates:
+ if text_input not in candidates \
+ or text_entities not in candidates[text_input]:
+ candidates[text_input].append(text_entities)
+ yield text_input, text_entities
+
+ def _parse_map_output(self, text, output, entities, intents):
+ """Parse the map output to the parser's result format"""
+ intent_id, slot_ids = output
+ intent_name = self._intents_names[intent_id]
+ if intents is not None and intent_name not in intents:
+ return None
+
+ parsed_intent = intent_classification_result(
+ intent_name=intent_name, probability=1.0)
+ slots = []
+ # assert invariant
+ assert len(slot_ids) == len(entities)
+ for slot_id, entity in zip(slot_ids, entities):
+ slot_name = self._slots_names[slot_id]
+ rng_start = entity[RES_MATCH_RANGE][START]
+ rng_end = entity[RES_MATCH_RANGE][END]
+ slot_value = text[rng_start:rng_end]
+ entity_name = entity[ENTITY_KIND]
+ slot = unresolved_slot(
+ [rng_start, rng_end], slot_value, entity_name, slot_name)
+ slots.append(slot)
+
+ return extraction_result(parsed_intent, slots)
+
+ @fitted_required
+ def get_intents(self, text):
+ """Returns the list of intents ordered by decreasing probability
+
+ The length of the returned list is exactly the number of intents in the
+ dataset + 1 for the None intent
+ """
+ nb_intents = len(self._intents_names)
+ top_intents = [intent_result[RES_INTENT] for intent_result in
+ self._parse_top_intents(text, top_n=nb_intents)]
+ matched_intents = {res[RES_INTENT_NAME] for res in top_intents}
+ for intent in self._intents_names:
+ if intent not in matched_intents:
+ top_intents.append(intent_classification_result(intent, 0.0))
+
+ # The None intent is not included in the lookup table and is thus
+ # never matched by the lookup parser
+ top_intents.append(intent_classification_result(None, 0.0))
+ return top_intents
+
+ @fitted_required
+ def get_slots(self, text, intent):
+ """Extracts slots from a text input, with the knowledge of the intent
+
+ Args:
+ text (str): input
+ intent (str): the intent which the input corresponds to
+
+ Returns:
+ list: the list of extracted slots
+
+ Raises:
+ IntentNotFoundError: When the intent was not part of the training
+ data
+ """
+ if intent is None:
+ return []
+
+ if intent not in self._intents_names:
+ raise IntentNotFoundError(intent)
+
+ slots = self.parse(text, intents=[intent])[RES_SLOTS]
+ if slots is None:
+ slots = []
+ return slots
+
+ def _get_intent_stop_words(self, intent):
+ whitelist = self._stop_words_whitelist.get(intent, set())
+ return self._stop_words.difference(whitelist)
+
+ def _get_intent_id(self, intent_name):
+ """generate a numeric id for an intent
+
+ Args:
+ intent_name (str): intent name
+
+ Returns:
+ int: numeric id
+
+ """
+ intent_id = self._intents_mapping.get(intent_name)
+ if intent_id is None:
+ intent_id = len(self._intents_names)
+ self._intents_names.append(intent_name)
+ self._intents_mapping[intent_name] = intent_id
+
+ return intent_id
+
+ def _get_slot_id(self, slot_name):
+ """generate a numeric id for a slot
+
+ Args:
+ slot_name (str): intent name
+
+ Returns:
+ int: numeric id
+
+ """
+ slot_id = self._slots_mapping.get(slot_name)
+ if slot_id is None:
+ slot_id = len(self._slots_names)
+ self._slots_names.append(slot_name)
+ self._slots_mapping[slot_name] = slot_id
+
+ return slot_id
+
+ def _preprocess_text(self, txt, intent):
+ """Replaces stop words and characters that are tokenized out by
+ whitespaces"""
+ stop_words = self._get_intent_stop_words(intent)
+ tokens = tokenize_light(txt, self.language)
+ cleaned_string = " ".join(
+ [tkn for tkn in tokens if normalize(tkn) not in stop_words])
+ return cleaned_string.lower()
+
+ def _generate_io_mapping(self, intents, entity_placeholders):
+ """Generate input-output pairs"""
+ for intent_name, intent in sorted(iteritems(intents)):
+ intent_id = self._get_intent_id(intent_name)
+ for entry in intent[UTTERANCES]:
+ yield self._build_io_mapping(
+ intent_id, entry, entity_placeholders)
+
+ def _build_io_mapping(self, intent_id, utterance, entity_placeholders):
+ input_ = []
+ output = [intent_id]
+ slots = []
+ for chunk in utterance[DATA]:
+ if SLOT_NAME in chunk:
+ slot_name = chunk[SLOT_NAME]
+ slot_id = self._get_slot_id(slot_name)
+ entity_name = chunk[ENTITY]
+ placeholder = entity_placeholders[entity_name]
+ input_.append(placeholder)
+ slots.append(slot_id)
+ else:
+ input_.append(chunk[TEXT])
+ output.append(slots)
+
+ intent = self._intents_names[intent_id]
+ key = self._preprocess_text(" ".join(input_), intent)
+
+ return key, output
+
+ def _replace_entities_with_placeholders(self, text, entities):
+ if not entities:
+ return text
+ entities = sorted(entities, key=lambda e: e[RES_MATCH_RANGE][START])
+ processed_text = ""
+ current_idx = 0
+ for ent in entities:
+ start = ent[RES_MATCH_RANGE][START]
+ end = ent[RES_MATCH_RANGE][END]
+ processed_text += text[current_idx:start]
+ place_holder = _get_entity_name_placeholder(
+ ent[ENTITY_KIND], self.language)
+ processed_text += place_holder
+ current_idx = end
+ processed_text += text[current_idx:]
+
+ return processed_text
+
+ @check_persisted_path
+ def persist(self, path):
+ """Persists the object at the given path"""
+ path.mkdir()
+ parser_json = json_string(self.to_dict())
+ parser_path = path / "intent_parser.json"
+
+ with parser_path.open(mode="w", encoding="utf8") as pfile:
+ pfile.write(parser_json)
+ self.persist_metadata(path)
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ """Loads a :class:`LookupIntentParser` instance from a path
+
+ The data at the given path must have been generated using
+ :func:`~LookupIntentParser.persist`
+ """
+ path = Path(path)
+ model_path = path / "intent_parser.json"
+ if not model_path.exists():
+ raise LoadingError(
+ "Missing lookup intent parser metadata file: %s"
+ % model_path.name)
+
+ with model_path.open(encoding="utf8") as pfile:
+ metadata = json.load(pfile)
+ return cls.from_dict(metadata, **shared)
+
+ def to_dict(self):
+ """Returns a json-serializable dict"""
+ stop_words_whitelist = None
+ if self._stop_words_whitelist is not None:
+ stop_words_whitelist = {
+ intent: sorted(values)
+ for intent, values in iteritems(self._stop_words_whitelist)}
+ return {
+ "config": self.config.to_dict(),
+ "language_code": self.language,
+ "map": self._map,
+ "slots_names": self._slots_names,
+ "intents_names": self._intents_names,
+ "entity_scopes": self._entity_scopes,
+ "stop_words_whitelist": stop_words_whitelist,
+ }
+
+ @classmethod
+ def from_dict(cls, unit_dict, **shared):
+ """Creates a :class:`LookupIntentParser` instance from a dict
+
+ The dict must have been generated with
+ :func:`~LookupIntentParser.to_dict`
+ """
+ config = cls.config_type.from_dict(unit_dict["config"])
+ parser = cls(config=config, **shared)
+ parser.language = unit_dict["language_code"]
+ # pylint:disable=protected-access
+ parser._map = _convert_dict_keys_to_int(unit_dict["map"])
+ parser._slots_names = unit_dict["slots_names"]
+ parser._intents_names = unit_dict["intents_names"]
+ parser._entity_scopes = unit_dict["entity_scopes"]
+ if parser.fitted:
+ whitelist = unit_dict["stop_words_whitelist"]
+ parser._stop_words_whitelist = {
+ intent: set(values) for intent, values in iteritems(whitelist)}
+ # pylint:enable=protected-access
+ return parser
+
+
+def _get_entity_scopes(dataset):
+ intent_entities = extract_intent_entities(dataset)
+ intent_groups = []
+ entity_scopes = []
+ for intent, entities in sorted(iteritems(intent_entities)):
+ scope = {
+ "builtin": list(
+ {ent for ent in entities if is_builtin_entity(ent)}),
+ "custom": list(
+ {ent for ent in entities if not is_builtin_entity(ent)})
+ }
+ if scope in entity_scopes:
+ group_idx = entity_scopes.index(scope)
+ intent_groups[group_idx].append(intent)
+ else:
+ entity_scopes.append(scope)
+ intent_groups.append([intent])
+ return [
+ {
+ "intent_group": intent_group,
+ "entity_scope": entity_scope
+ } for intent_group, entity_scope in zip(intent_groups, entity_scopes)
+ ]
+
+
+def _get_entity_placeholders(dataset, language):
+ return {
+ e: _get_entity_name_placeholder(e, language) for e in dataset[ENTITIES]
+ }
+
+
+def _get_entity_name_placeholder(entity_label, language):
+ return "%%%s%%" % "".join(tokenize_light(entity_label, language)).upper()
+
+
+def _convert_dict_keys_to_int(dct):
+ if isinstance(dct, dict):
+ return {int(k): v for k, v in iteritems(dct)}
+ return dct
+
+
+def _get_entities_combinations(entities):
+ yield ()
+ for nb_entities in reversed(range(1, len(entities) + 1)):
+ for combination in combinations(entities, nb_entities):
+ yield combination
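For reference, the fit/parse logic above reduces every training utterance to a normalized key (entity chunks replaced by placeholders, stop words dropped, text lowercased) before hashing it into the lookup map. The standalone sketch below is illustrative only: it uses simplified stand-ins for tokenize_light and hash_str, which are not shown in this diff, and skips the hashing step.

    def placeholder(entity_label):
        # mirrors _get_entity_name_placeholder: the "room" entity becomes "%ROOM%"
        return "%%%s%%" % "".join(entity_label.split()).upper()

    def build_key(chunks, stop_words):
        # chunks mirror utterance[DATA]: (text, entity_name_or_None) pairs
        tokens = [placeholder(entity) if entity else text
                  for text, entity in chunks]
        joined = " ".join(tokens)
        kept = [t for t in joined.split() if t.lower() not in stop_words]
        return " ".join(kept).lower()

    key = build_key([("turn on the", None), ("kitchen", "room"), ("light", None)],
                    stop_words={"the"})
    print(key)  # -> "turn on %room% light"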
diff --git a/snips_inference_agl/intent_parser/probabilistic_intent_parser.py b/snips_inference_agl/intent_parser/probabilistic_intent_parser.py
new file mode 100644
index 0000000..23e7829
--- /dev/null
+++ b/snips_inference_agl/intent_parser/probabilistic_intent_parser.py
@@ -0,0 +1,250 @@
+from __future__ import unicode_literals
+
+import json
+import logging
+from builtins import str
+from copy import deepcopy
+from datetime import datetime
+from pathlib import Path
+
+from future.utils import iteritems, itervalues
+
+from snips_inference_agl.common.log_utils import log_elapsed_time, log_result
+from snips_inference_agl.common.utils import (
+ check_persisted_path, elapsed_since, fitted_required, json_string)
+from snips_inference_agl.constants import INTENTS, RES_INTENT_NAME
+from snips_inference_agl.dataset import validate_and_format_dataset
+from snips_inference_agl.exceptions import IntentNotFoundError, LoadingError
+from snips_inference_agl.intent_classifier import IntentClassifier
+from snips_inference_agl.intent_parser.intent_parser import IntentParser
+from snips_inference_agl.pipeline.configs import ProbabilisticIntentParserConfig
+from snips_inference_agl.result import parsing_result, extraction_result
+from snips_inference_agl.slot_filler import SlotFiller
+
+logger = logging.getLogger(__name__)
+
+
+@IntentParser.register("probabilistic_intent_parser")
+class ProbabilisticIntentParser(IntentParser):
+ """Intent parser which consists in two steps: intent classification then
+ slot filling"""
+
+ config_type = ProbabilisticIntentParserConfig
+
+ def __init__(self, config=None, **shared):
+ """The probabilistic intent parser can be configured by passing a
+ :class:`.ProbabilisticIntentParserConfig`"""
+ super(ProbabilisticIntentParser, self).__init__(config, **shared)
+ self.intent_classifier = None
+ self.slot_fillers = dict()
+
+ @property
+ def fitted(self):
+ """Whether or not the intent parser has already been fitted"""
+ return self.intent_classifier is not None \
+ and self.intent_classifier.fitted \
+ and all(slot_filler is not None and slot_filler.fitted
+ for slot_filler in itervalues(self.slot_fillers))
+
+ @log_elapsed_time(logger, logging.INFO,
+ "Fitted probabilistic intent parser in {elapsed_time}")
+ # pylint:disable=arguments-differ
+ def fit(self, dataset, force_retrain=True):
+ """Fits the probabilistic intent parser
+
+ Args:
+ dataset (dict): A valid Snips dataset
+ force_retrain (bool, optional): If *False*, the intent classifier
+ and slot fillers will not be retrained when they are already
+ fitted. Defaults to *True*.
+
+ Returns:
+ :class:`ProbabilisticIntentParser`: The same instance, trained
+ """
+ logger.info("Fitting probabilistic intent parser...")
+ dataset = validate_and_format_dataset(dataset)
+ intents = list(dataset[INTENTS])
+ if self.intent_classifier is None:
+ self.intent_classifier = IntentClassifier.from_config(
+ self.config.intent_classifier_config,
+ builtin_entity_parser=self.builtin_entity_parser,
+ custom_entity_parser=self.custom_entity_parser,
+ resources=self.resources,
+ random_state=self.random_state,
+ )
+
+ if force_retrain or not self.intent_classifier.fitted:
+ self.intent_classifier.fit(dataset)
+
+ if self.slot_fillers is None:
+ self.slot_fillers = dict()
+ slot_fillers_start = datetime.now()
+ for intent_name in intents:
+ # We need to copy the slot filler config as it may be mutated
+ if self.slot_fillers.get(intent_name) is None:
+ slot_filler_config = deepcopy(self.config.slot_filler_config)
+ self.slot_fillers[intent_name] = SlotFiller.from_config(
+ slot_filler_config,
+ builtin_entity_parser=self.builtin_entity_parser,
+ custom_entity_parser=self.custom_entity_parser,
+ resources=self.resources,
+ random_state=self.random_state,
+ )
+ if force_retrain or not self.slot_fillers[intent_name].fitted:
+ self.slot_fillers[intent_name].fit(dataset, intent_name)
+ logger.debug("Fitted slot fillers in %s",
+ elapsed_since(slot_fillers_start))
+ return self
+
+ # pylint:enable=arguments-differ
+
+ @log_result(logger, logging.DEBUG,
+ "ProbabilisticIntentParser result -> {result}")
+ @log_elapsed_time(logger, logging.DEBUG,
+ "ProbabilisticIntentParser parsed in {elapsed_time}")
+ @fitted_required
+ def parse(self, text, intents=None, top_n=None):
+ """Performs intent parsing on the provided *text* by first classifying
+ the intent and then using the correspond slot filler to extract slots
+
+ Args:
+ text (str): input
+ intents (str or list of str): if provided, reduces the scope of
+ intent parsing to the provided list of intents
+ top_n (int, optional): when provided, this method will return a
+ list of at most top_n most likely intents, instead of a single
+ parsing result.
+ Note that the returned list can contain fewer than ``top_n``
+ elements, for instance when the parameter ``intents`` is not
+ None, or when ``top_n`` is greater than the total number of
+ intents.
+
+ Returns:
+ dict or list: the most likely intent(s) along with the extracted
+ slots. See :func:`.parsing_result` and :func:`.extraction_result`
+ for the output format.
+
+ Raises:
+ NotTrained: when the intent parser is not fitted
+ """
+ if isinstance(intents, str):
+ intents = {intents}
+ elif isinstance(intents, list):
+ intents = list(intents)
+
+ if top_n is None:
+ intent_result = self.intent_classifier.get_intent(text, intents)
+ intent_name = intent_result[RES_INTENT_NAME]
+ if intent_name is not None:
+ slots = self.slot_fillers[intent_name].get_slots(text)
+ else:
+ slots = []
+ return parsing_result(text, intent_result, slots)
+
+ results = []
+ intents_results = self.intent_classifier.get_intents(text)
+ for intent_result in intents_results[:top_n]:
+ intent_name = intent_result[RES_INTENT_NAME]
+ if intent_name is not None:
+ slots = self.slot_fillers[intent_name].get_slots(text)
+ else:
+ slots = []
+ results.append(extraction_result(intent_result, slots))
+ return results
+
+ @fitted_required
+ def get_intents(self, text):
+ """Returns the list of intents ordered by decreasing probability
+
+ The length of the returned list is exactly the number of intents in the
+ dataset + 1 for the None intent
+ """
+ return self.intent_classifier.get_intents(text)
+
+ @fitted_required
+ def get_slots(self, text, intent):
+ """Extracts slots from a text input, with the knowledge of the intent
+
+ Args:
+ text (str): input
+ intent (str): the intent which the input corresponds to
+
+ Returns:
+ list: the list of extracted slots
+
+ Raises:
+ IntentNotFoundError: When the intent was not part of the training
+ data
+ """
+ if intent is None:
+ return []
+
+ if intent not in self.slot_fillers:
+ raise IntentNotFoundError(intent)
+ return self.slot_fillers[intent].get_slots(text)
+
+ @check_persisted_path
+ def persist(self, path):
+ """Persists the object at the given path"""
+ path.mkdir()
+ sorted_slot_fillers = sorted(iteritems(self.slot_fillers))
+ slot_fillers = []
+ for i, (intent, slot_filler) in enumerate(sorted_slot_fillers):
+ slot_filler_name = "slot_filler_%s" % i
+ slot_filler.persist(path / slot_filler_name)
+ slot_fillers.append({
+ "intent": intent,
+ "slot_filler_name": slot_filler_name
+ })
+
+ if self.intent_classifier is not None:
+ self.intent_classifier.persist(path / "intent_classifier")
+
+ model = {
+ "config": self.config.to_dict(),
+ "slot_fillers": slot_fillers
+ }
+ model_json = json_string(model)
+ model_path = path / "intent_parser.json"
+ with model_path.open(mode="w") as f:
+ f.write(model_json)
+ self.persist_metadata(path)
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ """Loads a :class:`ProbabilisticIntentParser` instance from a path
+
+ The data at the given path must have been generated using
+ :func:`~ProbabilisticIntentParser.persist`
+ """
+ path = Path(path)
+ model_path = path / "intent_parser.json"
+ if not model_path.exists():
+ raise LoadingError(
+ "Missing probabilistic intent parser model file: %s"
+ % model_path.name)
+
+ with model_path.open(encoding="utf8") as f:
+ model = json.load(f)
+
+ config = cls.config_type.from_dict(model["config"])
+ parser = cls(config=config, **shared)
+ classifier = None
+ intent_classifier_path = path / "intent_classifier"
+ if intent_classifier_path.exists():
+ classifier_unit_name = config.intent_classifier_config.unit_name
+ classifier = IntentClassifier.load_from_path(
+ intent_classifier_path, classifier_unit_name, **shared)
+
+ slot_fillers = dict()
+ slot_filler_unit_name = config.slot_filler_config.unit_name
+ for slot_filler_conf in model["slot_fillers"]:
+ intent = slot_filler_conf["intent"]
+ slot_filler_path = path / slot_filler_conf["slot_filler_name"]
+ slot_filler = SlotFiller.load_from_path(
+ slot_filler_path, slot_filler_unit_name, **shared)
+ slot_fillers[intent] = slot_filler
+
+ parser.intent_classifier = classifier
+ parser.slot_fillers = slot_fillers
+ return parser
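The parse method above boils down to "classify the intent, then run that intent's slot filler". The toy sketch below reproduces that control flow with trivial stand-in components; the real units are the LogReg classifier and the CRF slot fillers, and the result keys ("intentName", "probability") are assumed to match the result-format constants defined elsewhere in the package.

    class ToyIntentClassifier:
        def get_intent(self, text, intents=None):
            name = "turnLightOn" if "light" in text else None
            return {"intentName": name, "probability": 0.9 if name else 0.7}

    class ToySlotFiller:
        def get_slots(self, text):
            return [{"slotName": "room", "rawValue": "kitchen"}] \
                if "kitchen" in text else []

    def parse(text, classifier, slot_fillers):
        # classify first, then delegate slot extraction to the matching filler
        intent = classifier.get_intent(text)
        name = intent["intentName"]
        slots = slot_fillers[name].get_slots(text) if name is not None else []
        return {"input": text, "intent": intent, "slots": slots}

    print(parse("turn on the kitchen light",
                ToyIntentClassifier(), {"turnLightOn": ToySlotFiller()}))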
diff --git a/snips_inference_agl/languages.py b/snips_inference_agl/languages.py
new file mode 100644
index 0000000..cc205a3
--- /dev/null
+++ b/snips_inference_agl/languages.py
@@ -0,0 +1,44 @@
+from __future__ import unicode_literals
+
+import re
+import string
+
+_PUNCTUATION_REGEXES = dict()
+_NUM2WORDS_SUPPORT = dict()
+
+
+# pylint:disable=unused-argument
+def get_default_sep(language):
+ return " "
+
+
+# pylint:enable=unused-argument
+
+# pylint:disable=unused-argument
+def get_punctuation(language):
+ return string.punctuation
+
+
+# pylint:enable=unused-argument
+
+
+def get_punctuation_regex(language):
+ global _PUNCTUATION_REGEXES
+ if language not in _PUNCTUATION_REGEXES:
+ pattern = r"|".join(re.escape(p) for p in get_punctuation(language))
+ _PUNCTUATION_REGEXES[language] = re.compile(pattern)
+ return _PUNCTUATION_REGEXES[language]
+
+
+def supports_num2words(language):
+ from num2words import num2words
+
+ global _NUM2WORDS_SUPPORT
+
+ if language not in _NUM2WORDS_SUPPORT:
+ try:
+ num2words(0, lang=language)
+ _NUM2WORDS_SUPPORT[language] = True
+ except NotImplementedError:
+ _NUM2WORDS_SUPPORT[language] = False
+ return _NUM2WORDS_SUPPORT[language]
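get_punctuation_regex simply compiles an alternation over string.punctuation and caches it per language. A typical use is stripping punctuation before tokenization; the minimal, self-contained illustration below mirrors that behaviour, and the call site is hypothetical rather than taken from this codebase.

    import re
    import string

    # same pattern construction as get_punctuation_regex above
    pattern = re.compile("|".join(re.escape(p) for p in string.punctuation))
    print(pattern.sub(" ", "what's the weather, tomorrow?"))
    # -> "what s the weather  tomorrow "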
diff --git a/snips_inference_agl/nlu_engine/__init__.py b/snips_inference_agl/nlu_engine/__init__.py
new file mode 100644
index 0000000..b45eb32
--- /dev/null
+++ b/snips_inference_agl/nlu_engine/__init__.py
@@ -0,0 +1 @@
+from snips_inference_agl.nlu_engine.nlu_engine import SnipsNLUEngine
diff --git a/snips_inference_agl/nlu_engine/nlu_engine.py b/snips_inference_agl/nlu_engine/nlu_engine.py
new file mode 100644
index 0000000..3f19aba
--- /dev/null
+++ b/snips_inference_agl/nlu_engine/nlu_engine.py
@@ -0,0 +1,330 @@
+from __future__ import unicode_literals
+
+import json
+import logging
+from builtins import str
+from pathlib import Path
+
+from future.utils import itervalues
+
+from snips_inference_agl.__about__ import __model_version__, __version__
+from snips_inference_agl.common.log_utils import log_elapsed_time
+from snips_inference_agl.common.utils import fitted_required
+from snips_inference_agl.constants import (
+ AUTOMATICALLY_EXTENSIBLE, BUILTIN_ENTITY_PARSER, CUSTOM_ENTITY_PARSER,
+ ENTITIES, ENTITY_KIND, LANGUAGE, RESOLVED_VALUE, RES_ENTITY,
+ RES_INTENT, RES_INTENT_NAME, RES_MATCH_RANGE, RES_PROBA, RES_SLOTS,
+ RES_VALUE, RESOURCES, BYPASS_VERSION_CHECK)
+# from snips_inference_agl.dataset import validate_and_format_dataset
+from snips_inference_agl.entity_parser import CustomEntityParser
+from snips_inference_agl.entity_parser.builtin_entity_parser import (
+ BuiltinEntityParser, is_builtin_entity)
+from snips_inference_agl.exceptions import (
+ InvalidInputError, IntentNotFoundError, LoadingError,
+ IncompatibleModelError)
+from snips_inference_agl.intent_parser import IntentParser
+from snips_inference_agl.pipeline.configs import NLUEngineConfig
+from snips_inference_agl.pipeline.processing_unit import ProcessingUnit
+from snips_inference_agl.resources import load_resources_from_dir
+from snips_inference_agl.result import (
+ builtin_slot, custom_slot, empty_result, extraction_result, is_empty,
+ parsing_result)
+
+logger = logging.getLogger(__name__)
+
+
+@ProcessingUnit.register("nlu_engine")
+class SnipsNLUEngine(ProcessingUnit):
+ """Main class to use for intent parsing
+
+ A :class:`SnipsNLUEngine` relies on a list of :class:`.IntentParser`
+ objects to parse intents, calling them successively and using the first
+ positive output.
+
+ With the default parameters, it will use the two following intent parsers
+ in this order:
+
+ - a :class:`.DeterministicIntentParser`
+ - a :class:`.ProbabilisticIntentParser`
+
+ The rationale is to first use a conservative parser which has a very
+ good precision while its recall is modest, so simple patterns will be
+ caught, and then fall back on a second, machine-learning based parser
+ which is able to parse unseen utterances while keeping a good
+ precision and recall.
+ """
+
+ config_type = NLUEngineConfig
+
+ def __init__(self, config=None, **shared):
+ """The NLU engine can be configured by passing a
+ :class:`.NLUEngineConfig`"""
+ super(SnipsNLUEngine, self).__init__(config, **shared)
+ self.intent_parsers = []
+ """list of :class:`.IntentParser`"""
+ self.dataset_metadata = None
+
+ @classmethod
+ def default_config(cls):
+ # Do not use the global default config, and use per-language default
+ # configs instead
+ return None
+
+ @property
+ def fitted(self):
+ """Whether or not the nlu engine has already been fitted"""
+ return self.dataset_metadata is not None
+
+ @log_elapsed_time(logger, logging.DEBUG, "Parsed input in {elapsed_time}")
+ @fitted_required
+ def parse(self, text, intents=None, top_n=None):
+ """Performs intent parsing on the provided *text* by calling its intent
+ parsers successively
+
+ Args:
+ text (str): Input
+ intents (str or list of str, optional): If provided, reduces the
+ scope of intent parsing to the provided list of intents.
+ The ``None`` intent is never filtered out, meaning that it can
+ be returned even when using an intents scope.
+ top_n (int, optional): when provided, this method will return a
+ list of at most ``top_n`` most likely intents, instead of a
+ single parsing result.
+ Note that the returned list can contain fewer than ``top_n``
+ elements, for instance when the parameter ``intents`` is not
+ None, or when ``top_n`` is greater than the total number of
+ intents.
+
+ Returns:
+ dict or list: the most likely intent(s) along with the extracted
+ slots. See :func:`.parsing_result` and :func:`.extraction_result`
+ for the output format.
+
+ Raises:
+ NotTrained: When the nlu engine is not fitted
+ InvalidInputError: When input type is not unicode
+ """
+ if not isinstance(text, str):
+ raise InvalidInputError("Expected unicode but received: %s"
+ % type(text))
+
+ if isinstance(intents, str):
+ intents = {intents}
+ elif isinstance(intents, list):
+ intents = set(intents)
+
+ if intents is not None:
+ for intent in intents:
+ if intent not in self.dataset_metadata["slot_name_mappings"]:
+ raise IntentNotFoundError(intent)
+
+ if top_n is None:
+ none_proba = 0.0
+ for parser in self.intent_parsers:
+ res = parser.parse(text, intents)
+ if is_empty(res):
+ none_proba = res[RES_INTENT][RES_PROBA]
+ continue
+ resolved_slots = self._resolve_slots(text, res[RES_SLOTS])
+ return parsing_result(text, intent=res[RES_INTENT],
+ slots=resolved_slots)
+ return empty_result(text, none_proba)
+
+ intents_results = self.get_intents(text)
+ if intents is not None:
+ intents_results = [res for res in intents_results
+ if res[RES_INTENT_NAME] is None
+ or res[RES_INTENT_NAME] in intents]
+ intents_results = intents_results[:top_n]
+ results = []
+ for intent_res in intents_results:
+ slots = self.get_slots(text, intent_res[RES_INTENT_NAME])
+ results.append(extraction_result(intent_res, slots))
+ return results
+
+ @log_elapsed_time(logger, logging.DEBUG, "Got intents in {elapsed_time}")
+ @fitted_required
+ def get_intents(self, text):
+ """Performs intent classification on the provided *text* and returns
+ the list of intents ordered by decreasing probability
+
+ The length of the returned list is exactly the number of intents in the
+ dataset + 1 for the None intent
+
+ .. note::
+
+ The probabilities returned along with each intent are not
+ guaranteed to sum to 1.0. They should be considered as scores
+ between 0 and 1.
+ """
+ results = None
+ for parser in self.intent_parsers:
+ parser_results = parser.get_intents(text)
+ if results is None:
+ results = {res[RES_INTENT_NAME]: res for res in parser_results}
+ continue
+
+ for res in parser_results:
+ intent = res[RES_INTENT_NAME]
+ proba = max(res[RES_PROBA], results[intent][RES_PROBA])
+ results[intent][RES_PROBA] = proba
+
+ return sorted(itervalues(results), key=lambda res: -res[RES_PROBA])
+
+ @log_elapsed_time(logger, logging.DEBUG, "Parsed slots in {elapsed_time}")
+ @fitted_required
+ def get_slots(self, text, intent):
+ """Extracts slots from a text input, with the knowledge of the intent
+
+ Args:
+ text (str): input
+ intent (str): the intent which the input corresponds to
+
+ Returns:
+ list: the list of extracted slots
+
+ Raises:
+ IntentNotFoundError: When the intent was not part of the training
+ data
+ InvalidInputError: When input type is not unicode
+ """
+ if not isinstance(text, str):
+ raise InvalidInputError("Expected unicode but received: %s"
+ % type(text))
+
+ if intent is None:
+ return []
+
+ if intent not in self.dataset_metadata["slot_name_mappings"]:
+ raise IntentNotFoundError(intent)
+
+ for parser in self.intent_parsers:
+ slots = parser.get_slots(text, intent)
+ if not slots:
+ continue
+ return self._resolve_slots(text, slots)
+ return []
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ """Loads a :class:`SnipsNLUEngine` instance from a directory path
+
+ The data at the given path must have been generated using
+ :func:`~SnipsNLUEngine.persist`
+
+ Args:
+ path (str): The path where the nlu engine is stored
+
+ Raises:
+ LoadingError: when some files are missing
+ IncompatibleModelError: when trying to load an engine model which
+ is not compatible with the current version of the lib
+ """
+ directory_path = Path(path)
+ model_path = directory_path / "nlu_engine.json"
+ if not model_path.exists():
+ raise LoadingError("Missing nlu engine model file: %s"
+ % model_path.name)
+
+ with model_path.open(encoding="utf8") as f:
+ model = json.load(f)
+ model_version = model.get("model_version")
+ if model_version is None or model_version != __model_version__:
+ bypass_version_check = shared.get(BYPASS_VERSION_CHECK, False)
+ if bypass_version_check:
+ logger.warning(
+ "Incompatible model version found. The library expected "
+ "'%s' but the loaded engine is '%s'. The NLU engine may "
+ "not load correctly.", __model_version__, model_version)
+ else:
+ raise IncompatibleModelError(model_version)
+
+ dataset_metadata = model["dataset_metadata"]
+ if shared.get(RESOURCES) is None and dataset_metadata is not None:
+ language = dataset_metadata["language_code"]
+ resources_dir = directory_path / "resources" / language
+ if resources_dir.is_dir():
+ resources = load_resources_from_dir(resources_dir)
+ shared[RESOURCES] = resources
+
+ if shared.get(BUILTIN_ENTITY_PARSER) is None:
+ path = model["builtin_entity_parser"]
+ if path is not None:
+ parser_path = directory_path / path
+ shared[BUILTIN_ENTITY_PARSER] = BuiltinEntityParser.from_path(
+ parser_path)
+
+ if shared.get(CUSTOM_ENTITY_PARSER) is None:
+ path = model["custom_entity_parser"]
+ if path is not None:
+ parser_path = directory_path / path
+ shared[CUSTOM_ENTITY_PARSER] = CustomEntityParser.from_path(
+ parser_path)
+
+ config = cls.config_type.from_dict(model["config"])
+ nlu_engine = cls(config=config, **shared)
+ nlu_engine.dataset_metadata = dataset_metadata
+ intent_parsers = []
+ for parser_idx, parser_name in enumerate(model["intent_parsers"]):
+ parser_config = config.intent_parsers_configs[parser_idx]
+ intent_parser_path = directory_path / parser_name
+ intent_parser = IntentParser.load_from_path(
+ intent_parser_path, parser_config.unit_name, **shared)
+ intent_parsers.append(intent_parser)
+ nlu_engine.intent_parsers = intent_parsers
+ return nlu_engine
+
+ def _resolve_slots(self, text, slots):
+ builtin_scope = [slot[RES_ENTITY] for slot in slots
+ if is_builtin_entity(slot[RES_ENTITY])]
+ custom_scope = [slot[RES_ENTITY] for slot in slots
+ if not is_builtin_entity(slot[RES_ENTITY])]
+ # Do not use cached entities here as datetimes must be computed using
+ # current context
+ builtin_entities = self.builtin_entity_parser.parse(
+ text, builtin_scope, use_cache=False)
+ custom_entities = self.custom_entity_parser.parse(
+ text, custom_scope, use_cache=True)
+
+ resolved_slots = []
+ for slot in slots:
+ entity_name = slot[RES_ENTITY]
+ raw_value = slot[RES_VALUE]
+ is_builtin = is_builtin_entity(entity_name)
+ if is_builtin:
+ entities = builtin_entities
+ parser = self.builtin_entity_parser
+ slot_builder = builtin_slot
+ use_cache = False
+ extensible = False
+ else:
+ entities = custom_entities
+ parser = self.custom_entity_parser
+ slot_builder = custom_slot
+ use_cache = True
+ extensible = self.dataset_metadata[ENTITIES][entity_name][
+ AUTOMATICALLY_EXTENSIBLE]
+
+ resolved_slot = None
+ for ent in entities:
+ if ent[ENTITY_KIND] == entity_name and \
+ ent[RES_MATCH_RANGE] == slot[RES_MATCH_RANGE]:
+ resolved_slot = slot_builder(slot, ent[RESOLVED_VALUE])
+ break
+ if resolved_slot is None:
+ matches = parser.parse(
+ raw_value, scope=[entity_name], use_cache=use_cache)
+ if matches:
+ match = matches[0]
+ if is_builtin or len(match[RES_VALUE]) == len(raw_value):
+ resolved_slot = slot_builder(
+ slot, match[RESOLVED_VALUE])
+
+ if resolved_slot is None and extensible:
+ resolved_slot = slot_builder(slot)
+
+ if resolved_slot is not None:
+ resolved_slots.append(resolved_slot)
+
+ return resolved_slots
\ No newline at end of file
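In practice the engine above is used by loading a previously trained model directory and calling parse. The hedged usage sketch below assumes such a directory exists; the path is a placeholder, and the result keys ("intent", "intentName", "slots") are assumed to match the result-format constants referenced in the code but defined elsewhere in the package.

    from snips_inference_agl.nlu_engine import SnipsNLUEngine

    engine = SnipsNLUEngine.from_path("/path/to/trained_engine")  # placeholder path
    result = engine.parse("turn on the lights in the kitchen")
    print(result["intent"]["intentName"], result["slots"])

    # intent ranking and slot extraction, as exposed above
    print(engine.get_intents("turn on the lights")[:3])
    print(engine.get_slots("turn on the lights",
                           intent=result["intent"]["intentName"]))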
diff --git a/snips_inference_agl/pipeline/__init__.py b/snips_inference_agl/pipeline/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/snips_inference_agl/pipeline/__init__.py
diff --git a/snips_inference_agl/pipeline/configs/__init__.py b/snips_inference_agl/pipeline/configs/__init__.py
new file mode 100644
index 0000000..027f286
--- /dev/null
+++ b/snips_inference_agl/pipeline/configs/__init__.py
@@ -0,0 +1,10 @@
+from .config import Config, ProcessingUnitConfig
+from .features import default_features_factories
+from .intent_classifier import (CooccurrenceVectorizerConfig, FeaturizerConfig,
+ IntentClassifierDataAugmentationConfig,
+ LogRegIntentClassifierConfig)
+from .intent_parser import (DeterministicIntentParserConfig,
+ LookupIntentParserConfig,
+ ProbabilisticIntentParserConfig)
+from .nlu_engine import NLUEngineConfig
+from .slot_filler import CRFSlotFillerConfig, SlotFillerDataAugmentationConfig
diff --git a/snips_inference_agl/pipeline/configs/config.py b/snips_inference_agl/pipeline/configs/config.py
new file mode 100644
index 0000000..4267fa2
--- /dev/null
+++ b/snips_inference_agl/pipeline/configs/config.py
@@ -0,0 +1,49 @@
+from __future__ import unicode_literals
+
+from abc import ABCMeta, abstractmethod, abstractproperty
+from builtins import object
+
+from future.utils import with_metaclass
+
+
+class Config(with_metaclass(ABCMeta, object)):
+ @abstractmethod
+ def to_dict(self):
+ pass
+
+ @classmethod
+ def from_dict(cls, obj_dict):
+ raise NotImplementedError
+
+
+class ProcessingUnitConfig(with_metaclass(ABCMeta, Config)):
+ """Represents the configuration object needed to initialize a
+ :class:`.ProcessingUnit`"""
+
+ @abstractproperty
+ def unit_name(self):
+ raise NotImplementedError
+
+ def set_unit_name(self, value):
+ pass
+
+ def get_required_resources(self):
+ return None
+
+
+class DefaultProcessingUnitConfig(dict, ProcessingUnitConfig):
+ """Default config implemented as a simple dict"""
+
+ @property
+ def unit_name(self):
+ return self["unit_name"]
+
+ def set_unit_name(self, value):
+ self["unit_name"] = value
+
+ def to_dict(self):
+ return self
+
+ @classmethod
+ def from_dict(cls, obj_dict):
+ return cls(obj_dict)
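DefaultProcessingUnitConfig is just a dict that also satisfies the ProcessingUnitConfig interface, so serialization is the identity. A small sketch, importing from the module defined above; the "my_unit" name and extra key are arbitrary.

    from snips_inference_agl.pipeline.configs.config import DefaultProcessingUnitConfig

    cfg = DefaultProcessingUnitConfig({"unit_name": "my_unit", "foo": 42})
    assert cfg.unit_name == "my_unit"
    assert DefaultProcessingUnitConfig.from_dict(cfg.to_dict()) == cfg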
diff --git a/snips_inference_agl/pipeline/configs/features.py b/snips_inference_agl/pipeline/configs/features.py
new file mode 100644
index 0000000..fa12e1a
--- /dev/null
+++ b/snips_inference_agl/pipeline/configs/features.py
@@ -0,0 +1,81 @@
+def default_features_factories():
+ """These are the default features used by the :class:`.CRFSlotFiller`
+ objects"""
+
+ from snips_inference_agl.slot_filler.crf_utils import TaggingScheme
+ from snips_inference_agl.slot_filler.feature_factory import (
+ NgramFactory, IsDigitFactory, IsFirstFactory, IsLastFactory,
+ ShapeNgramFactory, CustomEntityMatchFactory, BuiltinEntityMatchFactory)
+
+ return [
+ {
+ "args": {
+ "common_words_gazetteer_name": None,
+ "use_stemming": False,
+ "n": 1
+ },
+ "factory_name": NgramFactory.name,
+ "offsets": [-2, -1, 0, 1, 2]
+ },
+ {
+ "args": {
+ "common_words_gazetteer_name": None,
+ "use_stemming": False,
+ "n": 2
+ },
+ "factory_name": NgramFactory.name,
+ "offsets": [-2, 1]
+ },
+ {
+ "args": {},
+ "factory_name": IsDigitFactory.name,
+ "offsets": [-1, 0, 1]
+ },
+ {
+ "args": {},
+ "factory_name": IsFirstFactory.name,
+ "offsets": [-2, -1, 0]
+ },
+ {
+ "args": {},
+ "factory_name": IsLastFactory.name,
+ "offsets": [0, 1, 2]
+ },
+ {
+ "args": {
+ "n": 1
+ },
+ "factory_name": ShapeNgramFactory.name,
+ "offsets": [0]
+ },
+ {
+ "args": {
+ "n": 2
+ },
+ "factory_name": ShapeNgramFactory.name,
+ "offsets": [-1, 0]
+ },
+ {
+ "args": {
+ "n": 3
+ },
+ "factory_name": ShapeNgramFactory.name,
+ "offsets": [-1]
+ },
+ {
+ "args": {
+ "use_stemming": False,
+ "tagging_scheme_code": TaggingScheme.BILOU.value,
+ },
+ "factory_name": CustomEntityMatchFactory.name,
+ "offsets": [-2, -1, 0],
+ "drop_out": 0.5
+ },
+ {
+ "args": {
+ "tagging_scheme_code": TaggingScheme.BIO.value,
+ },
+ "factory_name": BuiltinEntityMatchFactory.name,
+ "offsets": [-2, -1, 0]
+ },
+ ]
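Each entry returned by default_features_factories is a plain dict naming a CRF feature factory, its arguments, and the token offsets at which the feature is applied. A quick way to inspect them, assuming the package and its slot_filler subpackage are importable:

    from snips_inference_agl.pipeline.configs.features import default_features_factories

    for factory in default_features_factories():
        print(factory["factory_name"], "offsets:", factory["offsets"])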
diff --git a/snips_inference_agl/pipeline/configs/intent_classifier.py b/snips_inference_agl/pipeline/configs/intent_classifier.py
new file mode 100644
index 0000000..fc22c87
--- /dev/null
+++ b/snips_inference_agl/pipeline/configs/intent_classifier.py
@@ -0,0 +1,307 @@
+from __future__ import unicode_literals
+
+from snips_inference_agl.common.from_dict import FromDict
+from snips_inference_agl.constants import (
+ CUSTOM_ENTITY_PARSER_USAGE, NOISE, STEMS, STOP_WORDS, WORD_CLUSTERS)
+from snips_inference_agl.entity_parser.custom_entity_parser import (
+ CustomEntityParserUsage)
+from snips_inference_agl.pipeline.configs import Config, ProcessingUnitConfig
+from snips_inference_agl.resources import merge_required_resources
+
+
+class LogRegIntentClassifierConfig(FromDict, ProcessingUnitConfig):
+ """Configuration of a :class:`.LogRegIntentClassifier`"""
+
+ # pylint: disable=line-too-long
+ def __init__(self, data_augmentation_config=None, featurizer_config=None,
+ noise_reweight_factor=1.0):
+ """
+ Args:
+ data_augmentation_config (:class:`IntentClassifierDataAugmentationConfig`):
+ Defines the strategy of the underlying data augmentation
+ featurizer_config (:class:`FeaturizerConfig`): Configuration of the
+ :class:`.Featurizer` used underneath
+ noise_reweight_factor (float, optional): this parameter allows
+ changing the weight of the None class. By default, the class
+ weights are computed using a "balanced" strategy. The
+ noise_reweight_factor allows deviating from this strategy.
+ """
+ if data_augmentation_config is None:
+ data_augmentation_config = IntentClassifierDataAugmentationConfig()
+ if featurizer_config is None:
+ featurizer_config = FeaturizerConfig()
+ self._data_augmentation_config = None
+ self.data_augmentation_config = data_augmentation_config
+ self._featurizer_config = None
+ self.featurizer_config = featurizer_config
+ self.noise_reweight_factor = noise_reweight_factor
+
+ # pylint: enable=line-too-long
+
+ @property
+ def data_augmentation_config(self):
+ return self._data_augmentation_config
+
+ @data_augmentation_config.setter
+ def data_augmentation_config(self, value):
+ if isinstance(value, dict):
+ self._data_augmentation_config = \
+ IntentClassifierDataAugmentationConfig.from_dict(value)
+ elif isinstance(value, IntentClassifierDataAugmentationConfig):
+ self._data_augmentation_config = value
+ else:
+ raise TypeError("Expected instance of "
+ "IntentClassifierDataAugmentationConfig or dict "
+ "but received: %s" % type(value))
+
+ @property
+ def featurizer_config(self):
+ return self._featurizer_config
+
+ @featurizer_config.setter
+ def featurizer_config(self, value):
+ if isinstance(value, dict):
+ self._featurizer_config = \
+ FeaturizerConfig.from_dict(value)
+ elif isinstance(value, FeaturizerConfig):
+ self._featurizer_config = value
+ else:
+ raise TypeError("Expected instance of FeaturizerConfig or dict "
+ "but received: %s" % type(value))
+
+ @property
+ def unit_name(self):
+ from snips_inference_agl.intent_classifier import LogRegIntentClassifier
+ return LogRegIntentClassifier.unit_name
+
+ def get_required_resources(self):
+ resources = self.data_augmentation_config.get_required_resources()
+ resources = merge_required_resources(
+ resources, self.featurizer_config.get_required_resources())
+ return resources
+
+ def to_dict(self):
+ return {
+ "unit_name": self.unit_name,
+ "data_augmentation_config":
+ self.data_augmentation_config.to_dict(),
+ "featurizer_config": self.featurizer_config.to_dict(),
+ "noise_reweight_factor": self.noise_reweight_factor,
+ }
+
+
+class IntentClassifierDataAugmentationConfig(FromDict, Config):
+ """Configuration used by a :class:`.LogRegIntentClassifier` which defines
+ how to augment data to improve the training of the classifier"""
+
+ def __init__(self, min_utterances=20, noise_factor=5,
+ add_builtin_entities_examples=True, unknown_word_prob=0,
+ unknown_words_replacement_string=None,
+ max_unknown_words=None):
+ """
+ Args:
+ min_utterances (int, optional): The minimum number of utterances to
+ automatically generate for each intent, based on the existing
+ utterances. Default is 20.
+ noise_factor (int, optional): Defines the size of the noise to
+ generate to train the implicit *None* intent, as a multiplier
+ of the average size of the other intents. Default is 5.
+ add_builtin_entities_examples (bool, optional): If True, some
+ builtin entity examples will be automatically added to the
+ training data. Default is True.
+ """
+ self.min_utterances = min_utterances
+ self.noise_factor = noise_factor
+ self.add_builtin_entities_examples = add_builtin_entities_examples
+ self.unknown_word_prob = unknown_word_prob
+ self.unknown_words_replacement_string = \
+ unknown_words_replacement_string
+ if max_unknown_words is not None and max_unknown_words < 0:
+ raise ValueError("max_unknown_words must be None or >= 0")
+ self.max_unknown_words = max_unknown_words
+ if unknown_word_prob > 0 and unknown_words_replacement_string is None:
+ raise ValueError("unknown_word_prob is positive (%s) but the "
+ "replacement string is None" % unknown_word_prob)
+
+ @staticmethod
+ def get_required_resources():
+ return {
+ NOISE: True,
+ STOP_WORDS: True
+ }
+
+ def to_dict(self):
+ return {
+ "min_utterances": self.min_utterances,
+ "noise_factor": self.noise_factor,
+ "add_builtin_entities_examples":
+ self.add_builtin_entities_examples,
+ "unknown_word_prob": self.unknown_word_prob,
+ "unknown_words_replacement_string":
+ self.unknown_words_replacement_string,
+ "max_unknown_words": self.max_unknown_words
+ }
+
+
+class FeaturizerConfig(FromDict, ProcessingUnitConfig):
+ """Configuration of a :class:`.Featurizer` object"""
+
+ # pylint: disable=line-too-long
+ def __init__(self, tfidf_vectorizer_config=None,
+ cooccurrence_vectorizer_config=None,
+ pvalue_threshold=0.4,
+ added_cooccurrence_feature_ratio=0):
+ """
+ Args:
+ tfidf_vectorizer_config (:class:`.TfidfVectorizerConfig`, optional):
+ configuration of the featurizer's
+ :attr:`tfidf_vectorizer`
+ cooccurrence_vectorizer_config (:class:`.CooccurrenceVectorizerConfig`, optional):
+ configuration of the featurizer's
+ :attr:`cooccurrence_vectorizer`
+ pvalue_threshold (float): after fitting the training set to
+ extract tfidf features, a univariate feature selection is
+ applied. Features are tested for independence using a Chi-2
+ test, under the null hypothesis that each feature should be
+ equally present in each class. Only features having a p-value
+ lower than the threshold are kept
+ added_cooccurrence_feature_ratio (float, optional): proportion of
+ cooccurrence features to add with respect to the number of
+ tfidf features. For instance with a ratio of 0.5, if 100 tfidf
+ features are remaining after feature selection, a maximum of 50
+ cooccurrence features will be added
+ """
+ self.pvalue_threshold = pvalue_threshold
+ self.added_cooccurrence_feature_ratio = \
+ added_cooccurrence_feature_ratio
+
+ if tfidf_vectorizer_config is None:
+ tfidf_vectorizer_config = TfidfVectorizerConfig()
+ elif isinstance(tfidf_vectorizer_config, dict):
+ tfidf_vectorizer_config = TfidfVectorizerConfig.from_dict(
+ tfidf_vectorizer_config)
+ self.tfidf_vectorizer_config = tfidf_vectorizer_config
+
+ if cooccurrence_vectorizer_config is None:
+ cooccurrence_vectorizer_config = CooccurrenceVectorizerConfig()
+ elif isinstance(cooccurrence_vectorizer_config, dict):
+ cooccurrence_vectorizer_config = CooccurrenceVectorizerConfig \
+ .from_dict(cooccurrence_vectorizer_config)
+ self.cooccurrence_vectorizer_config = cooccurrence_vectorizer_config
+
+ # pylint: enable=line-too-long
+
+ @property
+ def unit_name(self):
+ from snips_inference_agl.intent_classifier import Featurizer
+ return Featurizer.unit_name
+
+ def get_required_resources(self):
+ required_resources = self.tfidf_vectorizer_config \
+ .get_required_resources()
+ if self.cooccurrence_vectorizer_config:
+ required_resources = merge_required_resources(
+ required_resources,
+ self.cooccurrence_vectorizer_config.get_required_resources())
+ return required_resources
+
+ def to_dict(self):
+ return {
+ "unit_name": self.unit_name,
+ "pvalue_threshold": self.pvalue_threshold,
+ "added_cooccurrence_feature_ratio":
+ self.added_cooccurrence_feature_ratio,
+ "tfidf_vectorizer_config": self.tfidf_vectorizer_config.to_dict(),
+ "cooccurrence_vectorizer_config":
+ self.cooccurrence_vectorizer_config.to_dict(),
+ }
+
+
+class TfidfVectorizerConfig(FromDict, ProcessingUnitConfig):
+ """Configuration of a :class:`.TfidfVectorizerConfig` object"""
+
+ def __init__(self, word_clusters_name=None, use_stemming=False):
+ """
+ Args:
+ word_clusters_name (str, optional): if a word cluster name is
+ provided then the featurizer will use the word cluster IDs
+ detected in the utterances and add them to the utterance text
+ before computing the tfidf. Defaults to None
+ use_stemming (bool, optional): use stemming before computing the
+ tfidf. Defaults to False (no stemming used)
+ """
+ self.word_clusters_name = word_clusters_name
+ self.use_stemming = use_stemming
+
+ @property
+ def unit_name(self):
+ from snips_inference_agl.intent_classifier import TfidfVectorizer
+ return TfidfVectorizer.unit_name
+
+ def get_required_resources(self):
+ resources = {STEMS: True if self.use_stemming else False}
+ if self.word_clusters_name:
+ resources[WORD_CLUSTERS] = {self.word_clusters_name}
+ return resources
+
+ def to_dict(self):
+ return {
+ "unit_name": self.unit_name,
+ "word_clusters_name": self.word_clusters_name,
+ "use_stemming": self.use_stemming
+ }
+
+
+class CooccurrenceVectorizerConfig(FromDict, ProcessingUnitConfig):
+ """Configuration of a :class:`.CooccurrenceVectorizer` object"""
+
+ def __init__(self, window_size=None, unknown_words_replacement_string=None,
+ filter_stop_words=True, keep_order=True):
+ """
+ Args:
+ window_size (int, optional): if provided, word cooccurrences will
+ be taken into account only in a context window of size
+ :attr:`window_size`. If the window size is 3 then given a word
+ w[i], the vectorizer will only extract the following pairs:
+ (w[i], w[i + 1]), (w[i], w[i + 2]) and (w[i], w[i + 3]).
+ Defaults to None, which means that we consider all words
+ unknown_words_replacement_string (str, optional)
+ filter_stop_words (bool, optional): if True, stop words are ignored
+ when computing cooccurrences
+ keep_order (bool, optional): if True then cooccurrences are computed
+ taking the word order into account, which means the pairs
+ (w1, w2) and (w2, w1) will count as two separate features.
+ Defaults to `True`.
+ """
+ self.window_size = window_size
+ self.unknown_words_replacement_string = \
+ unknown_words_replacement_string
+ self.filter_stop_words = filter_stop_words
+ self.keep_order = keep_order
+
+ @property
+ def unit_name(self):
+ from snips_inference_agl.intent_classifier import CooccurrenceVectorizer
+ return CooccurrenceVectorizer.unit_name
+
+ def get_required_resources(self):
+ return {
+ STOP_WORDS: self.filter_stop_words,
+ # We require the parser to be trained without stems because we
+ # don't normalize and stem when processing in the
+ # CooccurrenceVectorizer (in order to run the builtin and
+ # custom parser on the same unnormalized input).
+ # Requiring no stems ensures we'll be able to parse the unstemmed
+ # input
+ CUSTOM_ENTITY_PARSER_USAGE: CustomEntityParserUsage.WITHOUT_STEMS
+ }
+
+ def to_dict(self):
+ return {
+ "unit_name": self.unit_name,
+ "unknown_words_replacement_string":
+ self.unknown_words_replacement_string,
+ "window_size": self.window_size,
+ "filter_stop_words": self.filter_stop_words,
+ "keep_order": self.keep_order
+ }
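The classifier, featurizer and vectorizer configs above all follow the same pattern: keyword arguments with sensible defaults, plus to_dict/from_dict for (de)serialization. A short sketch of that round trip; the parameter values are arbitrary and only chosen for illustration.

    from snips_inference_agl.pipeline.configs import (
        FeaturizerConfig, LogRegIntentClassifierConfig)

    config = LogRegIntentClassifierConfig(
        featurizer_config=FeaturizerConfig(pvalue_threshold=0.2),
        noise_reweight_factor=2.0)
    serialized = config.to_dict()
    restored = LogRegIntentClassifierConfig.from_dict(serialized)
    print(restored.noise_reweight_factor)              # 2.0
    print(restored.featurizer_config.pvalue_threshold)  # 0.2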
diff --git a/snips_inference_agl/pipeline/configs/intent_parser.py b/snips_inference_agl/pipeline/configs/intent_parser.py
new file mode 100644
index 0000000..f017472
--- /dev/null
+++ b/snips_inference_agl/pipeline/configs/intent_parser.py
@@ -0,0 +1,127 @@
+from __future__ import unicode_literals
+
+from snips_inference_agl.common.from_dict import FromDict
+from snips_inference_agl.constants import CUSTOM_ENTITY_PARSER_USAGE, STOP_WORDS
+from snips_inference_agl.entity_parser import CustomEntityParserUsage
+from snips_inference_agl.pipeline.configs import ProcessingUnitConfig
+from snips_inference_agl.resources import merge_required_resources
+
+
+class ProbabilisticIntentParserConfig(FromDict, ProcessingUnitConfig):
+ """Configuration of a :class:`.ProbabilisticIntentParser` object
+
+ Args:
+ intent_classifier_config (:class:`.ProcessingUnitConfig`): The
+ configuration of the underlying intent classifier, by default
+ it uses a :class:`.LogRegIntentClassifierConfig`
+ slot_filler_config (:class:`.ProcessingUnitConfig`): The configuration
+ that will be used for the underlying slot fillers, by default it
+ uses a :class:`.CRFSlotFillerConfig`
+ """
+
+ def __init__(self, intent_classifier_config=None, slot_filler_config=None):
+ from snips_inference_agl.intent_classifier import IntentClassifier
+ from snips_inference_agl.slot_filler import SlotFiller
+
+ if intent_classifier_config is None:
+ from snips_inference_agl.pipeline.configs import LogRegIntentClassifierConfig
+ intent_classifier_config = LogRegIntentClassifierConfig()
+ if slot_filler_config is None:
+ from snips_inference_agl.pipeline.configs import CRFSlotFillerConfig
+ slot_filler_config = CRFSlotFillerConfig()
+ self.intent_classifier_config = IntentClassifier.get_config(
+ intent_classifier_config)
+ self.slot_filler_config = SlotFiller.get_config(slot_filler_config)
+
+ @property
+ def unit_name(self):
+ from snips_inference_agl.intent_parser import ProbabilisticIntentParser
+ return ProbabilisticIntentParser.unit_name
+
+ def get_required_resources(self):
+ resources = self.intent_classifier_config.get_required_resources()
+ resources = merge_required_resources(
+ resources, self.slot_filler_config.get_required_resources())
+ return resources
+
+ def to_dict(self):
+ return {
+ "unit_name": self.unit_name,
+ "slot_filler_config": self.slot_filler_config.to_dict(),
+ "intent_classifier_config": self.intent_classifier_config.to_dict()
+ }
+
+
+class DeterministicIntentParserConfig(FromDict, ProcessingUnitConfig):
+ """Configuration of a :class:`.DeterministicIntentParser`
+
+ Args:
+ max_queries (int, optional): Maximum number of regex patterns per
+ intent. 100 by default.
+ max_pattern_length (int, optional): Maximum length of regex patterns.
+ 1000 by default.
+ ignore_stop_words (bool, optional): If True, stop words will be
+ removed before building patterns.
+
+ The max_queries and max_pattern_length limits allow deactivating the use
+ of regular expressions when they would be too large, avoiding an
+ explosion in time and memory
+
+ Note:
+ In the future, an FST will be used instead of regexps, removing the
+ need for these limits
+ """
+
+ def __init__(self, max_queries=100, max_pattern_length=1000,
+ ignore_stop_words=False):
+ self.max_queries = max_queries
+ self.max_pattern_length = max_pattern_length
+ self.ignore_stop_words = ignore_stop_words
+
+ @property
+ def unit_name(self):
+ from snips_inference_agl.intent_parser import DeterministicIntentParser
+ return DeterministicIntentParser.unit_name
+
+ def get_required_resources(self):
+ return {
+ CUSTOM_ENTITY_PARSER_USAGE: CustomEntityParserUsage.WITHOUT_STEMS,
+ STOP_WORDS: self.ignore_stop_words
+ }
+
+ def to_dict(self):
+ return {
+ "unit_name": self.unit_name,
+ "max_queries": self.max_queries,
+ "max_pattern_length": self.max_pattern_length,
+ "ignore_stop_words": self.ignore_stop_words
+ }
+
+
+class LookupIntentParserConfig(FromDict, ProcessingUnitConfig):
+ """Configuration of a :class:`.LookupIntentParser`
+
+ Args:
+ ignore_stop_words (bool, optional): If True, stop words will be
+ removed before building the lookup table keys.
+ """
+
+ def __init__(self, ignore_stop_words=False):
+ self.ignore_stop_words = ignore_stop_words
+
+ @property
+ def unit_name(self):
+ from snips_inference_agl.intent_parser.lookup_intent_parser import \
+ LookupIntentParser
+ return LookupIntentParser.unit_name
+
+ def get_required_resources(self):
+ return {
+ CUSTOM_ENTITY_PARSER_USAGE: CustomEntityParserUsage.WITHOUT_STEMS,
+ STOP_WORDS: self.ignore_stop_words
+ }
+
+ def to_dict(self):
+ return {
+ "unit_name": self.unit_name,
+ "ignore_stop_words": self.ignore_stop_words
+ }
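The parser configs above mainly differ in which resources they declare: both the deterministic and lookup parsers require custom entities parsed without stems, and request stop words only when ignore_stop_words is set. A small illustration; the keys of the printed dictionary come from constants defined elsewhere in the package.

    from snips_inference_agl.pipeline.configs import LookupIntentParserConfig

    config = LookupIntentParserConfig(ignore_stop_words=True)
    print(config.get_required_resources())
    # the STOP_WORDS entry is True because ignore_stop_words was enabled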
diff --git a/snips_inference_agl/pipeline/configs/nlu_engine.py b/snips_inference_agl/pipeline/configs/nlu_engine.py
new file mode 100644
index 0000000..3826702
--- /dev/null
+++ b/snips_inference_agl/pipeline/configs/nlu_engine.py
@@ -0,0 +1,55 @@
+from __future__ import unicode_literals
+
+from snips_inference_agl.common.from_dict import FromDict
+from snips_inference_agl.constants import CUSTOM_ENTITY_PARSER_USAGE
+from snips_inference_agl.entity_parser import CustomEntityParserUsage
+from snips_inference_agl.pipeline.configs import ProcessingUnitConfig
+from snips_inference_agl.resources import merge_required_resources
+
+
+class NLUEngineConfig(FromDict, ProcessingUnitConfig):
+ """Configuration of a :class:`.SnipsNLUEngine` object
+
+ Args:
+ intent_parsers_configs (list): List of intent parser configs
+ (:class:`.ProcessingUnitConfig`). The order in the list determines
+ the order in which each parser will be called by the nlu engine.
+ """
+
+ def __init__(self, intent_parsers_configs=None, random_seed=None):
+ from snips_inference_agl.intent_parser import IntentParser
+
+ if intent_parsers_configs is None:
+ from snips_inference_agl.pipeline.configs import (
+ ProbabilisticIntentParserConfig,
+ DeterministicIntentParserConfig)
+ intent_parsers_configs = [
+ DeterministicIntentParserConfig(),
+ ProbabilisticIntentParserConfig()
+ ]
+ self.intent_parsers_configs = [
+ IntentParser.get_config(conf) for conf in intent_parsers_configs]
+ self.random_seed = random_seed
+
+ @property
+ def unit_name(self):
+ from snips_inference_agl.nlu_engine.nlu_engine import SnipsNLUEngine
+ return SnipsNLUEngine.unit_name
+
+ def get_required_resources(self):
+ # Resolving custom slot values must be done without stemming
+ resources = {
+ CUSTOM_ENTITY_PARSER_USAGE: CustomEntityParserUsage.WITHOUT_STEMS
+ }
+ for config in self.intent_parsers_configs:
+ resources = merge_required_resources(
+ resources, config.get_required_resources())
+ return resources
+
+ def to_dict(self):
+ return {
+ "unit_name": self.unit_name,
+ "intent_parsers_configs": [
+ config.to_dict() for config in self.intent_parsers_configs
+ ]
+ }
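
A short sketch of how the default engine config resolves its parsers and resources; it assumes NLUEngineConfig is re-exported by snips_inference_agl.pipeline.configs like the other config classes in this diff (otherwise import it from its nlu_engine submodule).

    from snips_inference_agl.pipeline.configs import NLUEngineConfig

    # With no arguments, the config falls back to a deterministic parser
    # followed by a probabilistic parser, in that order
    engine_config = NLUEngineConfig()
    parser_names = [c.unit_name for c in engine_config.intent_parsers_configs]

    # get_required_resources() merges each parser's requirements on top of the
    # WITHOUT_STEMS custom-entity-parser usage declared in the method above
    resources = engine_config.get_required_resources()
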
diff --git a/snips_inference_agl/pipeline/configs/slot_filler.py b/snips_inference_agl/pipeline/configs/slot_filler.py
new file mode 100644
index 0000000..be36e9c
--- /dev/null
+++ b/snips_inference_agl/pipeline/configs/slot_filler.py
@@ -0,0 +1,145 @@
+from __future__ import unicode_literals
+
+from snips_inference_agl.common.from_dict import FromDict
+from snips_inference_agl.constants import STOP_WORDS
+from snips_inference_agl.pipeline.configs import (
+ Config, ProcessingUnitConfig, default_features_factories)
+from snips_inference_agl.resources import merge_required_resources
+
+
+class CRFSlotFillerConfig(FromDict, ProcessingUnitConfig):
+ # pylint: disable=line-too-long
+ """Configuration of a :class:`.CRFSlotFiller`
+
+ Args:
+ feature_factory_configs (list, optional): List of configurations that
+ specify the list of :class:`.CRFFeatureFactory` to use with the CRF
+ tagging_scheme (:class:`.TaggingScheme`, optional): Tagging scheme to
+ use to enrich CRF labels (default=BIO)
+ crf_args (dict, optional): Allows overriding the parameters of the CRF
+ defined in *sklearn_crfsuite*, see :class:`sklearn_crfsuite.CRF`
+ (default={"c1": .1, "c2": .1, "algorithm": "lbfgs"})
+ data_augmentation_config (dict or :class:`.SlotFillerDataAugmentationConfig`, optional):
+ Specify how to augment data before training the CRF, see the
+ corresponding config object for more details.
+ random_seed (int, optional): Specify to make the CRF training
+ deterministic and reproducible (default=None)
+ """
+
+ # pylint: enable=line-too-long
+
+ def __init__(self, feature_factory_configs=None,
+ tagging_scheme=None, crf_args=None,
+ data_augmentation_config=None):
+ if tagging_scheme is None:
+ from snips_inference_agl.slot_filler.crf_utils import TaggingScheme
+ tagging_scheme = TaggingScheme.BIO
+ if feature_factory_configs is None:
+ feature_factory_configs = default_features_factories()
+ if crf_args is None:
+ crf_args = _default_crf_args()
+ if data_augmentation_config is None:
+ data_augmentation_config = SlotFillerDataAugmentationConfig()
+ self.feature_factory_configs = feature_factory_configs
+ self._tagging_scheme = None
+ self.tagging_scheme = tagging_scheme
+ self.crf_args = crf_args
+ self._data_augmentation_config = None
+ self.data_augmentation_config = data_augmentation_config
+
+ @property
+ def tagging_scheme(self):
+ return self._tagging_scheme
+
+ @tagging_scheme.setter
+ def tagging_scheme(self, value):
+ from snips_inference_agl.slot_filler.crf_utils import TaggingScheme
+ if isinstance(value, TaggingScheme):
+ self._tagging_scheme = value
+ elif isinstance(value, int):
+ self._tagging_scheme = TaggingScheme(value)
+ else:
+ raise TypeError("Expected instance of TaggingScheme or int but "
+ "received: %s" % type(value))
+
+ @property
+ def data_augmentation_config(self):
+ return self._data_augmentation_config
+
+ @data_augmentation_config.setter
+ def data_augmentation_config(self, value):
+ if isinstance(value, dict):
+ self._data_augmentation_config = \
+ SlotFillerDataAugmentationConfig.from_dict(value)
+ elif isinstance(value, SlotFillerDataAugmentationConfig):
+ self._data_augmentation_config = value
+ else:
+ raise TypeError("Expected instance of "
+ "SlotFillerDataAugmentationConfig or dict but "
+ "received: %s" % type(value))
+
+ @property
+ def unit_name(self):
+ from snips_inference_agl.slot_filler import CRFSlotFiller
+ return CRFSlotFiller.unit_name
+
+ def get_required_resources(self):
+ # Import here to avoid circular imports
+ from snips_inference_agl.slot_filler.feature_factory import CRFFeatureFactory
+
+ resources = self.data_augmentation_config.get_required_resources()
+ for config in self.feature_factory_configs:
+ factory = CRFFeatureFactory.from_config(config)
+ resources = merge_required_resources(
+ resources, factory.get_required_resources())
+ return resources
+
+ def to_dict(self):
+ return {
+ "unit_name": self.unit_name,
+ "feature_factory_configs": self.feature_factory_configs,
+ "crf_args": self.crf_args,
+ "tagging_scheme": self.tagging_scheme.value,
+ "data_augmentation_config":
+ self.data_augmentation_config.to_dict()
+ }
+
+
+class SlotFillerDataAugmentationConfig(FromDict, Config):
+ """Specify how to augment data before training the CRF
+
+ Data augmentation essentially consists in creating additional utterances
+ by combining utterance patterns and slot values
+
+ Args:
+ min_utterances (int, optional): Specify the minimum number of
+ utterances to generate per intent (default=200)
+ capitalization_ratio (float, optional): If an entity has one or more
+ capitalized values, the data augmentation will randomly capitalize
+ its values with a ratio of *capitalization_ratio* (default=.2)
+ add_builtin_entities_examples (bool, optional): If True, some builtin
+ entity examples will be automatically added to the training data.
+ Default is True.
+ """
+
+ def __init__(self, min_utterances=200, capitalization_ratio=.2,
+ add_builtin_entities_examples=True):
+ self.min_utterances = min_utterances
+ self.capitalization_ratio = capitalization_ratio
+ self.add_builtin_entities_examples = add_builtin_entities_examples
+
+ def get_required_resources(self):
+ return {
+ STOP_WORDS: True
+ }
+
+ def to_dict(self):
+ return {
+ "min_utterances": self.min_utterances,
+ "capitalization_ratio": self.capitalization_ratio,
+ "add_builtin_entities_examples": self.add_builtin_entities_examples
+ }
+
+
+def _default_crf_args():
+ return {"c1": .1, "c2": .1, "algorithm": "lbfgs"}
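
A minimal sketch of the coercion logic in this config; CRFSlotFillerConfig is imported from snips_inference_agl.pipeline.configs, as the CRF slot filler module does later in this diff, and the assertions simply restate the defaults shown above.

    from snips_inference_agl.pipeline.configs import CRFSlotFillerConfig
    from snips_inference_agl.slot_filler.crf_utils import TaggingScheme

    # The tagging_scheme setter accepts either a TaggingScheme member or its int value
    config = CRFSlotFillerConfig(tagging_scheme=2)
    assert config.tagging_scheme is TaggingScheme.BILOU
    assert config.crf_args == {"c1": .1, "c2": .1, "algorithm": "lbfgs"}

    # The data augmentation config can also be assigned as a plain dict
    config.data_augmentation_config = {
        "min_utterances": 50,
        "capitalization_ratio": .2,
        "add_builtin_entities_examples": False,
    }
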
diff --git a/snips_inference_agl/pipeline/processing_unit.py b/snips_inference_agl/pipeline/processing_unit.py
new file mode 100644
index 0000000..1928470
--- /dev/null
+++ b/snips_inference_agl/pipeline/processing_unit.py
@@ -0,0 +1,177 @@
+from __future__ import unicode_literals
+
+import io
+import json
+import shutil
+from abc import ABCMeta, abstractmethod, abstractproperty
+from builtins import str, bytes
+from pathlib import Path
+
+from future.utils import with_metaclass
+
+from snips_inference_agl.common.abc_utils import abstractclassmethod, classproperty
+from snips_inference_agl.common.io_utils import temp_dir, unzip_archive
+from snips_inference_agl.common.registrable import Registrable
+from snips_inference_agl.common.utils import (
+ json_string, check_random_state)
+from snips_inference_agl.constants import (
+ BUILTIN_ENTITY_PARSER, CUSTOM_ENTITY_PARSER, CUSTOM_ENTITY_PARSER_USAGE,
+ RESOURCES, LANGUAGE, RANDOM_STATE)
+from snips_inference_agl.entity_parser import (
+ BuiltinEntityParser, CustomEntityParser, CustomEntityParserUsage)
+from snips_inference_agl.exceptions import LoadingError
+from snips_inference_agl.pipeline.configs import ProcessingUnitConfig
+from snips_inference_agl.pipeline.configs.config import DefaultProcessingUnitConfig
+from snips_inference_agl.resources import load_resources
+
+
+class ProcessingUnit(with_metaclass(ABCMeta, Registrable)):
+ """Abstraction of a NLU pipeline unit
+
+ Pipeline processing units such as intent parsers, intent classifiers and
+ slot fillers must implement this class.
+
+ A :class:`ProcessingUnit` is associated with a *config_type*, which
+ represents the :class:`.ProcessingUnitConfig` used to initialize it.
+ """
+
+ def __init__(self, config, **shared):
+ if config is None:
+ self.config = self.default_config()
+ elif isinstance(config, ProcessingUnitConfig):
+ self.config = config
+ elif isinstance(config, dict):
+ self.config = self.config_type.from_dict(config)
+ else:
+ raise ValueError("Unexpected config type: %s" % type(config))
+ if self.config is not None:
+ self.config.set_unit_name(self.unit_name)
+ self.builtin_entity_parser = shared.get(BUILTIN_ENTITY_PARSER)
+ self.custom_entity_parser = shared.get(CUSTOM_ENTITY_PARSER)
+ self.resources = shared.get(RESOURCES)
+ self.random_state = check_random_state(shared.get(RANDOM_STATE))
+
+ @classproperty
+ def config_type(cls): # pylint:disable=no-self-argument
+ return DefaultProcessingUnitConfig
+
+ @classmethod
+ def default_config(cls):
+ config = cls.config_type() # pylint:disable=no-value-for-parameter
+ config.set_unit_name(cls.unit_name)
+ return config
+
+ @classproperty
+ def unit_name(cls): # pylint:disable=no-self-argument
+ return ProcessingUnit.registered_name(cls)
+
+ @classmethod
+ def from_config(cls, unit_config, **shared):
+ """Build a :class:`ProcessingUnit` from the provided config"""
+ unit = cls.by_name(unit_config.unit_name)
+ return unit(unit_config, **shared)
+
+ @classmethod
+ def load_from_path(cls, unit_path, unit_name=None, **shared):
+ """Load a :class:`ProcessingUnit` from a persisted processing unit
+ directory
+
+ Args:
+ unit_path (str or :class:`pathlib.Path`): path to the persisted
+ processing unit
+ unit_name (str, optional): Name of the processing unit to load.
+ By default, the unit name is assumed to be stored in a
+ "metadata.json" file located in the directory at unit_path.
+
+ Raises:
+ LoadingError: when unit_name is None and no metadata file is found
+ in the processing unit directory
+ """
+ unit_path = Path(unit_path)
+ if unit_name is None:
+ metadata_path = unit_path / "metadata.json"
+ if not metadata_path.exists():
+ raise LoadingError(
+ "Missing metadata for processing unit at path %s"
+ % str(unit_path))
+ with metadata_path.open(encoding="utf8") as f:
+ metadata = json.load(f)
+ unit_name = metadata["unit_name"]
+ unit = cls.by_name(unit_name)
+ return unit.from_path(unit_path, **shared)
+
+ @classmethod
+ def get_config(cls, unit_config):
+ """Returns the :class:`.ProcessingUnitConfig` corresponding to
+ *unit_config*"""
+ if isinstance(unit_config, ProcessingUnitConfig):
+ return unit_config
+ elif isinstance(unit_config, dict):
+ unit_name = unit_config["unit_name"]
+ processing_unit_type = cls.by_name(unit_name)
+ return processing_unit_type.config_type.from_dict(unit_config)
+ elif isinstance(unit_config, (str, bytes)):
+ unit_name = unit_config
+ unit_config = {"unit_name": unit_name}
+ processing_unit_type = cls.by_name(unit_name)
+ return processing_unit_type.config_type.from_dict(unit_config)
+ else:
+ raise ValueError(
+ "Expected `unit_config` to be an instance of "
+ "ProcessingUnitConfig or dict or str but found: %s"
+ % type(unit_config))
+
+ @abstractproperty
+ def fitted(self):
+ """Whether or not the processing unit has already been trained"""
+ pass
+
+ def load_resources_if_needed(self, language):
+ if self.resources is None or self.fitted:
+ required_resources = None
+ if self.config is not None:
+ required_resources = self.config.get_required_resources()
+ self.resources = load_resources(language, required_resources)
+
+ def fit_builtin_entity_parser_if_needed(self, dataset):
+ # We only fit a builtin entity parser when the unit has already been
+ # fitted or if the parser is None.
+ # In the other cases, the parser is provided already fitted by another unit.
+ if self.builtin_entity_parser is None or self.fitted:
+ self.builtin_entity_parser = BuiltinEntityParser.build(
+ dataset=dataset)
+ return self
+
+ def fit_custom_entity_parser_if_needed(self, dataset):
+ # We only fit a custom entity parser when the unit has already been
+ # fitted or if the parser is None.
+ # In the other cases, the parser is provided already fitted by another unit.
+ required_resources = self.config.get_required_resources()
+ if not required_resources or not required_resources.get(
+ CUSTOM_ENTITY_PARSER_USAGE):
+ # In these cases we need a custom entity parser only to do the
+ # final slot resolution step, which must be done without stemming.
+ parser_usage = CustomEntityParserUsage.WITHOUT_STEMS
+ else:
+ parser_usage = required_resources[CUSTOM_ENTITY_PARSER_USAGE]
+
+ if self.custom_entity_parser is None or self.fitted:
+ self.load_resources_if_needed(dataset[LANGUAGE])
+ self.custom_entity_parser = CustomEntityParser.build(
+ dataset, parser_usage, self.resources)
+ return self
+
+ def persist_metadata(self, path, **kwargs):
+ metadata = {"unit_name": self.unit_name}
+ metadata.update(kwargs)
+ metadata_json = json_string(metadata)
+ with (path / "metadata.json").open(mode="w", encoding="utf8") as f:
+ f.write(metadata_json)
+
+ # @abstractmethod
+ def persist(self, path):
+ pass
+
+ @abstractclassmethod
+ def from_path(cls, path, **shared):
+ pass
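
To illustrate the three forms accepted by get_config(), here is a hedged sketch using the "crf_slot_filler" name registered later in this diff; it assumes SlotFiller subclasses ProcessingUnit and owns the registry entry created by that @SlotFiller.register decorator.

    from snips_inference_agl.slot_filler import SlotFiller

    # All three forms resolve to a CRFSlotFillerConfig; the string and dict
    # forms go through the by_name registry lookup shown above
    cfg_from_str = SlotFiller.get_config("crf_slot_filler")
    cfg_from_dict = SlotFiller.get_config({"unit_name": "crf_slot_filler"})
    cfg_from_obj = SlotFiller.get_config(cfg_from_str)
    assert cfg_from_obj is cfg_from_str
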
diff --git a/snips_inference_agl/preprocessing.py b/snips_inference_agl/preprocessing.py
new file mode 100644
index 0000000..cfb4aa5
--- /dev/null
+++ b/snips_inference_agl/preprocessing.py
@@ -0,0 +1,97 @@
+# coding=utf-8
+from __future__ import unicode_literals
+
+from builtins import object
+
+from snips_inference_agl.resources import get_stems
+
+
+def stem(string, language, resources):
+ from snips_nlu_utils import normalize
+
+ normalized_string = normalize(string)
+ tokens = tokenize_light(normalized_string, language)
+ stemmed_tokens = [_stem(token, resources) for token in tokens]
+ return " ".join(stemmed_tokens)
+
+
+def stem_token(token, resources):
+ from snips_nlu_utils import normalize
+
+ if token.stemmed_value:
+ return token.stemmed_value
+ if not token.normalized_value:
+ token.normalized_value = normalize(token.value)
+ token.stemmed_value = _stem(token.normalized_value, resources)
+ return token.stemmed_value
+
+
+def normalize_token(token):
+ from snips_nlu_utils import normalize
+
+ if token.normalized_value:
+ return token.normalized_value
+ token.normalized_value = normalize(token.value)
+ return token.normalized_value
+
+
+def _stem(string, resources):
+ return get_stems(resources).get(string, string)
+
+
+class Token(object):
+ """Token object which is output by the tokenization
+
+ Attributes:
+ value (str): Tokenized string
+ start (int): Start position of the token within the sentence
+ end (int): End position of the token within the sentence
+ normalized_value (str): Normalized value of the tokenized string
+ stemmed_value (str): Stemmed value of the tokenized string
+ """
+
+ def __init__(self, value, start, end, normalized_value=None,
+ stemmed_value=None):
+ self.value = value
+ self.start = start
+ self.end = end
+ self.normalized_value = normalized_value
+ self.stemmed_value = stemmed_value
+
+ def __eq__(self, other):
+ if not isinstance(other, type(self)):
+ return False
+ return (self.value == other.value
+ and self.start == other.start
+ and self.end == other.end)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+def tokenize(string, language):
+ """Tokenizes the input
+
+ Args:
+ string (str): Input to tokenize
+ language (str): Language to use during tokenization
+
+ Returns:
+ list of :class:`.Token`: The list of tokenized values
+ """
+ from snips_nlu_utils import tokenize as _tokenize
+
+ tokens = [Token(value=token["value"],
+ start=token["char_range"]["start"],
+ end=token["char_range"]["end"])
+ for token in _tokenize(string, language)]
+ return tokens
+
+
+def tokenize_light(string, language):
+ """Same behavior as :func:`tokenize` but returns tokenized strings instead
+ of :class:`Token` objects"""
+ from snips_nlu_utils import tokenize_light as _tokenize_light
+
+ tokenized_string = _tokenize_light(string, language)
+ return tokenized_string
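
A small usage sketch of the two tokenizers above; it requires the native snips_nlu_utils package and an available language such as "en", and the values in the comments are indicative only.

    from snips_inference_agl.preprocessing import tokenize, tokenize_light

    tokens = tokenize("Turn the heat up", "en")
    first = tokens[0]
    # Each Token carries its character range within the input,
    # e.g. value "Turn" with start 0 and end 4
    print(first.value, first.start, first.end)
    print(tokenize_light("Turn the heat up", "en"))  # e.g. ["Turn", "the", "heat", "up"]
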
diff --git a/snips_inference_agl/resources.py b/snips_inference_agl/resources.py
new file mode 100644
index 0000000..3b3e478
--- /dev/null
+++ b/snips_inference_agl/resources.py
@@ -0,0 +1,267 @@
+from __future__ import unicode_literals
+
+import json
+from copy import deepcopy
+from pathlib import Path
+
+from snips_inference_agl.common.utils import get_package_path, is_package
+from snips_inference_agl.constants import (
+ CUSTOM_ENTITY_PARSER_USAGE, DATA_PATH, GAZETTEERS, NOISE,
+ STEMS, STOP_WORDS, WORD_CLUSTERS, METADATA)
+from snips_inference_agl.entity_parser.custom_entity_parser import (
+ CustomEntityParserUsage)
+
+
+class MissingResource(LookupError):
+ pass
+
+
+def load_resources(name, required_resources=None):
+ """Load language specific resources
+
+ Args:
+ name (str): Resource name as in ``snips-nlu download <name>``. Can also
+ be the name of a python package or a directory path.
+ required_resources (dict, optional): Resources requirement
+ dict which, when provided, limits the set of resources
+ to load. By default, all existing resources are loaded.
+ """
+ if name in set(d.name for d in DATA_PATH.iterdir()):
+ return load_resources_from_dir(DATA_PATH / name, required_resources)
+ elif is_package(name):
+ package_path = get_package_path(name)
+ resources_sub_dir = get_resources_sub_directory(package_path)
+ return load_resources_from_dir(resources_sub_dir, required_resources)
+ elif Path(name).exists():
+ path = Path(name)
+ if (path / "__init__.py").exists():
+ path = get_resources_sub_directory(path)
+ return load_resources_from_dir(path, required_resources)
+ else:
+ raise MissingResource("Language resource '{r}' not found. This may be "
+ "solved by running "
+ "'python -m snips_nlu download {r}'"
+ .format(r=name))
+
+
+def load_resources_from_dir(resources_dir, required_resources=None):
+ with (resources_dir / "metadata.json").open(encoding="utf8") as f:
+ metadata = json.load(f)
+ metadata = _update_metadata(metadata, required_resources)
+ gazetteer_names = metadata["gazetteers"]
+ clusters_names = metadata["word_clusters"]
+ stop_words_filename = metadata["stop_words"]
+ stems_filename = metadata["stems"]
+ noise_filename = metadata["noise"]
+
+ gazetteers = _get_gazetteers(resources_dir / "gazetteers", gazetteer_names)
+ word_clusters = _get_word_clusters(resources_dir / "word_clusters",
+ clusters_names)
+
+ stems = None
+ stop_words = None
+ noise = None
+
+ if stems_filename is not None:
+ stems = _get_stems(resources_dir / "stemming", stems_filename)
+ if stop_words_filename is not None:
+ stop_words = _get_stop_words(resources_dir, stop_words_filename)
+ if noise_filename is not None:
+ noise = _get_noise(resources_dir, noise_filename)
+
+ return {
+ METADATA: metadata,
+ WORD_CLUSTERS: word_clusters,
+ GAZETTEERS: gazetteers,
+ STOP_WORDS: stop_words,
+ NOISE: noise,
+ STEMS: stems,
+ }
+
+
+def _update_metadata(metadata, required_resources):
+ if required_resources is None:
+ return metadata
+ metadata = deepcopy(metadata)
+ required_gazetteers = required_resources.get(GAZETTEERS, [])
+ required_word_clusters = required_resources.get(WORD_CLUSTERS, [])
+ for gazetteer in required_gazetteers:
+ if gazetteer not in metadata["gazetteers"]:
+ raise ValueError("Unknown gazetteer for language '%s': '%s'"
+ % (metadata["language"], gazetteer))
+ for word_clusters in required_word_clusters:
+ if word_clusters not in metadata["word_clusters"]:
+ raise ValueError("Unknown word clusters for language '%s': '%s'"
+ % (metadata["language"], word_clusters))
+ metadata["gazetteers"] = required_gazetteers
+ metadata["word_clusters"] = required_word_clusters
+ if not required_resources.get(STEMS, False):
+ metadata["stems"] = None
+ if not required_resources.get(NOISE, False):
+ metadata["noise"] = None
+ if not required_resources.get(STOP_WORDS, False):
+ metadata["stop_words"] = None
+ return metadata
+
+
+def get_resources_sub_directory(resources_dir):
+ resources_dir = Path(resources_dir)
+ with (resources_dir / "metadata.json").open(encoding="utf8") as f:
+ metadata = json.load(f)
+ resource_name = metadata["name"]
+ version = metadata["version"]
+ sub_dir_name = "{r}-{v}".format(r=resource_name, v=version)
+ return resources_dir / sub_dir_name
+
+
+def get_stop_words(resources):
+ return _get_resource(resources, STOP_WORDS)
+
+
+def get_noise(resources):
+ return _get_resource(resources, NOISE)
+
+
+def get_word_clusters(resources):
+ return _get_resource(resources, WORD_CLUSTERS)
+
+
+def get_word_cluster(resources, cluster_name):
+ word_clusters = get_word_clusters(resources)
+ if cluster_name not in word_clusters:
+ raise MissingResource("Word cluster '%s' not found" % cluster_name)
+ return word_clusters[cluster_name]
+
+
+def get_gazetteer(resources, gazetteer_name):
+ gazetteers = _get_resource(resources, GAZETTEERS)
+ if gazetteer_name not in gazetteers:
+ raise MissingResource("Gazetteer '%s' not found in resources"
+ % gazetteer_name)
+ return gazetteers[gazetteer_name]
+
+
+def get_stems(resources):
+ return _get_resource(resources, STEMS)
+
+
+def merge_required_resources(lhs, rhs):
+ if not lhs:
+ return dict() if rhs is None else rhs
+ if not rhs:
+ return dict() if lhs is None else lhs
+ merged_resources = dict()
+ if lhs.get(NOISE, False) or rhs.get(NOISE, False):
+ merged_resources[NOISE] = True
+ if lhs.get(STOP_WORDS, False) or rhs.get(STOP_WORDS, False):
+ merged_resources[STOP_WORDS] = True
+ if lhs.get(STEMS, False) or rhs.get(STEMS, False):
+ merged_resources[STEMS] = True
+ lhs_parser_usage = lhs.get(CUSTOM_ENTITY_PARSER_USAGE)
+ rhs_parser_usage = rhs.get(CUSTOM_ENTITY_PARSER_USAGE)
+ parser_usage = CustomEntityParserUsage.merge_usages(
+ lhs_parser_usage, rhs_parser_usage)
+ merged_resources[CUSTOM_ENTITY_PARSER_USAGE] = parser_usage
+ gazetteers = lhs.get(GAZETTEERS, set()).union(rhs.get(GAZETTEERS, set()))
+ if gazetteers:
+ merged_resources[GAZETTEERS] = gazetteers
+ word_clusters = lhs.get(WORD_CLUSTERS, set()).union(
+ rhs.get(WORD_CLUSTERS, set()))
+ if word_clusters:
+ merged_resources[WORD_CLUSTERS] = word_clusters
+ return merged_resources
+
+
+def _get_resource(resources, resource_name):
+ if resource_name not in resources or resources[resource_name] is None:
+ raise MissingResource("Resource '%s' not found" % resource_name)
+ return resources[resource_name]
+
+
+def _get_stop_words(resources_dir, stop_words_filename):
+ if not stop_words_filename:
+ return None
+ stop_words_path = (resources_dir / stop_words_filename).with_suffix(".txt")
+ return _load_stop_words(stop_words_path)
+
+
+def _load_stop_words(stop_words_path):
+ with stop_words_path.open(encoding="utf8") as f:
+ stop_words = set(l.strip() for l in f if l)
+ return stop_words
+
+
+def _get_noise(resources_dir, noise_filename):
+ if not noise_filename:
+ return None
+ noise_path = (resources_dir / noise_filename).with_suffix(".txt")
+ return _load_noise(noise_path)
+
+
+def _load_noise(noise_path):
+ with noise_path.open(encoding="utf8") as f:
+ # Here we split on a " " knowing that it's always ignored by
+ # the tokenization (see tokenization unit tests)
+ # It is not important to tokenize precisely as this noise is just used
+ # to generate utterances for the None intent
+ noise = [word for l in f for word in l.split()]
+ return noise
+
+
+def _get_word_clusters(word_clusters_dir, clusters_names):
+ if not clusters_names:
+ return dict()
+
+ clusters = dict()
+ for clusters_name in clusters_names:
+ clusters_path = (word_clusters_dir / clusters_name).with_suffix(".txt")
+ clusters[clusters_name] = _load_word_clusters(clusters_path)
+ return clusters
+
+
+def _load_word_clusters(path):
+ clusters = dict()
+ with path.open(encoding="utf8") as f:
+ for line in f:
+ split = line.rstrip().split("\t")
+ if not split:
+ continue
+ clusters[split[0]] = split[1]
+ return clusters
+
+
+def _get_gazetteers(gazetteers_dir, gazetteer_names):
+ if not gazetteer_names:
+ return dict()
+
+ gazetteers = dict()
+ for gazetteer_name in gazetteer_names:
+ gazetteer_path = (gazetteers_dir / gazetteer_name).with_suffix(".txt")
+ gazetteers[gazetteer_name] = _load_gazetteer(gazetteer_path)
+ return gazetteers
+
+
+def _load_gazetteer(path):
+ with path.open(encoding="utf8") as f:
+ gazetteer = set(v.strip() for v in f if v)
+ return gazetteer
+
+
+
+def _get_stems(stems_dir, filename):
+ if not filename:
+ return None
+ stems_path = (stems_dir / filename).with_suffix(".txt")
+ return _load_stems(stems_path)
+
+
+def _load_stems(path):
+ with path.open(encoding="utf8") as f:
+ stems = dict()
+ for line in f:
+ elements = line.strip().split(',')
+ stem = elements[0]
+ for value in elements[1:]:
+ stems[value] = stem
+ return stems
+
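
A sketch of how two requirement dicts combine through merge_required_resources; the gazetteer names are made up for the example, and only the keys from the constants module shown above are used.

    from snips_inference_agl.constants import GAZETTEERS, STEMS, STOP_WORDS
    from snips_inference_agl.resources import merge_required_resources

    lhs = {STOP_WORDS: True, GAZETTEERS: {"top_10000_words"}}
    rhs = {STEMS: True, GAZETTEERS: {"cities_world"}}
    merged = merge_required_resources(lhs, rhs)
    # merged keeps STOP_WORDS and STEMS set to True and unions the gazetteer sets
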
diff --git a/snips_inference_agl/result.py b/snips_inference_agl/result.py
new file mode 100644
index 0000000..39930c8
--- /dev/null
+++ b/snips_inference_agl/result.py
@@ -0,0 +1,342 @@
+from __future__ import unicode_literals
+
+from snips_inference_agl.constants import (
+ RES_ENTITY, RES_INPUT, RES_INTENT, RES_INTENT_NAME, RES_MATCH_RANGE,
+ RES_PROBA, RES_RAW_VALUE, RES_SLOTS, RES_SLOT_NAME, RES_VALUE, ENTITY_KIND,
+ RESOLVED_VALUE, VALUE)
+
+
+def intent_classification_result(intent_name, probability):
+ """Creates an intent classification result to be returned by
+ :meth:`.IntentClassifier.get_intent`
+
+ Example:
+
+ >>> intent_classification_result("GetWeather", 0.93)
+ {'intentName': 'GetWeather', 'probability': 0.93}
+ """
+ return {
+ RES_INTENT_NAME: intent_name,
+ RES_PROBA: probability
+ }
+
+
+def unresolved_slot(match_range, value, entity, slot_name):
+ """Creates an internal slot yet to be resolved
+
+ Example:
+
+ >>> from snips_inference_agl.common.utils import json_string
+ >>> slot = unresolved_slot([0, 8], "tomorrow", "snips/datetime", \
+ "startDate")
+ >>> print(json_string(slot, indent=4, sort_keys=True))
+ {
+ "entity": "snips/datetime",
+ "range": {
+ "end": 8,
+ "start": 0
+ },
+ "slotName": "startDate",
+ "value": "tomorrow"
+ }
+ """
+ return {
+ RES_MATCH_RANGE: _convert_range(match_range),
+ RES_VALUE: value,
+ RES_ENTITY: entity,
+ RES_SLOT_NAME: slot_name
+ }
+
+
+def custom_slot(internal_slot, resolved_value=None):
+ """Creates a custom slot with *resolved_value* being the reference value
+ of the slot
+
+ Example:
+
+ >>> s = unresolved_slot([10, 19], "earl grey", "beverage", "beverage")
+ >>> from snips_inference_agl.common.utils import json_string
+ >>> print(json_string(custom_slot(s, "tea"), indent=4, sort_keys=True))
+ {
+ "entity": "beverage",
+ "range": {
+ "end": 19,
+ "start": 10
+ },
+ "rawValue": "earl grey",
+ "slotName": "beverage",
+ "value": {
+ "kind": "Custom",
+ "value": "tea"
+ }
+ }
+ """
+
+ if resolved_value is None:
+ resolved_value = internal_slot[RES_VALUE]
+ return {
+ RES_MATCH_RANGE: _convert_range(internal_slot[RES_MATCH_RANGE]),
+ RES_RAW_VALUE: internal_slot[RES_VALUE],
+ RES_VALUE: {
+ "kind": "Custom",
+ "value": resolved_value
+ },
+ RES_ENTITY: internal_slot[RES_ENTITY],
+ RES_SLOT_NAME: internal_slot[RES_SLOT_NAME]
+ }
+
+
+def builtin_slot(internal_slot, resolved_value):
+ """Creates a builtin slot with *resolved_value* being the resolved value
+ of the slot
+
+ Example:
+
+ >>> rng = [10, 32]
+ >>> raw_value = "twenty degrees celsius"
+ >>> entity = "snips/temperature"
+ >>> slot_name = "beverageTemperature"
+ >>> s = unresolved_slot(rng, raw_value, entity, slot_name)
+ >>> resolved = {
+ ... "kind": "Temperature",
+ ... "value": 20,
+ ... "unit": "celsius"
+ ... }
+ >>> from snips_inference_agl.common.utils import json_string
+ >>> print(json_string(builtin_slot(s, resolved), indent=4))
+ {
+ "entity": "snips/temperature",
+ "range": {
+ "end": 32,
+ "start": 10
+ },
+ "rawValue": "twenty degrees celsius",
+ "slotName": "beverageTemperature",
+ "value": {
+ "kind": "Temperature",
+ "unit": "celsius",
+ "value": 20
+ }
+ }
+ """
+ return {
+ RES_MATCH_RANGE: _convert_range(internal_slot[RES_MATCH_RANGE]),
+ RES_RAW_VALUE: internal_slot[RES_VALUE],
+ RES_VALUE: resolved_value,
+ RES_ENTITY: internal_slot[RES_ENTITY],
+ RES_SLOT_NAME: internal_slot[RES_SLOT_NAME]
+ }
+
+
+def resolved_slot(match_range, raw_value, resolved_value, entity, slot_name):
+ """Creates a resolved slot
+
+ Args:
+ match_range (dict): Range of the slot within the sentence
+ (ex: {"start": 3, "end": 10})
+ raw_value (str): Slot value as it appears in the sentence
+ resolved_value (dict): Resolved value of the slot
+ entity (str): Entity which the slot belongs to
+ slot_name (str): Slot type
+
+ Returns:
+ dict: The resolved slot
+
+ Example:
+
+ >>> resolved_value = {
+ ... "kind": "Temperature",
+ ... "value": 20,
+ ... "unit": "celsius"
+ ... }
+ >>> slot = resolved_slot({"start": 10, "end": 19}, "earl grey",
+ ... resolved_value, "beverage", "beverage")
+ >>> from snips_inference_agl.common.utils import json_string
+ >>> print(json_string(slot, indent=4, sort_keys=True))
+ {
+ "entity": "beverage",
+ "range": {
+ "end": 19,
+ "start": 10
+ },
+ "rawValue": "earl grey",
+ "slotName": "beverage",
+ "value": {
+ "kind": "Temperature",
+ "unit": "celsius",
+ "value": 20
+ }
+ }
+ """
+ return {
+ RES_MATCH_RANGE: match_range,
+ RES_RAW_VALUE: raw_value,
+ RES_VALUE: resolved_value,
+ RES_ENTITY: entity,
+ RES_SLOT_NAME: slot_name
+ }
+
+
+def parsing_result(input, intent, slots): # pylint:disable=redefined-builtin
+ """Create the final output of :meth:`.SnipsNLUEngine.parse` or
+ :meth:`.IntentParser.parse`
+
+ Example:
+
+ >>> text = "Hello Bill!"
+ >>> intent_result = intent_classification_result("Greeting", 0.95)
+ >>> internal_slot = unresolved_slot([6, 10], "Bill", "name",
+ ... "greetee")
+ >>> slots = [custom_slot(internal_slot, "William")]
+ >>> res = parsing_result(text, intent_result, slots)
+ >>> from snips_inference_agl.common.utils import json_string
+ >>> print(json_string(res, indent=4, sort_keys=True))
+ {
+ "input": "Hello Bill!",
+ "intent": {
+ "intentName": "Greeting",
+ "probability": 0.95
+ },
+ "slots": [
+ {
+ "entity": "name",
+ "range": {
+ "end": 10,
+ "start": 6
+ },
+ "rawValue": "Bill",
+ "slotName": "greetee",
+ "value": {
+ "kind": "Custom",
+ "value": "William"
+ }
+ }
+ ]
+ }
+ """
+ return {
+ RES_INPUT: input,
+ RES_INTENT: intent,
+ RES_SLOTS: slots
+ }
+
+
+def extraction_result(intent, slots):
+ """Create the items in the output of :meth:`.SnipsNLUEngine.parse` or
+ :meth:`.IntentParser.parse` when called with a defined ``top_n`` value
+
+ This differs from :func:`.parsing_result` in that the input is omitted.
+
+ Example:
+
+ >>> intent_result = intent_classification_result("Greeting", 0.95)
+ >>> internal_slot = unresolved_slot([6, 10], "Bill", "name",
+ ... "greetee")
+ >>> slots = [custom_slot(internal_slot, "William")]
+ >>> res = extraction_result(intent_result, slots)
+ >>> from snips_inference_agl.common.utils import json_string
+ >>> print(json_string(res, indent=4, sort_keys=True))
+ {
+ "intent": {
+ "intentName": "Greeting",
+ "probability": 0.95
+ },
+ "slots": [
+ {
+ "entity": "name",
+ "range": {
+ "end": 10,
+ "start": 6
+ },
+ "rawValue": "Bill",
+ "slotName": "greetee",
+ "value": {
+ "kind": "Custom",
+ "value": "William"
+ }
+ }
+ ]
+ }
+ """
+ return {
+ RES_INTENT: intent,
+ RES_SLOTS: slots
+ }
+
+
+def is_empty(result):
+ """Check if a result is empty
+
+ Example:
+
+ >>> res = empty_result("foo bar", 1.0)
+ >>> is_empty(res)
+ True
+ """
+ return result[RES_INTENT][RES_INTENT_NAME] is None
+
+
+def empty_result(input, probability): # pylint:disable=redefined-builtin
+ """Creates an empty parsing result of the same format as the one of
+ :func:`parsing_result`
+
+ An empty result is typically returned by a :class:`.SnipsNLUEngine` or
+ :class:`.IntentParser` when neither an intent nor slots were found.
+
+ Example:
+
+ >>> res = empty_result("foo bar", 0.8)
+ >>> from snips_inference_agl.common.utils import json_string
+ >>> print(json_string(res, indent=4, sort_keys=True))
+ {
+ "input": "foo bar",
+ "intent": {
+ "intentName": null,
+ "probability": 0.8
+ },
+ "slots": []
+ }
+ """
+ intent = intent_classification_result(None, probability)
+ return parsing_result(input=input, intent=intent, slots=[])
+
+
+def parsed_entity(entity_kind, entity_value, entity_resolved_value,
+ entity_range):
+ """Create the items in the output of
+ :meth:`snips_inference_agl.entity_parser.EntityParser.parse`
+
+ Example:
+ >>> resolved_value = dict(age=28, role="datascientist")
+ >>> range = dict(start=0, end=6)
+ >>> ent = parsed_entity("snipster", "adrien", resolved_value, range)
+ >>> import json
+ >>> print(json.dumps(ent, indent=4, sort_keys=True))
+ {
+ "entity_kind": "snipster",
+ "range": {
+ "end": 6,
+ "start": 0
+ },
+ "resolved_value": {
+ "age": 28,
+ "role": "datascientist"
+ },
+ "value": "adrien"
+ }
+ """
+ return {
+ VALUE: entity_value,
+ RESOLVED_VALUE: entity_resolved_value,
+ ENTITY_KIND: entity_kind,
+ RES_MATCH_RANGE: entity_range
+ }
+
+
+def _convert_range(rng):
+ if isinstance(rng, dict):
+ return rng
+ return {
+ "start": rng[0],
+ "end": rng[1]
+ }
diff --git a/snips_inference_agl/slot_filler/__init__.py b/snips_inference_agl/slot_filler/__init__.py
new file mode 100644
index 0000000..70974aa
--- /dev/null
+++ b/snips_inference_agl/slot_filler/__init__.py
@@ -0,0 +1,3 @@
+from .crf_slot_filler import CRFSlotFiller
+from .feature import Feature
+from .slot_filler import SlotFiller
diff --git a/snips_inference_agl/slot_filler/crf_slot_filler.py b/snips_inference_agl/slot_filler/crf_slot_filler.py
new file mode 100644
index 0000000..e6ec7e6
--- /dev/null
+++ b/snips_inference_agl/slot_filler/crf_slot_filler.py
@@ -0,0 +1,467 @@
+from __future__ import unicode_literals
+
+import base64
+import json
+import logging
+import math
+import os
+import shutil
+import tempfile
+from builtins import range
+from copy import deepcopy
+from pathlib import Path
+
+from future.utils import iteritems
+
+from snips_inference_agl.common.dataset_utils import get_slot_name_mapping
+from snips_inference_agl.common.dict_utils import UnupdatableDict
+from snips_inference_agl.common.io_utils import mkdir_p
+from snips_inference_agl.common.log_utils import DifferedLoggingMessage, log_elapsed_time
+from snips_inference_agl.common.utils import (
+ check_persisted_path, fitted_required, json_string)
+from snips_inference_agl.constants import DATA, LANGUAGE
+from snips_inference_agl.data_augmentation import augment_utterances
+from snips_inference_agl.dataset import validate_and_format_dataset
+from snips_inference_agl.exceptions import LoadingError
+from snips_inference_agl.pipeline.configs import CRFSlotFillerConfig
+from snips_inference_agl.preprocessing import tokenize
+from snips_inference_agl.slot_filler.crf_utils import (
+ OUTSIDE, TAGS, TOKENS, tags_to_slots, utterance_to_sample)
+from snips_inference_agl.slot_filler.feature import TOKEN_NAME
+from snips_inference_agl.slot_filler.feature_factory import CRFFeatureFactory
+from snips_inference_agl.slot_filler.slot_filler import SlotFiller
+
+CRF_MODEL_FILENAME = "model.crfsuite"
+
+logger = logging.getLogger(__name__)
+
+
+@SlotFiller.register("crf_slot_filler")
+class CRFSlotFiller(SlotFiller):
+ """Slot filler which uses Linear-Chain Conditional Random Fields underneath
+
+ Check https://en.wikipedia.org/wiki/Conditional_random_field to learn
+ more about CRFs
+ """
+
+ config_type = CRFSlotFillerConfig
+
+ def __init__(self, config=None, **shared):
+ """The CRF slot filler can be configured by passing a
+ :class:`.CRFSlotFillerConfig`"""
+ # The CRFSlotFillerConfig must be deep-copied as it is mutated when
+ # fitting the feature factories
+ config = deepcopy(config)
+ super(CRFSlotFiller, self).__init__(config, **shared)
+ self.crf_model = None
+ self.features_factories = [
+ CRFFeatureFactory.from_config(conf, **shared)
+ for conf in self.config.feature_factory_configs]
+ self._features = None
+ self.language = None
+ self.intent = None
+ self.slot_name_mapping = None
+
+ @property
+ def features(self):
+ """List of :class:`.Feature` used by the CRF"""
+ if self._features is None:
+ self._features = []
+ feature_names = set()
+ for factory in self.features_factories:
+ for feature in factory.build_features():
+ if feature.name in feature_names:
+ raise KeyError("Duplicated feature: %s" % feature.name)
+ feature_names.add(feature.name)
+ self._features.append(feature)
+ return self._features
+
+ @property
+ def labels(self):
+ """List of CRF labels
+
+ These labels differ from the slot names as they contain an additional
+ prefix which depends on the :class:`.TaggingScheme` that is used
+ (BIO by default).
+ """
+ labels = []
+ if self.crf_model.tagger_ is not None:
+ labels = [_decode_tag(label) for label in
+ self.crf_model.tagger_.labels()]
+ return labels
+
+ @property
+ def fitted(self):
+ """Whether or not the slot filler has already been fitted"""
+ return self.slot_name_mapping is not None
+
+ @log_elapsed_time(logger, logging.INFO,
+ "Fitted CRFSlotFiller in {elapsed_time}")
+ # pylint:disable=arguments-differ
+ def fit(self, dataset, intent):
+ """Fits the slot filler
+
+ Args:
+ dataset (dict): A valid Snips dataset
+ intent (str): The specific intent of the dataset to train
+ the slot filler on
+
+ Returns:
+ :class:`CRFSlotFiller`: The same instance, trained
+ """
+ logger.info("Fitting %s slot filler...", intent)
+ dataset = validate_and_format_dataset(dataset)
+ self.load_resources_if_needed(dataset[LANGUAGE])
+ self.fit_builtin_entity_parser_if_needed(dataset)
+ self.fit_custom_entity_parser_if_needed(dataset)
+
+ for factory in self.features_factories:
+ factory.custom_entity_parser = self.custom_entity_parser
+ factory.builtin_entity_parser = self.builtin_entity_parser
+ factory.resources = self.resources
+
+ self.language = dataset[LANGUAGE]
+ self.intent = intent
+ self.slot_name_mapping = get_slot_name_mapping(dataset, intent)
+
+ if not self.slot_name_mapping:
+ # No need to train the CRF if the intent has no slots
+ return self
+
+ augmented_intent_utterances = augment_utterances(
+ dataset, self.intent, language=self.language,
+ resources=self.resources, random_state=self.random_state,
+ **self.config.data_augmentation_config.to_dict())
+
+ crf_samples = [
+ utterance_to_sample(u[DATA], self.config.tagging_scheme,
+ self.language)
+ for u in augmented_intent_utterances]
+
+ for factory in self.features_factories:
+ factory.fit(dataset, intent)
+
+ # Ensure that X, Y are safe and that the OUTSIDE label is learnt to
+ # avoid segfault at inference time
+ # pylint: disable=C0103
+ X = [self.compute_features(sample[TOKENS], drop_out=True)
+ for sample in crf_samples]
+ Y = [[tag for tag in sample[TAGS]] for sample in crf_samples]
+ X, Y = _ensure_safe(X, Y)
+
+ # ensure ascii tags
+ Y = [[_encode_tag(tag) for tag in y] for y in Y]
+
+ # pylint: enable=C0103
+ self.crf_model = _get_crf_model(self.config.crf_args)
+ self.crf_model.fit(X, Y)
+
+ logger.debug(
+ "Most relevant features for %s:\n%s", self.intent,
+ DifferedLoggingMessage(self.log_weights))
+ return self
+
+ # pylint:enable=arguments-differ
+
+ @fitted_required
+ def get_slots(self, text):
+ """Extracts slots from the provided text
+
+ Returns:
+ list of dict: The list of extracted slots
+
+ Raises:
+ NotTrained: When the slot filler is not fitted
+ """
+ if not self.slot_name_mapping:
+ # Early return if the intent has no slots
+ return []
+
+ tokens = tokenize(text, self.language)
+ if not tokens:
+ return []
+ features = self.compute_features(tokens)
+ tags = self.crf_model.predict_single(features)
+ logger.debug(DifferedLoggingMessage(
+ self.log_inference_weights, text, tokens=tokens, features=features,
+ tags=tags))
+ decoded_tags = [_decode_tag(t) for t in tags]
+ return tags_to_slots(text, tokens, decoded_tags,
+ self.config.tagging_scheme,
+ self.slot_name_mapping)
+
+ def compute_features(self, tokens, drop_out=False):
+ """Computes features on the provided tokens
+
+ The *drop_out* parameter allows activating drop out on features that
+ have a positive drop out ratio. This should only be used during
+ training.
+ """
+
+ cache = [{TOKEN_NAME: token} for token in tokens]
+ features = []
+ for i in range(len(tokens)):
+ token_features = UnupdatableDict()
+ for feature in self.features:
+ f_drop_out = feature.drop_out
+ if drop_out and self.random_state.rand() < f_drop_out:
+ continue
+ value = feature.compute(i, cache)
+ if value is not None:
+ token_features[feature.name] = value
+ features.append(token_features)
+ return features
+
+ @fitted_required
+ def get_sequence_probability(self, tokens, labels):
+ """Gives the joint probability of a sequence of tokens and CRF labels
+
+ Args:
+ tokens (list of :class:`.Token`): list of tokens
+ labels (list of str): CRF labels with their tagging scheme prefix
+ ("B-color", "I-color", "O", etc)
+
+ Note:
+ The absolute value returned here is generally not very useful;
+ however, it can be used to compare a sequence of labels relative
+ to another one.
+ """
+ if not self.slot_name_mapping:
+ return 0.0 if any(label != OUTSIDE for label in labels) else 1.0
+ features = self.compute_features(tokens)
+ return self._get_sequence_probability(features, labels)
+
+ @fitted_required
+ def _get_sequence_probability(self, features, labels):
+ # Use a default substitution label when a label was not seen during
+ # training
+ substitution_label = OUTSIDE if OUTSIDE in self.labels else \
+ self.labels[0]
+ cleaned_labels = [
+ _encode_tag(substitution_label if l not in self.labels else l)
+ for l in labels]
+ self.crf_model.tagger_.set(features)
+ return self.crf_model.tagger_.probability(cleaned_labels)
+
+ @fitted_required
+ def log_weights(self):
+ """Returns a logs for both the label-to-label and label-to-features
+ weights"""
+ if not self.slot_name_mapping:
+ return "No weights to display: intent '%s' has no slots" \
+ % self.intent
+ log = ""
+ transition_features = self.crf_model.transition_features_
+ transition_features = sorted(
+ iteritems(transition_features), key=_weight_absolute_value,
+ reverse=True)
+ log += "\nTransition weights: \n\n"
+ for (state_1, state_2), weight in transition_features:
+ log += "\n%s %s: %s" % (
+ _decode_tag(state_1), _decode_tag(state_2), weight)
+ feature_weights = self.crf_model.state_features_
+ feature_weights = sorted(
+ iteritems(feature_weights), key=_weight_absolute_value,
+ reverse=True)
+ log += "\n\nFeature weights: \n\n"
+ for (feat, tag), weight in feature_weights:
+ log += "\n%s %s: %s" % (feat, _decode_tag(tag), weight)
+ return log
+
+ def log_inference_weights(self, text, tokens, features, tags):
+ model_features = set(
+ f for (f, _), w in iteritems(self.crf_model.state_features_))
+ log = "Feature weights for \"%s\":\n\n" % text
+ max_index = len(tokens) - 1
+ tokens_logs = []
+ for i, (token, feats, tag) in enumerate(zip(tokens, features, tags)):
+ token_log = "# Token \"%s\" (tagged as %s):" \
+ % (token.value, _decode_tag(tag))
+ if i != 0:
+ weights = sorted(self._get_outgoing_weights(tags[i - 1]),
+ key=_weight_absolute_value, reverse=True)
+ if weights:
+ token_log += "\n\nTransition weights from previous tag:"
+ weight_lines = (
+ "- (%s, %s) -> %s"
+ % (_decode_tag(a), _decode_tag(b), w)
+ for (a, b), w in weights
+ )
+ token_log += "\n" + "\n".join(weight_lines)
+ else:
+ token_log += \
+ "\n\nNo transition from previous tag seen at" \
+ " train time !"
+
+ if i != max_index:
+ weights = sorted(self._get_incoming_weights(tags[i + 1]),
+ key=_weight_absolute_value, reverse=True)
+ if weights:
+ token_log += "\n\nTransition weights to next tag:"
+ weight_lines = (
+ "- (%s, %s) -> %s"
+ % (_decode_tag(a), _decode_tag(b), w)
+ for (a, b), w in weights
+ )
+ token_log += "\n" + "\n".join(weight_lines)
+ else:
+ token_log += \
+ "\n\nNo transition to next tag seen at train time !"
+ feats = [":".join(f) for f in iteritems(feats)]
+ weights = (w for f in feats for w in self._get_feature_weight(f))
+ weights = sorted(weights, key=_weight_absolute_value, reverse=True)
+ if weights:
+ token_log += "\n\nFeature weights:\n"
+ token_log += "\n".join(
+ "- (%s, %s) -> %s"
+ % (f, _decode_tag(t), w) for (f, t), w in weights
+ )
+ else:
+ token_log += "\n\nNo feature weights!"
+
+ unseen_features = sorted(
+ set(f for f in feats if f not in model_features))
+ if unseen_features:
+ token_log += "\n\nFeatures not seen at train time:\n%s" % \
+ "\n".join("- %s" % f for f in unseen_features)
+ tokens_logs.append(token_log)
+
+ log += "\n\n\n".join(tokens_logs)
+ return log
+
+ @fitted_required
+ def _get_incoming_weights(self, tag):
+ return [((first, second), w) for (first, second), w
+ in iteritems(self.crf_model.transition_features_)
+ if second == tag]
+
+ @fitted_required
+ def _get_outgoing_weights(self, tag):
+ return [((first, second), w) for (first, second), w
+ in iteritems(self.crf_model.transition_features_)
+ if first == tag]
+
+ @fitted_required
+ def _get_feature_weight(self, feature):
+ return [((f, tag), w) for (f, tag), w
+ in iteritems(self.crf_model.state_features_) if f == feature]
+
+ @check_persisted_path
+ def persist(self, path):
+ """Persists the object at the given path"""
+ path.mkdir()
+
+ crf_model_file = None
+ if self.crf_model is not None:
+ crf_model_file = CRF_MODEL_FILENAME
+ destination = path / crf_model_file
+ shutil.copy(self.crf_model.modelfile.name, str(destination))
+ # On Windows, permissions of crfsuite files are already correct
+ if os.name == "posix":
+ umask = os.umask(0o022) # retrieve the system umask
+ os.umask(umask) # restore the sys umask to its original value
+ os.chmod(str(destination), 0o644 & ~umask)
+
+ model = {
+ "language_code": self.language,
+ "intent": self.intent,
+ "crf_model_file": crf_model_file,
+ "slot_name_mapping": self.slot_name_mapping,
+ "config": self.config.to_dict(),
+ }
+ model_json = json_string(model)
+ model_path = path / "slot_filler.json"
+ with model_path.open(mode="w", encoding="utf8") as f:
+ f.write(model_json)
+ self.persist_metadata(path)
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ """Loads a :class:`CRFSlotFiller` instance from a path
+
+ The data at the given path must have been generated using
+ :func:`~CRFSlotFiller.persist`
+ """
+ path = Path(path)
+ model_path = path / "slot_filler.json"
+ if not model_path.exists():
+ raise LoadingError(
+ "Missing slot filler model file: %s" % model_path.name)
+
+ with model_path.open(encoding="utf8") as f:
+ model = json.load(f)
+
+ slot_filler_config = cls.config_type.from_dict(model["config"])
+ slot_filler = cls(config=slot_filler_config, **shared)
+ slot_filler.language = model["language_code"]
+ slot_filler.intent = model["intent"]
+ slot_filler.slot_name_mapping = model["slot_name_mapping"]
+ crf_model_file = model["crf_model_file"]
+ if crf_model_file is not None:
+ crf = _crf_model_from_path(path / crf_model_file)
+ slot_filler.crf_model = crf
+ return slot_filler
+
+ def _cleanup(self):
+ if self.crf_model is not None:
+ self.crf_model.modelfile.cleanup()
+
+ def __del__(self):
+ self._cleanup()
+
+
+def _get_crf_model(crf_args):
+ from sklearn_crfsuite import CRF
+
+ model_filename = crf_args.get("model_filename", None)
+ if model_filename is not None:
+ directory = Path(model_filename).parent
+ if not directory.is_dir():
+ mkdir_p(directory)
+
+ return CRF(model_filename=model_filename, **crf_args)
+
+
+def _encode_tag(tag):
+ return base64.b64encode(tag.encode("utf8"))
+
+
+def _decode_tag(tag):
+ return base64.b64decode(tag).decode("utf8")
+
+
+def _crf_model_from_path(crf_model_path):
+ from sklearn_crfsuite import CRF
+
+ with crf_model_path.open(mode="rb") as f:
+ crf_model_data = f.read()
+ with tempfile.NamedTemporaryFile(suffix=".crfsuite", prefix="model",
+ delete=False) as f:
+ f.write(crf_model_data)
+ f.flush()
+ crf = CRF(model_filename=f.name)
+ return crf
+
+
+# pylint: disable=invalid-name
+def _ensure_safe(X, Y):
+ """Ensures that Y has at least one not empty label, otherwise the CRF model
+ does not contain any label and crashes at
+
+ Args:
+ X: features
+ Y: labels
+
+ Returns:
+ (safe_X, safe_Y): a pair of safe features and labels
+ """
+ safe_X = list(X)
+ safe_Y = list(Y)
+ if not any(X) or not any(Y):
+ safe_X.append([""]) # empty feature
+ safe_Y.append([OUTSIDE]) # outside label
+ return safe_X, safe_Y
+
+
+def _weight_absolute_value(x):
+ return math.fabs(x[1])
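
The slot filler base64-encodes every tag before training and decodes it back when reading predictions (see the "ensure ascii tags" comment above). A tiny round-trip sketch using the private helpers defined in this file, assuming the package imports cleanly:

    from snips_inference_agl.slot_filler.crf_slot_filler import (
        _decode_tag, _encode_tag)

    encoded = _encode_tag("B-böök")     # ASCII-safe bytes handed to crfsuite
    assert _decode_tag(encoded) == "B-böök"
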
diff --git a/snips_inference_agl/slot_filler/crf_utils.py b/snips_inference_agl/slot_filler/crf_utils.py
new file mode 100644
index 0000000..817a59b
--- /dev/null
+++ b/snips_inference_agl/slot_filler/crf_utils.py
@@ -0,0 +1,219 @@
+from __future__ import unicode_literals
+
+from builtins import range
+from enum import Enum, unique
+
+from snips_inference_agl.constants import END, SLOT_NAME, START, TEXT
+from snips_inference_agl.preprocessing import Token, tokenize
+from snips_inference_agl.result import unresolved_slot
+
+BEGINNING_PREFIX = "B-"
+INSIDE_PREFIX = "I-"
+LAST_PREFIX = "L-"
+UNIT_PREFIX = "U-"
+OUTSIDE = "O"
+
+RANGE = "range"
+TAGS = "tags"
+TOKENS = "tokens"
+
+
+@unique
+class TaggingScheme(Enum):
+ """CRF Coding Scheme"""
+
+ IO = 0
+ """Inside-Outside scheme"""
+ BIO = 1
+ """Beginning-Inside-Outside scheme"""
+ BILOU = 2
+ """Beginning-Inside-Last-Outside-Unit scheme, sometimes referred as
+ BWEMO"""
+
+
+def tag_name_to_slot_name(tag):
+ return tag[2:]
+
+
+def start_of_io_slot(tags, i):
+ if i == 0:
+ return tags[i] != OUTSIDE
+ if tags[i] == OUTSIDE:
+ return False
+ return tags[i - 1] == OUTSIDE
+
+
+def end_of_io_slot(tags, i):
+ if i + 1 == len(tags):
+ return tags[i] != OUTSIDE
+ if tags[i] == OUTSIDE:
+ return False
+ return tags[i + 1] == OUTSIDE
+
+
+def start_of_bio_slot(tags, i):
+ if i == 0:
+ return tags[i] != OUTSIDE
+ if tags[i] == OUTSIDE:
+ return False
+ if tags[i].startswith(BEGINNING_PREFIX):
+ return True
+ if tags[i - 1] != OUTSIDE:
+ return False
+ return True
+
+
+def end_of_bio_slot(tags, i):
+ if i + 1 == len(tags):
+ return tags[i] != OUTSIDE
+ if tags[i] == OUTSIDE:
+ return False
+ if tags[i + 1].startswith(INSIDE_PREFIX):
+ return False
+ return True
+
+
+def start_of_bilou_slot(tags, i):
+ if i == 0:
+ return tags[i] != OUTSIDE
+ if tags[i] == OUTSIDE:
+ return False
+ if tags[i].startswith(BEGINNING_PREFIX):
+ return True
+ if tags[i].startswith(UNIT_PREFIX):
+ return True
+ if tags[i - 1].startswith(UNIT_PREFIX):
+ return True
+ if tags[i - 1].startswith(LAST_PREFIX):
+ return True
+ if tags[i - 1] != OUTSIDE:
+ return False
+ return True
+
+
+def end_of_bilou_slot(tags, i):
+ if i + 1 == len(tags):
+ return tags[i] != OUTSIDE
+ if tags[i] == OUTSIDE:
+ return False
+ if tags[i + 1] == OUTSIDE:
+ return True
+ if tags[i].startswith(LAST_PREFIX):
+ return True
+ if tags[i].startswith(UNIT_PREFIX):
+ return True
+ if tags[i + 1].startswith(BEGINNING_PREFIX):
+ return True
+ if tags[i + 1].startswith(UNIT_PREFIX):
+ return True
+ return False
+
+
+def _tags_to_preslots(tags, tokens, is_start_of_slot, is_end_of_slot):
+ slots = []
+ current_slot_start = 0
+ for i, tag in enumerate(tags):
+ if is_start_of_slot(tags, i):
+ current_slot_start = i
+ if is_end_of_slot(tags, i):
+ slots.append({
+ RANGE: {
+ START: tokens[current_slot_start].start,
+ END: tokens[i].end
+ },
+ SLOT_NAME: tag_name_to_slot_name(tag)
+ })
+ current_slot_start = i
+ return slots
+
+
+def tags_to_preslots(tokens, tags, tagging_scheme):
+ if tagging_scheme == TaggingScheme.IO:
+ slots = _tags_to_preslots(tags, tokens, start_of_io_slot,
+ end_of_io_slot)
+ elif tagging_scheme == TaggingScheme.BIO:
+ slots = _tags_to_preslots(tags, tokens, start_of_bio_slot,
+ end_of_bio_slot)
+ elif tagging_scheme == TaggingScheme.BILOU:
+ slots = _tags_to_preslots(tags, tokens, start_of_bilou_slot,
+ end_of_bilou_slot)
+ else:
+ raise ValueError("Unknown tagging scheme %s" % tagging_scheme)
+ return slots
+
+
+def tags_to_slots(text, tokens, tags, tagging_scheme, intent_slots_mapping):
+ slots = tags_to_preslots(tokens, tags, tagging_scheme)
+ return [
+ unresolved_slot(match_range=slot[RANGE],
+ value=text[slot[RANGE][START]:slot[RANGE][END]],
+ entity=intent_slots_mapping[slot[SLOT_NAME]],
+ slot_name=slot[SLOT_NAME])
+ for slot in slots
+ ]
+
+
+def positive_tagging(tagging_scheme, slot_name, slot_size):
+ if slot_name == OUTSIDE:
+ return [OUTSIDE for _ in range(slot_size)]
+
+ if tagging_scheme == TaggingScheme.IO:
+ tags = [INSIDE_PREFIX + slot_name for _ in range(slot_size)]
+ elif tagging_scheme == TaggingScheme.BIO:
+ if slot_size > 0:
+ tags = [BEGINNING_PREFIX + slot_name]
+ tags += [INSIDE_PREFIX + slot_name for _ in range(1, slot_size)]
+ else:
+ tags = []
+ elif tagging_scheme == TaggingScheme.BILOU:
+ if slot_size == 0:
+ tags = []
+ elif slot_size == 1:
+ tags = [UNIT_PREFIX + slot_name]
+ else:
+ tags = [BEGINNING_PREFIX + slot_name]
+ tags += [INSIDE_PREFIX + slot_name
+ for _ in range(1, slot_size - 1)]
+ tags.append(LAST_PREFIX + slot_name)
+ else:
+ raise ValueError("Invalid tagging scheme %s" % tagging_scheme)
+ return tags
+
+
+def negative_tagging(size):
+ return [OUTSIDE for _ in range(size)]
+
+
+def utterance_to_sample(query_data, tagging_scheme, language):
+ tokens, tags = [], []
+ current_length = 0
+ for chunk in query_data:
+ chunk_tokens = tokenize(chunk[TEXT], language)
+ tokens += [Token(t.value, current_length + t.start,
+ current_length + t.end) for t in chunk_tokens]
+ current_length += len(chunk[TEXT])
+ if SLOT_NAME not in chunk:
+ tags += negative_tagging(len(chunk_tokens))
+ else:
+ tags += positive_tagging(tagging_scheme, chunk[SLOT_NAME],
+ len(chunk_tokens))
+ return {TOKENS: tokens, TAGS: tags}
+
+
+def get_scheme_prefix(index, indexes, tagging_scheme):
+ if tagging_scheme == TaggingScheme.IO:
+ return INSIDE_PREFIX
+ elif tagging_scheme == TaggingScheme.BIO:
+ if index == indexes[0]:
+ return BEGINNING_PREFIX
+ return INSIDE_PREFIX
+ elif tagging_scheme == TaggingScheme.BILOU:
+ if len(indexes) == 1:
+ return UNIT_PREFIX
+ if index == indexes[0]:
+ return BEGINNING_PREFIX
+ if index == indexes[-1]:
+ return LAST_PREFIX
+ return INSIDE_PREFIX
+ else:
+ raise ValueError("Invalid tagging scheme %s" % tagging_scheme)
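
To make the tagging schemes above concrete, here is a small sketch of the tag layouts produced by positive_tagging for a three-token and a one-token slot; it only uses the pure-Python helpers defined in this file.

    from snips_inference_agl.slot_filler.crf_utils import TaggingScheme, positive_tagging

    print(positive_tagging(TaggingScheme.BIO, "color", 3))
    # ['B-color', 'I-color', 'I-color']
    print(positive_tagging(TaggingScheme.BILOU, "color", 3))
    # ['B-color', 'I-color', 'L-color']
    print(positive_tagging(TaggingScheme.BILOU, "color", 1))
    # ['U-color']
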
diff --git a/snips_inference_agl/slot_filler/feature.py b/snips_inference_agl/slot_filler/feature.py
new file mode 100644
index 0000000..a6da552
--- /dev/null
+++ b/snips_inference_agl/slot_filler/feature.py
@@ -0,0 +1,69 @@
+from __future__ import unicode_literals
+
+from builtins import object
+
+TOKEN_NAME = "token"
+
+
+class Feature(object):
+ """CRF Feature which is used by :class:`.CRFSlotFiller`
+
+ Attributes:
+ base_name (str): Feature name (e.g. 'is_digit', 'is_first' etc)
+ func (function): The actual feature function for example:
+
+ def is_first(tokens, token_index):
+ return "1" if token_index == 0 else None
+
+ offset (int, optional): Token offset to consider when computing
+ the feature (e.g -1 for computing the feature on the previous word)
+ drop_out (float, optional): Drop out to use when computing the
+ feature during training
+
+ Note:
+ The easiest way to add additional features to the existing ones is
+ to create a :class:`.CRFFeatureFactory`
+ """
+
+ def __init__(self, base_name, func, offset=0, drop_out=0):
+ if base_name == TOKEN_NAME:
+ raise ValueError("'%s' name is reserved" % TOKEN_NAME)
+ self.offset = offset
+ self._name = None
+ self._base_name = None
+ self.base_name = base_name
+ self.function = func
+ self.drop_out = drop_out
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def base_name(self):
+ return self._base_name
+
+ @base_name.setter
+ def base_name(self, value):
+ self._name = _offset_name(value, self.offset)
+ self._base_name = _offset_name(value, 0)
+
+ def compute(self, token_index, cache):
+ if not 0 <= (token_index + self.offset) < len(cache):
+ return None
+
+ if self.base_name in cache[token_index + self.offset]:
+ return cache[token_index + self.offset][self.base_name]
+
+ tokens = [c["token"] for c in cache]
+ value = self.function(tokens, token_index + self.offset)
+ cache[token_index + self.offset][self.base_name] = value
+ return value
+
+
+def _offset_name(name, offset):
+ if offset > 0:
+ return "%s[+%s]" % (name, offset)
+ if offset < 0:
+ return "%s[%s]" % (name, offset)
+ return name
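
A minimal sketch of how a Feature with a negative offset looks at the previous token; the cache layout mirrors what CRFSlotFiller.compute_features builds, and is_first here is just an illustrative feature function.

    from snips_inference_agl.preprocessing import Token
    from snips_inference_agl.slot_filler.feature import Feature

    def is_first(tokens, token_index):
        return "1" if token_index == 0 else None

    feature = Feature("is_first", is_first, offset=-1)
    print(feature.name)               # "is_first[-1]"

    cache = [{"token": Token("hello", 0, 5)}, {"token": Token("world", 6, 11)}]
    print(feature.compute(1, cache))  # "1": the token at offset -1 is the first one
    print(feature.compute(0, cache))  # None: the offset falls outside the utterance
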
diff --git a/snips_inference_agl/slot_filler/feature_factory.py b/snips_inference_agl/slot_filler/feature_factory.py
new file mode 100644
index 0000000..50f4598
--- /dev/null
+++ b/snips_inference_agl/slot_filler/feature_factory.py
@@ -0,0 +1,568 @@
+from __future__ import unicode_literals
+
+import logging
+
+from abc import ABCMeta, abstractmethod
+from builtins import str
+
+from future.utils import with_metaclass
+
+from snips_inference_agl.common.abc_utils import classproperty
+from snips_inference_agl.common.registrable import Registrable
+from snips_inference_agl.common.utils import check_random_state
+from snips_inference_agl.constants import (
+ CUSTOM_ENTITY_PARSER_USAGE, END, GAZETTEERS, LANGUAGE, RES_MATCH_RANGE,
+ START, STEMS, WORD_CLUSTERS, CUSTOM_ENTITY_PARSER, BUILTIN_ENTITY_PARSER,
+ RESOURCES, RANDOM_STATE, AUTOMATICALLY_EXTENSIBLE, ENTITIES)
+from snips_inference_agl.dataset import (
+ extract_intent_entities, get_dataset_gazetteer_entities)
+from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity
+from snips_inference_agl.entity_parser.custom_entity_parser import (
+ CustomEntityParserUsage)
+from snips_inference_agl.languages import get_default_sep
+from snips_inference_agl.preprocessing import Token, normalize_token, stem_token
+from snips_inference_agl.resources import get_gazetteer, get_word_cluster
+from snips_inference_agl.slot_filler.crf_utils import TaggingScheme, get_scheme_prefix
+from snips_inference_agl.slot_filler.feature import Feature
+from snips_inference_agl.slot_filler.features_utils import (
+ entity_filter, get_word_chunk, initial_string_from_tokens)
+
+logger = logging.getLogger(__name__)
+
+
+class CRFFeatureFactory(with_metaclass(ABCMeta, Registrable)):
+ """Abstraction to implement to build CRF features
+
+    A :class:`CRFFeatureFactory` is initialized with a dict which describes
+    the feature; it must contain the following three keys:
+
+ - 'factory_name'
+ - 'args': the parameters of the feature, if any
+    - 'offsets': the offsets to consider when using the feature in the CRF.
+      An empty list means the feature is not used at all.
+
+
+ In addition, a 'drop_out' to use at training time can be specified.
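+
+    For example, a config for the 'ngram' factory may look like the following
+    (values shown are purely illustrative):
+
+        {
+            "factory_name": "ngram",
+            "args": {
+                "n": 2,
+                "use_stemming": False,
+                "common_words_gazetteer_name": None
+            },
+            "offsets": [-1, 0],
+            "drop_out": 0.1
+        }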
+ """
+
+ def __init__(self, factory_config, **shared):
+ self.factory_config = factory_config
+ self.resources = shared.get(RESOURCES)
+ self.builtin_entity_parser = shared.get(BUILTIN_ENTITY_PARSER)
+ self.custom_entity_parser = shared.get(CUSTOM_ENTITY_PARSER)
+ self.random_state = check_random_state(shared.get(RANDOM_STATE))
+
+ @classmethod
+ def from_config(cls, factory_config, **shared):
+ """Retrieve the :class:`CRFFeatureFactory` corresponding the provided
+ config
+
+ Raises:
+ NotRegisteredError: when the factory is not registered
+ """
+ factory_name = factory_config["factory_name"]
+ factory = cls.by_name(factory_name)
+ return factory(factory_config, **shared)
+
+ @classproperty
+ def name(cls): # pylint:disable=no-self-argument
+ return CRFFeatureFactory.registered_name(cls)
+
+ @property
+ def args(self):
+ return self.factory_config["args"]
+
+ @property
+ def offsets(self):
+ return self.factory_config["offsets"]
+
+ @property
+ def drop_out(self):
+ return self.factory_config.get("drop_out", 0.0)
+
+ def fit(self, dataset, intent): # pylint: disable=unused-argument
+ """Fit the factory, if needed, with the provided *dataset* and *intent*
+ """
+ return self
+
+ @abstractmethod
+ def build_features(self):
+ """Build a list of :class:`.Feature`"""
+ pass
+
+ def get_required_resources(self):
+ return None
+
+
+class SingleFeatureFactory(with_metaclass(ABCMeta, CRFFeatureFactory)):
+ """A CRF feature factory which produces only one feature"""
+
+ @property
+ def feature_name(self):
+ # by default, use the factory name
+ return self.name
+
+ @abstractmethod
+ def compute_feature(self, tokens, token_index):
+ pass
+
+ def build_features(self):
+ return [
+ Feature(
+ base_name=self.feature_name,
+ func=self.compute_feature,
+ offset=offset,
+ drop_out=self.drop_out) for offset in self.offsets
+ ]
+
+
+@CRFFeatureFactory.register("is_digit")
+class IsDigitFactory(SingleFeatureFactory):
+ """Feature: is the considered token a digit?"""
+
+ def compute_feature(self, tokens, token_index):
+ return "1" if tokens[token_index].value.isdigit() else None
+
+
+@CRFFeatureFactory.register("is_first")
+class IsFirstFactory(SingleFeatureFactory):
+ """Feature: is the considered token the first in the input?"""
+
+ def compute_feature(self, tokens, token_index):
+ return "1" if token_index == 0 else None
+
+
+@CRFFeatureFactory.register("is_last")
+class IsLastFactory(SingleFeatureFactory):
+ """Feature: is the considered token the last in the input?"""
+
+ def compute_feature(self, tokens, token_index):
+ return "1" if token_index == len(tokens) - 1 else None
+
+
+@CRFFeatureFactory.register("ngram")
+class NgramFactory(SingleFeatureFactory):
+ """Feature: the n-gram consisting of the considered token and potentially
+ the following ones
+
+ This feature has several parameters:
+
+    - 'n' (int): Size of the n-gram. n=1 corresponds to a unigram, n=2 to a
+      bigram, etc.
+    - 'use_stemming' (bool): Whether or not to stem the n-gram
+    - 'common_words_gazetteer_name' (str, optional): If defined, use a
+      gazetteer of common words and replace out-of-corpus n-grams with the
+      alias 'rare_word'
+
+ """
+
+ def __init__(self, factory_config, **shared):
+ super(NgramFactory, self).__init__(factory_config, **shared)
+ self.n = self.args["n"]
+ if self.n < 1:
+ raise ValueError("n should be >= 1")
+
+ self.use_stemming = self.args["use_stemming"]
+ self.common_words_gazetteer_name = self.args[
+ "common_words_gazetteer_name"]
+ self._gazetteer = None
+ self._language = None
+ self.language = self.args.get("language_code")
+
+ @property
+ def language(self):
+ return self._language
+
+ @language.setter
+ def language(self, value):
+ if value is not None:
+ self._language = value
+ self.args["language_code"] = self.language
+
+ @property
+ def gazetteer(self):
+ # Load the gazetteer lazily
+ if self.common_words_gazetteer_name is None:
+ return None
+ if self._gazetteer is None:
+ self._gazetteer = get_gazetteer(
+ self.resources, self.common_words_gazetteer_name)
+ return self._gazetteer
+
+ @property
+ def feature_name(self):
+ return "ngram_%s" % self.n
+
+ def fit(self, dataset, intent):
+ self.language = dataset[LANGUAGE]
+
+ def compute_feature(self, tokens, token_index):
+ max_len = len(tokens)
+ end = token_index + self.n
+ if 0 <= token_index < max_len and end <= max_len:
+ if self.gazetteer is None:
+ if self.use_stemming:
+ stems = (stem_token(t, self.resources)
+ for t in tokens[token_index:end])
+ return get_default_sep(self.language).join(stems)
+ normalized_values = (normalize_token(t)
+ for t in tokens[token_index:end])
+ return get_default_sep(self.language).join(normalized_values)
+ words = []
+ for t in tokens[token_index:end]:
+ if self.use_stemming:
+ value = stem_token(t, self.resources)
+ else:
+ value = normalize_token(t)
+ words.append(value if value in self.gazetteer else "rare_word")
+ return get_default_sep(self.language).join(words)
+ return None
+
+ def get_required_resources(self):
+ resources = dict()
+ if self.common_words_gazetteer_name is not None:
+ resources[GAZETTEERS] = {self.common_words_gazetteer_name}
+ if self.use_stemming:
+ resources[STEMS] = True
+ return resources
+
+
+@CRFFeatureFactory.register("shape_ngram")
+class ShapeNgramFactory(SingleFeatureFactory):
+ """Feature: the shape of the n-gram consisting of the considered token and
+ potentially the following ones
+
+    This feature has one parameter, *n*, which corresponds to the size of the
+    n-gram.
+
+ Possible types of shape are:
+
+ - 'xxx' -> lowercased
+ - 'Xxx' -> Capitalized
+ - 'XXX' -> UPPERCASED
+ - 'xX' -> None of the above
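+
+    For instance, with n=2 the tokens "Foo" and "BAR" would typically produce
+    the shape n-gram "Xxx XXX" (assuming a whitespace token separator for the
+    language).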
+ """
+
+ def __init__(self, factory_config, **shared):
+ super(ShapeNgramFactory, self).__init__(factory_config, **shared)
+ self.n = self.args["n"]
+ if self.n < 1:
+ raise ValueError("n should be >= 1")
+ self._language = None
+ self.language = self.args.get("language_code")
+
+ @property
+ def language(self):
+ return self._language
+
+ @language.setter
+ def language(self, value):
+ if value is not None:
+ self._language = value
+ self.args["language_code"] = value
+
+ @property
+ def feature_name(self):
+ return "shape_ngram_%s" % self.n
+
+ def fit(self, dataset, intent):
+ self.language = dataset[LANGUAGE]
+
+ def compute_feature(self, tokens, token_index):
+ from snips_nlu_utils import get_shape
+
+ max_len = len(tokens)
+ end = token_index + self.n
+ if 0 <= token_index < max_len and end <= max_len:
+ return get_default_sep(self.language).join(
+ get_shape(t.value) for t in tokens[token_index:end])
+ return None
+
+
+@CRFFeatureFactory.register("word_cluster")
+class WordClusterFactory(SingleFeatureFactory):
+ """Feature: The cluster which the considered token belongs to, if any
+
+ This feature has several parameters:
+
+ - 'cluster_name' (str): the name of the word cluster to use
+ - 'use_stemming' (bool): whether or not to stem the token before looking
+ for its cluster
+
+    Typical word clusters are Brown clusters, in which words are clustered
+    into a binary tree, resulting in cluster identifiers of the form
+    '100111001'. See https://en.wikipedia.org/wiki/Brown_clustering
+ """
+
+ def __init__(self, factory_config, **shared):
+ super(WordClusterFactory, self).__init__(factory_config, **shared)
+ self.cluster_name = self.args["cluster_name"]
+ self.use_stemming = self.args["use_stemming"]
+ self._cluster = None
+
+ @property
+ def cluster(self):
+ if self._cluster is None:
+ self._cluster = get_word_cluster(self.resources, self.cluster_name)
+ return self._cluster
+
+ @property
+ def feature_name(self):
+ return "word_cluster_%s" % self.cluster_name
+
+ def compute_feature(self, tokens, token_index):
+ if self.use_stemming:
+ value = stem_token(tokens[token_index], self.resources)
+ else:
+ value = normalize_token(tokens[token_index])
+ return self.cluster.get(value, None)
+
+ def get_required_resources(self):
+ return {
+ WORD_CLUSTERS: {self.cluster_name},
+ STEMS: self.use_stemming
+ }
+
+
+@CRFFeatureFactory.register("entity_match")
+class CustomEntityMatchFactory(CRFFeatureFactory):
+ """Features: does the considered token belongs to the values of one of the
+ entities in the training dataset
+
+ This factory builds as many features as there are entities in the dataset,
+ one per entity.
+
+ It has the following parameters:
+
+ - 'use_stemming' (bool): whether or not to stem the token before looking
+ for it among the (stemmed) entity values
+    - 'tagging_scheme_code' (int): Represents a :class:`.TaggingScheme`. This
+        allows giving more information about the match.
+ - 'entity_filter' (dict): a filter applied to select the custom entities
+ for which the custom match feature will be computed. Available
+ filters:
+        - 'automatically_extensible': if True, only automatically extensible
+            entities are selected; if False, only non automatically
+            extensible entities are selected (see the example below)
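+
+    For instance, the filter {"automatically_extensible": False} restricts
+    the features to entities whose values are not automatically extensible.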
+ """
+
+ def __init__(self, factory_config, **shared):
+ super(CustomEntityMatchFactory, self).__init__(factory_config,
+ **shared)
+ self.use_stemming = self.args["use_stemming"]
+ self.tagging_scheme = TaggingScheme(
+ self.args["tagging_scheme_code"])
+ self._entities = None
+ self.entities = self.args.get("entities")
+ ent_filter = self.args.get("entity_filter")
+ if ent_filter:
+ try:
+ _check_custom_entity_filter(ent_filter)
+ except _InvalidCustomEntityFilter as e:
+ logger.warning(
+ "Invalid filter '%s', invalid arguments have been ignored:"
+ " %s", ent_filter, e,
+ )
+ self.entity_filter = ent_filter or dict()
+
+ @property
+ def entities(self):
+ return self._entities
+
+ @entities.setter
+ def entities(self, value):
+ if value is not None:
+ self._entities = value
+ self.args["entities"] = value
+
+ def fit(self, dataset, intent):
+ entities_names = extract_intent_entities(
+ dataset, lambda e: not is_builtin_entity(e))[intent]
+ extensible = self.entity_filter.get(AUTOMATICALLY_EXTENSIBLE)
+ if extensible is not None:
+ entities_names = [
+ e for e in entities_names
+ if dataset[ENTITIES][e][AUTOMATICALLY_EXTENSIBLE] == extensible
+ ]
+ self.entities = list(entities_names)
+ return self
+
+ def _transform(self, tokens):
+ if self.use_stemming:
+ light_tokens = (stem_token(t, self.resources) for t in tokens)
+ else:
+ light_tokens = (normalize_token(t) for t in tokens)
+ current_index = 0
+ transformed_tokens = []
+ for light_token in light_tokens:
+ transformed_token = Token(
+ value=light_token,
+ start=current_index,
+ end=current_index + len(light_token))
+ transformed_tokens.append(transformed_token)
+ current_index = transformed_token.end + 1
+ return transformed_tokens
+
+ def build_features(self):
+ features = []
+ for entity_name in self.entities:
+ # We need to call this wrapper in order to properly capture
+ # `entity_name`
+ entity_match = self._build_entity_match_fn(entity_name)
+
+ for offset in self.offsets:
+ feature = Feature("entity_match_%s" % entity_name,
+ entity_match, offset, self.drop_out)
+ features.append(feature)
+ return features
+
+ def _build_entity_match_fn(self, entity):
+
+ def entity_match(tokens, token_index):
+ transformed_tokens = self._transform(tokens)
+ text = initial_string_from_tokens(transformed_tokens)
+ token_start = transformed_tokens[token_index].start
+ token_end = transformed_tokens[token_index].end
+ custom_entities = self.custom_entity_parser.parse(
+ text, scope=[entity], use_cache=True)
+            # only keep custom entities (of type `entity`) which overlap
+            # with the current token
+ custom_entities = [ent for ent in custom_entities
+ if entity_filter(ent, token_start, token_end)]
+ if custom_entities:
+ # In most cases, 0 or 1 entity will be found. We fall back to
+ # the first entity if 2 or more were found
+ ent = custom_entities[0]
+ indexes = []
+ for index, token in enumerate(transformed_tokens):
+ if entity_filter(ent, token.start, token.end):
+ indexes.append(index)
+ return get_scheme_prefix(token_index, indexes,
+ self.tagging_scheme)
+ return None
+
+ return entity_match
+
+ def get_required_resources(self):
+ if self.use_stemming:
+ return {
+ STEMS: True,
+ CUSTOM_ENTITY_PARSER_USAGE: CustomEntityParserUsage.WITH_STEMS
+ }
+ return {
+ STEMS: False,
+ CUSTOM_ENTITY_PARSER_USAGE:
+ CustomEntityParserUsage.WITHOUT_STEMS
+ }
+
+
+class _InvalidCustomEntityFilter(ValueError):
+ pass
+
+
+CUSTOM_ENTITIES_FILTER_KEYS = {"automatically_extensible"}
+
+
+# pylint: disable=redefined-outer-name
+def _check_custom_entity_filter(entity_filter):
+ for k in entity_filter:
+ if k not in CUSTOM_ENTITIES_FILTER_KEYS:
+ msg = "Invalid custom entity filter key '%s'. Accepted filter " \
+ "keys are %s" % (k, list(CUSTOM_ENTITIES_FILTER_KEYS))
+ raise _InvalidCustomEntityFilter(msg)
+
+
+@CRFFeatureFactory.register("builtin_entity_match")
+class BuiltinEntityMatchFactory(CRFFeatureFactory):
+ """Features: is the considered token part of a builtin entity such as a
+ date, a temperature etc
+
+ This factory builds as many features as there are builtin entities
+ available in the considered language.
+
+ It has one parameter, *tagging_scheme_code*, which represents a
+    :class:`.TaggingScheme`. This allows giving more information about the
+    match.
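+
+    For instance, in English this typically produces features named
+    "builtin_entity_match_snips/datetime", "builtin_entity_match_snips/number",
+    etc., one per builtin entity in the scope.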
+ """
+
+ def __init__(self, factory_config, **shared):
+ super(BuiltinEntityMatchFactory, self).__init__(factory_config,
+ **shared)
+ self.tagging_scheme = TaggingScheme(
+ self.args["tagging_scheme_code"])
+ self.builtin_entities = None
+ self.builtin_entities = self.args.get("entity_labels")
+ self._language = None
+ self.language = self.args.get("language_code")
+
+ @property
+ def language(self):
+ return self._language
+
+ @language.setter
+ def language(self, value):
+ if value is not None:
+ self._language = value
+ self.args["language_code"] = self.language
+
+ def fit(self, dataset, intent):
+ self.language = dataset[LANGUAGE]
+ self.builtin_entities = sorted(
+ self._get_builtin_entity_scope(dataset, intent))
+ self.args["entity_labels"] = self.builtin_entities
+
+ def build_features(self):
+ features = []
+
+ for builtin_entity in self.builtin_entities:
+ # We need to call this wrapper in order to properly capture
+ # `builtin_entity`
+ builtin_entity_match = self._build_entity_match_fn(builtin_entity)
+ for offset in self.offsets:
+ feature_name = "builtin_entity_match_%s" % builtin_entity
+ feature = Feature(feature_name, builtin_entity_match, offset,
+ self.drop_out)
+ features.append(feature)
+
+ return features
+
+ def _build_entity_match_fn(self, builtin_entity):
+
+ def builtin_entity_match(tokens, token_index):
+ text = initial_string_from_tokens(tokens)
+ start = tokens[token_index].start
+ end = tokens[token_index].end
+
+ builtin_entities = self.builtin_entity_parser.parse(
+ text, scope=[builtin_entity], use_cache=True)
+ # only keep builtin entities (of type `builtin_entity`) which
+ # overlap with the current token
+ builtin_entities = [ent for ent in builtin_entities
+ if entity_filter(ent, start, end)]
+ if builtin_entities:
+ # In most cases, 0 or 1 entity will be found. We fall back to
+ # the first entity if 2 or more were found
+ ent = builtin_entities[0]
+ entity_start = ent[RES_MATCH_RANGE][START]
+ entity_end = ent[RES_MATCH_RANGE][END]
+ indexes = []
+ for index, token in enumerate(tokens):
+ if (entity_start <= token.start < entity_end) \
+ and (entity_start < token.end <= entity_end):
+ indexes.append(index)
+ return get_scheme_prefix(token_index, indexes,
+ self.tagging_scheme)
+ return None
+
+ return builtin_entity_match
+
+ @staticmethod
+ def _get_builtin_entity_scope(dataset, intent=None):
+ from snips_nlu_parsers import get_supported_grammar_entities
+
+ language = dataset[LANGUAGE]
+ grammar_entities = list(get_supported_grammar_entities(language))
+ gazetteer_entities = list(
+ get_dataset_gazetteer_entities(dataset, intent))
+ return grammar_entities + gazetteer_entities
diff --git a/snips_inference_agl/slot_filler/features_utils.py b/snips_inference_agl/slot_filler/features_utils.py
new file mode 100644
index 0000000..483e9c0
--- /dev/null
+++ b/snips_inference_agl/slot_filler/features_utils.py
@@ -0,0 +1,47 @@
+from __future__ import unicode_literals
+
+from copy import deepcopy
+
+from snips_inference_agl.common.dict_utils import LimitedSizeDict
+from snips_inference_agl.constants import END, RES_MATCH_RANGE, START
+
+_NGRAMS_CACHE = LimitedSizeDict(size_limit=1000)
+
+
+def get_all_ngrams(tokens):
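+    """Return all the n-grams of the given tokens
+
+    Results are memoized in a size-limited cache; a deep copy of the cached
+    value is returned so that callers can safely mutate it.
+    """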
+ from snips_nlu_utils import compute_all_ngrams
+
+ if not tokens:
+ return []
+ key = "<||>".join(tokens)
+ if key not in _NGRAMS_CACHE:
+ ngrams = compute_all_ngrams(tokens, len(tokens))
+ _NGRAMS_CACHE[key] = ngrams
+ return deepcopy(_NGRAMS_CACHE[key])
+
+
+def get_word_chunk(word, chunk_size, chunk_start, reverse=False):
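+    """Return a chunk of *chunk_size* characters taken from *word*, starting
+    (or ending, when *reverse* is True) at *chunk_start*
+
+    For example, get_word_chunk("snips", 3, 0) returns "sni" and
+    get_word_chunk("snips", 3, 5, reverse=True) returns "ips".
+    """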
+ if chunk_size < 1:
+ raise ValueError("chunk size should be >= 1")
+ if chunk_size > len(word):
+ return None
+ start = chunk_start - chunk_size if reverse else chunk_start
+ end = chunk_start if reverse else chunk_start + chunk_size
+ return word[start:end]
+
+
+def initial_string_from_tokens(tokens):
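+    """Reconstruct a string from tokens, padding the gaps between token
+    offsets with spaces
+    """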
+ current_index = 0
+ s = ""
+ for t in tokens:
+ if t.start > current_index:
+ s += " " * (t.start - current_index)
+ s += t.value
+ current_index = t.end
+ return s
+
+
+def entity_filter(entity, start, end):
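+    """Return True when the [start, end) range lies within the entity's
+    match range
+    """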
+ entity_start = entity[RES_MATCH_RANGE][START]
+ entity_end = entity[RES_MATCH_RANGE][END]
+ return entity_start <= start < end <= entity_end
diff --git a/snips_inference_agl/slot_filler/keyword_slot_filler.py b/snips_inference_agl/slot_filler/keyword_slot_filler.py
new file mode 100644
index 0000000..087d997
--- /dev/null
+++ b/snips_inference_agl/slot_filler/keyword_slot_filler.py
@@ -0,0 +1,70 @@
+from __future__ import unicode_literals
+
+import json
+
+from snips_inference_agl.common.utils import json_string
+from snips_inference_agl.preprocessing import tokenize
+from snips_inference_agl.result import unresolved_slot
+from snips_inference_agl.slot_filler import SlotFiller
+
+
+@SlotFiller.register("keyword_slot_filler")
+class KeywordSlotFiller(SlotFiller):
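+    """Minimal slot filler which matches each token of the input against the
+    keywords found in the slot chunks of the training utterances
+    """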
+ def __init__(self, config=None, **shared):
+ super(KeywordSlotFiller, self).__init__(config, **shared)
+ self.slots_keywords = None
+ self.language = None
+
+ @property
+ def fitted(self):
+ return self.slots_keywords is not None
+
+ def fit(self, dataset, intent):
+ self.language = dataset["language"]
+ self.slots_keywords = dict()
+ utterances = dataset["intents"][intent]["utterances"]
+ for utterance in utterances:
+ for chunk in utterance["data"]:
+ if "slot_name" in chunk:
+ text = chunk["text"]
+ if self.config.get("lowercase", False):
+ text = text.lower()
+ self.slots_keywords[text] = [
+ chunk["entity"],
+ chunk["slot_name"]
+ ]
+ return self
+
+ def get_slots(self, text):
+ tokens = tokenize(text, self.language)
+ slots = []
+ for token in tokens:
+ normalized_value = token.value
+ if self.config.get("lowercase", False):
+ normalized_value = normalized_value.lower()
+ if normalized_value in self.slots_keywords:
+ entity = self.slots_keywords[normalized_value][0]
+ slot_name = self.slots_keywords[normalized_value][1]
+ slot = unresolved_slot((token.start, token.end), token.value,
+ entity, slot_name)
+ slots.append(slot)
+ return slots
+
+ def persist(self, path):
+ model = {
+ "language": self.language,
+ "slots_keywords": self.slots_keywords,
+ "config": self.config.to_dict()
+ }
+ with path.open(mode="w", encoding="utf8") as f:
+ f.write(json_string(model))
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ with path.open() as f:
+ model = json.load(f)
+ slot_filler = cls()
+ slot_filler.language = model["language"]
+ slot_filler.slots_keywords = model["slots_keywords"]
+ slot_filler.config = cls.config_type.from_dict(model["config"])
+ return slot_filler
diff --git a/snips_inference_agl/slot_filler/slot_filler.py b/snips_inference_agl/slot_filler/slot_filler.py
new file mode 100644
index 0000000..a1fc937
--- /dev/null
+++ b/snips_inference_agl/slot_filler/slot_filler.py
@@ -0,0 +1,33 @@
+from abc import abstractmethod, ABCMeta
+
+from future.utils import with_metaclass
+
+from snips_inference_agl.common.abc_utils import classproperty
+from snips_inference_agl.pipeline.processing_unit import ProcessingUnit
+
+
+class SlotFiller(with_metaclass(ABCMeta, ProcessingUnit)):
+ """Abstraction which performs slot filling
+
+ A custom slot filler must inherit this class to be used in a
+ :class:`.ProbabilisticIntentParser`
+ """
+
+ @classproperty
+ def unit_name(cls): # pylint:disable=no-self-argument
+ return SlotFiller.registered_name(cls)
+
+ @abstractmethod
+ def fit(self, dataset, intent):
+ """Fit the slot filler with a valid Snips dataset"""
+ pass
+
+ @abstractmethod
+ def get_slots(self, text):
+ """Performs slot extraction (slot filling) on the provided *text*
+
+ Returns:
+ list of dict: The list of extracted slots. See
+ :func:`.unresolved_slot` for the output format of a slot
+ """
+ pass
diff --git a/snips_inference_agl/string_variations.py b/snips_inference_agl/string_variations.py
new file mode 100644
index 0000000..f65e34e
--- /dev/null
+++ b/snips_inference_agl/string_variations.py
@@ -0,0 +1,195 @@
+# coding=utf-8
+from __future__ import unicode_literals
+
+import itertools
+import re
+from builtins import range, str, zip
+
+from future.utils import iteritems
+
+from snips_inference_agl.constants import (
+ END, LANGUAGE_DE, LANGUAGE_EN, LANGUAGE_ES, LANGUAGE_FR, RESOLVED_VALUE,
+ RES_MATCH_RANGE, SNIPS_NUMBER, START, VALUE)
+from snips_inference_agl.languages import (
+ get_default_sep, get_punctuation_regex, supports_num2words)
+from snips_inference_agl.preprocessing import tokenize_light
+
+AND_UTTERANCES = {
+ LANGUAGE_EN: ["and", "&"],
+ LANGUAGE_FR: ["et", "&"],
+ LANGUAGE_ES: ["y", "&"],
+ LANGUAGE_DE: ["und", "&"],
+}
+
+AND_REGEXES = {
+ language: re.compile(
+ r"|".join(r"(?<=\s)%s(?=\s)" % re.escape(u) for u in utterances),
+ re.IGNORECASE)
+ for language, utterances in iteritems(AND_UTTERANCES)
+}
+
+MAX_ENTITY_VARIATIONS = 10
+
+
+def build_variated_query(string, ranges_and_utterances):
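+    """Rebuild *string*, substituting each of the given ranges with its
+    associated utterance
+
+    For example, build_variated_query("2 and 3", [({START: 2, END: 5}, "&")])
+    returns "2 & 3"; ranges are expected to be sorted and non-overlapping.
+    """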
+ variated_string = ""
+ current_ix = 0
+ for rng, u in ranges_and_utterances:
+ start = rng[START]
+ end = rng[END]
+ variated_string += string[current_ix:start]
+ variated_string += u
+ current_ix = end
+ variated_string += string[current_ix:]
+ return variated_string
+
+
+def and_variations(string, language):
+ and_regex = AND_REGEXES.get(language, None)
+ if and_regex is None:
+ return set()
+
+ matches = [m for m in and_regex.finditer(string)]
+ if not matches:
+ return set()
+
+ matches = sorted(matches, key=lambda x: x.start())
+ and_utterances = AND_UTTERANCES[language]
+ values = [({START: m.start(), END: m.end()}, and_utterances)
+ for m in matches]
+
+ n_values = len(values)
+ n_and_utterances = len(and_utterances)
+ if n_and_utterances ** n_values > MAX_ENTITY_VARIATIONS:
+ return set()
+
+ combinations = itertools.product(range(n_and_utterances), repeat=n_values)
+ variations = set()
+ for c in combinations:
+ ranges_and_utterances = [(values[i][0], values[i][1][ix])
+ for i, ix in enumerate(c)]
+ variations.add(build_variated_query(string, ranges_and_utterances))
+ return variations
+
+
+def punctuation_variations(string, language):
+ matches = [m for m in get_punctuation_regex(language).finditer(string)]
+ if not matches:
+ return set()
+
+ matches = sorted(matches, key=lambda x: x.start())
+ values = [({START: m.start(), END: m.end()}, (m.group(0), ""))
+ for m in matches]
+
+ n_values = len(values)
+ if 2 ** n_values > MAX_ENTITY_VARIATIONS:
+ return set()
+
+ combinations = itertools.product(range(2), repeat=n_values)
+ variations = set()
+ for c in combinations:
+ ranges_and_utterances = [(values[i][0], values[i][1][ix])
+ for i, ix in enumerate(c)]
+ variations.add(build_variated_query(string, ranges_and_utterances))
+ return variations
+
+
+def digit_value(number_entity):
+ value = number_entity[RESOLVED_VALUE][VALUE]
+ if value == int(value):
+ # Convert 24.0 into "24" instead of "24.0"
+ value = int(value)
+ return str(value)
+
+
+def alphabetic_value(number_entity, language):
+ from num2words import num2words
+
+ value = number_entity[RESOLVED_VALUE][VALUE]
+ if value != int(value): # num2words does not handle floats correctly
+ return None
+ return num2words(int(value), lang=language)
+
+
+def numbers_variations(string, language, builtin_entity_parser):
+ if not supports_num2words(language):
+ return set()
+
+ number_entities = builtin_entity_parser.parse(
+ string, scope=[SNIPS_NUMBER], use_cache=True)
+
+ number_entities = sorted(number_entities,
+ key=lambda x: x[RES_MATCH_RANGE][START])
+ if not number_entities:
+ return set()
+
+ digit_values = [digit_value(e) for e in number_entities]
+ alpha_values = [alphabetic_value(e, language) for e in number_entities]
+
+ values = [(n[RES_MATCH_RANGE], (d, a)) for (n, d, a) in
+ zip(number_entities, digit_values, alpha_values)
+ if a is not None]
+
+ n_values = len(values)
+ if 2 ** n_values > MAX_ENTITY_VARIATIONS:
+ return set()
+
+ combinations = itertools.product(range(2), repeat=n_values)
+ variations = set()
+ for c in combinations:
+ ranges_and_utterances = [(values[i][0], values[i][1][ix])
+ for i, ix in enumerate(c)]
+ variations.add(build_variated_query(string, ranges_and_utterances))
+ return variations
+
+
+def case_variations(string):
+ return {string.lower(), string.title()}
+
+
+def normalization_variations(string):
+ from snips_nlu_utils import normalize
+
+ return {normalize(string)}
+
+
+def flatten(results):
+ return set(i for r in results for i in r)
+
+
+def get_string_variations(string, language, builtin_entity_parser,
+ numbers=True, case=True, and_=True,
+ punctuation=True):
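+    """Build a set of variations of *string* for the given language
+
+    Depending on the flags, case variations, normalization, "and"/"&"
+    substitutions, punctuation removal, digit/alphabetic number forms,
+    whitespace collapsing and tokenized forms are combined. For instance, in
+    English "two and two" may yield variations such as "2 and 2" or
+    "two & two".
+    """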
+ variations = {string}
+ if case:
+ variations.update(flatten(case_variations(v) for v in variations))
+
+ variations.update(flatten(normalization_variations(v) for v in variations))
+ # We re-generate case variations as normalization can produce new
+ # variations
+ if case:
+ variations.update(flatten(case_variations(v) for v in variations))
+ if and_:
+ variations.update(
+ flatten(and_variations(v, language) for v in variations))
+ if punctuation:
+ variations.update(
+ flatten(punctuation_variations(v, language) for v in variations))
+
+    # Special case: number variations are slow to generate because the
+    # BuiltinEntityParser runs on each variation
+ if numbers:
+ variations.update(
+ flatten(numbers_variations(v, language, builtin_entity_parser)
+ for v in variations)
+ )
+
+ # Add single space variations
+ single_space_variations = set(" ".join(v.split()) for v in variations)
+ variations.update(single_space_variations)
+ # Add tokenized variations
+ tokenized_variations = set(
+ get_default_sep(language).join(tokenize_light(v, language)) for v in
+ variations)
+ variations.update(tokenized_variations)
+ return variations