81 files changed, 10057 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..c8dbf9d --- /dev/null +++ b/.gitignore @@ -0,0 +1,32 @@ +# Byte-compiled / optimized files +__pycache__/ +*.py[cod] +*$py.class + +# Distribution / packaging +dist/ +build/ +*.egg-info/ +*.egg + +# Virtual environments +venv/ +venv390/ +venv310/ +env/ +.env/ + +# IDE specific files +.idea/ +.vscode/ + +# Logs and databases +*.log +*.sqlite3 + +# OS specific files +.DS_Store +Thumbs.db + +# Unnecessary scripts +temp_*.sh
\ No newline at end of file diff --git a/.gitreview b/.gitreview new file mode 100644 index 0000000..eb0e0d3 --- /dev/null +++ b/.gitreview @@ -0,0 +1,5 @@ +[gerrit] +host=gerrit.automotivelinux.org +port=29418 +project=src/snips-inference-agl +defaultbranch=master diff --git a/CONTRIBUTORS.rst b/CONTRIBUTORS.rst new file mode 100644 index 0000000..2c54baf --- /dev/null +++ b/CONTRIBUTORS.rst @@ -0,0 +1,16 @@ +Contributors +============ + + +This is the list of everyone who made contributions to the Snips inference AGL library. + +* `Malik Talha <https://github.com/malik727>`_ + + +This is the list of everyone who has made contributions to the original Snips NLU which is used by Snips Inference AGL as a foundation. + +* `Alice Coucke <https://github.com/choufractal>`_ +* `cclauss <https://github.com/cclauss>`_ +* `ddorian <https://github.com/ddorian>`_ +* `Josh Meyer <https://github.com/JRMeyer>`_ +* `Matthieu Brouillard <https://github.com/McFoggy>`_ @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. 
The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 Malik Talha, Snips Open Source Community + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000..4ad28ce --- /dev/null +++ b/README.md @@ -0,0 +1,2 @@ +# Snips Inference AGL +Inference-only module of the original Snips NLU library. This module is also compatible with Python versions up to 3.10.
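Since the README stops there, here is a minimal usage sketch based on the CLI entry point and the SnipsNLUEngine API introduced elsewhere in this commit; the engine directory path and the query are placeholders, and loading assumes an engine already trained with the original Snips NLU library.

```python
# Hedged sketch: /path/to/trained_engine is a placeholder for a directory
# produced by training with the original Snips NLU library.
from snips_inference_agl import SnipsNLUEngine

engine = SnipsNLUEngine.from_path("/path/to/trained_engine")
result = engine.parse("turn on the radio")
print(result)  # dict with "input", "intent" and "slots" keys (see constants.py)

# Equivalent one-off parse via the console script declared in setup.py:
#   snips-inference parse /path/to/trained_engine -q "turn on the radio"
```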
\ No newline at end of file diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..92178e2 --- /dev/null +++ b/setup.py @@ -0,0 +1,57 @@ +import io +import os + +from setuptools import setup, find_packages + +packages = [p for p in find_packages() + if "tests" not in p and "debug" not in p] + +root = os.path.abspath(os.path.dirname(__file__)) + +with io.open(os.path.join(root, "snips_inference_agl", "__about__.py"), + encoding="utf8") as f: + about = dict() + exec(f.read(), about) + +required = [ + "deprecation>=2.0,<3.0", + "future>=0.16,<0.18", + "numpy>=1.22.0,<1.22.4", + "num2words>=0.5.6,<0.6", + "pyaml>=17.0,<20.0", + "requests>=2.0,<3.0", + "scipy>=1.8.0,<1.9.0", + "threadpoolctl>=2.0.0", + "scikit-learn==0.24.2", + "sklearn-crfsuite>=0.3.6,<0.4", + "snips-nlu-parsers>=0.4.3,<0.4.4", + "snips-nlu-utils>=0.9.1,<0.9.2", +] + +setup(name=about["__title__"], + description=about["__summary__"], + version=about["__version__"], + author=about["__author__"], + author_email=about["__email__"], + license=about["__license__"], + url=about["__github_url__"], + project_urls={ + "Source": about["__github_url__"], + "Tracker": about["__tracker_url__"], + }, + install_requires=required, + classifiers=[ + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + ], + keywords="nlu nlp language machine learning text processing intent", + packages=packages, + python_requires='>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, <4', + include_package_data=True, + entry_points={ + "console_scripts": [ + "snips-inference=snips_inference_agl.cli:main" + ] + }, + zip_safe=False) diff --git a/snips_inference_agl/__about__.py b/snips_inference_agl/__about__.py new file mode 100644 index 0000000..eec36b5 --- /dev/null +++ b/snips_inference_agl/__about__.py @@ -0,0 +1,24 @@ +# inspired from: +# https://python-packaging-user-guide.readthedocs.io/guides/single-sourcing-package-version/ +# https://github.com/pypa/warehouse/blob/master/warehouse/__about__.py + +# pylint:disable=line-too-long + +__title__ = "snips_inference_agl" +__summary__ = "A modified version of original Snips NLU library for NLU inference on AGL platform." 
+__github_url__ = "https://github.com/snipsco/snips-nlu" +__tracker_url__ = "https://github.com/snipsco/snips-nlu/issues" +__author__ = "Clement Doumouro, Adrien Ball" +__email__ = "clement.doumouro@snips.ai, adrien.ball@snips.ai" +__license__ = "Apache License, Version 2.0" + +__version__ = "0.20.2" +__model_version__ = "0.20.0" + +__download_url__ = "https://github.com/snipsco/snips-nlu-language-resources/releases/download" +__compatibility__ = "https://raw.githubusercontent.com/snipsco/snips-nlu-language-resources/master/compatibility.json" +__shortcuts__ = "https://raw.githubusercontent.com/snipsco/snips-nlu-language-resources/master/shortcuts.json" + +__entities_download_url__ = "https://resources.snips.ai/nlu/gazetteer-entities" + +# pylint:enable=line-too-long diff --git a/snips_inference_agl/__init__.py b/snips_inference_agl/__init__.py new file mode 100644 index 0000000..fe05b12 --- /dev/null +++ b/snips_inference_agl/__init__.py @@ -0,0 +1,15 @@ +from deprecation import deprecated + +from snips_inference_agl.__about__ import __model_version__, __version__ +from snips_inference_agl.nlu_engine import SnipsNLUEngine +from snips_inference_agl.pipeline.configs import NLUEngineConfig + + +@deprecated(deprecated_in="0.19.7", removed_in="0.21.0", + current_version=__version__, + details="Loading resources in the client code is no longer " + "required") +def load_resources(name, required_resources=None): + from snips_inference_agl.resources import load_resources as _load_resources + + return _load_resources(name, required_resources) diff --git a/snips_inference_agl/__main__.py b/snips_inference_agl/__main__.py new file mode 100644 index 0000000..7cf8cfe --- /dev/null +++ b/snips_inference_agl/__main__.py @@ -0,0 +1,6 @@ +from __future__ import print_function, unicode_literals + + +if __name__ == "__main__": + from snips_inference_agl.cli import main + main() diff --git a/snips_inference_agl/cli/__init__.py b/snips_inference_agl/cli/__init__.py new file mode 100644 index 0000000..ccfbf18 --- /dev/null +++ b/snips_inference_agl/cli/__init__.py @@ -0,0 +1,39 @@ +import argparse + + +class Formatter(argparse.ArgumentDefaultsHelpFormatter): + def __init__(self, prog): + super(Formatter, self).__init__(prog, max_help_position=35, width=150) + + +def get_arg_parser(): + from snips_inference_agl.cli.inference import add_parse_parser + from snips_inference_agl.cli.versions import ( + add_version_parser, add_model_version_parser) + + arg_parser = argparse.ArgumentParser( + description="Snips NLU command line interface", + prog="python -m snips_nlu", formatter_class=Formatter) + arg_parser.add_argument("-v", "--version", action="store_true", + help="Print package version") + subparsers = arg_parser.add_subparsers( + title="available commands", metavar="command [options ...]") + add_parse_parser(subparsers, formatter_class=Formatter) + add_version_parser(subparsers, formatter_class=Formatter) + add_model_version_parser(subparsers, formatter_class=Formatter) + return arg_parser + + +def main(): + from snips_inference_agl.__about__ import __version__ + + arg_parser = get_arg_parser() + args = arg_parser.parse_args() + + if hasattr(args, "func"): + args.func(args) + elif "version" in args: + print(__version__) + else: + arg_parser.print_help() + exit(1) diff --git a/snips_inference_agl/cli/inference.py b/snips_inference_agl/cli/inference.py new file mode 100644 index 0000000..4ef5618 --- /dev/null +++ b/snips_inference_agl/cli/inference.py @@ -0,0 +1,66 @@ +from __future__ import unicode_literals, 
print_function + + +def add_parse_parser(subparsers, formatter_class): + subparser = subparsers.add_parser( + "parse", formatter_class=formatter_class, + help="Load a trained NLU engine and perform parsing") + subparser.add_argument("training_path", type=str, + help="Path to a trained engine") + subparser.add_argument("-q", "--query", type=str, + help="Query to parse. If provided, it disables the " + "interactive behavior.") + subparser.add_argument("-v", "--verbosity", action="count", default=0, + help="Increase output verbosity") + subparser.add_argument("-f", "--intents-filter", type=str, + help="Intents filter as a comma-separated list") + subparser.set_defaults(func=_parse) + return subparser + + +def _parse(args_namespace): + return parse(args_namespace.training_path, args_namespace.query, + args_namespace.verbosity, args_namespace.intents_filter) + + +def parse(training_path, query, verbose=False, intents_filter=None): + """Load a trained NLU engine and play with its parsing API interactively""" + import csv + import logging + from builtins import input, str + from snips_inference_agl import SnipsNLUEngine + from snips_inference_agl.cli.utils import set_nlu_logger + + if verbose == 1: + set_nlu_logger(logging.INFO) + elif verbose >= 2: + set_nlu_logger(logging.DEBUG) + if intents_filter: + # use csv in order to properly handle commas and other special + # characters in intent names + intents_filter = next(csv.reader([intents_filter])) + else: + intents_filter = None + + engine = SnipsNLUEngine.from_path(training_path) + + if query: + print_parsing_result(engine, query, intents_filter) + return + + while True: + query = input("Enter a query (type 'q' to quit): ").strip() + if not isinstance(query, str): + query = query.decode("utf-8") + if query == "q": + break + print_parsing_result(engine, query, intents_filter) + + +def print_parsing_result(engine, query, intents_filter): + from snips_inference_agl.common.utils import unicode_string, json_string + + query = unicode_string(query) + json_dump = json_string(engine.parse(query, intents_filter), + sort_keys=True, indent=2) + print(json_dump) diff --git a/snips_inference_agl/cli/utils.py b/snips_inference_agl/cli/utils.py new file mode 100644 index 0000000..0e5464c --- /dev/null +++ b/snips_inference_agl/cli/utils.py @@ -0,0 +1,79 @@ +from __future__ import print_function, unicode_literals + +import logging +import sys +from enum import Enum, unique + +import requests + +import snips_inference_agl +from snips_inference_agl import __about__ + + +@unique +class PrettyPrintLevel(Enum): + INFO = 0 + WARNING = 1 + ERROR = 2 + SUCCESS = 3 + + +FMT = "[%(levelname)s][%(asctime)s.%(msecs)03d][%(name)s]: %(message)s" +DATE_FMT = "%H:%M:%S" + + +def pretty_print(*texts, **kwargs): + """Print formatted message + + Args: + *texts (str): Texts to print. Each argument is rendered as paragraph. + **kwargs: 'title' becomes coloured headline. exits=True performs sys + exit. 
+ """ + exits = kwargs.get("exits") + title = kwargs.get("title") + level = kwargs.get("level", PrettyPrintLevel.INFO) + title_color = _color_from_level(level) + if title: + title = "\033[{color}m{title}\033[0m\n".format(title=title, + color=title_color) + else: + title = "" + message = "\n\n".join([text for text in texts]) + print("\n{title}{message}\n".format(title=title, message=message)) + if exits is not None: + sys.exit(exits) + + +def _color_from_level(level): + if level == PrettyPrintLevel.INFO: + return "92" + if level == PrettyPrintLevel.WARNING: + return "93" + if level == PrettyPrintLevel.ERROR: + return "91" + if level == PrettyPrintLevel.SUCCESS: + return "92" + else: + raise ValueError("Unknown PrettyPrintLevel: %s" % level) + + +def get_json(url, desc): + r = requests.get(url) + if r.status_code != 200: + raise OSError("%s: Received status code %s when fetching the resource" + % (desc, r.status_code)) + return r.json() + + +def set_nlu_logger(level=logging.INFO): + logger = logging.getLogger(snips_inference_agl.__name__) + logger.setLevel(level) + + formatter = logging.Formatter(FMT, DATE_FMT) + + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + handler.setLevel(level) + + logger.addHandler(handler) diff --git a/snips_inference_agl/cli/versions.py b/snips_inference_agl/cli/versions.py new file mode 100644 index 0000000..d7922ef --- /dev/null +++ b/snips_inference_agl/cli/versions.py @@ -0,0 +1,19 @@ +from __future__ import print_function + + +def add_version_parser(subparsers, formatter_class): + from snips_inference_agl.__about__ import __version__ + subparser = subparsers.add_parser( + "version", formatter_class=formatter_class, + help="Print the package version") + subparser.set_defaults(func=lambda _: print(__version__)) + return subparser + + +def add_model_version_parser(subparsers, formatter_class): + from snips_inference_agl.__about__ import __model_version__ + subparser = subparsers.add_parser( + "model-version", formatter_class=formatter_class, + help="Print the model version") + subparser.set_defaults(func=lambda _: print(__model_version__)) + return subparser diff --git a/snips_inference_agl/common/__init__.py b/snips_inference_agl/common/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/snips_inference_agl/common/__init__.py diff --git a/snips_inference_agl/common/abc_utils.py b/snips_inference_agl/common/abc_utils.py new file mode 100644 index 0000000..db7b933 --- /dev/null +++ b/snips_inference_agl/common/abc_utils.py @@ -0,0 +1,36 @@ +class abstractclassmethod(classmethod): # pylint: disable=invalid-name + __isabstractmethod__ = True + + def __init__(self, callable): + callable.__isabstractmethod__ = True + super(abstractclassmethod, self).__init__(callable) + + +class ClassPropertyDescriptor(object): + def __init__(self, fget, fset=None): + self.fget = fget + self.fset = fset + + def __get__(self, obj, klass=None): + if klass is None: + klass = type(obj) + return self.fget.__get__(obj, klass)() + + def __set__(self, obj, value): + if not self.fset: + raise AttributeError("can't set attribute") + type_ = type(obj) + return self.fset.__get__(obj, type_)(value) + + def setter(self, func): + if not isinstance(func, (classmethod, staticmethod)): + func = classmethod(func) + self.fset = func + return self + + +def classproperty(func): + if not isinstance(func, (classmethod, staticmethod)): + func = classmethod(func) + + return ClassPropertyDescriptor(func) diff --git 
a/snips_inference_agl/common/dataset_utils.py b/snips_inference_agl/common/dataset_utils.py new file mode 100644 index 0000000..34648e6 --- /dev/null +++ b/snips_inference_agl/common/dataset_utils.py @@ -0,0 +1,48 @@ +from snips_inference_agl.constants import INTENTS, UTTERANCES, DATA, SLOT_NAME, ENTITY +from snips_inference_agl.exceptions import DatasetFormatError + + +def type_error(expected_type, found_type, object_label=None): + if object_label is None: + raise DatasetFormatError("Invalid type: expected %s but found %s" + % (expected_type, found_type)) + raise DatasetFormatError("Invalid type for '%s': expected %s but found %s" + % (object_label, expected_type, found_type)) + + +def validate_type(obj, expected_type, object_label=None): + if not isinstance(obj, expected_type): + type_error(expected_type, type(obj), object_label) + + +def missing_key_error(key, object_label=None): + if object_label is None: + raise DatasetFormatError("Missing key: '%s'" % key) + raise DatasetFormatError("Expected %s to have key: '%s'" + % (object_label, key)) + + +def validate_key(obj, key, object_label=None): + if key not in obj: + missing_key_error(key, object_label) + + +def validate_keys(obj, keys, object_label=None): + for key in keys: + validate_key(obj, key, object_label) + + +def get_slot_name_mapping(dataset, intent): + """Returns a dict which maps slot names to entities for the provided intent + """ + slot_name_mapping = dict() + for utterance in dataset[INTENTS][intent][UTTERANCES]: + for chunk in utterance[DATA]: + if SLOT_NAME in chunk: + slot_name_mapping[chunk[SLOT_NAME]] = chunk[ENTITY] + return slot_name_mapping + +def get_slot_name_mappings(dataset): + """Returns a dict which maps intents to their slot name mapping""" + return {intent: get_slot_name_mapping(dataset, intent) + for intent in dataset[INTENTS]}
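As shown in the sketch below, these helpers only rely on the generic dataset keys defined in constants.py; the intent, slot, and entity names used here are hypothetical and purely illustrative.

```python
# Hypothetical dataset fragment: "intents" -> intent -> "utterances" -> "data"
# chunks, where slot chunks carry "slot_name" and "entity" (see constants.py).
from snips_inference_agl.common.dataset_utils import get_slot_name_mappings

dataset = {
    "intents": {
        "setTemperature": {
            "utterances": [
                {"data": [
                    {"text": "set the temperature to "},
                    {"text": "21 degrees",
                     "slot_name": "temperature",
                     "entity": "snips/number"},
                ]}
            ]
        }
    }
}

print(get_slot_name_mappings(dataset))
# -> {'setTemperature': {'temperature': 'snips/number'}}
```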
\ No newline at end of file diff --git a/snips_inference_agl/common/dict_utils.py b/snips_inference_agl/common/dict_utils.py new file mode 100644 index 0000000..a70b217 --- /dev/null +++ b/snips_inference_agl/common/dict_utils.py @@ -0,0 +1,36 @@ +from collections import OrderedDict + + +class LimitedSizeDict(OrderedDict): + def __init__(self, *args, **kwds): + if "size_limit" not in kwds: + raise ValueError("'size_limit' must be passed as a keyword " + "argument") + self.size_limit = kwds.pop("size_limit") + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + if len(args) == 1 and len(args[0]) + len(kwds) > self.size_limit: + raise ValueError("Tried to initialize LimitedSizedDict with more " + "value than permitted with 'limit_size'") + super(LimitedSizeDict, self).__init__(*args, **kwds) + + def __setitem__(self, key, value, dict_setitem=OrderedDict.__setitem__): + dict_setitem(self, key, value) + self._check_size_limit() + + def _check_size_limit(self): + if self.size_limit is not None: + while len(self) > self.size_limit: + self.popitem(last=False) + + def __eq__(self, other): + if self.size_limit != other.size_limit: + return False + return super(LimitedSizeDict, self).__eq__(other) + + +class UnupdatableDict(dict): + def __setitem__(self, key, value): + if key in self: + raise KeyError("Can't update key '%s'" % key) + super(UnupdatableDict, self).__setitem__(key, value) diff --git a/snips_inference_agl/common/from_dict.py b/snips_inference_agl/common/from_dict.py new file mode 100644 index 0000000..2b776a6 --- /dev/null +++ b/snips_inference_agl/common/from_dict.py @@ -0,0 +1,30 @@ +try: + import funcsigs as inspect +except ImportError: + import inspect + +from future.utils import iteritems + +KEYWORD_KINDS = {inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY} + + +class FromDict(object): + @classmethod + def from_dict(cls, dict): + if dict is None: + return cls() + params = inspect.signature(cls.__init__).parameters + + if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in + params.values()): + return cls(**dict) + + param_names = set() + for i, (name, param) in enumerate(iteritems(params)): + if i == 0 and name == "self": + continue + if param.kind in KEYWORD_KINDS: + param_names.add(name) + filtered_dict = {k: v for k, v in iteritems(dict) if k in param_names} + return cls(**filtered_dict) diff --git a/snips_inference_agl/common/io_utils.py b/snips_inference_agl/common/io_utils.py new file mode 100644 index 0000000..f4407a3 --- /dev/null +++ b/snips_inference_agl/common/io_utils.py @@ -0,0 +1,36 @@ +import errno +import os +import shutil +from contextlib import contextmanager +from pathlib import Path +from tempfile import mkdtemp +from zipfile import ZipFile, ZIP_DEFLATED + + +def mkdir_p(path): + """Reproduces the 'mkdir -p shell' command + + See + http://stackoverflow.com/questions/600268/mkdir-p-functionality-in-python + """ + try: + os.makedirs(str(path)) + except OSError as exc: # Python >2.5 + if exc.errno == errno.EEXIST and path.is_dir(): + pass + else: + raise + + +@contextmanager +def temp_dir(): + tmp_dir = mkdtemp() + try: + yield Path(tmp_dir) + finally: + shutil.rmtree(tmp_dir) + + +def unzip_archive(archive_file, destination_dir): + with ZipFile(archive_file, "r", ZIP_DEFLATED) as zipf: + zipf.extractall(str(destination_dir)) diff --git a/snips_inference_agl/common/log_utils.py b/snips_inference_agl/common/log_utils.py new file mode 100644 index 0000000..47da34e --- /dev/null +++ 
b/snips_inference_agl/common/log_utils.py @@ -0,0 +1,61 @@ +from __future__ import unicode_literals + +from builtins import str +from datetime import datetime +from functools import wraps + +from snips_inference_agl.common.utils import json_debug_string + + +class DifferedLoggingMessage(object): + + def __init__(self, fn, *args, **kwargs): + self.fn = fn + self.args = args + self.kwargs = kwargs + + def __str__(self): + return str(self.fn(*self.args, **self.kwargs)) + + +def log_elapsed_time(logger, level, output_msg=None): + if output_msg is None: + output_msg = "Elapsed time ->:\n{elapsed_time}" + + def get_wrapper(fn): + @wraps(fn) + def wrapped(*args, **kwargs): + start = datetime.now() + msg_fmt = dict() + res = fn(*args, **kwargs) + if "elapsed_time" in output_msg: + msg_fmt["elapsed_time"] = datetime.now() - start + logger.log(level, output_msg.format(**msg_fmt)) + return res + + return wrapped + + return get_wrapper + + +def log_result(logger, level, output_msg=None): + if output_msg is None: + output_msg = "Result ->:\n{result}" + + def get_wrapper(fn): + @wraps(fn) + def wrapped(*args, **kwargs): + msg_fmt = dict() + res = fn(*args, **kwargs) + if "result" in output_msg: + try: + res_debug_string = json_debug_string(res) + except TypeError: + res_debug_string = str(res) + msg_fmt["result"] = res_debug_string + logger.log(level, output_msg.format(**msg_fmt)) + return res + + return wrapped + + return get_wrapper diff --git a/snips_inference_agl/common/registrable.py b/snips_inference_agl/common/registrable.py new file mode 100644 index 0000000..bbc7bdc --- /dev/null +++ b/snips_inference_agl/common/registrable.py @@ -0,0 +1,73 @@ +# This module is largely inspired by the AllenNLP library +# See github.com/allenai/allennlp/blob/master/allennlp/common/registrable.py + +from collections import defaultdict +from future.utils import iteritems + +from snips_inference_agl.exceptions import AlreadyRegisteredError, NotRegisteredError + + +class Registrable(object): + """ + Any class that inherits from ``Registrable`` gains access to a named + registry for its subclasses. To register them, just decorate them with the + classmethod ``@BaseClass.register(name)``. + + After which you can call ``BaseClass.list_available()`` to get the keys + for the registered subclasses, and ``BaseClass.by_name(name)`` to get the + corresponding subclass. + + Note that if you use this class to implement a new ``Registrable`` + abstract class, you must ensure that all subclasses of the abstract class + are loaded when the module is loaded, because the subclasses register + themselves in their respective files. You can achieve this by having the + abstract class and all subclasses in the __init__.py of the module in + which they reside (as this causes any import of either the abstract class + or a subclass to load all other subclasses and the abstract class). + """ + _registry = defaultdict(dict) + + @classmethod + def register(cls, name, override=False): + """Decorator used to add the decorated subclass to the registry of the + base class + + Args: + name (str): name use to identify the registered subclass + override (bool, optional): this parameter controls the behavior in + case where a subclass is registered with the same identifier. + If True, then the previous subclass will be unregistered in + profit of the new subclass. 
+ + Raises: + AlreadyRegisteredError: when ``override`` is False, while trying + to register a subclass with a name already used by another + registered subclass + """ + registry = Registrable._registry[cls] + + def add_subclass_to_registry(subclass): + # Add to registry, raise an error if key has already been used. + if not override and name in registry: + raise AlreadyRegisteredError(name, cls, registry[name]) + registry[name] = subclass + return subclass + + return add_subclass_to_registry + + @classmethod + def registered_name(cls, registered_class): + for name, subclass in iteritems(Registrable._registry[cls]): + if subclass == registered_class: + return name + raise NotRegisteredError(cls, registered_cls=registered_class) + + @classmethod + def by_name(cls, name): + if name not in Registrable._registry[cls]: + raise NotRegisteredError(cls, name=name) + return Registrable._registry[cls][name] + + @classmethod + def list_available(cls): + return list(Registrable._registry[cls].keys()) diff --git a/snips_inference_agl/common/utils.py b/snips_inference_agl/common/utils.py new file mode 100644 index 0000000..816dae7 --- /dev/null +++ b/snips_inference_agl/common/utils.py @@ -0,0 +1,239 @@ +from __future__ import unicode_literals + +import importlib +import json +import numbers +import re +from builtins import bytes as newbytes, str as newstr +from datetime import datetime +from functools import wraps +from pathlib import Path + +from future.utils import text_type + +from snips_inference_agl.constants import (END, ENTITY_KIND, RES_MATCH_RANGE, RES_VALUE, + START) +from snips_inference_agl.exceptions import NotTrained, PersistingError + +REGEX_PUNCT = {'\\', '.', '+', '*', '?', '(', ')', '|', '[', ']', '{', '}', + '^', '$', '#', '&', '-', '~'} + + +# pylint:disable=line-too-long +def regex_escape(s): + """Escapes all regular expression meta characters in *s* + + The string returned may be safely used as a literal in a regular + expression. + + This function is more precise than :func:`re.escape`, the latter escapes + all non-alphanumeric characters which can cause cross-platform + compatibility issues. + + References: + + - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/lib.rs#L1685 + - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/parser.rs#L1378 + """ + escaped_string = "" + for c in s: + if c in REGEX_PUNCT: + escaped_string += "\\" + escaped_string += c + return escaped_string + + +# pylint:enable=line-too-long + + +def check_random_state(seed): + """Turn seed into a :class:`numpy.random.RandomState` instance + + If seed is None, return the RandomState singleton used by np.random. + If seed is an int, return a new RandomState instance seeded with seed. + If seed is already a RandomState instance, return it. + Otherwise raise ValueError. 
+ """ + import numpy as np + + # pylint: disable=W0212 + # pylint: disable=c-extension-no-member + if seed is None or seed is np.random: + return np.random.mtrand._rand # pylint: disable=c-extension-no-member + if isinstance(seed, (numbers.Integral, np.integer)): + return np.random.RandomState(seed) + if isinstance(seed, np.random.RandomState): + return seed + raise ValueError('%r cannot be used to seed a numpy.random.RandomState' + ' instance' % seed) + + +def ranges_overlap(lhs_range, rhs_range): + if isinstance(lhs_range, dict) and isinstance(rhs_range, dict): + return lhs_range[END] > rhs_range[START] \ + and lhs_range[START] < rhs_range[END] + elif isinstance(lhs_range, (tuple, list)) \ + and isinstance(rhs_range, (tuple, list)): + return lhs_range[1] > rhs_range[0] and lhs_range[0] < rhs_range[1] + else: + raise TypeError("Cannot check overlap on objects of type: %s and %s" + % (type(lhs_range), type(rhs_range))) + + +def elapsed_since(time): + return datetime.now() - time + + +def json_debug_string(dict_data): + return json.dumps(dict_data, ensure_ascii=False, indent=2, sort_keys=True) + +def json_string(json_object, indent=2, sort_keys=True): + json_dump = json.dumps(json_object, indent=indent, sort_keys=sort_keys, + separators=(',', ': ')) + return unicode_string(json_dump) + +def unicode_string(string): + if isinstance(string, text_type): + return string + if isinstance(string, bytes): + return string.decode("utf8") + if isinstance(string, newstr): + return text_type(string) + if isinstance(string, newbytes): + string = bytes(string).decode("utf8") + + raise TypeError("Cannot convert %s into unicode string" % type(string)) + + +def check_persisted_path(func): + @wraps(func) + def func_wrapper(self, path, *args, **kwargs): + path = Path(path) + if path.exists(): + raise PersistingError(path) + return func(self, path, *args, **kwargs) + + return func_wrapper + + +def fitted_required(func): + @wraps(func) + def func_wrapper(self, *args, **kwargs): + if not self.fitted: + raise NotTrained("%s must be fitted" % self.unit_name) + return func(self, *args, **kwargs) + + return func_wrapper + + +def is_package(name): + """Check if name maps to a package installed via pip. + + Args: + name (str): Name of package + + Returns: + bool: True if an installed packaged corresponds to this name, False + otherwise. + """ + import pkg_resources + + name = name.lower().replace("-", "_") + packages = pkg_resources.working_set.by_key.keys() + for package in packages: + if package.lower().replace("-", "_") == name: + return True + return False + + +def get_package_path(name): + """Get the path to an installed package. 
+ + Args: + name (str): Package name + + Returns: + class:`.Path`: Path to the installed package + """ + name = name.lower().replace("-", "_") + pkg = importlib.import_module(name) + return Path(pkg.__file__).parent + + +def deduplicate_overlapping_items(items, overlap_fn, sort_key_fn): + """Deduplicates the items by looping over the items, sorted using + sort_key_fn, and checking overlaps with previously seen items using + overlap_fn + """ + sorted_items = sorted(items, key=sort_key_fn) + deduplicated_items = [] + for item in sorted_items: + if not any(overlap_fn(item, dedup_item) + for dedup_item in deduplicated_items): + deduplicated_items.append(item) + return deduplicated_items + + +def replace_entities_with_placeholders(text, entities, placeholder_fn): + """Processes the text in order to replace entity values with placeholders + as defined by the placeholder function + """ + if not entities: + return dict(), text + + entities = deduplicate_overlapping_entities(entities) + entities = sorted( + entities, key=lambda e: e[RES_MATCH_RANGE][START]) + + range_mapping = dict() + processed_text = "" + offset = 0 + current_ix = 0 + for ent in entities: + ent_start = ent[RES_MATCH_RANGE][START] + ent_end = ent[RES_MATCH_RANGE][END] + rng_start = ent_start + offset + + processed_text += text[current_ix:ent_start] + + entity_length = ent_end - ent_start + entity_place_holder = placeholder_fn(ent[ENTITY_KIND]) + + offset += len(entity_place_holder) - entity_length + + processed_text += entity_place_holder + rng_end = ent_end + offset + new_range = (rng_start, rng_end) + range_mapping[new_range] = ent[RES_MATCH_RANGE] + current_ix = ent_end + + processed_text += text[current_ix:] + return range_mapping, processed_text + + +def deduplicate_overlapping_entities(entities): + """Deduplicates entities based on overlapping ranges""" + + def overlap(lhs_entity, rhs_entity): + return ranges_overlap(lhs_entity[RES_MATCH_RANGE], + rhs_entity[RES_MATCH_RANGE]) + + def sort_key_fn(entity): + return -len(entity[RES_VALUE]) + + deduplicated_entities = deduplicate_overlapping_items( + entities, overlap, sort_key_fn) + return sorted(deduplicated_entities, + key=lambda entity: entity[RES_MATCH_RANGE][START]) + + +SEMVER_PATTERN = r"^(?P<major>0|[1-9]\d*)" \ + r".(?P<minor>0|[1-9]\d*)" \ + r".(?P<patch>0|[1-9]\d*)" \ + r"(?:.(?P<subpatch>0|[1-9]\d*))?" \ + r"(?:-(?P<prerelease>(?:0|[1-9]\d*|\d*[a-zA-Z-]" \ + r"[0-9a-zA-Z-]*)" \ + r"(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?" 
\ + r"(?:\+(?P<buildmetadata>[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*)" \ + r")?$" +SEMVER_REGEX = re.compile(SEMVER_PATTERN) diff --git a/snips_inference_agl/constants.py b/snips_inference_agl/constants.py new file mode 100644 index 0000000..8affd08 --- /dev/null +++ b/snips_inference_agl/constants.py @@ -0,0 +1,74 @@ +from __future__ import unicode_literals + +from pathlib import Path + +# package +ROOT_PATH = Path(__file__).parent.parent +PACKAGE_NAME = "snips_inference_agl" +DATA_PACKAGE_NAME = "data" +DATA_PATH = ROOT_PATH / PACKAGE_NAME / DATA_PACKAGE_NAME +PACKAGE_PATH = ROOT_PATH / PACKAGE_NAME + +# result +RES_INPUT = "input" +RES_INTENT = "intent" +RES_SLOTS = "slots" +RES_INTENT_NAME = "intentName" +RES_PROBA = "probability" +RES_SLOT_NAME = "slotName" +RES_ENTITY = "entity" +RES_VALUE = "value" +RES_RAW_VALUE = "rawValue" +RES_MATCH_RANGE = "range" + +# miscellaneous +AUTOMATICALLY_EXTENSIBLE = "automatically_extensible" +USE_SYNONYMS = "use_synonyms" +SYNONYMS = "synonyms" +DATA = "data" +INTENTS = "intents" +ENTITIES = "entities" +ENTITY = "entity" +ENTITY_KIND = "entity_kind" +RESOLVED_VALUE = "resolved_value" +SLOT_NAME = "slot_name" +TEXT = "text" +UTTERANCES = "utterances" +LANGUAGE = "language" +VALUE = "value" +NGRAM = "ngram" +CAPITALIZE = "capitalize" +UNKNOWNWORD = "unknownword" +VALIDATED = "validated" +START = "start" +END = "end" +BUILTIN_ENTITY_PARSER = "builtin_entity_parser" +CUSTOM_ENTITY_PARSER = "custom_entity_parser" +MATCHING_STRICTNESS = "matching_strictness" +LICENSE_INFO = "license_info" +RANDOM_STATE = "random_state" +BYPASS_VERSION_CHECK = "bypass_version_check" + +# resources +RESOURCES = "resources" +METADATA = "metadata" +STOP_WORDS = "stop_words" +NOISE = "noise" +GAZETTEERS = "gazetteers" +STEMS = "stems" +CUSTOM_ENTITY_PARSER_USAGE = "custom_entity_parser_usage" +WORD_CLUSTERS = "word_clusters" + +# builtin entities +SNIPS_NUMBER = "snips/number" + +# languages +LANGUAGE_DE = "de" +LANGUAGE_EN = "en" +LANGUAGE_ES = "es" +LANGUAGE_FR = "fr" +LANGUAGE_IT = "it" +LANGUAGE_JA = "ja" +LANGUAGE_KO = "ko" +LANGUAGE_PT_BR = "pt_br" +LANGUAGE_PT_PT = "pt_pt" diff --git a/snips_inference_agl/data_augmentation.py b/snips_inference_agl/data_augmentation.py new file mode 100644 index 0000000..5a37f5e --- /dev/null +++ b/snips_inference_agl/data_augmentation.py @@ -0,0 +1,121 @@ +from __future__ import unicode_literals + +from builtins import next +from copy import deepcopy +from itertools import cycle + +from future.utils import iteritems + +from snips_inference_agl.constants import ( + CAPITALIZE, DATA, ENTITIES, ENTITY, INTENTS, TEXT, UTTERANCES) +from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity +from snips_inference_agl.languages import get_default_sep +from snips_inference_agl.preprocessing import tokenize_light +from snips_inference_agl.resources import get_stop_words + + +def capitalize(text, language, resources): + tokens = tokenize_light(text, language) + stop_words = get_stop_words(resources) + return get_default_sep(language).join( + t.title() if t.lower() not in stop_words + else t.lower() for t in tokens) + + +def capitalize_utterances(utterances, entities, language, ratio, resources, + random_state): + capitalized_utterances = [] + for utterance in utterances: + capitalized_utterance = deepcopy(utterance) + for i, chunk in enumerate(capitalized_utterance[DATA]): + capitalized_utterance[DATA][i][TEXT] = chunk[TEXT].lower() + if ENTITY not in chunk: + continue + entity_label = chunk[ENTITY] + if 
is_builtin_entity(entity_label): + continue + if not entities[entity_label][CAPITALIZE]: + continue + if random_state.rand() > ratio: + continue + capitalized_utterance[DATA][i][TEXT] = capitalize( + chunk[TEXT], language, resources) + capitalized_utterances.append(capitalized_utterance) + return capitalized_utterances + + +def generate_utterance(contexts_iterator, entities_iterators): + context = deepcopy(next(contexts_iterator)) + context_data = [] + for chunk in context[DATA]: + if ENTITY in chunk: + chunk[TEXT] = deepcopy( + next(entities_iterators[chunk[ENTITY]])) + chunk[TEXT] = chunk[TEXT].strip() + " " + context_data.append(chunk) + context[DATA] = context_data + return context + + +def get_contexts_iterator(dataset, intent_name, random_state): + shuffled_utterances = random_state.permutation( + dataset[INTENTS][intent_name][UTTERANCES]) + return cycle(shuffled_utterances) + + +def get_entities_iterators(intent_entities, language, + add_builtin_entities_examples, random_state): + from snips_nlu_parsers import get_builtin_entity_examples + + entities_its = dict() + for entity_name, entity in iteritems(intent_entities): + utterance_values = random_state.permutation(sorted(entity[UTTERANCES])) + if add_builtin_entities_examples and is_builtin_entity(entity_name): + entity_examples = get_builtin_entity_examples( + entity_name, language) + # Builtin entity examples must be kept first in the iterator to + # ensure that they are used when augmenting data + iterator_values = entity_examples + list(utterance_values) + else: + iterator_values = utterance_values + entities_its[entity_name] = cycle(iterator_values) + return entities_its + + +def get_intent_entities(dataset, intent_name): + intent_entities = set() + for utterance in dataset[INTENTS][intent_name][UTTERANCES]: + for chunk in utterance[DATA]: + if ENTITY in chunk: + intent_entities.add(chunk[ENTITY]) + return sorted(intent_entities) + + +def num_queries_to_generate(dataset, intent_name, min_utterances): + nb_utterances = len(dataset[INTENTS][intent_name][UTTERANCES]) + return max(nb_utterances, min_utterances) + + +def augment_utterances(dataset, intent_name, language, min_utterances, + capitalization_ratio, add_builtin_entities_examples, + resources, random_state): + contexts_it = get_contexts_iterator(dataset, intent_name, random_state) + intent_entities = {e: dataset[ENTITIES][e] + for e in get_intent_entities(dataset, intent_name)} + entities_its = get_entities_iterators(intent_entities, language, + add_builtin_entities_examples, + random_state) + generated_utterances = [] + nb_to_generate = num_queries_to_generate(dataset, intent_name, + min_utterances) + while nb_to_generate > 0: + generated_utterance = generate_utterance(contexts_it, entities_its) + generated_utterances.append(generated_utterance) + nb_to_generate -= 1 + + generated_utterances = capitalize_utterances( + generated_utterances, dataset[ENTITIES], language, + ratio=capitalization_ratio, resources=resources, + random_state=random_state) + + return generated_utterances diff --git a/snips_inference_agl/dataset/__init__.py b/snips_inference_agl/dataset/__init__.py new file mode 100644 index 0000000..4cbed08 --- /dev/null +++ b/snips_inference_agl/dataset/__init__.py @@ -0,0 +1,7 @@ +from snips_inference_agl.dataset.dataset import Dataset +from snips_inference_agl.dataset.entity import Entity +from snips_inference_agl.dataset.intent import Intent +from snips_inference_agl.dataset.utils import ( + extract_intent_entities, extract_utterance_entities, + 
get_dataset_gazetteer_entities, get_text_from_chunks) +from snips_inference_agl.dataset.validation import validate_and_format_dataset diff --git a/snips_inference_agl/dataset/dataset.py b/snips_inference_agl/dataset/dataset.py new file mode 100644 index 0000000..2ad7867 --- /dev/null +++ b/snips_inference_agl/dataset/dataset.py @@ -0,0 +1,102 @@ +# coding=utf-8 +from __future__ import print_function, unicode_literals + +import io +from itertools import cycle + +from snips_inference_agl.common.utils import unicode_string +from snips_inference_agl.dataset.entity import Entity +from snips_inference_agl.dataset.intent import Intent +from snips_inference_agl.exceptions import DatasetFormatError + + +class Dataset(object): + """Dataset used in the main NLU training API + + Consists of intents and entities data. This object can be built either from + text files (:meth:`.Dataset.from_files`) or from YAML files + (:meth:`.Dataset.from_yaml_files`). + + Attributes: + language (str): language of the intents + intents (list of :class:`.Intent`): intents data + entities (list of :class:`.Entity`): entities data + """ + + def __init__(self, language, intents, entities): + self.language = language + self.intents = intents + self.entities = entities + self._add_missing_entities() + self._ensure_entity_values() + + @classmethod + def _load_dataset_parts(cls, stream, stream_description): + from snips_inference_agl.dataset.yaml_wrapper import yaml + + intents = [] + entities = [] + for doc in yaml.safe_load_all(stream): + doc_type = doc.get("type") + if doc_type == "entity": + entities.append(Entity.from_yaml(doc)) + elif doc_type == "intent": + intents.append(Intent.from_yaml(doc)) + else: + raise DatasetFormatError( + "Invalid 'type' value in YAML file '%s': '%s'" + % (stream_description, doc_type)) + return intents, entities + + def _add_missing_entities(self): + entity_names = set(e.name for e in self.entities) + + # Add entities appearing only in the intents utterances + for intent in self.intents: + for entity_name in intent.entities_names: + if entity_name not in entity_names: + entity_names.add(entity_name) + self.entities.append(Entity(name=entity_name)) + + def _ensure_entity_values(self): + entities_values = {entity.name: self._get_entity_values(entity) + for entity in self.entities} + for intent in self.intents: + for utterance in intent.utterances: + for chunk in utterance.slot_chunks: + if chunk.text is not None: + continue + try: + chunk.text = next(entities_values[chunk.entity]) + except StopIteration: + raise DatasetFormatError( + "At least one entity value must be provided for " + "entity '%s'" % chunk.entity) + return self + + def _get_entity_values(self, entity): + from snips_nlu_parsers import get_builtin_entity_examples + + if entity.is_builtin: + return cycle(get_builtin_entity_examples( + entity.name, self.language)) + values = [v for utterance in entity.utterances + for v in utterance.variations] + values_set = set(values) + for intent in self.intents: + for utterance in intent.utterances: + for chunk in utterance.slot_chunks: + if not chunk.text or chunk.entity != entity.name: + continue + if chunk.text not in values_set: + values_set.add(chunk.text) + values.append(chunk.text) + return cycle(values) + + @property + def json(self): + """Dataset data in json format""" + intents = {intent_data.intent_name: intent_data.json + for intent_data in self.intents} + entities = {entity.name: entity.json for entity in self.entities} + return dict(language=self.language, intents=intents, 
entities=entities) diff --git a/snips_inference_agl/dataset/entity.py b/snips_inference_agl/dataset/entity.py new file mode 100644 index 0000000..65b9994 --- /dev/null +++ b/snips_inference_agl/dataset/entity.py @@ -0,0 +1,175 @@ +# coding=utf-8 +from __future__ import unicode_literals + +from builtins import str +from io import IOBase + +from snips_inference_agl.constants import ( + AUTOMATICALLY_EXTENSIBLE, DATA, MATCHING_STRICTNESS, SYNONYMS, + USE_SYNONYMS, VALUE) +from snips_inference_agl.exceptions import EntityFormatError + + +class Entity(object): + """Entity data of a :class:`.Dataset` + + This class can represents both a custom or a builtin entity. When the + entity is a builtin one, only the `name` attribute is relevant. + + Attributes: + name (str): name of the entity + utterances (list of :class:`.EntityUtterance`): entity utterances + (only for custom entities) + automatically_extensible (bool): whether or not the entity can be + extended to values not present in the data (only for custom + entities) + use_synonyms (bool): whether or not to map entity values using + synonyms (only for custom entities) + matching_strictness (float): controls the matching strictness of the + entity (only for custom entities). Must be between 0.0 and 1.0. + """ + + def __init__(self, name, utterances=None, automatically_extensible=True, + use_synonyms=True, matching_strictness=1.0): + if utterances is None: + utterances = [] + self.name = name + self.utterances = utterances + self.automatically_extensible = automatically_extensible + self.use_synonyms = use_synonyms + self.matching_strictness = matching_strictness + + @property + def is_builtin(self): + from snips_nlu_parsers import get_all_builtin_entities + + return self.name in get_all_builtin_entities() + + @classmethod + def from_yaml(cls, yaml_dict): + """Build an :class:`.Entity` from its YAML definition object + + Args: + yaml_dict (dict or :class:`.IOBase`): object containing the YAML + definition of the entity. It can be either a stream, or the + corresponding python dict. + + Examples: + An entity can be defined with a YAML document following the schema + illustrated in the example below: + + >>> import io + >>> from snips_inference_agl.common.utils import json_string + >>> entity_yaml = io.StringIO(''' + ... # City Entity + ... --- + ... type: entity + ... name: city + ... automatically_extensible: false # default value is true + ... use_synonyms: false # default value is true + ... matching_strictness: 0.8 # default value is 1.0 + ... values: + ... - london + ... - [new york, big apple] + ... 
- [paris, city of lights]''') + >>> entity = Entity.from_yaml(entity_yaml) + >>> print(json_string(entity.json, indent=4, sort_keys=True)) + { + "automatically_extensible": false, + "data": [ + { + "synonyms": [], + "value": "london" + }, + { + "synonyms": [ + "big apple" + ], + "value": "new york" + }, + { + "synonyms": [ + "city of lights" + ], + "value": "paris" + } + ], + "matching_strictness": 0.8, + "use_synonyms": false + } + + Raises: + EntityFormatError: When the YAML dict does not correspond to the + :ref:`expected entity format <yaml_entity_format>` + """ + if isinstance(yaml_dict, IOBase): + from snips_inference_agl.dataset.yaml_wrapper import yaml + + yaml_dict = yaml.safe_load(yaml_dict) + + object_type = yaml_dict.get("type") + if object_type and object_type != "entity": + raise EntityFormatError("Wrong type: '%s'" % object_type) + entity_name = yaml_dict.get("name") + if not entity_name: + raise EntityFormatError("Missing 'name' attribute") + auto_extensible = yaml_dict.get(AUTOMATICALLY_EXTENSIBLE, True) + use_synonyms = yaml_dict.get(USE_SYNONYMS, True) + matching_strictness = yaml_dict.get("matching_strictness", 1.0) + utterances = [] + for entity_value in yaml_dict.get("values", []): + if isinstance(entity_value, list): + utterance = EntityUtterance(entity_value[0], entity_value[1:]) + elif isinstance(entity_value, str): + utterance = EntityUtterance(entity_value) + else: + raise EntityFormatError( + "YAML entity values must be either strings or lists, but " + "found: %s" % type(entity_value)) + utterances.append(utterance) + + return cls(name=entity_name, + utterances=utterances, + automatically_extensible=auto_extensible, + use_synonyms=use_synonyms, + matching_strictness=matching_strictness) + + @property + def json(self): + """Returns the entity in json format""" + if self.is_builtin: + return dict() + return { + AUTOMATICALLY_EXTENSIBLE: self.automatically_extensible, + USE_SYNONYMS: self.use_synonyms, + DATA: [u.json for u in self.utterances], + MATCHING_STRICTNESS: self.matching_strictness + } + + +class EntityUtterance(object): + """Represents a value of a :class:`.CustomEntity` with potential synonyms + + Attributes: + value (str): entity value + synonyms (list of str): The values to remap to the utterance value + """ + + def __init__(self, value, synonyms=None): + self.value = value + if synonyms is None: + synonyms = [] + self.synonyms = synonyms + + @property + def variations(self): + return [self.value] + self.synonyms + + @property + def json(self): + return {VALUE: self.value, SYNONYMS: self.synonyms} + + +def utf_8_encoder(f): + for line in f: + yield line.encode("utf-8") diff --git a/snips_inference_agl/dataset/intent.py b/snips_inference_agl/dataset/intent.py new file mode 100644 index 0000000..0a915ce --- /dev/null +++ b/snips_inference_agl/dataset/intent.py @@ -0,0 +1,339 @@ +from __future__ import absolute_import, print_function, unicode_literals + +from abc import ABCMeta, abstractmethod +from builtins import object +from io import IOBase + +from future.utils import with_metaclass + +from snips_inference_agl.constants import DATA, ENTITY, SLOT_NAME, TEXT, UTTERANCES +from snips_inference_agl.exceptions import IntentFormatError + + +class Intent(object): + """Intent data of a :class:`.Dataset` + + Attributes: + intent_name (str): name of the intent + utterances (list of :class:`.IntentUtterance`): annotated intent + utterances + slot_mapping (dict): mapping between slot names and entities + """ + + def __init__(self, intent_name, utterances, 
slot_mapping=None): + if slot_mapping is None: + slot_mapping = dict() + self.intent_name = intent_name + self.utterances = utterances + self.slot_mapping = slot_mapping + self._complete_slot_name_mapping() + self._ensure_entity_names() + + @classmethod + def from_yaml(cls, yaml_dict): + """Build an :class:`.Intent` from its YAML definition object + + Args: + yaml_dict (dict or :class:`.IOBase`): object containing the YAML + definition of the intent. It can be either a stream, or the + corresponding python dict. + + Examples: + An intent can be defined with a YAML document following the schema + illustrated in the example below: + + >>> import io + >>> from snips_inference_agl.common.utils import json_string + >>> intent_yaml = io.StringIO(''' + ... # searchFlight Intent + ... --- + ... type: intent + ... name: searchFlight + ... slots: + ... - name: origin + ... entity: city + ... - name: destination + ... entity: city + ... - name: date + ... entity: snips/datetime + ... utterances: + ... - find me a flight from [origin](Oslo) to [destination](Lima) + ... - I need a flight leaving to [destination](Berlin)''') + >>> intent = Intent.from_yaml(intent_yaml) + >>> print(json_string(intent.json, indent=4, sort_keys=True)) + { + "utterances": [ + { + "data": [ + { + "text": "find me a flight from " + }, + { + "entity": "city", + "slot_name": "origin", + "text": "Oslo" + }, + { + "text": " to " + }, + { + "entity": "city", + "slot_name": "destination", + "text": "Lima" + } + ] + }, + { + "data": [ + { + "text": "I need a flight leaving to " + }, + { + "entity": "city", + "slot_name": "destination", + "text": "Berlin" + } + ] + } + ] + } + + Raises: + IntentFormatError: When the YAML dict does not correspond to the + :ref:`expected intent format <yaml_intent_format>` + """ + + if isinstance(yaml_dict, IOBase): + from snips_inference_agl.dataset.yaml_wrapper import yaml + + yaml_dict = yaml.safe_load(yaml_dict) + + object_type = yaml_dict.get("type") + if object_type and object_type != "intent": + raise IntentFormatError("Wrong type: '%s'" % object_type) + intent_name = yaml_dict.get("name") + if not intent_name: + raise IntentFormatError("Missing 'name' attribute") + slot_mapping = dict() + for slot in yaml_dict.get("slots", []): + slot_mapping[slot["name"]] = slot["entity"] + utterances = [IntentUtterance.parse(u.strip()) + for u in yaml_dict["utterances"] if u.strip()] + if not utterances: + raise IntentFormatError( + "Intent must contain at least one utterance") + return cls(intent_name, utterances, slot_mapping) + + def _complete_slot_name_mapping(self): + for utterance in self.utterances: + for chunk in utterance.slot_chunks: + if chunk.entity and chunk.slot_name not in self.slot_mapping: + self.slot_mapping[chunk.slot_name] = chunk.entity + return self + + def _ensure_entity_names(self): + for utterance in self.utterances: + for chunk in utterance.slot_chunks: + if chunk.entity: + continue + chunk.entity = self.slot_mapping.get( + chunk.slot_name, chunk.slot_name) + return self + + @property + def json(self): + """Intent data in json format""" + return { + UTTERANCES: [ + {DATA: [chunk.json for chunk in utterance.chunks]} + for utterance in self.utterances + ] + } + + @property + def entities_names(self): + return set(chunk.entity for u in self.utterances + for chunk in u.chunks if isinstance(chunk, SlotChunk)) + + +class IntentUtterance(object): + def __init__(self, chunks): + self.chunks = chunks + + @property + def text(self): + return "".join((chunk.text for chunk in self.chunks)) + + 
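
As a minimal sketch of how the slot-to-entity resolution above plays out end to end (using a hypothetical intent definition, not one shipped with the repository): a slot chunk whose entity is omitted in the annotated utterance is filled in from the intent-level slot declarations, which is what entities_names then reports.

import io
from snips_inference_agl.dataset.intent import Intent

# Hypothetical intent: the [origin](Paris) chunk carries no explicit entity,
# so _ensure_entity_names resolves it through the declared slot mapping.
intent_yaml = io.StringIO('''
type: intent
name: searchFlight
slots:
  - name: origin
    entity: city
utterances:
  - book a flight from [origin](Paris)
''')
intent = Intent.from_yaml(intent_yaml)
assert intent.slot_mapping == {"origin": "city"}
assert intent.entities_names == {"city"}
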
@property + def slot_chunks(self): + return (chunk for chunk in self.chunks if isinstance(chunk, SlotChunk)) + + @classmethod + def parse(cls, string): + """Parses an utterance + + Args: + string (str): an utterance in the class:`.Utterance` format + + Examples: + + >>> from snips_inference_agl.dataset.intent import IntentUtterance + >>> u = IntentUtterance.\ + parse("president of [country:default](France)") + >>> u.text + 'president of France' + >>> len(u.chunks) + 2 + >>> u.chunks[0].text + 'president of ' + >>> u.chunks[1].slot_name + 'country' + >>> u.chunks[1].entity + 'default' + """ + sm = SM(string) + capture_text(sm) + return cls(sm.chunks) + + +class Chunk(with_metaclass(ABCMeta, object)): + def __init__(self, text): + self.text = text + + @abstractmethod + def json(self): + pass + + +class SlotChunk(Chunk): + def __init__(self, slot_name, entity, text): + super(SlotChunk, self).__init__(text) + self.slot_name = slot_name + self.entity = entity + + @property + def json(self): + return { + TEXT: self.text, + SLOT_NAME: self.slot_name, + ENTITY: self.entity, + } + + +class TextChunk(Chunk): + @property + def json(self): + return { + TEXT: self.text + } + + +class SM(object): + """State Machine for parsing""" + + def __init__(self, input): + self.input = input + self.chunks = [] + self.current = 0 + + @property + def end_of_input(self): + return self.current >= len(self.input) + + def add_slot(self, name, entity=None): + """Adds a named slot + + Args: + name (str): slot name + entity (str): entity name + """ + chunk = SlotChunk(slot_name=name, entity=entity, text=None) + self.chunks.append(chunk) + + def add_text(self, text): + """Adds a simple text chunk using the current position""" + chunk = TextChunk(text=text) + self.chunks.append(chunk) + + def add_tagged(self, text): + """Adds text to the last slot""" + if not self.chunks: + raise AssertionError("Cannot add tagged text because chunks list " + "is empty") + self.chunks[-1].text = text + + def find(self, s): + return self.input.find(s, self.current) + + def move(self, pos): + """Moves the cursor of the state to position after given + + Args: + pos (int): position to place the cursor just after + """ + self.current = pos + 1 + + def peek(self): + if self.end_of_input: + return None + return self[0] + + def read(self): + c = self[0] + self.current += 1 + return c + + def __getitem__(self, key): + current = self.current + if isinstance(key, int): + return self.input[current + key] + elif isinstance(key, slice): + start = current + key.start if key.start else current + return self.input[slice(start, key.stop, key.step)] + else: + raise TypeError("Bad key type: %s" % type(key)) + + +def capture_text(state): + next_pos = state.find('[') + sub = state[:] if next_pos < 0 else state[:next_pos] + if sub: + state.add_text(sub) + if next_pos >= 0: + state.move(next_pos) + capture_slot(state) + + +def capture_slot(state): + next_colon_pos = state.find(':') + next_square_bracket_pos = state.find(']') + if next_square_bracket_pos < 0: + raise IntentFormatError( + "Missing ending ']' in annotated utterance \"%s\"" % state.input) + if next_colon_pos < 0 or next_square_bracket_pos < next_colon_pos: + slot_name = state[:next_square_bracket_pos] + state.move(next_square_bracket_pos) + state.add_slot(slot_name) + else: + slot_name = state[:next_colon_pos] + state.move(next_colon_pos) + entity = state[:next_square_bracket_pos] + state.move(next_square_bracket_pos) + state.add_slot(slot_name, entity) + if state.peek() == '(': + state.read() + 
capture_tagged(state) + else: + capture_text(state) + + +def capture_tagged(state): + next_pos = state.find(')') + if next_pos < 1: + raise IntentFormatError( + "Missing ending ')' in annotated utterance \"%s\"" % state.input) + else: + tagged_text = state[:next_pos] + state.add_tagged(tagged_text) + state.move(next_pos) + capture_text(state) diff --git a/snips_inference_agl/dataset/utils.py b/snips_inference_agl/dataset/utils.py new file mode 100644 index 0000000..f147f0f --- /dev/null +++ b/snips_inference_agl/dataset/utils.py @@ -0,0 +1,67 @@ +from __future__ import unicode_literals + +from future.utils import iteritems, itervalues + +from snips_inference_agl.constants import ( + DATA, ENTITIES, ENTITY, INTENTS, TEXT, UTTERANCES) +from snips_inference_agl.entity_parser.builtin_entity_parser import is_gazetteer_entity + + +def extract_utterance_entities(dataset): + entities_values = {ent_name: set() for ent_name in dataset[ENTITIES]} + + for intent in itervalues(dataset[INTENTS]): + for utterance in intent[UTTERANCES]: + for chunk in utterance[DATA]: + if ENTITY in chunk: + entities_values[chunk[ENTITY]].add(chunk[TEXT].strip()) + return {k: list(v) for k, v in iteritems(entities_values)} + + +def extract_intent_entities(dataset, entity_filter=None): + intent_entities = {intent: set() for intent in dataset[INTENTS]} + for intent_name, intent_data in iteritems(dataset[INTENTS]): + for utterance in intent_data[UTTERANCES]: + for chunk in utterance[DATA]: + if ENTITY in chunk: + if entity_filter and not entity_filter(chunk[ENTITY]): + continue + intent_entities[intent_name].add(chunk[ENTITY]) + return intent_entities + + +def extract_entity_values(dataset, apply_normalization): + from snips_nlu_utils import normalize + + entities_per_intent = {intent: set() for intent in dataset[INTENTS]} + intent_entities = extract_intent_entities(dataset) + for intent, entities in iteritems(intent_entities): + for entity in entities: + entity_values = set(dataset[ENTITIES][entity][UTTERANCES]) + if apply_normalization: + entity_values = {normalize(v) for v in entity_values} + entities_per_intent[intent].update(entity_values) + return entities_per_intent + + +def get_text_from_chunks(chunks): + return "".join(chunk[TEXT] for chunk in chunks) + + +def get_dataset_gazetteer_entities(dataset, intent=None): + if intent is not None: + return extract_intent_entities(dataset, is_gazetteer_entity)[intent] + return {e for e in dataset[ENTITIES] if is_gazetteer_entity(e)} + + +def get_stop_words_whitelist(dataset, stop_words): + """Extracts stop words whitelists per intent consisting of entity values + that appear in the stop_words list""" + entity_values_per_intent = extract_entity_values( + dataset, apply_normalization=True) + stop_words_whitelist = dict() + for intent, entity_values in iteritems(entity_values_per_intent): + whitelist = stop_words.intersection(entity_values) + if whitelist: + stop_words_whitelist[intent] = whitelist + return stop_words_whitelist diff --git a/snips_inference_agl/dataset/validation.py b/snips_inference_agl/dataset/validation.py new file mode 100644 index 0000000..d6fc4a1 --- /dev/null +++ b/snips_inference_agl/dataset/validation.py @@ -0,0 +1,254 @@ +from __future__ import division, unicode_literals + +import json +from builtins import str +from collections import Counter +from copy import deepcopy + +from future.utils import iteritems, itervalues + +from snips_inference_agl.common.dataset_utils import (validate_key, validate_keys, + validate_type) +from 
snips_inference_agl.constants import ( + AUTOMATICALLY_EXTENSIBLE, CAPITALIZE, DATA, ENTITIES, ENTITY, INTENTS, + LANGUAGE, MATCHING_STRICTNESS, SLOT_NAME, SYNONYMS, TEXT, USE_SYNONYMS, + UTTERANCES, VALIDATED, VALUE, LICENSE_INFO) +from snips_inference_agl.dataset import extract_utterance_entities, Dataset +from snips_inference_agl.entity_parser.builtin_entity_parser import ( + BuiltinEntityParser, is_builtin_entity) +from snips_inference_agl.exceptions import DatasetFormatError +from snips_inference_agl.preprocessing import tokenize_light +from snips_inference_agl.string_variations import get_string_variations + +NUMBER_VARIATIONS_THRESHOLD = 1e3 +VARIATIONS_GENERATION_THRESHOLD = 1e4 + + +def validate_and_format_dataset(dataset): + """Checks that the dataset is valid and format it + + Raise: + DatasetFormatError: When the dataset format is wrong + """ + from snips_nlu_parsers import get_all_languages + + if isinstance(dataset, Dataset): + dataset = dataset.json + + # Make this function idempotent + if dataset.get(VALIDATED, False): + return dataset + dataset = deepcopy(dataset) + dataset = json.loads(json.dumps(dataset)) + validate_type(dataset, dict, object_label="dataset") + mandatory_keys = [INTENTS, ENTITIES, LANGUAGE] + for key in mandatory_keys: + validate_key(dataset, key, object_label="dataset") + validate_type(dataset[ENTITIES], dict, object_label="entities") + validate_type(dataset[INTENTS], dict, object_label="intents") + language = dataset[LANGUAGE] + validate_type(language, str, object_label="language") + if language not in get_all_languages(): + raise DatasetFormatError("Unknown language: '%s'" % language) + + dataset[INTENTS] = { + intent_name: intent_data + for intent_name, intent_data in sorted(iteritems(dataset[INTENTS]))} + for intent in itervalues(dataset[INTENTS]): + _validate_and_format_intent(intent, dataset[ENTITIES]) + + utterance_entities_values = extract_utterance_entities(dataset) + builtin_entity_parser = BuiltinEntityParser.build(dataset=dataset) + + dataset[ENTITIES] = { + intent_name: entity_data + for intent_name, entity_data in sorted(iteritems(dataset[ENTITIES]))} + + for entity_name, entity in iteritems(dataset[ENTITIES]): + uterrance_entities = utterance_entities_values[entity_name] + if is_builtin_entity(entity_name): + dataset[ENTITIES][entity_name] = \ + _validate_and_format_builtin_entity(entity, uterrance_entities) + else: + dataset[ENTITIES][entity_name] = \ + _validate_and_format_custom_entity( + entity, uterrance_entities, language, + builtin_entity_parser) + dataset[VALIDATED] = True + return dataset + + +def _validate_and_format_intent(intent, entities): + validate_type(intent, dict, "intent") + validate_key(intent, UTTERANCES, object_label="intent dict") + validate_type(intent[UTTERANCES], list, object_label="utterances") + for utterance in intent[UTTERANCES]: + validate_type(utterance, dict, object_label="utterance") + validate_key(utterance, DATA, object_label="utterance") + validate_type(utterance[DATA], list, object_label="utterance data") + for chunk in utterance[DATA]: + validate_type(chunk, dict, object_label="utterance chunk") + validate_key(chunk, TEXT, object_label="chunk") + if ENTITY in chunk or SLOT_NAME in chunk: + mandatory_keys = [ENTITY, SLOT_NAME] + validate_keys(chunk, mandatory_keys, object_label="chunk") + if is_builtin_entity(chunk[ENTITY]): + continue + else: + validate_key(entities, chunk[ENTITY], + object_label=ENTITIES) + return intent + + +def _has_any_capitalization(entity_utterances, language): + for utterance 
in entity_utterances: + tokens = tokenize_light(utterance, language) + if any(t.isupper() or t.istitle() for t in tokens): + return True + return False + + +def _add_entity_variations(utterances, entity_variations, entity_value): + utterances[entity_value] = entity_value + for variation in entity_variations[entity_value]: + if variation: + utterances[variation] = entity_value + return utterances + + +def _extract_entity_values(entity): + values = set() + for ent in entity[DATA]: + values.add(ent[VALUE]) + if entity[USE_SYNONYMS]: + values.update(set(ent[SYNONYMS])) + return values + + +def _validate_and_format_custom_entity(entity, utterance_entities, language, + builtin_entity_parser): + validate_type(entity, dict, object_label="entity") + + # TODO: this is here temporarily, only to allow backward compatibility + if MATCHING_STRICTNESS not in entity: + strictness = entity.get("parser_threshold", 1.0) + + entity[MATCHING_STRICTNESS] = strictness + + mandatory_keys = [USE_SYNONYMS, AUTOMATICALLY_EXTENSIBLE, DATA, + MATCHING_STRICTNESS] + validate_keys(entity, mandatory_keys, object_label="custom entity") + validate_type(entity[USE_SYNONYMS], bool, object_label="use_synonyms") + validate_type(entity[AUTOMATICALLY_EXTENSIBLE], bool, + object_label="automatically_extensible") + validate_type(entity[DATA], list, object_label="entity data") + validate_type(entity[MATCHING_STRICTNESS], (float, int), + object_label="matching_strictness") + + formatted_entity = dict() + formatted_entity[AUTOMATICALLY_EXTENSIBLE] = entity[ + AUTOMATICALLY_EXTENSIBLE] + formatted_entity[MATCHING_STRICTNESS] = entity[MATCHING_STRICTNESS] + if LICENSE_INFO in entity: + formatted_entity[LICENSE_INFO] = entity[LICENSE_INFO] + use_synonyms = entity[USE_SYNONYMS] + + # Validate format and filter out unused data + valid_entity_data = [] + for entry in entity[DATA]: + validate_type(entry, dict, object_label="entity entry") + validate_keys(entry, [VALUE, SYNONYMS], object_label="entity entry") + entry[VALUE] = entry[VALUE].strip() + if not entry[VALUE]: + continue + validate_type(entry[SYNONYMS], list, object_label="entity synonyms") + entry[SYNONYMS] = [s.strip() for s in entry[SYNONYMS] if s.strip()] + valid_entity_data.append(entry) + entity[DATA] = valid_entity_data + + # Compute capitalization before normalizing + # Normalization lowercase and hence lead to bad capitalization calculation + formatted_entity[CAPITALIZE] = _has_any_capitalization(utterance_entities, + language) + + validated_utterances = dict() + # Map original values an synonyms + for data in entity[DATA]: + ent_value = data[VALUE] + validated_utterances[ent_value] = ent_value + if use_synonyms: + for s in data[SYNONYMS]: + if s not in validated_utterances: + validated_utterances[s] = ent_value + + # Number variations in entities values are expensive since each entity + # value is parsed with the builtin entity parser before creating the + # variations. 
We avoid generating these variations if there's enough entity + # values + + # Add variations if not colliding + all_original_values = _extract_entity_values(entity) + if len(entity[DATA]) < VARIATIONS_GENERATION_THRESHOLD: + variations_args = { + "case": True, + "and_": True, + "punctuation": True + } + else: + variations_args = { + "case": False, + "and_": False, + "punctuation": False + } + + variations_args["numbers"] = len( + entity[DATA]) < NUMBER_VARIATIONS_THRESHOLD + + variations = dict() + for data in entity[DATA]: + ent_value = data[VALUE] + values_to_variate = {ent_value} + if use_synonyms: + values_to_variate.update(set(data[SYNONYMS])) + variations[ent_value] = set( + v for value in values_to_variate + for v in get_string_variations( + value, language, builtin_entity_parser, **variations_args) + ) + variation_counter = Counter( + [v for variations_ in itervalues(variations) for v in variations_]) + non_colliding_variations = { + value: [ + v for v in variations if + v not in all_original_values and variation_counter[v] == 1 + ] + for value, variations in iteritems(variations) + } + + for entry in entity[DATA]: + entry_value = entry[VALUE] + validated_utterances = _add_entity_variations( + validated_utterances, non_colliding_variations, entry_value) + + # Merge utterances entities + utterance_entities_variations = { + ent: get_string_variations( + ent, language, builtin_entity_parser, **variations_args) + for ent in utterance_entities + } + + for original_ent, variations in iteritems(utterance_entities_variations): + if not original_ent or original_ent in validated_utterances: + continue + validated_utterances[original_ent] = original_ent + for variation in variations: + if variation and variation not in validated_utterances \ + and variation not in utterance_entities: + validated_utterances[variation] = original_ent + formatted_entity[UTTERANCES] = validated_utterances + return formatted_entity + + +def _validate_and_format_builtin_entity(entity, utterance_entities): + validate_type(entity, dict, object_label="builtin entity") + return {UTTERANCES: set(utterance_entities)} diff --git a/snips_inference_agl/dataset/yaml_wrapper.py b/snips_inference_agl/dataset/yaml_wrapper.py new file mode 100644 index 0000000..ba8390d --- /dev/null +++ b/snips_inference_agl/dataset/yaml_wrapper.py @@ -0,0 +1,11 @@ +import yaml + + +def _construct_yaml_str(self, node): + # Override the default string handling function + # to always return unicode objects + return self.construct_scalar(node) + + +yaml.Loader.add_constructor("tag:yaml.org,2002:str", _construct_yaml_str) +yaml.SafeLoader.add_constructor("tag:yaml.org,2002:str", _construct_yaml_str) diff --git a/snips_inference_agl/default_configs/__init__.py b/snips_inference_agl/default_configs/__init__.py new file mode 100644 index 0000000..fc66d33 --- /dev/null +++ b/snips_inference_agl/default_configs/__init__.py @@ -0,0 +1,26 @@ +from __future__ import unicode_literals + +from snips_inference_agl.constants import ( + LANGUAGE_DE, LANGUAGE_EN, LANGUAGE_ES, LANGUAGE_FR, LANGUAGE_IT, + LANGUAGE_JA, LANGUAGE_KO, LANGUAGE_PT_BR, LANGUAGE_PT_PT) +from .config_de import CONFIG as CONFIG_DE +from .config_en import CONFIG as CONFIG_EN +from .config_es import CONFIG as CONFIG_ES +from .config_fr import CONFIG as CONFIG_FR +from .config_it import CONFIG as CONFIG_IT +from .config_ja import CONFIG as CONFIG_JA +from .config_ko import CONFIG as CONFIG_KO +from .config_pt_br import CONFIG as CONFIG_PT_BR +from .config_pt_pt import CONFIG as CONFIG_PT_PT + 
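
The DEFAULT_CONFIGS mapping defined just below keys each supported language to its inference pipeline configuration (the per-language config_* modules imported above). A rough usage sketch, assuming the LANGUAGE_* constants resolve to the usual language codes:

from snips_inference_agl.constants import LANGUAGE_EN
from snips_inference_agl.default_configs import DEFAULT_CONFIGS

# Each value is a plain dict describing the engine pipeline: a lookup
# intent parser followed by a probabilistic parser (CRF slot filler +
# log-reg intent classifier), as laid out in the config_* modules.
en_config = DEFAULT_CONFIGS[LANGUAGE_EN]
parser_units = [p["unit_name"] for p in en_config["intent_parsers_configs"]]
print(parser_units)  # expected: ['lookup_intent_parser', 'probabilistic_intent_parser']
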
+DEFAULT_CONFIGS = { + LANGUAGE_DE: CONFIG_DE, + LANGUAGE_EN: CONFIG_EN, + LANGUAGE_ES: CONFIG_ES, + LANGUAGE_FR: CONFIG_FR, + LANGUAGE_IT: CONFIG_IT, + LANGUAGE_JA: CONFIG_JA, + LANGUAGE_KO: CONFIG_KO, + LANGUAGE_PT_BR: CONFIG_PT_BR, + LANGUAGE_PT_PT: CONFIG_PT_PT, +} diff --git a/snips_inference_agl/default_configs/config_de.py b/snips_inference_agl/default_configs/config_de.py new file mode 100644 index 0000000..200fc30 --- /dev/null +++ b/snips_inference_agl/default_configs/config_de.py @@ -0,0 +1,159 @@ +from __future__ import unicode_literals + +CONFIG = { + "unit_name": "nlu_engine", + "intent_parsers_configs": [ + { + "unit_name": "lookup_intent_parser", + "ignore_stop_words": True + }, + { + "unit_name": "probabilistic_intent_parser", + "slot_filler_config": { + "unit_name": "crf_slot_filler", + "feature_factory_configs": [ + { + "args": { + "common_words_gazetteer_name": + "top_200000_words_stemmed", + "use_stemming": True, + "n": 1 + }, + "factory_name": "ngram", + "offsets": [-2, -1, 0, 1, 2] + }, + { + "args": { + "common_words_gazetteer_name": + "top_200000_words_stemmed", + "use_stemming": True, + "n": 2 + }, + "factory_name": "ngram", + "offsets": [-2, 1] + }, + { + "args": { + "prefix_size": 2 + }, + "factory_name": "prefix", + "offsets": [0] + }, + { + "args": {"prefix_size": 5}, + "factory_name": "prefix", + "offsets": [0] + }, + { + "args": {"suffix_size": 2}, + "factory_name": "suffix", + "offsets": [0] + }, + { + "args": {"suffix_size": 5}, + "factory_name": "suffix", + "offsets": [0] + }, + { + "args": {}, + "factory_name": "is_digit", + "offsets": [-1, 0, 1] + }, + { + "args": {}, + "factory_name": "is_first", + "offsets": [-2, -1, 0] + }, + { + "args": {}, + "factory_name": "is_last", + "offsets": [0, 1, 2] + }, + { + "args": {"n": 1}, + "factory_name": "shape_ngram", + "offsets": [0] + }, + { + "args": {"n": 2}, + "factory_name": "shape_ngram", + "offsets": [-1, 0] + }, + { + "args": {"n": 3}, + "factory_name": "shape_ngram", + "offsets": [-1] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": False + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": True + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0], + "drop_out": 0.5 + }, + { + "args": {"tagging_scheme_code": 1}, + "factory_name": "builtin_entity_match", + "offsets": [-2, -1, 0] + } + ], + "crf_args": { + "c1": 0.1, + "c2": 0.1, + "algorithm": "lbfgs" + }, + "tagging_scheme": 1, + "data_augmentation_config": { + "min_utterances": 200, + "capitalization_ratio": 0.2, + "add_builtin_entities_examples": True + } + }, + "intent_classifier_config": { + "unit_name": "log_reg_intent_classifier", + "data_augmentation_config": { + "min_utterances": 20, + "noise_factor": 5, + "add_builtin_entities_examples": False, + "max_unknown_words": None, + "unknown_word_prob": 0.0, + "unknown_words_replacement_string": None + }, + "featurizer_config": { + "unit_name": "featurizer", + "pvalue_threshold": 0.4, + "added_cooccurrence_feature_ratio": 0.0, + "tfidf_vectorizer_config": { + "unit_name": "tfidf_vectorizer", + "use_stemming": True, + "word_clusters_name": None + }, + "cooccurrence_vectorizer_config": { + "unit_name": "cooccurrence_vectorizer", + "window_size": None, + "filter_stop_words": True, + "unknown_words_replacement_string": None, + "keep_order": True + } + }, + "noise_reweight_factor": 1, 
+ } + } + ] +} diff --git a/snips_inference_agl/default_configs/config_en.py b/snips_inference_agl/default_configs/config_en.py new file mode 100644 index 0000000..12f7ae1 --- /dev/null +++ b/snips_inference_agl/default_configs/config_en.py @@ -0,0 +1,145 @@ +from __future__ import unicode_literals + +CONFIG = { + "unit_name": "nlu_engine", + "intent_parsers_configs": [ + { + "unit_name": "lookup_intent_parser", + "ignore_stop_words": True + }, + { + "unit_name": "probabilistic_intent_parser", + "slot_filler_config": { + "unit_name": "crf_slot_filler", + "feature_factory_configs": [ + { + "args": { + "common_words_gazetteer_name": + "top_10000_words_stemmed", + "use_stemming": True, + "n": 1 + }, + "factory_name": "ngram", + "offsets": [-2, -1, 0, 1, 2] + }, + { + "args": { + "common_words_gazetteer_name": + "top_10000_words_stemmed", + "use_stemming": True, + "n": 2 + }, + "factory_name": "ngram", + "offsets": [-2, 1] + }, + { + "args": {}, + "factory_name": "is_digit", + "offsets": [-1, 0, 1] + }, + { + "args": {}, + "factory_name": "is_first", + "offsets": [-2, -1, 0] + }, + { + "args": {}, + "factory_name": "is_last", + "offsets": [0, 1, 2] + }, + { + "args": {"n": 1}, + "factory_name": "shape_ngram", + "offsets": [0] + }, + { + "args": {"n": 2}, + "factory_name": "shape_ngram", + "offsets": [-1, 0] + }, + { + "args": {"n": 3}, + "factory_name": "shape_ngram", + "offsets": [-1] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": False + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": True + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0], + "drop_out": 0.5 + }, + { + "args": {"tagging_scheme_code": 1}, + "factory_name": "builtin_entity_match", + "offsets": [-2, -1, 0] + }, + { + "args": { + "cluster_name": "brown_clusters", + "use_stemming": False + }, + "factory_name": "word_cluster", + "offsets": [-2, -1, 0, 1] + } + ], + "crf_args": { + "c1": 0.1, + "c2": 0.1, + "algorithm": "lbfgs" + }, + "tagging_scheme": 1, + "data_augmentation_config": { + "min_utterances": 200, + "capitalization_ratio": 0.2, + "add_builtin_entities_examples": True + } + }, + "intent_classifier_config": { + "unit_name": "log_reg_intent_classifier", + "data_augmentation_config": { + "min_utterances": 20, + "noise_factor": 5, + "add_builtin_entities_examples": False, + "max_unknown_words": None, + "unknown_word_prob": 0.0, + "unknown_words_replacement_string": None + }, + "featurizer_config": { + "unit_name": "featurizer", + "pvalue_threshold": 0.4, + "added_cooccurrence_feature_ratio": 0.0, + "tfidf_vectorizer_config": { + "unit_name": "tfidf_vectorizer", + "use_stemming": False, + "word_clusters_name": None + }, + "cooccurrence_vectorizer_config": { + "unit_name": "cooccurrence_vectorizer", + "window_size": None, + "filter_stop_words": True, + "unknown_words_replacement_string": None, + "keep_order": True + } + }, + "noise_reweight_factor": 1, + } + } + ] +} diff --git a/snips_inference_agl/default_configs/config_es.py b/snips_inference_agl/default_configs/config_es.py new file mode 100644 index 0000000..28969ce --- /dev/null +++ b/snips_inference_agl/default_configs/config_es.py @@ -0,0 +1,138 @@ +from __future__ import unicode_literals + +CONFIG = { + "unit_name": "nlu_engine", + "intent_parsers_configs": [ + { + "unit_name": "lookup_intent_parser", + "ignore_stop_words": True + }, 
+ { + "unit_name": "probabilistic_intent_parser", + "slot_filler_config": { + "unit_name": "crf_slot_filler", + "feature_factory_configs": [ + { + "args": { + "common_words_gazetteer_name": + "top_10000_words_stemmed", + "use_stemming": True, + "n": 1 + }, + "factory_name": "ngram", + "offsets": [-2, -1, 0, 1, 2] + }, + { + "args": { + "common_words_gazetteer_name": + "top_10000_words_stemmed", + "use_stemming": True, + "n": 2 + }, + "factory_name": "ngram", + "offsets": [-2, 1] + }, + { + "args": {}, + "factory_name": "is_digit", + "offsets": [-1, 0, 1] + }, + { + "args": {}, + "factory_name": "is_first", + "offsets": [-2, -1, 0] + }, + { + "args": {}, + "factory_name": "is_last", + "offsets": [0, 1, 2] + }, + { + "args": {"n": 1}, + "factory_name": "shape_ngram", + "offsets": [0] + }, + { + "args": {"n": 2}, + "factory_name": "shape_ngram", + "offsets": [-1, 0] + }, + { + "args": {"n": 3}, + "factory_name": "shape_ngram", + "offsets": [-1] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": False + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": True + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0], + "drop_out": 0.5 + }, + { + "args": {"tagging_scheme_code": 1}, + "factory_name": "builtin_entity_match", + "offsets": [-2, -1, 0] + } + ], + "crf_args": { + "c1": 0.1, + "c2": 0.1, + "algorithm": "lbfgs" + }, + "tagging_scheme": 1, + "data_augmentation_config": { + "min_utterances": 200, + "capitalization_ratio": 0.2, + "add_builtin_entities_examples": True + }, + + }, + "intent_classifier_config": { + "unit_name": "log_reg_intent_classifier", + "data_augmentation_config": { + "min_utterances": 20, + "noise_factor": 5, + "add_builtin_entities_examples": False, + "max_unknown_words": None, + "unknown_word_prob": 0.0, + "unknown_words_replacement_string": None + }, + "featurizer_config": { + "unit_name": "featurizer", + "pvalue_threshold": 0.4, + "added_cooccurrence_feature_ratio": 0.0, + "tfidf_vectorizer_config": { + "unit_name": "tfidf_vectorizer", + "use_stemming": True, + "word_clusters_name": None + }, + "cooccurrence_vectorizer_config": { + "unit_name": "cooccurrence_vectorizer", + "window_size": None, + "filter_stop_words": True, + "unknown_words_replacement_string": None, + "keep_order": True + } + }, + "noise_reweight_factor": 1, + } + } + ] +} diff --git a/snips_inference_agl/default_configs/config_fr.py b/snips_inference_agl/default_configs/config_fr.py new file mode 100644 index 0000000..a2da590 --- /dev/null +++ b/snips_inference_agl/default_configs/config_fr.py @@ -0,0 +1,137 @@ +from __future__ import unicode_literals + +CONFIG = { + "unit_name": "nlu_engine", + "intent_parsers_configs": [ + { + "unit_name": "lookup_intent_parser", + "ignore_stop_words": True + }, + { + "unit_name": "probabilistic_intent_parser", + "slot_filler_config": { + "unit_name": "crf_slot_filler", + "feature_factory_configs": [ + { + "args": { + "common_words_gazetteer_name": + "top_10000_words_stemmed", + "use_stemming": True, + "n": 1 + }, + "factory_name": "ngram", + "offsets": [-2, -1, 0, 1, 2] + }, + { + "args": { + "common_words_gazetteer_name": + "top_10000_words_stemmed", + "use_stemming": True, + "n": 2 + }, + "factory_name": "ngram", + "offsets": [-2, 1] + }, + { + "args": {}, + "factory_name": "is_digit", + "offsets": [-1, 0, 1] + }, + { + "args": {}, + 
"factory_name": "is_first", + "offsets": [-2, -1, 0] + }, + { + "args": {}, + "factory_name": "is_last", + "offsets": [0, 1, 2] + }, + { + "args": {"n": 1}, + "factory_name": "shape_ngram", + "offsets": [0] + }, + { + "args": {"n": 2}, + "factory_name": "shape_ngram", + "offsets": [-1, 0] + }, + { + "args": {"n": 3}, + "factory_name": "shape_ngram", + "offsets": [-1] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": False + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": True + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0], + "drop_out": 0.5 + }, + { + "args": {"tagging_scheme_code": 1}, + "factory_name": "builtin_entity_match", + "offsets": [-2, -1, 0] + } + ], + "crf_args": { + "c1": 0.1, + "c2": 0.1, + "algorithm": "lbfgs" + }, + "tagging_scheme": 1, + "data_augmentation_config": { + "min_utterances": 200, + "capitalization_ratio": 0.2, + "add_builtin_entities_examples": True + } + }, + "intent_classifier_config": { + "unit_name": "log_reg_intent_classifier", + "data_augmentation_config": { + "min_utterances": 20, + "noise_factor": 5, + "add_builtin_entities_examples": False, + "max_unknown_words": None, + "unknown_word_prob": 0.0, + "unknown_words_replacement_string": None + }, + "featurizer_config": { + "unit_name": "featurizer", + "pvalue_threshold": 0.4, + "added_cooccurrence_feature_ratio": 0.0, + "tfidf_vectorizer_config": { + "unit_name": "tfidf_vectorizer", + "use_stemming": True, + "word_clusters_name": None + }, + "cooccurrence_vectorizer_config": { + "unit_name": "cooccurrence_vectorizer", + "window_size": None, + "filter_stop_words": True, + "unknown_words_replacement_string": None, + "keep_order": True + } + }, + "noise_reweight_factor": 1, + } + } + ] +} diff --git a/snips_inference_agl/default_configs/config_it.py b/snips_inference_agl/default_configs/config_it.py new file mode 100644 index 0000000..a2da590 --- /dev/null +++ b/snips_inference_agl/default_configs/config_it.py @@ -0,0 +1,137 @@ +from __future__ import unicode_literals + +CONFIG = { + "unit_name": "nlu_engine", + "intent_parsers_configs": [ + { + "unit_name": "lookup_intent_parser", + "ignore_stop_words": True + }, + { + "unit_name": "probabilistic_intent_parser", + "slot_filler_config": { + "unit_name": "crf_slot_filler", + "feature_factory_configs": [ + { + "args": { + "common_words_gazetteer_name": + "top_10000_words_stemmed", + "use_stemming": True, + "n": 1 + }, + "factory_name": "ngram", + "offsets": [-2, -1, 0, 1, 2] + }, + { + "args": { + "common_words_gazetteer_name": + "top_10000_words_stemmed", + "use_stemming": True, + "n": 2 + }, + "factory_name": "ngram", + "offsets": [-2, 1] + }, + { + "args": {}, + "factory_name": "is_digit", + "offsets": [-1, 0, 1] + }, + { + "args": {}, + "factory_name": "is_first", + "offsets": [-2, -1, 0] + }, + { + "args": {}, + "factory_name": "is_last", + "offsets": [0, 1, 2] + }, + { + "args": {"n": 1}, + "factory_name": "shape_ngram", + "offsets": [0] + }, + { + "args": {"n": 2}, + "factory_name": "shape_ngram", + "offsets": [-1, 0] + }, + { + "args": {"n": 3}, + "factory_name": "shape_ngram", + "offsets": [-1] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": False + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0] + }, + { + "args": { + 
"use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": True + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0], + "drop_out": 0.5 + }, + { + "args": {"tagging_scheme_code": 1}, + "factory_name": "builtin_entity_match", + "offsets": [-2, -1, 0] + } + ], + "crf_args": { + "c1": 0.1, + "c2": 0.1, + "algorithm": "lbfgs" + }, + "tagging_scheme": 1, + "data_augmentation_config": { + "min_utterances": 200, + "capitalization_ratio": 0.2, + "add_builtin_entities_examples": True + } + }, + "intent_classifier_config": { + "unit_name": "log_reg_intent_classifier", + "data_augmentation_config": { + "min_utterances": 20, + "noise_factor": 5, + "add_builtin_entities_examples": False, + "max_unknown_words": None, + "unknown_word_prob": 0.0, + "unknown_words_replacement_string": None + }, + "featurizer_config": { + "unit_name": "featurizer", + "pvalue_threshold": 0.4, + "added_cooccurrence_feature_ratio": 0.0, + "tfidf_vectorizer_config": { + "unit_name": "tfidf_vectorizer", + "use_stemming": True, + "word_clusters_name": None + }, + "cooccurrence_vectorizer_config": { + "unit_name": "cooccurrence_vectorizer", + "window_size": None, + "filter_stop_words": True, + "unknown_words_replacement_string": None, + "keep_order": True + } + }, + "noise_reweight_factor": 1, + } + } + ] +} diff --git a/snips_inference_agl/default_configs/config_ja.py b/snips_inference_agl/default_configs/config_ja.py new file mode 100644 index 0000000..b28791f --- /dev/null +++ b/snips_inference_agl/default_configs/config_ja.py @@ -0,0 +1,164 @@ +from __future__ import unicode_literals + +CONFIG = { + "unit_name": "nlu_engine", + "intent_parsers_configs": [ + { + "unit_name": "lookup_intent_parser", + "ignore_stop_words": False + }, + { + "unit_name": "probabilistic_intent_parser", + "slot_filler_config": { + "unit_name": "crf_slot_filler", + "feature_factory_configs": [ + { + "args": { + "common_words_gazetteer_name": None, + "use_stemming": False, + "n": 1 + }, + "factory_name": "ngram", + "offsets": [-2, -1, 0, 1, 2] + }, + { + "args": { + "common_words_gazetteer_name": None, + "use_stemming": False, + "n": 2 + }, + "factory_name": "ngram", + "offsets": [-2, 0, 1, 2] + }, + { + "args": {"prefix_size": 1}, + "factory_name": "prefix", + "offsets": [0, 1] + }, + { + "args": {"prefix_size": 2}, + "factory_name": "prefix", + "offsets": [0, 1] + }, + { + "args": {"suffix_size": 1}, + "factory_name": "suffix", + "offsets": [0, 1] + }, + { + "args": {"suffix_size": 2}, + "factory_name": "suffix", + "offsets": [0, 1] + }, + { + "args": {}, + "factory_name": "is_digit", + "offsets": [-1, 0, 1] + }, + { + "args": {}, + "factory_name": "is_first", + "offsets": [-2, -1, 0] + }, + { + "args": {}, + "factory_name": "is_last", + "offsets": [0, 1, 2] + }, + { + "args": {"n": 1}, + "factory_name": "shape_ngram", + "offsets": [0] + }, + { + "args": {"n": 2}, + "factory_name": "shape_ngram", + "offsets": [-1, 0] + }, + { + "args": {"n": 3}, + "factory_name": "shape_ngram", + "offsets": [-1] + }, + { + "args": { + "use_stemming": False, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": False + } + }, + "factory_name": "entity_match", + "offsets": [-1, 0, 1, 2], + }, + { + "args": { + "use_stemming": False, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": False + } + }, + "factory_name": "entity_match", + "offsets": [-1, 0, 1, 2], + "drop_out": 0.5 + }, + { + "args": {"tagging_scheme_code": 1}, + "factory_name": 
"builtin_entity_match", + "offsets": [-1, 0, 1, 2] + }, + { + "args": { + "cluster_name": "w2v_clusters", + "use_stemming": False + }, + "factory_name": "word_cluster", + "offsets": [-2, -1, 0, 1, 2] + } + ], + "crf_args": { + "c1": 0.1, + "c2": 0.1, + "algorithm": "lbfgs" + }, + "tagging_scheme": 1, + "data_augmentation_config": { + "min_utterances": 200, + "capitalization_ratio": 0.2, + "add_builtin_entities_examples": True + }, + + }, + "intent_classifier_config": { + "unit_name": "log_reg_intent_classifier", + "data_augmentation_config": { + "min_utterances": 20, + "noise_factor": 5, + "add_builtin_entities_examples": False, + "max_unknown_words": None, + "unknown_word_prob": 0.0, + "unknown_words_replacement_string": None + }, + "featurizer_config": { + "unit_name": "featurizer", + "pvalue_threshold": 0.9, + "added_cooccurrence_feature_ratio": 0.0, + "tfidf_vectorizer_config": { + "unit_name": "tfidf_vectorizer", + "use_stemming": False, + "word_clusters_name": "w2v_clusters" + }, + "cooccurrence_vectorizer_config": { + "unit_name": "cooccurrence_vectorizer", + "window_size": None, + "filter_stop_words": True, + "unknown_words_replacement_string": None, + "keep_order": True + } + }, + "noise_reweight_factor": 1, + } + } + ] +} diff --git a/snips_inference_agl/default_configs/config_ko.py b/snips_inference_agl/default_configs/config_ko.py new file mode 100644 index 0000000..1630796 --- /dev/null +++ b/snips_inference_agl/default_configs/config_ko.py @@ -0,0 +1,155 @@ +from __future__ import unicode_literals + +CONFIG = { + "unit_name": "nlu_engine", + "intent_parsers_configs": [ + { + "unit_name": "lookup_intent_parser", + "ignore_stop_words": False + }, + { + "unit_name": "probabilistic_intent_parser", + "slot_filler_config": { + "unit_name": "crf_slot_filler", + "feature_factory_configs": [ + { + "args": { + "common_words_gazetteer_name": None, + "use_stemming": False, + "n": 1 + }, + "factory_name": "ngram", + "offsets": [-2, -1, 0, 1, 2] + }, + { + "args": { + "common_words_gazetteer_name": None, + "use_stemming": False, + "n": 2 + }, + "factory_name": "ngram", + "offsets": [-2, 1] + }, + { + "args": {"prefix_size": 1}, + "factory_name": "prefix", + "offsets": [0] + }, + { + "args": {"prefix_size": 2}, + "factory_name": "prefix", + "offsets": [0] + }, + { + "args": {"suffix_size": 1}, + "factory_name": "suffix", + "offsets": [0] + }, + { + "args": {"suffix_size": 2}, + "factory_name": "suffix", + "offsets": [0] + }, + { + "args": {}, + "factory_name": "is_digit", + "offsets": [-1, 0, 1] + }, + { + "args": {}, + "factory_name": "is_first", + "offsets": [-2, -1, 0] + }, + { + "args": {}, + "factory_name": "is_last", + "offsets": [0, 1, 2] + }, + { + "args": {"n": 1}, + "factory_name": "shape_ngram", + "offsets": [0] + }, + { + "args": {"n": 2}, + "factory_name": "shape_ngram", + "offsets": [-1, 0] + }, + { + "args": {"n": 3}, + "factory_name": "shape_ngram", + "offsets": [-1] + }, + { + "args": { + "use_stemming": False, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": False + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0] + }, + { + "args": { + "use_stemming": False, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": True + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0], + "drop_out": 0.5 + }, + { + "args": {"tagging_scheme_code": 1}, + "factory_name": "builtin_entity_match", + "offsets": [-2, -1, 0] + } + ], + "crf_args": { + "c1": 0.1, + "c2": 0.1, + "algorithm": "lbfgs" + }, + 
"tagging_scheme": 1, + "data_augmentation_config": { + "min_utterances": 200, + "capitalization_ratio": 0.2, + "add_builtin_entities_examples": True + } + }, + "intent_classifier_config": { + "unit_name": "log_reg_intent_classifier", + "data_augmentation_config": { + "min_utterances": 20, + "noise_factor": 5, + "add_builtin_entities_examples": False, + "max_unknown_words": None, + "unknown_word_prob": 0.0, + "unknown_words_replacement_string": None + }, + "featurizer_config": { + "unit_name": "featurizer", + "pvalue_threshold": 0.4, + "added_cooccurrence_feature_ratio": 0.0, + "tfidf_vectorizer_config": { + "unit_name": "tfidf_vectorizer", + "use_stemming": False, + "word_clusters_name": None + }, + "cooccurrence_vectorizer_config": { + "unit_name": "cooccurrence_vectorizer", + "window_size": None, + "filter_stop_words": True, + "unknown_words_replacement_string": None, + "keep_order": True + } + }, + "noise_reweight_factor": 1, + } + } + ] +} diff --git a/snips_inference_agl/default_configs/config_pt_br.py b/snips_inference_agl/default_configs/config_pt_br.py new file mode 100644 index 0000000..450f0db --- /dev/null +++ b/snips_inference_agl/default_configs/config_pt_br.py @@ -0,0 +1,137 @@ +from __future__ import unicode_literals + +CONFIG = { + "unit_name": "nlu_engine", + "intent_parsers_configs": [ + { + "unit_name": "lookup_intent_parser", + "ignore_stop_words": True + }, + { + "unit_name": "probabilistic_intent_parser", + "slot_filler_config": { + "unit_name": "crf_slot_filler", + "feature_factory_configs": [ + { + "args": { + "common_words_gazetteer_name": + "top_5000_words_stemmed", + "use_stemming": True, + "n": 1 + }, + "factory_name": "ngram", + "offsets": [-2, -1, 0, 1, 2] + }, + { + "args": { + "common_words_gazetteer_name": + "top_5000_words_stemmed", + "use_stemming": True, + "n": 2 + }, + "factory_name": "ngram", + "offsets": [-2, 1] + }, + { + "args": {}, + "factory_name": "is_digit", + "offsets": [-1, 0, 1] + }, + { + "args": {}, + "factory_name": "is_first", + "offsets": [-2, -1, 0] + }, + { + "args": {}, + "factory_name": "is_last", + "offsets": [0, 1, 2] + }, + { + "args": {"n": 1}, + "factory_name": "shape_ngram", + "offsets": [0] + }, + { + "args": {"n": 2}, + "factory_name": "shape_ngram", + "offsets": [-1, 0] + }, + { + "args": {"n": 3}, + "factory_name": "shape_ngram", + "offsets": [-1] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": False + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": True + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0], + "drop_out": 0.5 + }, + { + "args": {"tagging_scheme_code": 1}, + "factory_name": "builtin_entity_match", + "offsets": [-2, -1, 0] + } + ], + "crf_args": { + "c1": 0.1, + "c2": 0.1, + "algorithm": "lbfgs" + }, + "tagging_scheme": 1, + "data_augmentation_config": { + "min_utterances": 200, + "capitalization_ratio": 0.2, + "add_builtin_entities_examples": True + }, + }, + "intent_classifier_config": { + "unit_name": "log_reg_intent_classifier", + "data_augmentation_config": { + "min_utterances": 20, + "noise_factor": 5, + "add_builtin_entities_examples": False, + "max_unknown_words": None, + "unknown_word_prob": 0.0, + "unknown_words_replacement_string": None + }, + "featurizer_config": { + "unit_name": "featurizer", + "pvalue_threshold": 0.4, + "added_cooccurrence_feature_ratio": 0.0, + 
"tfidf_vectorizer_config": { + "unit_name": "tfidf_vectorizer", + "use_stemming": True, + "word_clusters_name": None + }, + "cooccurrence_vectorizer_config": { + "unit_name": "cooccurrence_vectorizer", + "window_size": None, + "filter_stop_words": True, + "unknown_words_replacement_string": None, + "keep_order": True + } + }, + }, + "noise_reweight_factor": 1, + } + ] +} diff --git a/snips_inference_agl/default_configs/config_pt_pt.py b/snips_inference_agl/default_configs/config_pt_pt.py new file mode 100644 index 0000000..450f0db --- /dev/null +++ b/snips_inference_agl/default_configs/config_pt_pt.py @@ -0,0 +1,137 @@ +from __future__ import unicode_literals + +CONFIG = { + "unit_name": "nlu_engine", + "intent_parsers_configs": [ + { + "unit_name": "lookup_intent_parser", + "ignore_stop_words": True + }, + { + "unit_name": "probabilistic_intent_parser", + "slot_filler_config": { + "unit_name": "crf_slot_filler", + "feature_factory_configs": [ + { + "args": { + "common_words_gazetteer_name": + "top_5000_words_stemmed", + "use_stemming": True, + "n": 1 + }, + "factory_name": "ngram", + "offsets": [-2, -1, 0, 1, 2] + }, + { + "args": { + "common_words_gazetteer_name": + "top_5000_words_stemmed", + "use_stemming": True, + "n": 2 + }, + "factory_name": "ngram", + "offsets": [-2, 1] + }, + { + "args": {}, + "factory_name": "is_digit", + "offsets": [-1, 0, 1] + }, + { + "args": {}, + "factory_name": "is_first", + "offsets": [-2, -1, 0] + }, + { + "args": {}, + "factory_name": "is_last", + "offsets": [0, 1, 2] + }, + { + "args": {"n": 1}, + "factory_name": "shape_ngram", + "offsets": [0] + }, + { + "args": {"n": 2}, + "factory_name": "shape_ngram", + "offsets": [-1, 0] + }, + { + "args": {"n": 3}, + "factory_name": "shape_ngram", + "offsets": [-1] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": False + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0] + }, + { + "args": { + "use_stemming": True, + "tagging_scheme_code": 2, + "entity_filter": { + "automatically_extensible": True + } + }, + "factory_name": "entity_match", + "offsets": [-2, -1, 0], + "drop_out": 0.5 + }, + { + "args": {"tagging_scheme_code": 1}, + "factory_name": "builtin_entity_match", + "offsets": [-2, -1, 0] + } + ], + "crf_args": { + "c1": 0.1, + "c2": 0.1, + "algorithm": "lbfgs" + }, + "tagging_scheme": 1, + "data_augmentation_config": { + "min_utterances": 200, + "capitalization_ratio": 0.2, + "add_builtin_entities_examples": True + }, + }, + "intent_classifier_config": { + "unit_name": "log_reg_intent_classifier", + "data_augmentation_config": { + "min_utterances": 20, + "noise_factor": 5, + "add_builtin_entities_examples": False, + "max_unknown_words": None, + "unknown_word_prob": 0.0, + "unknown_words_replacement_string": None + }, + "featurizer_config": { + "unit_name": "featurizer", + "pvalue_threshold": 0.4, + "added_cooccurrence_feature_ratio": 0.0, + "tfidf_vectorizer_config": { + "unit_name": "tfidf_vectorizer", + "use_stemming": True, + "word_clusters_name": None + }, + "cooccurrence_vectorizer_config": { + "unit_name": "cooccurrence_vectorizer", + "window_size": None, + "filter_stop_words": True, + "unknown_words_replacement_string": None, + "keep_order": True + } + }, + }, + "noise_reweight_factor": 1, + } + ] +} diff --git a/snips_inference_agl/entity_parser/__init__.py b/snips_inference_agl/entity_parser/__init__.py new file mode 100644 index 0000000..c54f0b2 --- /dev/null +++ 
b/snips_inference_agl/entity_parser/__init__.py @@ -0,0 +1,6 @@ +# coding=utf-8 +from __future__ import unicode_literals + +from snips_inference_agl.entity_parser.builtin_entity_parser import BuiltinEntityParser +from snips_inference_agl.entity_parser.custom_entity_parser import ( + CustomEntityParser, CustomEntityParserUsage) diff --git a/snips_inference_agl/entity_parser/builtin_entity_parser.py b/snips_inference_agl/entity_parser/builtin_entity_parser.py new file mode 100644 index 0000000..02fa610 --- /dev/null +++ b/snips_inference_agl/entity_parser/builtin_entity_parser.py @@ -0,0 +1,162 @@ +from __future__ import unicode_literals + +import json +import shutil + +from future.builtins import str + +from snips_inference_agl.common.io_utils import temp_dir +from snips_inference_agl.common.utils import json_string +from snips_inference_agl.constants import DATA_PATH, ENTITIES, LANGUAGE +from snips_inference_agl.entity_parser.entity_parser import EntityParser +from snips_inference_agl.result import parsed_entity + +_BUILTIN_ENTITY_PARSERS = dict() + +try: + FileNotFoundError +except NameError: + FileNotFoundError = IOError + + +class BuiltinEntityParser(EntityParser): + def __init__(self, parser): + super(BuiltinEntityParser, self).__init__() + self._parser = parser + + def _parse(self, text, scope=None): + entities = self._parser.parse(text.lower(), scope=scope) + result = [] + for entity in entities: + ent = parsed_entity( + entity_kind=entity["entity_kind"], + entity_value=entity["value"], + entity_resolved_value=entity["entity"], + entity_range=entity["range"] + ) + result.append(ent) + return result + + def persist(self, path): + self._parser.persist(path) + + @classmethod + def from_path(cls, path): + from snips_nlu_parsers import ( + BuiltinEntityParser as _BuiltinEntityParser) + + parser = _BuiltinEntityParser.from_path(path) + return cls(parser) + + @classmethod + def build(cls, dataset=None, language=None, gazetteer_entity_scope=None): + from snips_nlu_parsers import get_supported_gazetteer_entities + + global _BUILTIN_ENTITY_PARSERS + + if dataset is not None: + language = dataset[LANGUAGE] + gazetteer_entity_scope = [entity for entity in dataset[ENTITIES] + if is_gazetteer_entity(entity)] + + if language is None: + raise ValueError("Either a dataset or a language must be provided " + "in order to build a BuiltinEntityParser") + + if gazetteer_entity_scope is None: + gazetteer_entity_scope = [] + caching_key = _get_caching_key(language, gazetteer_entity_scope) + if caching_key not in _BUILTIN_ENTITY_PARSERS: + for entity in gazetteer_entity_scope: + if entity not in get_supported_gazetteer_entities(language): + raise ValueError( + "Gazetteer entity '%s' is not supported in " + "language '%s'" % (entity, language)) + _BUILTIN_ENTITY_PARSERS[caching_key] = _build_builtin_parser( + language, gazetteer_entity_scope) + return _BUILTIN_ENTITY_PARSERS[caching_key] + + +def _build_builtin_parser(language, gazetteer_entities): + from snips_nlu_parsers import BuiltinEntityParser as _BuiltinEntityParser + + with temp_dir() as serialization_dir: + gazetteer_entity_parser = None + if gazetteer_entities: + gazetteer_entity_parser = _build_gazetteer_parser( + serialization_dir, gazetteer_entities, language) + + metadata = { + "language": language.upper(), + "gazetteer_parser": gazetteer_entity_parser + } + metadata_path = serialization_dir / "metadata.json" + with metadata_path.open("w", encoding="utf-8") as f: + f.write(json_string(metadata)) + parser = 
_BuiltinEntityParser.from_path(serialization_dir) + return BuiltinEntityParser(parser) + + +def _build_gazetteer_parser(target_dir, gazetteer_entities, language): + from snips_nlu_parsers import get_builtin_entity_shortname + + gazetteer_parser_name = "gazetteer_entity_parser" + gazetteer_parser_path = target_dir / gazetteer_parser_name + gazetteer_parser_metadata = [] + for ent in sorted(gazetteer_entities): + # Fetch the compiled parser in the resources + source_parser_path = find_gazetteer_entity_data_path(language, ent) + short_name = get_builtin_entity_shortname(ent).lower() + target_parser_path = gazetteer_parser_path / short_name + parser_metadata = { + "entity_identifier": ent, + "entity_parser": short_name + } + gazetteer_parser_metadata.append(parser_metadata) + # Copy the single entity parser + shutil.copytree(str(source_parser_path), str(target_parser_path)) + # Dump the parser metadata + gazetteer_entity_parser_metadata = { + "parsers_metadata": gazetteer_parser_metadata + } + gazetteer_parser_metadata_path = gazetteer_parser_path / "metadata.json" + with gazetteer_parser_metadata_path.open("w", encoding="utf-8") as f: + f.write(json_string(gazetteer_entity_parser_metadata)) + return gazetteer_parser_name + + +def is_builtin_entity(entity_label): + from snips_nlu_parsers import get_all_builtin_entities + + return entity_label in get_all_builtin_entities() + + +def is_gazetteer_entity(entity_label): + from snips_nlu_parsers import get_all_gazetteer_entities + + return entity_label in get_all_gazetteer_entities() + + +def find_gazetteer_entity_data_path(language, entity_name): + for directory in DATA_PATH.iterdir(): + if not directory.is_dir(): + continue + metadata_path = directory / "metadata.json" + if not metadata_path.exists(): + continue + with metadata_path.open(encoding="utf8") as f: + metadata = json.load(f) + if metadata.get("entity_name") == entity_name \ + and metadata.get("language") == language: + return directory / metadata["data_directory"] + raise FileNotFoundError( + "No data found for the '{e}' builtin entity in language '{lang}'. 
" + "You must download the corresponding resources by running " + "'python -m snips_nlu download-entity {e} {lang}' before you can use " + "this builtin entity.".format(e=entity_name, lang=language)) + + +def _get_caching_key(language, entity_scope): + tuple_key = (language,) + tuple_key += tuple(entity for entity in sorted(entity_scope)) + return tuple_key diff --git a/snips_inference_agl/entity_parser/custom_entity_parser.py b/snips_inference_agl/entity_parser/custom_entity_parser.py new file mode 100644 index 0000000..949df1f --- /dev/null +++ b/snips_inference_agl/entity_parser/custom_entity_parser.py @@ -0,0 +1,209 @@ +# coding=utf-8 +from __future__ import unicode_literals + +import json +import operator +from copy import deepcopy +from pathlib import Path + +from future.utils import iteritems, viewvalues + +from snips_inference_agl.common.utils import json_string +from snips_inference_agl.constants import ( + END, ENTITIES, LANGUAGE, MATCHING_STRICTNESS, START, UTTERANCES, + LICENSE_INFO) +from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity +from snips_inference_agl.entity_parser.custom_entity_parser_usage import ( + CustomEntityParserUsage) +from snips_inference_agl.entity_parser.entity_parser import EntityParser +from snips_inference_agl.preprocessing import stem, tokenize, tokenize_light +from snips_inference_agl.result import parsed_entity + +STOPWORDS_FRACTION = 1e-3 + + +class CustomEntityParser(EntityParser): + def __init__(self, parser, language, parser_usage): + super(CustomEntityParser, self).__init__() + self._parser = parser + self.language = language + self.parser_usage = parser_usage + + def _parse(self, text, scope=None): + tokens = tokenize(text, self.language) + shifts = _compute_char_shifts(tokens) + cleaned_text = " ".join(token.value for token in tokens) + + entities = self._parser.parse(cleaned_text, scope) + result = [] + for entity in entities: + start = entity["range"]["start"] + start -= shifts[start] + end = entity["range"]["end"] + end -= shifts[end - 1] + entity_range = {START: start, END: end} + ent = parsed_entity( + entity_kind=entity["entity_identifier"], + entity_value=entity["value"], + entity_resolved_value=entity["resolved_value"], + entity_range=entity_range + ) + result.append(ent) + return result + + def persist(self, path): + path = Path(path) + path.mkdir() + parser_directory = "parser" + metadata = { + "language": self.language, + "parser_usage": self.parser_usage.value, + "parser_directory": parser_directory + } + with (path / "metadata.json").open(mode="w", encoding="utf8") as f: + f.write(json_string(metadata)) + self._parser.persist(path / parser_directory) + + @classmethod + def from_path(cls, path): + from snips_nlu_parsers import GazetteerEntityParser + + path = Path(path) + with (path / "metadata.json").open(encoding="utf8") as f: + metadata = json.load(f) + language = metadata["language"] + parser_usage = CustomEntityParserUsage(metadata["parser_usage"]) + parser_path = path / metadata["parser_directory"] + parser = GazetteerEntityParser.from_path(parser_path) + return cls(parser, language, parser_usage) + + @classmethod + def build(cls, dataset, parser_usage, resources): + from snips_nlu_parsers import GazetteerEntityParser + from snips_inference_agl.dataset import validate_and_format_dataset + + dataset = validate_and_format_dataset(dataset) + language = dataset[LANGUAGE] + custom_entities = { + entity_name: deepcopy(entity) + for entity_name, entity in iteritems(dataset[ENTITIES]) + if not 
is_builtin_entity(entity_name) + } + if parser_usage == CustomEntityParserUsage.WITH_AND_WITHOUT_STEMS: + for ent in viewvalues(custom_entities): + stemmed_utterances = _stem_entity_utterances( + ent[UTTERANCES], language, resources) + ent[UTTERANCES] = _merge_entity_utterances( + ent[UTTERANCES], stemmed_utterances) + elif parser_usage == CustomEntityParserUsage.WITH_STEMS: + for ent in viewvalues(custom_entities): + ent[UTTERANCES] = _stem_entity_utterances( + ent[UTTERANCES], language, resources) + elif parser_usage is None: + raise ValueError("A parser usage must be defined in order to fit " + "a CustomEntityParser") + configuration = _create_custom_entity_parser_configuration( + custom_entities, + language=dataset[LANGUAGE], + stopwords_fraction=STOPWORDS_FRACTION, + ) + parser = GazetteerEntityParser.build(configuration) + return cls(parser, language, parser_usage) + + +def _stem_entity_utterances(entity_utterances, language, resources): + values = dict() + # Sort by resolved value, so that values conflict in a deterministic way + for raw_value, resolved_value in sorted( + iteritems(entity_utterances), key=operator.itemgetter(1)): + stemmed_value = stem(raw_value, language, resources) + if stemmed_value not in values: + values[stemmed_value] = resolved_value + return values + + +def _merge_entity_utterances(raw_utterances, stemmed_utterances): + # Sort by resolved value, so that values conflict in a deterministic way + for raw_stemmed_value, resolved_value in sorted( + iteritems(stemmed_utterances), key=operator.itemgetter(1)): + if raw_stemmed_value not in raw_utterances: + raw_utterances[raw_stemmed_value] = resolved_value + return raw_utterances + + +def _create_custom_entity_parser_configuration( + entities, stopwords_fraction, language): + """Dynamically creates the gazetteer parser configuration. + + Args: + entities (dict): entity for the dataset + stopwords_fraction (float): fraction of the vocabulary of + the entity values that will be considered as stop words ( + the top n_vocabulary * stopwords_fraction most frequent words will + be considered stop words) + language (str): language of the entities + + Returns: the parser configuration as dictionary + """ + + if not 0 < stopwords_fraction < 1: + raise ValueError("stopwords_fraction must be in ]0.0, 1.0[") + + parser_configurations = [] + for entity_name, entity in sorted(iteritems(entities)): + vocabulary = set( + t for raw_value in entity[UTTERANCES] + for t in tokenize_light(raw_value, language) + ) + num_stopwords = int(stopwords_fraction * len(vocabulary)) + config = { + "entity_identifier": entity_name, + "entity_parser": { + "threshold": entity[MATCHING_STRICTNESS], + "n_gazetteer_stop_words": num_stopwords, + "gazetteer": [ + { + "raw_value": k, + "resolved_value": v + } for k, v in sorted(iteritems(entity[UTTERANCES])) + ] + } + } + if LICENSE_INFO in entity: + config["entity_parser"][LICENSE_INFO] = entity[LICENSE_INFO] + parser_configurations.append(config) + + configuration = { + "entity_parsers": parser_configurations + } + + return configuration + + +def _compute_char_shifts(tokens): + """Compute the shifts in characters that occur when comparing the + tokens string with the string consisting of all tokens separated with a + space + + For instance, if "hello?world" is tokenized in ["hello", "?", "world"], + then the character shifts between "hello?world" and "hello ? 
world" are + [0, 0, 0, 0, 0, 1, 1, 2, 2, 2, 2, 2, 2] + """ + characters_shifts = [] + if not tokens: + return characters_shifts + + current_shift = 0 + for token_index, token in enumerate(tokens): + if token_index == 0: + previous_token_end = 0 + previous_space_len = 0 + else: + previous_token_end = tokens[token_index - 1].end + previous_space_len = 1 + offset = (token.start - previous_token_end) - previous_space_len + current_shift -= offset + token_len = token.end - token.start + index_shift = token_len + previous_space_len + characters_shifts += [current_shift for _ in range(index_shift)] + return characters_shifts diff --git a/snips_inference_agl/entity_parser/custom_entity_parser_usage.py b/snips_inference_agl/entity_parser/custom_entity_parser_usage.py new file mode 100644 index 0000000..72d420a --- /dev/null +++ b/snips_inference_agl/entity_parser/custom_entity_parser_usage.py @@ -0,0 +1,23 @@ +from __future__ import unicode_literals + +from enum import Enum, unique + + +@unique +class CustomEntityParserUsage(Enum): + WITH_STEMS = 0 + """The parser is used with stemming""" + WITHOUT_STEMS = 1 + """The parser is used without stemming""" + WITH_AND_WITHOUT_STEMS = 2 + """The parser is used both with and without stemming""" + + @classmethod + def merge_usages(cls, lhs_usage, rhs_usage): + if lhs_usage is None: + return rhs_usage + if rhs_usage is None: + return lhs_usage + if lhs_usage == rhs_usage: + return lhs_usage + return cls.WITH_AND_WITHOUT_STEMS diff --git a/snips_inference_agl/entity_parser/entity_parser.py b/snips_inference_agl/entity_parser/entity_parser.py new file mode 100644 index 0000000..46de55e --- /dev/null +++ b/snips_inference_agl/entity_parser/entity_parser.py @@ -0,0 +1,85 @@ +# coding=utf-8 +from __future__ import unicode_literals + +from abc import ABCMeta, abstractmethod + +from future.builtins import object +from future.utils import with_metaclass + +from snips_inference_agl.common.dict_utils import LimitedSizeDict + +# pylint: disable=ungrouped-imports + +try: + from abc import abstractclassmethod +except ImportError: + from snips_inference_agl.common.abc_utils import abstractclassmethod + + +# pylint: enable=ungrouped-imports + + +class EntityParser(with_metaclass(ABCMeta, object)): + """Abstraction of a entity parser implementing some basic caching + """ + + def __init__(self): + self._cache = LimitedSizeDict(size_limit=1000) + + def parse(self, text, scope=None, use_cache=True): + """Search the given text for entities defined in the scope. If no + scope is provided, search for all kinds of entities. + + Args: + text (str): input text + scope (list or set of str, optional): if provided the parser + will only look for entities which entity kind is given in + the scope. By default the scope is None and the parser + will search for all kinds of supported entities + use_cache (bool): if False the internal cache will not be use, + this can be useful if the output of the parser depends on + the current timestamp. Defaults to True. 
+ + Returns: + list of dict: list of the parsed entities formatted as a dict + containing the string value, the resolved value, the + entity kind and the entity range + """ + if not use_cache: + return self._parse(text, scope) + scope_key = tuple(sorted(scope)) if scope is not None else scope + cache_key = (text, scope_key) + if cache_key not in self._cache: + parser_result = self._parse(text, scope) + self._cache[cache_key] = parser_result + return self._cache[cache_key] + + @abstractmethod + def _parse(self, text, scope=None): + """Internal parse method to implement in each subclass of + :class:`.EntityParser` + + Args: + text (str): input text + scope (list or set of str, optional): if provided the parser + will only look for entities which entity kind is given in + the scope. By default the scope is None and the parser + will search for all kinds of supported entities + use_cache (bool): if False the internal cache will not be use, + this can be useful if the output of the parser depends on + the current timestamp. Defaults to True. + + Returns: + list of dict: list of the parsed entities. These entity must + have the same output format as the + :func:`snips_inference_agl.utils.result.parsed_entity` function + """ + pass + + @abstractmethod + def persist(self, path): + pass + + @abstractclassmethod + def from_path(cls, path): + pass diff --git a/snips_inference_agl/exceptions.py b/snips_inference_agl/exceptions.py new file mode 100644 index 0000000..b1037a1 --- /dev/null +++ b/snips_inference_agl/exceptions.py @@ -0,0 +1,87 @@ +from snips_inference_agl.__about__ import __model_version__ + + +class SnipsNLUError(Exception): + """Base class for exceptions raised in the snips-nlu library""" + + +class IncompatibleModelError(Exception): + """Raised when trying to load an incompatible NLU engine + + This happens when the engine data was persisted with a previous version of + the library which is not compatible with the one used to load the model. 
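A minimal illustration of the message this exception carries; the persisted version string is an arbitrary example:

    # Hedged sketch, not part of the original diff.
    err = IncompatibleModelError(persisted_version="0.19.0")
    print(err)  # Incompatible data model: persisted model=0.19.0, python lib model=<__model_version__>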
+ """ + + def __init__(self, persisted_version): + super(IncompatibleModelError, self).__init__( + "Incompatible data model: persisted model=%s, python lib model=%s" + % (persisted_version, __model_version__)) + + +class InvalidInputError(Exception): + """Raised when an incorrect input is passed to one of the APIs""" + + +class NotTrained(SnipsNLUError): + """Raised when a processing unit is used while not fitted""" + + +class IntentNotFoundError(SnipsNLUError): + """Raised when an intent is used although it was not part of the + training data""" + + def __init__(self, intent): + super(IntentNotFoundError, self).__init__("Unknown intent '%s'" + % intent) + + +class DatasetFormatError(SnipsNLUError): + """Raised when attempting to create a Snips NLU dataset using a wrong + format""" + + +class EntityFormatError(DatasetFormatError): + """Raised when attempting to create a Snips NLU entity using a wrong + format""" + + +class IntentFormatError(DatasetFormatError): + """Raised when attempting to create a Snips NLU intent using a wrong + format""" + + +class AlreadyRegisteredError(SnipsNLUError): + """Raised when attempting to register a subclass which is already + registered""" + + def __init__(self, name, new_class, existing_class): + msg = "Cannot register %s for %s as it has already been used to " \ + "register %s" \ + % (name, new_class.__name__, existing_class.__name__) + super(AlreadyRegisteredError, self).__init__(msg) + + +class NotRegisteredError(SnipsNLUError): + """Raised when trying to use a subclass which was not registered""" + + def __init__(self, registrable_cls, name=None, registered_cls=None): + if name is not None: + msg = "'%s' has not been registered for type %s. " \ + % (name, registrable_cls) + else: + msg = "subclass %s has not been registered for type %s. 
" \ + % (registered_cls, registrable_cls) + msg += "Use @BaseClass.register('my_component') to register a subclass" + super(NotRegisteredError, self).__init__(msg) + +class PersistingError(SnipsNLUError): + """Raised when trying to persist a processing unit to a path which already + exists""" + + def __init__(self, path): + super(PersistingError, self).__init__("Path already exists: %s" + % str(path)) + +class LoadingError(SnipsNLUError): + """Raised when trying to load a processing unit while some files are + missing""" diff --git a/snips_inference_agl/intent_classifier/__init__.py b/snips_inference_agl/intent_classifier/__init__.py new file mode 100644 index 0000000..89ccf95 --- /dev/null +++ b/snips_inference_agl/intent_classifier/__init__.py @@ -0,0 +1,3 @@ +from .intent_classifier import IntentClassifier +from .log_reg_classifier import LogRegIntentClassifier +from .featurizer import Featurizer, CooccurrenceVectorizer, TfidfVectorizer diff --git a/snips_inference_agl/intent_classifier/featurizer.py b/snips_inference_agl/intent_classifier/featurizer.py new file mode 100644 index 0000000..116837f --- /dev/null +++ b/snips_inference_agl/intent_classifier/featurizer.py @@ -0,0 +1,452 @@ +from __future__ import division, unicode_literals + +import json +from builtins import str, zip +from copy import deepcopy +from pathlib import Path + +from future.utils import iteritems + +from snips_inference_agl.common.utils import ( + fitted_required, replace_entities_with_placeholders) +from snips_inference_agl.constants import ( + DATA, ENTITY, ENTITY_KIND, NGRAM, TEXT) +from snips_inference_agl.dataset import get_text_from_chunks +from snips_inference_agl.entity_parser.builtin_entity_parser import ( + is_builtin_entity) +from snips_inference_agl.exceptions import (LoadingError) +from snips_inference_agl.languages import get_default_sep +from snips_inference_agl.pipeline.configs import FeaturizerConfig +from snips_inference_agl.pipeline.configs.intent_classifier import ( + CooccurrenceVectorizerConfig, TfidfVectorizerConfig) +from snips_inference_agl.pipeline.processing_unit import ProcessingUnit +from snips_inference_agl.preprocessing import stem, tokenize_light +from snips_inference_agl.resources import get_stop_words, get_word_cluster +from snips_inference_agl.slot_filler.features_utils import get_all_ngrams + + +@ProcessingUnit.register("featurizer") +class Featurizer(ProcessingUnit): + """Feature extractor for text classification relying on ngrams tfidf and + optionally word cooccurrences features""" + + config_type = FeaturizerConfig + + def __init__(self, config=None, **shared): + super(Featurizer, self).__init__(config, **shared) + self.language = None + self.tfidf_vectorizer = None + self.cooccurrence_vectorizer = None + + @property + def fitted(self): + if not self.tfidf_vectorizer or not self.tfidf_vectorizer.vocabulary: + return False + return True + + def transform(self, utterances): + import scipy.sparse as sp + + x = self.tfidf_vectorizer.transform(utterances) + if self.cooccurrence_vectorizer: + x_cooccurrence = self.cooccurrence_vectorizer.transform(utterances) + x = sp.hstack((x, x_cooccurrence)) + return x + + @classmethod + def from_path(cls, path, **shared): + path = Path(path) + + model_path = path / "featurizer.json" + if not model_path.exists(): + raise LoadingError("Missing featurizer model file: %s" + % model_path.name) + with model_path.open("r", encoding="utf-8") as f: + featurizer_dict = json.load(f) + + featurizer_config = featurizer_dict["config"] + featurizer = 
cls(featurizer_config, **shared) + + featurizer.language = featurizer_dict["language_code"] + + tfidf_vectorizer = featurizer_dict["tfidf_vectorizer"] + if tfidf_vectorizer: + vectorizer_path = path / featurizer_dict["tfidf_vectorizer"] + tfidf_vectorizer = TfidfVectorizer.from_path( + vectorizer_path, **shared) + featurizer.tfidf_vectorizer = tfidf_vectorizer + + cooccurrence_vectorizer = featurizer_dict["cooccurrence_vectorizer"] + if cooccurrence_vectorizer: + vectorizer_path = path / featurizer_dict["cooccurrence_vectorizer"] + cooccurrence_vectorizer = CooccurrenceVectorizer.from_path( + vectorizer_path, **shared) + featurizer.cooccurrence_vectorizer = cooccurrence_vectorizer + + return featurizer + + +@ProcessingUnit.register("tfidf_vectorizer") +class TfidfVectorizer(ProcessingUnit): + """Wrapper of the scikit-learn TfidfVectorizer""" + + config_type = TfidfVectorizerConfig + + def __init__(self, config=None, **shared): + super(TfidfVectorizer, self).__init__(config, **shared) + self._tfidf_vectorizer = None + self._language = None + self.builtin_entity_scope = None + + @property + def fitted(self): + return self._tfidf_vectorizer is not None and hasattr( + self._tfidf_vectorizer, "vocabulary_") + + @fitted_required + def transform(self, x): + """Featurizes the given utterances after enriching them with builtin + entities matches, custom entities matches and the potential word + clusters matches + + Args: + x (list of dict): list of utterances + + Returns: + :class:`.scipy.sparse.csr_matrix`: A sparse matrix X of shape + (len(x), len(self.vocabulary)) where X[i, j] contains tfdif of + the ngram of index j of the vocabulary in the utterance i + + Raises: + NotTrained: when the vectorizer is not fitted: + """ + utterances = [self._enrich_utterance(*data) + for data in zip(*self._preprocess(x))] + return self._tfidf_vectorizer.transform(utterances) + + def _preprocess(self, utterances): + normalized_utterances = deepcopy(utterances) + for u in normalized_utterances: + nb_chunks = len(u[DATA]) + for i, chunk in enumerate(u[DATA]): + chunk[TEXT] = _normalize_stem( + chunk[TEXT], self.language, self.resources, + self.config.use_stemming) + if i < nb_chunks - 1: + chunk[TEXT] += " " + + # Extract builtin entities on unormalized utterances + builtin_ents = [ + self.builtin_entity_parser.parse( + get_text_from_chunks(u[DATA]), + self.builtin_entity_scope, use_cache=True) + for u in utterances + ] + # Extract builtin entities on normalized utterances + custom_ents = [ + self.custom_entity_parser.parse( + get_text_from_chunks(u[DATA]), use_cache=True) + for u in normalized_utterances + ] + if self.config.word_clusters_name: + # Extract world clusters on unormalized utterances + original_utterances_text = [get_text_from_chunks(u[DATA]) + for u in utterances] + w_clusters = [ + _get_word_cluster_features( + tokenize_light(u.lower(), self.language), + self.config.word_clusters_name, + self.resources) + for u in original_utterances_text + ] + else: + w_clusters = [None for _ in normalized_utterances] + + return normalized_utterances, builtin_ents, custom_ents, w_clusters + + def _enrich_utterance(self, utterance, builtin_entities, custom_entities, + word_clusters): + custom_entities_features = [ + _entity_name_to_feature(e[ENTITY_KIND], self.language) + for e in custom_entities] + + builtin_entities_features = [ + _builtin_entity_to_feature(ent[ENTITY_KIND], self.language) + for ent in builtin_entities + ] + + # We remove values of builtin slots from the utterance to avoid + # learning specific 
samples such as '42' or 'tomorrow' + filtered_tokens = [ + chunk[TEXT] for chunk in utterance[DATA] + if ENTITY not in chunk or not is_builtin_entity(chunk[ENTITY]) + ] + + features = get_default_sep(self.language).join(filtered_tokens) + + if builtin_entities_features: + features += " " + " ".join(sorted(builtin_entities_features)) + if custom_entities_features: + features += " " + " ".join(sorted(custom_entities_features)) + if word_clusters: + features += " " + " ".join(sorted(word_clusters)) + + return features + + @property + def language(self): + # Create this getter to prevent the language from being set elsewhere + # than in the fit + return self._language + + @property + def vocabulary(self): + if self._tfidf_vectorizer and hasattr( + self._tfidf_vectorizer, "vocabulary_"): + return self._tfidf_vectorizer.vocabulary_ + return None + + @property + def idf_diag(self): + if self._tfidf_vectorizer and hasattr( + self._tfidf_vectorizer, "vocabulary_"): + return self._tfidf_vectorizer.idf_ + return None + + @classmethod + # pylint: disable=W0212 + def from_path(cls, path, **shared): + import numpy as np + import scipy.sparse as sp + from sklearn.feature_extraction.text import ( + TfidfTransformer, TfidfVectorizer as SklearnTfidfVectorizer) + + path = Path(path) + + model_path = path / "vectorizer.json" + if not model_path.exists(): + raise LoadingError("Missing vectorizer model file: %s" + % model_path.name) + with model_path.open("r", encoding="utf-8") as f: + vectorizer_dict = json.load(f) + + vectorizer = cls(vectorizer_dict["config"], **shared) + vectorizer._language = vectorizer_dict["language_code"] + + builtin_entity_scope = vectorizer_dict["builtin_entity_scope"] + if builtin_entity_scope is not None: + builtin_entity_scope = set(builtin_entity_scope) + vectorizer.builtin_entity_scope = builtin_entity_scope + + vectorizer_ = vectorizer_dict["vectorizer"] + if vectorizer_: + vocab = vectorizer_["vocab"] + idf_diag_data = vectorizer_["idf_diag"] + idf_diag_data = np.array(idf_diag_data) + + idf_diag_shape = (len(idf_diag_data), len(idf_diag_data)) + row = list(range(idf_diag_shape[0])) + col = list(range(idf_diag_shape[0])) + idf_diag = sp.csr_matrix( + (idf_diag_data, (row, col)), shape=idf_diag_shape) + + tfidf_transformer = TfidfTransformer() + tfidf_transformer._idf_diag = idf_diag + + vectorizer_ = SklearnTfidfVectorizer( + tokenizer=lambda x: tokenize_light(x, vectorizer._language)) + vectorizer_.vocabulary_ = vocab + + vectorizer_._tfidf = tfidf_transformer + + vectorizer._tfidf_vectorizer = vectorizer_ + return vectorizer + + +@ProcessingUnit.register("cooccurrence_vectorizer") +class CooccurrenceVectorizer(ProcessingUnit): + """Featurizer that takes utterances and extracts ordered word cooccurrence + features matrix from them""" + + config_type = CooccurrenceVectorizerConfig + + def __init__(self, config=None, **shared): + super(CooccurrenceVectorizer, self).__init__(config, **shared) + self._word_pairs = None + self._language = None + self.builtin_entity_scope = None + + @property + def language(self): + # Create this getter to prevent the language from being set elsewhere + # than in the fit + return self._language + + @property + def word_pairs(self): + return self._word_pairs + + @property + def fitted(self): + """Whether or not the vectorizer is fitted""" + return self.word_pairs is not None + + @fitted_required + def transform(self, x): + """Computes the cooccurrence feature matrix. 
+ + Args: + x (list of dict): list of utterances + + Returns: + :class:`.scipy.sparse.csr_matrix`: A sparse matrix X of shape + (len(x), len(self.word_pairs)) where X[i, j] = 1.0 if + x[i][0] contains the words cooccurrence (w1, w2) and if + self.word_pairs[(w1, w2)] = j + + Raises: + NotTrained: when the vectorizer is not fitted + """ + import numpy as np + import scipy.sparse as sp + + preprocessed = self._preprocess(x) + utterances = [ + self._enrich_utterance(utterance, builtin_ents, custom_ent) + for utterance, builtin_ents, custom_ent in zip(*preprocessed)] + + x_coo = sp.dok_matrix((len(x), len(self.word_pairs)), dtype=np.int32) + for i, u in enumerate(utterances): + for p in self._extract_word_pairs(u): + if p in self.word_pairs: + x_coo[i, self.word_pairs[p]] = 1 + + return x_coo.tocsr() + + def _preprocess(self, x): + # Extract all entities on unnormalized data + builtin_ents = [ + self.builtin_entity_parser.parse( + get_text_from_chunks(u[DATA]), + self.builtin_entity_scope, + use_cache=True + ) for u in x + ] + custom_ents = [ + self.custom_entity_parser.parse( + get_text_from_chunks(u[DATA]), use_cache=True) + for u in x + ] + return x, builtin_ents, custom_ents + + def _extract_word_pairs(self, utterance): + if self.config.filter_stop_words: + stop_words = get_stop_words(self.resources) + utterance = [t for t in utterance if t not in stop_words] + pairs = set() + for j, w1 in enumerate(utterance): + max_index = None + if self.config.window_size is not None: + max_index = j + self.config.window_size + 1 + for w2 in utterance[j + 1:max_index]: + key = (w1, w2) + if not self.config.keep_order: + key = tuple(sorted(key)) + pairs.add(key) + return pairs + + def _enrich_utterance(self, x, builtin_ents, custom_ents): + utterance = get_text_from_chunks(x[DATA]) + all_entities = builtin_ents + custom_ents + placeholder_fn = self._placeholder_fn + # Replace entities with placeholders + enriched_utterance = replace_entities_with_placeholders( + utterance, all_entities, placeholder_fn)[1] + # Tokenize + enriched_utterance = tokenize_light(enriched_utterance, self.language) + # Remove the unknownword strings if needed + if self.config.unknown_words_replacement_string: + enriched_utterance = [ + t for t in enriched_utterance + if t != self.config.unknown_words_replacement_string + ] + return enriched_utterance + + def _extract_word_pairs(self, utterance): + if self.config.filter_stop_words: + stop_words = get_stop_words(self.resources) + utterance = [t for t in utterance if t not in stop_words] + pairs = set() + for j, w1 in enumerate(utterance): + max_index = None + if self.config.window_size is not None: + max_index = j + self.config.window_size + 1 + for w2 in utterance[j + 1:max_index]: + key = (w1, w2) + if not self.config.keep_order: + key = tuple(sorted(key)) + pairs.add(key) + return pairs + + def _placeholder_fn(self, entity_name): + return "".join( + tokenize_light(str(entity_name), str(self.language))).upper() + + @classmethod + # pylint: disable=protected-access + def from_path(cls, path, **shared): + path = Path(path) + model_path = path / "vectorizer.json" + if not model_path.exists(): + raise LoadingError("Missing vectorizer model file: %s" + % model_path.name) + + with model_path.open(encoding="utf8") as f: + vectorizer_dict = json.load(f) + config = vectorizer_dict.pop("config") + + self = cls(config, **shared) + self._language = vectorizer_dict["language_code"] + self._word_pairs = None + + builtin_entity_scope = vectorizer_dict["builtin_entity_scope"] + if 
builtin_entity_scope is not None: + builtin_entity_scope = set(builtin_entity_scope) + self.builtin_entity_scope = builtin_entity_scope + + if vectorizer_dict["word_pairs"]: + self._word_pairs = { + tuple(p): int(i) + for i, p in iteritems(vectorizer_dict["word_pairs"]) + } + return self + +def _entity_name_to_feature(entity_name, language): + return "entityfeature%s" % "".join(tokenize_light( + entity_name.lower(), language)) + + +def _builtin_entity_to_feature(builtin_entity_label, language): + return "builtinentityfeature%s" % "".join(tokenize_light( + builtin_entity_label.lower(), language)) + + +def _normalize_stem(text, language, resources, use_stemming): + from snips_nlu_utils import normalize + + if use_stemming: + return stem(text, language, resources) + return normalize(text) + + +def _get_word_cluster_features(query_tokens, clusters_name, resources): + if not clusters_name: + return [] + ngrams = get_all_ngrams(query_tokens) + cluster_features = [] + for ngram in ngrams: + cluster = get_word_cluster(resources, clusters_name).get( + ngram[NGRAM].lower(), None) + if cluster is not None: + cluster_features.append(cluster) + return cluster_features diff --git a/snips_inference_agl/intent_classifier/intent_classifier.py b/snips_inference_agl/intent_classifier/intent_classifier.py new file mode 100644 index 0000000..f9a7952 --- /dev/null +++ b/snips_inference_agl/intent_classifier/intent_classifier.py @@ -0,0 +1,51 @@ +from abc import ABCMeta + +from future.utils import with_metaclass + +from snips_inference_agl.pipeline.processing_unit import ProcessingUnit +from snips_inference_agl.common.abc_utils import classproperty + + +class IntentClassifier(with_metaclass(ABCMeta, ProcessingUnit)): + """Abstraction which performs intent classification + + A custom intent classifier must inherit this class to be used in a + :class:`.ProbabilisticIntentParser` + """ + + @classproperty + def unit_name(cls): # pylint:disable=no-self-argument + return IntentClassifier.registered_name(cls) + + # @abstractmethod + def get_intent(self, text, intents_filter): + """Performs intent classification on the provided *text* + + Args: + text (str): Input + intents_filter (str or list of str): When defined, it will find + the most likely intent among the list, otherwise it will use + the whole list of intents defined in the dataset + + Returns: + dict or None: The most likely intent along with its probability or + *None* if no intent was found. See + :func:`.intent_classification_result` for the output format. + """ + pass + + # @abstractmethod + def get_intents(self, text): + """Performs intent classification on the provided *text* and returns + the list of intents ordered by decreasing probability + + The length of the returned list is exactly the number of intents in the + dataset + 1 for the None intent + + .. note:: + + The probabilities returned along with each intent are not + guaranteed to sum to 1.0. They should be considered as scores + between 0 and 1. 
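A hedged usage sketch relating the two methods, assuming a fitted `LogRegIntentClassifier` instance named `classifier` (hypothetical name) and an arbitrary input sentence:

    best = classifier.get_intent("turn on the radio")      # most likely intent, or the None intent
    ranking = classifier.get_intents("turn on the radio")  # every dataset intent plus None, best first
    # Each entry carries an independent score in [0, 1]; the scores need not sum to 1.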
+ """ + pass diff --git a/snips_inference_agl/intent_classifier/log_reg_classifier.py b/snips_inference_agl/intent_classifier/log_reg_classifier.py new file mode 100644 index 0000000..09e537c --- /dev/null +++ b/snips_inference_agl/intent_classifier/log_reg_classifier.py @@ -0,0 +1,211 @@ +from __future__ import unicode_literals + +import json +import logging +from builtins import str, zip +from pathlib import Path + +from snips_inference_agl.common.log_utils import DifferedLoggingMessage +from snips_inference_agl.common.utils import (fitted_required) +from snips_inference_agl.constants import RES_PROBA +from snips_inference_agl.exceptions import LoadingError +from snips_inference_agl.intent_classifier.featurizer import Featurizer +from snips_inference_agl.intent_classifier.intent_classifier import IntentClassifier +from snips_inference_agl.intent_classifier.log_reg_classifier_utils import (text_to_utterance) +from snips_inference_agl.pipeline.configs import LogRegIntentClassifierConfig +from snips_inference_agl.result import intent_classification_result + +logger = logging.getLogger(__name__) + +# We set tol to 1e-3 to silence the following warning with Python 2 ( +# scikit-learn 0.20): +# +# FutureWarning: max_iter and tol parameters have been added in SGDClassifier +# in 0.19. If max_iter is set but tol is left unset, the default value for tol +# in 0.19 and 0.20 will be None (which is equivalent to -infinity, so it has no +# effect) but will change in 0.21 to 1e-3. Specify tol to silence this warning. + +LOG_REG_ARGS = { + "loss": "log", + "penalty": "l2", + "max_iter": 1000, + "tol": 1e-3, + "n_jobs": -1 +} + + +@IntentClassifier.register("log_reg_intent_classifier") +class LogRegIntentClassifier(IntentClassifier): + """Intent classifier which uses a Logistic Regression underneath""" + + config_type = LogRegIntentClassifierConfig + + def __init__(self, config=None, **shared): + """The LogReg intent classifier can be configured by passing a + :class:`.LogRegIntentClassifierConfig`""" + super(LogRegIntentClassifier, self).__init__(config, **shared) + self.classifier = None + self.intent_list = None + self.featurizer = None + + @property + def fitted(self): + """Whether or not the intent classifier has already been fitted""" + return self.intent_list is not None + + @fitted_required + def get_intent(self, text, intents_filter=None): + """Performs intent classification on the provided *text* + + Args: + text (str): Input + intents_filter (str or list of str): When defined, it will find + the most likely intent among the list, otherwise it will use + the whole list of intents defined in the dataset + + Returns: + dict or None: The most likely intent along with its probability or + *None* if no intent was found + + Raises: + :class:`snips_nlu.exceptions.NotTrained`: When the intent + classifier is not fitted + + """ + return self._get_intents(text, intents_filter)[0] + + @fitted_required + def get_intents(self, text): + """Performs intent classification on the provided *text* and returns + the list of intents ordered by decreasing probability + + The length of the returned list is exactly the number of intents in the + dataset + 1 for the None intent + + Raises: + :class:`snips_nlu.exceptions.NotTrained`: when the intent + classifier is not fitted + """ + return self._get_intents(text, intents_filter=None) + + def _get_intents(self, text, intents_filter): + if isinstance(intents_filter, str): + intents_filter = {intents_filter} + elif isinstance(intents_filter, list): + intents_filter = 
set(intents_filter) + + if not text or not self.intent_list or not self.featurizer: + results = [intent_classification_result(None, 1.0)] + results += [intent_classification_result(i, 0.0) + for i in self.intent_list if i is not None] + return results + + if len(self.intent_list) == 1: + return [intent_classification_result(self.intent_list[0], 1.0)] + + # pylint: disable=C0103 + X = self.featurizer.transform([text_to_utterance(text)]) + # pylint: enable=C0103 + proba_vec = self._predict_proba(X) + logger.debug( + "%s", DifferedLoggingMessage(self.log_activation_weights, text, X)) + results = [ + intent_classification_result(i, proba) + for i, proba in zip(self.intent_list, proba_vec[0]) + if intents_filter is None or i is None or i in intents_filter] + + return sorted(results, key=lambda res: -res[RES_PROBA]) + + def _predict_proba(self, X): # pylint: disable=C0103 + import numpy as np + + self.classifier._check_proba() # pylint: disable=W0212 + + prob = self.classifier.decision_function(X) + prob *= -1 + np.exp(prob, prob) + prob += 1 + np.reciprocal(prob, prob) + if prob.ndim == 1: + return np.vstack([1 - prob, prob]).T + return prob + + @classmethod + def from_path(cls, path, **shared): + """Loads a :class:`LogRegIntentClassifier` instance from a path + + The data at the given path must have been generated using + :func:`~LogRegIntentClassifier.persist` + """ + import numpy as np + from sklearn.linear_model import SGDClassifier + + path = Path(path) + model_path = path / "intent_classifier.json" + if not model_path.exists(): + raise LoadingError("Missing intent classifier model file: %s" + % model_path.name) + + with model_path.open(encoding="utf8") as f: + model_dict = json.load(f) + + # Create the classifier + config = LogRegIntentClassifierConfig.from_dict(model_dict["config"]) + intent_classifier = cls(config=config, **shared) + intent_classifier.intent_list = model_dict['intent_list'] + + # Create the underlying SGD classifier + sgd_classifier = None + coeffs = model_dict['coeffs'] + intercept = model_dict['intercept'] + t_ = model_dict["t_"] + if coeffs is not None and intercept is not None: + sgd_classifier = SGDClassifier(**LOG_REG_ARGS) + sgd_classifier.coef_ = np.array(coeffs) + sgd_classifier.intercept_ = np.array(intercept) + sgd_classifier.t_ = t_ + intent_classifier.classifier = sgd_classifier + + # Add the featurizer + featurizer = model_dict['featurizer'] + if featurizer is not None: + featurizer_path = path / featurizer + intent_classifier.featurizer = Featurizer.from_path( + featurizer_path, **shared) + + return intent_classifier + + def log_activation_weights(self, text, x, top_n=50): + import numpy as np + + if not hasattr(self.featurizer, "feature_index_to_feature_name"): + return None + + log = "\n\nTop {} feature activations for: \"{}\":\n".format( + top_n, text) + activations = np.multiply( + self.classifier.coef_, np.asarray(x.todense())) + abs_activation = np.absolute(activations).flatten().squeeze() + + if top_n > activations.size: + top_n = activations.size + + top_n_activations_ix = np.argpartition(abs_activation, -top_n, + axis=None)[-top_n:] + top_n_activations_ix = np.unravel_index( + top_n_activations_ix, activations.shape) + + index_to_feature = self.featurizer.feature_index_to_feature_name + features_intent_and_activation = [ + (self.intent_list[i], index_to_feature[f], activations[i, f]) + for i, f in zip(*top_n_activations_ix)] + + features_intent_and_activation = sorted( + features_intent_and_activation, key=lambda x: abs(x[2]), + reverse=True) + + 
for intent, feature, activation in features_intent_and_activation: + log += "\n\n\"{}\" -> ({}, {:.2f})".format( + intent, feature, float(activation)) + log += "\n\n" + return log diff --git a/snips_inference_agl/intent_classifier/log_reg_classifier_utils.py b/snips_inference_agl/intent_classifier/log_reg_classifier_utils.py new file mode 100644 index 0000000..75a8ab1 --- /dev/null +++ b/snips_inference_agl/intent_classifier/log_reg_classifier_utils.py @@ -0,0 +1,94 @@ +from __future__ import division, unicode_literals + +import itertools +import re +from builtins import next, range, str +from copy import deepcopy +from uuid import uuid4 + +from future.utils import iteritems, itervalues + +from snips_inference_agl.constants import (DATA, ENTITY, INTENTS, TEXT, + UNKNOWNWORD, UTTERANCES) +from snips_inference_agl.data_augmentation import augment_utterances +from snips_inference_agl.dataset import get_text_from_chunks +from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity +from snips_inference_agl.preprocessing import tokenize_light +from snips_inference_agl.resources import get_noise + +NOISE_NAME = str(uuid4()) +WORD_REGEX = re.compile(r"\w+(\s+\w+)*") +UNKNOWNWORD_REGEX = re.compile(r"%s(\s+%s)*" % (UNKNOWNWORD, UNKNOWNWORD)) + + +def get_noise_it(noise, mean_length, std_length, random_state): + it = itertools.cycle(noise) + while True: + noise_length = int(random_state.normal(mean_length, std_length)) + # pylint: disable=stop-iteration-return + yield " ".join(next(it) for _ in range(noise_length)) + # pylint: enable=stop-iteration-return + + +def generate_smart_noise(noise, augmented_utterances, replacement_string, + language): + text_utterances = [get_text_from_chunks(u[DATA]) + for u in augmented_utterances] + vocab = [w for u in text_utterances for w in tokenize_light(u, language)] + vocab = set(vocab) + return [w if w in vocab else replacement_string for w in noise] + + +def generate_noise_utterances(augmented_utterances, noise, num_intents, + data_augmentation_config, language, + random_state): + import numpy as np + + if not augmented_utterances or not num_intents: + return [] + avg_num_utterances = len(augmented_utterances) / float(num_intents) + if data_augmentation_config.unknown_words_replacement_string is not None: + noise = generate_smart_noise( + noise, augmented_utterances, + data_augmentation_config.unknown_words_replacement_string, + language) + + noise_size = min( + int(data_augmentation_config.noise_factor * avg_num_utterances), + len(noise)) + utterances_lengths = [ + len(tokenize_light(get_text_from_chunks(u[DATA]), language)) + for u in augmented_utterances] + mean_utterances_length = np.mean(utterances_lengths) + std_utterances_length = np.std(utterances_lengths) + noise_it = get_noise_it(noise, mean_utterances_length, + std_utterances_length, random_state) + # Remove duplicate 'unknownword unknownword' + return [ + text_to_utterance(UNKNOWNWORD_REGEX.sub(UNKNOWNWORD, next(noise_it))) + for _ in range(noise_size)] + + +def add_unknown_word_to_utterances(utterances, replacement_string, + unknown_word_prob, max_unknown_words, + random_state): + if not max_unknown_words: + return utterances + + new_utterances = deepcopy(utterances) + for u in new_utterances: + if random_state.rand() < unknown_word_prob: + num_unknown = random_state.randint(1, max_unknown_words + 1) + # We choose to put the noise at the end of the sentence and not + # in the middle so that it doesn't impact to much ngrams + # computation + extra_chunk = { + TEXT: " " + " 
".join( + replacement_string for _ in range(num_unknown)) + } + u[DATA].append(extra_chunk) + return new_utterances + + +def text_to_utterance(text): + return {DATA: [{TEXT: text}]} diff --git a/snips_inference_agl/intent_parser/__init__.py b/snips_inference_agl/intent_parser/__init__.py new file mode 100644 index 0000000..1b0d446 --- /dev/null +++ b/snips_inference_agl/intent_parser/__init__.py @@ -0,0 +1,4 @@ +from .deterministic_intent_parser import DeterministicIntentParser +from .intent_parser import IntentParser +from .lookup_intent_parser import LookupIntentParser +from .probabilistic_intent_parser import ProbabilisticIntentParser diff --git a/snips_inference_agl/intent_parser/deterministic_intent_parser.py b/snips_inference_agl/intent_parser/deterministic_intent_parser.py new file mode 100644 index 0000000..845e59d --- /dev/null +++ b/snips_inference_agl/intent_parser/deterministic_intent_parser.py @@ -0,0 +1,518 @@ +from __future__ import unicode_literals + +import json +import logging +import re +from builtins import str +from collections import defaultdict +from pathlib import Path + +from future.utils import iteritems, itervalues + +from snips_inference_agl.common.dataset_utils import get_slot_name_mappings +from snips_inference_agl.common.log_utils import log_elapsed_time, log_result +from snips_inference_agl.common.utils import ( + check_persisted_path, deduplicate_overlapping_items, fitted_required, + json_string, ranges_overlap, regex_escape, + replace_entities_with_placeholders) +from snips_inference_agl.constants import ( + DATA, END, ENTITIES, ENTITY, + INTENTS, LANGUAGE, RES_INTENT, RES_INTENT_NAME, + RES_MATCH_RANGE, RES_SLOTS, RES_VALUE, SLOT_NAME, START, TEXT, UTTERANCES, + RES_PROBA) +from snips_inference_agl.dataset import validate_and_format_dataset +from snips_inference_agl.dataset.utils import get_stop_words_whitelist +from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity +from snips_inference_agl.exceptions import IntentNotFoundError, LoadingError +from snips_inference_agl.intent_parser.intent_parser import IntentParser +from snips_inference_agl.pipeline.configs import DeterministicIntentParserConfig +from snips_inference_agl.preprocessing import normalize_token, tokenize, tokenize_light +from snips_inference_agl.resources import get_stop_words +from snips_inference_agl.result import (empty_result, extraction_result, + intent_classification_result, parsing_result, + unresolved_slot) + +WHITESPACE_PATTERN = r"\s*" + +logger = logging.getLogger(__name__) + + +@IntentParser.register("deterministic_intent_parser") +class DeterministicIntentParser(IntentParser): + """Intent parser using pattern matching in a deterministic manner + + This intent parser is very strict by nature, and tends to have a very good + precision but a low recall. For this reason, it is interesting to use it + first before potentially falling back to another parser. 
+ """ + + config_type = DeterministicIntentParserConfig + + def __init__(self, config=None, **shared): + """The deterministic intent parser can be configured by passing a + :class:`.DeterministicIntentParserConfig`""" + super(DeterministicIntentParser, self).__init__(config, **shared) + self._language = None + self._slot_names_to_entities = None + self._group_names_to_slot_names = None + self._stop_words = None + self._stop_words_whitelist = None + self.slot_names_to_group_names = None + self.regexes_per_intent = None + self.entity_scopes = None + + @property + def language(self): + return self._language + + @language.setter + def language(self, value): + self._language = value + if value is None: + self._stop_words = None + else: + if self.config.ignore_stop_words: + self._stop_words = get_stop_words(self.resources) + else: + self._stop_words = set() + + @property + def slot_names_to_entities(self): + return self._slot_names_to_entities + + @slot_names_to_entities.setter + def slot_names_to_entities(self, value): + self._slot_names_to_entities = value + if value is None: + self.entity_scopes = None + else: + self.entity_scopes = { + intent: { + "builtin": {ent for ent in itervalues(slot_mapping) + if is_builtin_entity(ent)}, + "custom": {ent for ent in itervalues(slot_mapping) + if not is_builtin_entity(ent)} + } + for intent, slot_mapping in iteritems(value)} + + @property + def group_names_to_slot_names(self): + return self._group_names_to_slot_names + + @group_names_to_slot_names.setter + def group_names_to_slot_names(self, value): + self._group_names_to_slot_names = value + if value is not None: + self.slot_names_to_group_names = { + slot_name: group for group, slot_name in iteritems(value)} + + @property + def patterns(self): + """Dictionary of patterns per intent""" + if self.regexes_per_intent is not None: + return {i: [r.pattern for r in regex_list] for i, regex_list in + iteritems(self.regexes_per_intent)} + return None + + @patterns.setter + def patterns(self, value): + if value is not None: + self.regexes_per_intent = dict() + for intent, pattern_list in iteritems(value): + regexes = [re.compile(r"%s" % p, re.IGNORECASE) + for p in pattern_list] + self.regexes_per_intent[intent] = regexes + + @property + def fitted(self): + """Whether or not the intent parser has already been trained""" + return self.regexes_per_intent is not None + + @log_elapsed_time( + logger, logging.INFO, "Fitted deterministic parser in {elapsed_time}") + def fit(self, dataset, force_retrain=True): + """Fits the intent parser with a valid Snips dataset""" + logger.info("Fitting deterministic intent parser...") + dataset = validate_and_format_dataset(dataset) + self.load_resources_if_needed(dataset[LANGUAGE]) + self.fit_builtin_entity_parser_if_needed(dataset) + self.fit_custom_entity_parser_if_needed(dataset) + self.language = dataset[LANGUAGE] + self.regexes_per_intent = dict() + entity_placeholders = _get_entity_placeholders(dataset, self.language) + self.slot_names_to_entities = get_slot_name_mappings(dataset) + self.group_names_to_slot_names = _get_group_names_to_slot_names( + self.slot_names_to_entities) + self._stop_words_whitelist = get_stop_words_whitelist( + dataset, self._stop_words) + + # Do not use ambiguous patterns that appear in more than one intent + all_patterns = set() + ambiguous_patterns = set() + intent_patterns = dict() + for intent_name, intent in iteritems(dataset[INTENTS]): + patterns = self._generate_patterns(intent_name, intent[UTTERANCES], + entity_placeholders) + patterns = [p 
for p in patterns + if len(p) < self.config.max_pattern_length] + existing_patterns = {p for p in patterns if p in all_patterns} + ambiguous_patterns.update(existing_patterns) + all_patterns.update(set(patterns)) + intent_patterns[intent_name] = patterns + + for intent_name, patterns in iteritems(intent_patterns): + patterns = [p for p in patterns if p not in ambiguous_patterns] + patterns = patterns[:self.config.max_queries] + regexes = [re.compile(p, re.IGNORECASE) for p in patterns] + self.regexes_per_intent[intent_name] = regexes + return self + + @log_result( + logger, logging.DEBUG, "DeterministicIntentParser result -> {result}") + @log_elapsed_time(logger, logging.DEBUG, "Parsed in {elapsed_time}.") + @fitted_required + def parse(self, text, intents=None, top_n=None): + """Performs intent parsing on the provided *text* + + Intent and slots are extracted simultaneously through pattern matching + + Args: + text (str): input + intents (str or list of str): if provided, reduces the scope of + intent parsing to the provided list of intents + top_n (int, optional): when provided, this method will return a + list of at most top_n most likely intents, instead of a single + parsing result. + Note that the returned list can contain less than ``top_n`` + elements, for instance when the parameter ``intents`` is not + None, or when ``top_n`` is greater than the total number of + intents. + + Returns: + dict or list: the most likely intent(s) along with the extracted + slots. See :func:`.parsing_result` and :func:`.extraction_result` + for the output format. + + Raises: + NotTrained: when the intent parser is not fitted + """ + if top_n is None: + top_intents = self._parse_top_intents(text, top_n=1, + intents=intents) + if top_intents: + intent = top_intents[0][RES_INTENT] + slots = top_intents[0][RES_SLOTS] + if intent[RES_PROBA] <= 0.5: + # return None in case of ambiguity + return empty_result(text, probability=1.0) + return parsing_result(text, intent, slots) + return empty_result(text, probability=1.0) + return self._parse_top_intents(text, top_n=top_n, intents=intents) + + def _parse_top_intents(self, text, top_n, intents=None): + if isinstance(intents, str): + intents = {intents} + elif isinstance(intents, list): + intents = set(intents) + + if top_n < 1: + raise ValueError( + "top_n argument must be greater or equal to 1, but got: %s" + % top_n) + + def placeholder_fn(entity_name): + return _get_entity_name_placeholder(entity_name, self.language) + + results = [] + + for intent, entity_scope in iteritems(self.entity_scopes): + if intents is not None and intent not in intents: + continue + builtin_entities = self.builtin_entity_parser.parse( + text, scope=entity_scope["builtin"], use_cache=True) + custom_entities = self.custom_entity_parser.parse( + text, scope=entity_scope["custom"], use_cache=True) + all_entities = builtin_entities + custom_entities + mapping, processed_text = replace_entities_with_placeholders( + text, all_entities, placeholder_fn=placeholder_fn) + cleaned_text = self._preprocess_text(text, intent) + cleaned_processed_text = self._preprocess_text(processed_text, + intent) + for regex in self.regexes_per_intent[intent]: + res = self._get_matching_result(text, cleaned_text, regex, + intent) + if res is None and cleaned_text != cleaned_processed_text: + res = self._get_matching_result( + text, cleaned_processed_text, regex, intent, mapping) + + if res is not None: + results.append(res) + break + + # In some rare cases there can be multiple ambiguous intents + # In such 
cases, priority is given to results containing fewer slots + weights = [1.0 / (1.0 + len(res[RES_SLOTS])) for res in results] + total_weight = sum(weights) + + for res, weight in zip(results, weights): + res[RES_INTENT][RES_PROBA] = weight / total_weight + + results = sorted(results, key=lambda r: -r[RES_INTENT][RES_PROBA]) + + return results[:top_n] + + @fitted_required + def get_intents(self, text): + """Returns the list of intents ordered by decreasing probability + + The length of the returned list is exactly the number of intents in the + dataset + 1 for the None intent + """ + nb_intents = len(self.regexes_per_intent) + top_intents = [intent_result[RES_INTENT] for intent_result in + self._parse_top_intents(text, top_n=nb_intents)] + matched_intents = {res[RES_INTENT_NAME] for res in top_intents} + for intent in self.regexes_per_intent: + if intent not in matched_intents: + top_intents.append(intent_classification_result(intent, 0.0)) + + # The None intent is not included in the regex patterns and is thus + # never matched by the deterministic parser + top_intents.append(intent_classification_result(None, 0.0)) + return top_intents + + @fitted_required + def get_slots(self, text, intent): + """Extracts slots from a text input, with the knowledge of the intent + + Args: + text (str): input + intent (str): the intent which the input corresponds to + + Returns: + list: the list of extracted slots + + Raises: + IntentNotFoundError: When the intent was not part of the training + data + """ + if intent is None: + return [] + + if intent not in self.regexes_per_intent: + raise IntentNotFoundError(intent) + + slots = self.parse(text, intents=[intent])[RES_SLOTS] + if slots is None: + slots = [] + return slots + + def _get_intent_stop_words(self, intent): + whitelist = self._stop_words_whitelist.get(intent, set()) + return self._stop_words.difference(whitelist) + + def _preprocess_text(self, string, intent): + """Replaces stop words and characters that are tokenized out by + whitespaces""" + tokens = tokenize(string, self.language) + current_idx = 0 + cleaned_string = "" + stop_words = self._get_intent_stop_words(intent) + for token in tokens: + if stop_words and normalize_token(token) in stop_words: + token.value = "".join(" " for _ in range(len(token.value))) + prefix_length = token.start - current_idx + cleaned_string += "".join((" " for _ in range(prefix_length))) + cleaned_string += token.value + current_idx = token.end + suffix_length = len(string) - current_idx + cleaned_string += "".join((" " for _ in range(suffix_length))) + return cleaned_string + + def _get_matching_result(self, text, processed_text, regex, intent, + entities_ranges_mapping=None): + found_result = regex.match(processed_text) + if found_result is None: + return None + parsed_intent = intent_classification_result(intent_name=intent, + probability=1.0) + slots = [] + for group_name in found_result.groupdict(): + ref_group_name = group_name + if "_" in group_name: + ref_group_name = group_name.split("_")[0] + slot_name = self.group_names_to_slot_names[ref_group_name] + entity = self.slot_names_to_entities[intent][slot_name] + rng = (found_result.start(group_name), + found_result.end(group_name)) + if entities_ranges_mapping is not None: + if rng in entities_ranges_mapping: + rng = entities_ranges_mapping[rng] + else: + shift = _get_range_shift( + rng, entities_ranges_mapping) + rng = {START: rng[0] + shift, END: rng[1] + shift} + else: + rng = {START: rng[0], END: rng[1]} + value = text[rng[START]:rng[END]] + 
parsed_slot = unresolved_slot( + match_range=rng, value=value, entity=entity, + slot_name=slot_name) + slots.append(parsed_slot) + parsed_slots = _deduplicate_overlapping_slots(slots, self.language) + parsed_slots = sorted(parsed_slots, + key=lambda s: s[RES_MATCH_RANGE][START]) + return extraction_result(parsed_intent, parsed_slots) + + def _generate_patterns(self, intent, intent_utterances, + entity_placeholders): + unique_patterns = set() + patterns = [] + stop_words = self._get_intent_stop_words(intent) + for utterance in intent_utterances: + pattern = self._utterance_to_pattern( + utterance, stop_words, entity_placeholders) + if pattern not in unique_patterns: + unique_patterns.add(pattern) + patterns.append(pattern) + return patterns + + def _utterance_to_pattern(self, utterance, stop_words, + entity_placeholders): + from snips_nlu_utils import normalize + + slot_names_count = defaultdict(int) + pattern = [] + for chunk in utterance[DATA]: + if SLOT_NAME in chunk: + slot_name = chunk[SLOT_NAME] + slot_names_count[slot_name] += 1 + group_name = self.slot_names_to_group_names[slot_name] + count = slot_names_count[slot_name] + if count > 1: + group_name = "%s_%s" % (group_name, count) + placeholder = entity_placeholders[chunk[ENTITY]] + pattern.append(r"(?P<%s>%s)" % (group_name, placeholder)) + else: + tokens = tokenize_light(chunk[TEXT], self.language) + pattern += [regex_escape(t.lower()) for t in tokens + if normalize(t) not in stop_words] + + pattern = r"^%s%s%s$" % (WHITESPACE_PATTERN, + WHITESPACE_PATTERN.join(pattern), + WHITESPACE_PATTERN) + return pattern + + @check_persisted_path + def persist(self, path): + """Persists the object at the given path""" + path.mkdir() + parser_json = json_string(self.to_dict()) + parser_path = path / "intent_parser.json" + + with parser_path.open(mode="w", encoding="utf8") as f: + f.write(parser_json) + self.persist_metadata(path) + + @classmethod + def from_path(cls, path, **shared): + """Loads a :class:`DeterministicIntentParser` instance from a path + + The data at the given path must have been generated using + :func:`~DeterministicIntentParser.persist` + """ + path = Path(path) + model_path = path / "intent_parser.json" + if not model_path.exists(): + raise LoadingError( + "Missing deterministic intent parser metadata file: %s" + % model_path.name) + + with model_path.open(encoding="utf8") as f: + metadata = json.load(f) + return cls.from_dict(metadata, **shared) + + def to_dict(self): + """Returns a json-serializable dict""" + stop_words_whitelist = None + if self._stop_words_whitelist is not None: + stop_words_whitelist = { + intent: sorted(values) + for intent, values in iteritems(self._stop_words_whitelist)} + return { + "config": self.config.to_dict(), + "language_code": self.language, + "patterns": self.patterns, + "group_names_to_slot_names": self.group_names_to_slot_names, + "slot_names_to_entities": self.slot_names_to_entities, + "stop_words_whitelist": stop_words_whitelist + } + + @classmethod + def from_dict(cls, unit_dict, **shared): + """Creates a :class:`DeterministicIntentParser` instance from a dict + + The dict must have been generated with + :func:`~DeterministicIntentParser.to_dict` + """ + config = cls.config_type.from_dict(unit_dict["config"]) + parser = cls(config=config, **shared) + parser.patterns = unit_dict["patterns"] + parser.language = unit_dict["language_code"] + parser.group_names_to_slot_names = unit_dict[ + "group_names_to_slot_names"] + parser.slot_names_to_entities = unit_dict["slot_names_to_entities"] + 
if parser.fitted: + whitelist = unit_dict.get("stop_words_whitelist", dict()) + # pylint:disable=protected-access + parser._stop_words_whitelist = { + intent: set(values) for intent, values in iteritems(whitelist)} + # pylint:enable=protected-access + return parser + + +def _get_range_shift(matched_range, ranges_mapping): + shift = 0 + previous_replaced_range_end = None + matched_start = matched_range[0] + for replaced_range, orig_range in iteritems(ranges_mapping): + if replaced_range[1] <= matched_start: + if previous_replaced_range_end is None \ + or replaced_range[1] > previous_replaced_range_end: + previous_replaced_range_end = replaced_range[1] + shift = orig_range[END] - replaced_range[1] + return shift + + +def _get_group_names_to_slot_names(slot_names_mapping): + slot_names = {slot_name for mapping in itervalues(slot_names_mapping) + for slot_name in mapping} + return {"group%s" % i: name + for i, name in enumerate(sorted(slot_names))} + + +def _get_entity_placeholders(dataset, language): + return { + e: _get_entity_name_placeholder(e, language) + for e in dataset[ENTITIES] + } + + +def _deduplicate_overlapping_slots(slots, language): + def overlap(lhs_slot, rhs_slot): + return ranges_overlap(lhs_slot[RES_MATCH_RANGE], + rhs_slot[RES_MATCH_RANGE]) + + def sort_key_fn(slot): + tokens = tokenize(slot[RES_VALUE], language) + return -(len(tokens) + len(slot[RES_VALUE])) + + deduplicated_slots = deduplicate_overlapping_items( + slots, overlap, sort_key_fn) + return sorted(deduplicated_slots, + key=lambda slot: slot[RES_MATCH_RANGE][START]) + + +def _get_entity_name_placeholder(entity_label, language): + return "%%%s%%" % "".join( + tokenize_light(entity_label, language)).upper() diff --git a/snips_inference_agl/intent_parser/intent_parser.py b/snips_inference_agl/intent_parser/intent_parser.py new file mode 100644 index 0000000..b269774 --- /dev/null +++ b/snips_inference_agl/intent_parser/intent_parser.py @@ -0,0 +1,85 @@ +from abc import abstractmethod, ABCMeta + +from future.utils import with_metaclass + +from snips_inference_agl.common.abc_utils import classproperty +from snips_inference_agl.pipeline.processing_unit import ProcessingUnit + + +class IntentParser(with_metaclass(ABCMeta, ProcessingUnit)): + """Abstraction which performs intent parsing + + A custom intent parser must inherit this class to be used in a + :class:`.SnipsNLUEngine` + """ + + @classproperty + def unit_name(cls): # pylint:disable=no-self-argument + return IntentParser.registered_name(cls) + + @abstractmethod + def fit(self, dataset, force_retrain): + """Fit the intent parser with a valid Snips dataset + + Args: + dataset (dict): valid Snips NLU dataset + force_retrain (bool): specify whether or not sub units of the + intent parser that may be already trained should be retrained + """ + pass + + @abstractmethod + def parse(self, text, intents, top_n): + """Performs intent parsing on the provided *text* + + Args: + text (str): input + intents (str or list of str): if provided, reduces the scope of + intent parsing to the provided list of intents + top_n (int, optional): when provided, this method will return a + list of at most top_n most likely intents, instead of a single + parsing result. + Note that the returned list can contain less than ``top_n`` + elements, for instance when the parameter ``intents`` is not + None, or when ``top_n`` is greater than the total number of + intents. + + Returns: + dict or list: the most likely intent(s) along with the extracted + slots. 
See :func:`.parsing_result` and :func:`.extraction_result` + for the output format. + """ + pass + + @abstractmethod + def get_intents(self, text): + """Performs intent classification on the provided *text* and returns + the list of intents ordered by decreasing probability + + The length of the returned list is exactly the number of intents in the + dataset + 1 for the None intent + + .. note:: + + The probabilities returned along with each intent are not + guaranteed to sum to 1.0. They should be considered as scores + between 0 and 1. + """ + pass + + @abstractmethod + def get_slots(self, text, intent): + """Extract slots from a text input, with the knowledge of the intent + + Args: + text (str): input + intent (str): the intent which the input corresponds to + + Returns: + list: the list of extracted slots + + Raises: + IntentNotFoundError: when the intent was not part of the training + data + """ + pass diff --git a/snips_inference_agl/intent_parser/lookup_intent_parser.py b/snips_inference_agl/intent_parser/lookup_intent_parser.py new file mode 100644 index 0000000..921dcc5 --- /dev/null +++ b/snips_inference_agl/intent_parser/lookup_intent_parser.py @@ -0,0 +1,509 @@ +from __future__ import unicode_literals + +import json +import logging +from builtins import str +from collections import defaultdict +from itertools import combinations +from pathlib import Path + +from future.utils import iteritems, itervalues +from snips_nlu_utils import normalize, hash_str + +from snips_inference_agl.common.log_utils import log_elapsed_time, log_result +from snips_inference_agl.common.utils import ( + check_persisted_path, deduplicate_overlapping_entities, fitted_required, + json_string) +from snips_inference_agl.constants import ( + DATA, END, ENTITIES, ENTITY, ENTITY_KIND, INTENTS, LANGUAGE, RES_INTENT, + RES_INTENT_NAME, RES_MATCH_RANGE, RES_SLOTS, SLOT_NAME, START, TEXT, + UTTERANCES, RES_PROBA) +from snips_inference_agl.dataset import ( + validate_and_format_dataset, extract_intent_entities) +from snips_inference_agl.dataset.utils import get_stop_words_whitelist +from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity +from snips_inference_agl.exceptions import IntentNotFoundError, LoadingError +from snips_inference_agl.intent_parser.intent_parser import IntentParser +from snips_inference_agl.pipeline.configs import LookupIntentParserConfig +from snips_inference_agl.preprocessing import tokenize_light +from snips_inference_agl.resources import get_stop_words +from snips_inference_agl.result import ( + empty_result, intent_classification_result, parsing_result, + unresolved_slot, extraction_result) + +logger = logging.getLogger(__name__) + + +@IntentParser.register("lookup_intent_parser") +class LookupIntentParser(IntentParser): + """A deterministic Intent parser implementation based on a dictionary + + This intent parser is very strict by nature, and tends to have a very good + precision but a low recall. For this reason, it is interesting to use it + first before potentially falling back to another parser. 
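The cascade alluded to here can be sketched in a few lines. This is a minimal illustration, not library code: `parsers` is assumed to be an ordered list of already fitted parser instances, and the loop mirrors the one used by SnipsNLUEngine.parse later in this change.

from snips_inference_agl.result import is_empty

def cascade_parse(parsers, text):
    # Try the strict (high precision, low recall) parsers first and fall
    # back to the next one whenever a parser returns an empty result.
    result = None
    for parser in parsers:
        result = parser.parse(text)
        if not is_empty(result):
            return result
    return result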
+ """ + + config_type = LookupIntentParserConfig + + def __init__(self, config=None, **shared): + """The lookup intent parser can be configured by passing a + :class:`.LookupIntentParserConfig`""" + super(LookupIntentParser, self).__init__(config, **shared) + self._language = None + self._stop_words = None + self._stop_words_whitelist = None + self._map = None + self._intents_names = [] + self._slots_names = [] + self._intents_mapping = dict() + self._slots_mapping = dict() + self._entity_scopes = None + + @property + def language(self): + return self._language + + @language.setter + def language(self, value): + self._language = value + if value is None: + self._stop_words = None + else: + if self.config.ignore_stop_words: + self._stop_words = get_stop_words(self.resources) + else: + self._stop_words = set() + + @property + def fitted(self): + """Whether or not the intent parser has already been trained""" + return self._map is not None + + @log_elapsed_time( + logger, logging.INFO, "Fitted lookup intent parser in {elapsed_time}") + def fit(self, dataset, force_retrain=True): + """Fits the intent parser with a valid Snips dataset""" + logger.info("Fitting lookup intent parser...") + dataset = validate_and_format_dataset(dataset) + self.load_resources_if_needed(dataset[LANGUAGE]) + self.fit_builtin_entity_parser_if_needed(dataset) + self.fit_custom_entity_parser_if_needed(dataset) + self.language = dataset[LANGUAGE] + self._entity_scopes = _get_entity_scopes(dataset) + self._map = dict() + self._stop_words_whitelist = get_stop_words_whitelist( + dataset, self._stop_words) + entity_placeholders = _get_entity_placeholders(dataset, self.language) + + ambiguous_keys = set() + for (key, val) in self._generate_io_mapping(dataset[INTENTS], + entity_placeholders): + key = hash_str(key) + # handle key collisions -*- flag ambiguous entries -*- + if key in self._map and self._map[key] != val: + ambiguous_keys.add(key) + else: + self._map[key] = val + + # delete ambiguous keys + for key in ambiguous_keys: + self._map.pop(key) + + return self + + @log_result(logger, logging.DEBUG, "LookupIntentParser result -> {result}") + @log_elapsed_time(logger, logging.DEBUG, "Parsed in {elapsed_time}.") + @fitted_required + def parse(self, text, intents=None, top_n=None): + """Performs intent parsing on the provided *text* + + Intent and slots are extracted simultaneously through pattern matching + + Args: + text (str): input + intents (str or list of str): if provided, reduces the scope of + intent parsing to the provided list of intents + top_n (int, optional): when provided, this method will return a + list of at most top_n most likely intents, instead of a single + parsing result. + Note that the returned list can contain less than ``top_n`` + elements, for instance when the parameter ``intents`` is not + None, or when ``top_n`` is greater than the total number of + intents. + + Returns: + dict or list: the most likely intent(s) along with the extracted + slots. See :func:`.parsing_result` and :func:`.extraction_result` + for the output format. 
+ + Raises: + NotTrained: when the intent parser is not fitted + """ + if top_n is None: + top_intents = self._parse_top_intents(text, top_n=1, + intents=intents) + if top_intents: + intent = top_intents[0][RES_INTENT] + slots = top_intents[0][RES_SLOTS] + if intent[RES_PROBA] <= 0.5: + # return None in case of ambiguity + return empty_result(text, probability=1.0) + return parsing_result(text, intent, slots) + return empty_result(text, probability=1.0) + return self._parse_top_intents(text, top_n=top_n, intents=intents) + + def _parse_top_intents(self, text, top_n, intents=None): + if isinstance(intents, str): + intents = {intents} + elif isinstance(intents, list): + intents = set(intents) + + if top_n < 1: + raise ValueError( + "top_n argument must be greater or equal to 1, but got: %s" + % top_n) + + results_per_intent = defaultdict(list) + for text_candidate, entities in self._get_candidates(text, intents): + val = self._map.get(hash_str(text_candidate)) + if val is not None: + result = self._parse_map_output(text, val, entities, intents) + if result: + intent_name = result[RES_INTENT][RES_INTENT_NAME] + results_per_intent[intent_name].append(result) + + results = [] + for intent_results in itervalues(results_per_intent): + sorted_results = sorted(intent_results, + key=lambda res: len(res[RES_SLOTS])) + results.append(sorted_results[0]) + + # In some rare cases there can be multiple ambiguous intents + # In such cases, priority is given to results containing fewer slots + weights = [1.0 / (1.0 + len(res[RES_SLOTS])) for res in results] + total_weight = sum(weights) + + for res, weight in zip(results, weights): + res[RES_INTENT][RES_PROBA] = weight / total_weight + + results = sorted(results, key=lambda r: -r[RES_INTENT][RES_PROBA]) + return results[:top_n] + + def _get_candidates(self, text, intents): + candidates = defaultdict(list) + for grouped_entity_scope in self._entity_scopes: + entity_scope = grouped_entity_scope["entity_scope"] + intent_group = grouped_entity_scope["intent_group"] + intent_group = [intent_ for intent_ in intent_group + if intents is None or intent_ in intents] + if not intent_group: + continue + + builtin_entities = self.builtin_entity_parser.parse( + text, scope=entity_scope["builtin"], use_cache=True) + custom_entities = self.custom_entity_parser.parse( + text, scope=entity_scope["custom"], use_cache=True) + all_entities = builtin_entities + custom_entities + all_entities = deduplicate_overlapping_entities(all_entities) + + # We generate all subsets of entities to match utterances + # containing ambivalent words which can be both entity values or + # random words + for entities in _get_entities_combinations(all_entities): + processed_text = self._replace_entities_with_placeholders( + text, entities) + for intent in intent_group: + cleaned_text = self._preprocess_text(text, intent) + cleaned_processed_text = self._preprocess_text( + processed_text, intent) + + raw_candidate = cleaned_text, [] + placeholder_candidate = cleaned_processed_text, entities + intent_candidates = [raw_candidate, placeholder_candidate] + for text_input, text_entities in intent_candidates: + if text_input not in candidates \ + or text_entities not in candidates[text_input]: + candidates[text_input].append(text_entities) + yield text_input, text_entities + + def _parse_map_output(self, text, output, entities, intents): + """Parse the map output to the parser's result format""" + intent_id, slot_ids = output + intent_name = self._intents_names[intent_id] + if intents is not None and 
intent_name not in intents: + return None + + parsed_intent = intent_classification_result( + intent_name=intent_name, probability=1.0) + slots = [] + # assert invariant + assert len(slot_ids) == len(entities) + for slot_id, entity in zip(slot_ids, entities): + slot_name = self._slots_names[slot_id] + rng_start = entity[RES_MATCH_RANGE][START] + rng_end = entity[RES_MATCH_RANGE][END] + slot_value = text[rng_start:rng_end] + entity_name = entity[ENTITY_KIND] + slot = unresolved_slot( + [rng_start, rng_end], slot_value, entity_name, slot_name) + slots.append(slot) + + return extraction_result(parsed_intent, slots) + + @fitted_required + def get_intents(self, text): + """Returns the list of intents ordered by decreasing probability + + The length of the returned list is exactly the number of intents in the + dataset + 1 for the None intent + """ + nb_intents = len(self._intents_names) + top_intents = [intent_result[RES_INTENT] for intent_result in + self._parse_top_intents(text, top_n=nb_intents)] + matched_intents = {res[RES_INTENT_NAME] for res in top_intents} + for intent in self._intents_names: + if intent not in matched_intents: + top_intents.append(intent_classification_result(intent, 0.0)) + + # The None intent is not included in the lookup table and is thus + # never matched by the lookup parser + top_intents.append(intent_classification_result(None, 0.0)) + return top_intents + + @fitted_required + def get_slots(self, text, intent): + """Extracts slots from a text input, with the knowledge of the intent + + Args: + text (str): input + intent (str): the intent which the input corresponds to + + Returns: + list: the list of extracted slots + + Raises: + IntentNotFoundError: When the intent was not part of the training + data + """ + if intent is None: + return [] + + if intent not in self._intents_names: + raise IntentNotFoundError(intent) + + slots = self.parse(text, intents=[intent])[RES_SLOTS] + if slots is None: + slots = [] + return slots + + def _get_intent_stop_words(self, intent): + whitelist = self._stop_words_whitelist.get(intent, set()) + return self._stop_words.difference(whitelist) + + def _get_intent_id(self, intent_name): + """generate a numeric id for an intent + + Args: + intent_name (str): intent name + + Returns: + int: numeric id + + """ + intent_id = self._intents_mapping.get(intent_name) + if intent_id is None: + intent_id = len(self._intents_names) + self._intents_names.append(intent_name) + self._intents_mapping[intent_name] = intent_id + + return intent_id + + def _get_slot_id(self, slot_name): + """generate a numeric id for a slot + + Args: + slot_name (str): intent name + + Returns: + int: numeric id + + """ + slot_id = self._slots_mapping.get(slot_name) + if slot_id is None: + slot_id = len(self._slots_names) + self._slots_names.append(slot_name) + self._slots_mapping[slot_name] = slot_id + + return slot_id + + def _preprocess_text(self, txt, intent): + """Replaces stop words and characters that are tokenized out by + whitespaces""" + stop_words = self._get_intent_stop_words(intent) + tokens = tokenize_light(txt, self.language) + cleaned_string = " ".join( + [tkn for tkn in tokens if normalize(tkn) not in stop_words]) + return cleaned_string.lower() + + def _generate_io_mapping(self, intents, entity_placeholders): + """Generate input-output pairs""" + for intent_name, intent in sorted(iteritems(intents)): + intent_id = self._get_intent_id(intent_name) + for entry in intent[UTTERANCES]: + yield self._build_io_mapping( + intent_id, entry, entity_placeholders) 
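At its core the trained artifact is just a dictionary from a hash of the normalized, placeholder-substituted utterance to an (intent id, slot ids) pair. A toy, self-contained illustration of that principle, using Python's built-in hash instead of the library's hash_str and a hypothetical %SNIPSTEMPERATURE% placeholder:

table = {}

def add_utterance(key, intent_id, slot_ids):
    # In the real parser, colliding keys that map to different outputs are
    # flagged as ambiguous and removed from the table (see fit above).
    table[hash(key)] = (intent_id, slot_ids)

def lookup(key):
    return table.get(hash(key))

add_utterance("set the temperature to %SNIPSTEMPERATURE%", intent_id=0, slot_ids=[3])
print(lookup("set the temperature to %SNIPSTEMPERATURE%"))  # (0, [3])
print(lookup("play some jazz"))                              # None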
+ + def _build_io_mapping(self, intent_id, utterance, entity_placeholders): + input_ = [] + output = [intent_id] + slots = [] + for chunk in utterance[DATA]: + if SLOT_NAME in chunk: + slot_name = chunk[SLOT_NAME] + slot_id = self._get_slot_id(slot_name) + entity_name = chunk[ENTITY] + placeholder = entity_placeholders[entity_name] + input_.append(placeholder) + slots.append(slot_id) + else: + input_.append(chunk[TEXT]) + output.append(slots) + + intent = self._intents_names[intent_id] + key = self._preprocess_text(" ".join(input_), intent) + + return key, output + + def _replace_entities_with_placeholders(self, text, entities): + if not entities: + return text + entities = sorted(entities, key=lambda e: e[RES_MATCH_RANGE][START]) + processed_text = "" + current_idx = 0 + for ent in entities: + start = ent[RES_MATCH_RANGE][START] + end = ent[RES_MATCH_RANGE][END] + processed_text += text[current_idx:start] + place_holder = _get_entity_name_placeholder( + ent[ENTITY_KIND], self.language) + processed_text += place_holder + current_idx = end + processed_text += text[current_idx:] + + return processed_text + + @check_persisted_path + def persist(self, path): + """Persists the object at the given path""" + path.mkdir() + parser_json = json_string(self.to_dict()) + parser_path = path / "intent_parser.json" + + with parser_path.open(mode="w", encoding="utf8") as pfile: + pfile.write(parser_json) + self.persist_metadata(path) + + @classmethod + def from_path(cls, path, **shared): + """Loads a :class:`LookupIntentParser` instance from a path + + The data at the given path must have been generated using + :func:`~LookupIntentParser.persist` + """ + path = Path(path) + model_path = path / "intent_parser.json" + if not model_path.exists(): + raise LoadingError( + "Missing lookup intent parser metadata file: %s" + % model_path.name) + + with model_path.open(encoding="utf8") as pfile: + metadata = json.load(pfile) + return cls.from_dict(metadata, **shared) + + def to_dict(self): + """Returns a json-serializable dict""" + stop_words_whitelist = None + if self._stop_words_whitelist is not None: + stop_words_whitelist = { + intent: sorted(values) + for intent, values in iteritems(self._stop_words_whitelist)} + return { + "config": self.config.to_dict(), + "language_code": self.language, + "map": self._map, + "slots_names": self._slots_names, + "intents_names": self._intents_names, + "entity_scopes": self._entity_scopes, + "stop_words_whitelist": stop_words_whitelist, + } + + @classmethod + def from_dict(cls, unit_dict, **shared): + """Creates a :class:`LookupIntentParser` instance from a dict + + The dict must have been generated with + :func:`~LookupIntentParser.to_dict` + """ + config = cls.config_type.from_dict(unit_dict["config"]) + parser = cls(config=config, **shared) + parser.language = unit_dict["language_code"] + # pylint:disable=protected-access + parser._map = _convert_dict_keys_to_int(unit_dict["map"]) + parser._slots_names = unit_dict["slots_names"] + parser._intents_names = unit_dict["intents_names"] + parser._entity_scopes = unit_dict["entity_scopes"] + if parser.fitted: + whitelist = unit_dict["stop_words_whitelist"] + parser._stop_words_whitelist = { + intent: set(values) for intent, values in iteritems(whitelist)} + # pylint:enable=protected-access + return parser + + +def _get_entity_scopes(dataset): + intent_entities = extract_intent_entities(dataset) + intent_groups = [] + entity_scopes = [] + for intent, entities in sorted(iteritems(intent_entities)): + scope = { + "builtin": list( + 
{ent for ent in entities if is_builtin_entity(ent)}), + "custom": list( + {ent for ent in entities if not is_builtin_entity(ent)}) + } + if scope in entity_scopes: + group_idx = entity_scopes.index(scope) + intent_groups[group_idx].append(intent) + else: + entity_scopes.append(scope) + intent_groups.append([intent]) + return [ + { + "intent_group": intent_group, + "entity_scope": entity_scope + } for intent_group, entity_scope in zip(intent_groups, entity_scopes) + ] + + +def _get_entity_placeholders(dataset, language): + return { + e: _get_entity_name_placeholder(e, language) for e in dataset[ENTITIES] + } + + +def _get_entity_name_placeholder(entity_label, language): + return "%%%s%%" % "".join(tokenize_light(entity_label, language)).upper() + + +def _convert_dict_keys_to_int(dct): + if isinstance(dct, dict): + return {int(k): v for k, v in iteritems(dct)} + return dct + + +def _get_entities_combinations(entities): + yield () + for nb_entities in reversed(range(1, len(entities) + 1)): + for combination in combinations(entities, nb_entities): + yield combination diff --git a/snips_inference_agl/intent_parser/probabilistic_intent_parser.py b/snips_inference_agl/intent_parser/probabilistic_intent_parser.py new file mode 100644 index 0000000..23e7829 --- /dev/null +++ b/snips_inference_agl/intent_parser/probabilistic_intent_parser.py @@ -0,0 +1,250 @@ +from __future__ import unicode_literals + +import json +import logging +from builtins import str +from copy import deepcopy +from datetime import datetime +from pathlib import Path + +from future.utils import iteritems, itervalues + +from snips_inference_agl.common.log_utils import log_elapsed_time, log_result +from snips_inference_agl.common.utils import ( + check_persisted_path, elapsed_since, fitted_required, json_string) +from snips_inference_agl.constants import INTENTS, RES_INTENT_NAME +from snips_inference_agl.dataset import validate_and_format_dataset +from snips_inference_agl.exceptions import IntentNotFoundError, LoadingError +from snips_inference_agl.intent_classifier import IntentClassifier +from snips_inference_agl.intent_parser.intent_parser import IntentParser +from snips_inference_agl.pipeline.configs import ProbabilisticIntentParserConfig +from snips_inference_agl.result import parsing_result, extraction_result +from snips_inference_agl.slot_filler import SlotFiller + +logger = logging.getLogger(__name__) + + +@IntentParser.register("probabilistic_intent_parser") +class ProbabilisticIntentParser(IntentParser): + """Intent parser which consists in two steps: intent classification then + slot filling""" + + config_type = ProbabilisticIntentParserConfig + + def __init__(self, config=None, **shared): + """The probabilistic intent parser can be configured by passing a + :class:`.ProbabilisticIntentParserConfig`""" + super(ProbabilisticIntentParser, self).__init__(config, **shared) + self.intent_classifier = None + self.slot_fillers = dict() + + @property + def fitted(self): + """Whether or not the intent parser has already been fitted""" + return self.intent_classifier is not None \ + and self.intent_classifier.fitted \ + and all(slot_filler is not None and slot_filler.fitted + for slot_filler in itervalues(self.slot_fillers)) + + @log_elapsed_time(logger, logging.INFO, + "Fitted probabilistic intent parser in {elapsed_time}") + # pylint:disable=arguments-differ + def fit(self, dataset, force_retrain=True): + """Fits the probabilistic intent parser + + Args: + dataset (dict): A valid Snips dataset + force_retrain (bool, optional): 
If *False*, will not retrain intent + classifier and slot fillers when they are already fitted. + Default to *True*. + + Returns: + :class:`ProbabilisticIntentParser`: The same instance, trained + """ + logger.info("Fitting probabilistic intent parser...") + dataset = validate_and_format_dataset(dataset) + intents = list(dataset[INTENTS]) + if self.intent_classifier is None: + self.intent_classifier = IntentClassifier.from_config( + self.config.intent_classifier_config, + builtin_entity_parser=self.builtin_entity_parser, + custom_entity_parser=self.custom_entity_parser, + resources=self.resources, + random_state=self.random_state, + ) + + if force_retrain or not self.intent_classifier.fitted: + self.intent_classifier.fit(dataset) + + if self.slot_fillers is None: + self.slot_fillers = dict() + slot_fillers_start = datetime.now() + for intent_name in intents: + # We need to copy the slot filler config as it may be mutated + if self.slot_fillers.get(intent_name) is None: + slot_filler_config = deepcopy(self.config.slot_filler_config) + self.slot_fillers[intent_name] = SlotFiller.from_config( + slot_filler_config, + builtin_entity_parser=self.builtin_entity_parser, + custom_entity_parser=self.custom_entity_parser, + resources=self.resources, + random_state=self.random_state, + ) + if force_retrain or not self.slot_fillers[intent_name].fitted: + self.slot_fillers[intent_name].fit(dataset, intent_name) + logger.debug("Fitted slot fillers in %s", + elapsed_since(slot_fillers_start)) + return self + + # pylint:enable=arguments-differ + + @log_result(logger, logging.DEBUG, + "ProbabilisticIntentParser result -> {result}") + @log_elapsed_time(logger, logging.DEBUG, + "ProbabilisticIntentParser parsed in {elapsed_time}") + @fitted_required + def parse(self, text, intents=None, top_n=None): + """Performs intent parsing on the provided *text* by first classifying + the intent and then using the correspond slot filler to extract slots + + Args: + text (str): input + intents (str or list of str): if provided, reduces the scope of + intent parsing to the provided list of intents + top_n (int, optional): when provided, this method will return a + list of at most top_n most likely intents, instead of a single + parsing result. + Note that the returned list can contain less than ``top_n`` + elements, for instance when the parameter ``intents`` is not + None, or when ``top_n`` is greater than the total number of + intents. + + Returns: + dict or list: the most likely intent(s) along with the extracted + slots. See :func:`.parsing_result` and :func:`.extraction_result` + for the output format. 
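A hedged sketch of that two-step flow using the parser's public API; `parser` is assumed to be an already fitted ProbabilisticIntentParser and the utterance is made up:

from snips_inference_agl.constants import RES_INTENT_NAME

text = "book a table for two tomorrow"
ranked = parser.get_intents(text)        # step 1: intent classification
best = ranked[0][RES_INTENT_NAME]        # may be None (the None intent)
slots = parser.get_slots(text, best) if best is not None else []  # step 2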
+ + Raises: + NotTrained: when the intent parser is not fitted + """ + if isinstance(intents, str): + intents = {intents} + elif isinstance(intents, list): + intents = list(intents) + + if top_n is None: + intent_result = self.intent_classifier.get_intent(text, intents) + intent_name = intent_result[RES_INTENT_NAME] + if intent_name is not None: + slots = self.slot_fillers[intent_name].get_slots(text) + else: + slots = [] + return parsing_result(text, intent_result, slots) + + results = [] + intents_results = self.intent_classifier.get_intents(text) + for intent_result in intents_results[:top_n]: + intent_name = intent_result[RES_INTENT_NAME] + if intent_name is not None: + slots = self.slot_fillers[intent_name].get_slots(text) + else: + slots = [] + results.append(extraction_result(intent_result, slots)) + return results + + @fitted_required + def get_intents(self, text): + """Returns the list of intents ordered by decreasing probability + + The length of the returned list is exactly the number of intents in the + dataset + 1 for the None intent + """ + return self.intent_classifier.get_intents(text) + + @fitted_required + def get_slots(self, text, intent): + """Extracts slots from a text input, with the knowledge of the intent + + Args: + text (str): input + intent (str): the intent which the input corresponds to + + Returns: + list: the list of extracted slots + + Raises: + IntentNotFoundError: When the intent was not part of the training + data + """ + if intent is None: + return [] + + if intent not in self.slot_fillers: + raise IntentNotFoundError(intent) + return self.slot_fillers[intent].get_slots(text) + + @check_persisted_path + def persist(self, path): + """Persists the object at the given path""" + path.mkdir() + sorted_slot_fillers = sorted(iteritems(self.slot_fillers)) + slot_fillers = [] + for i, (intent, slot_filler) in enumerate(sorted_slot_fillers): + slot_filler_name = "slot_filler_%s" % i + slot_filler.persist(path / slot_filler_name) + slot_fillers.append({ + "intent": intent, + "slot_filler_name": slot_filler_name + }) + + if self.intent_classifier is not None: + self.intent_classifier.persist(path / "intent_classifier") + + model = { + "config": self.config.to_dict(), + "slot_fillers": slot_fillers + } + model_json = json_string(model) + model_path = path / "intent_parser.json" + with model_path.open(mode="w") as f: + f.write(model_json) + self.persist_metadata(path) + + @classmethod + def from_path(cls, path, **shared): + """Loads a :class:`ProbabilisticIntentParser` instance from a path + + The data at the given path must have been generated using + :func:`~ProbabilisticIntentParser.persist` + """ + path = Path(path) + model_path = path / "intent_parser.json" + if not model_path.exists(): + raise LoadingError( + "Missing probabilistic intent parser model file: %s" + % model_path.name) + + with model_path.open(encoding="utf8") as f: + model = json.load(f) + + config = cls.config_type.from_dict(model["config"]) + parser = cls(config=config, **shared) + classifier = None + intent_classifier_path = path / "intent_classifier" + if intent_classifier_path.exists(): + classifier_unit_name = config.intent_classifier_config.unit_name + classifier = IntentClassifier.load_from_path( + intent_classifier_path, classifier_unit_name, **shared) + + slot_fillers = dict() + slot_filler_unit_name = config.slot_filler_config.unit_name + for slot_filler_conf in model["slot_fillers"]: + intent = slot_filler_conf["intent"] + slot_filler_path = path / slot_filler_conf["slot_filler_name"] + 
slot_filler = SlotFiller.load_from_path( + slot_filler_path, slot_filler_unit_name, **shared) + slot_fillers[intent] = slot_filler + + parser.intent_classifier = classifier + parser.slot_fillers = slot_fillers + return parser diff --git a/snips_inference_agl/languages.py b/snips_inference_agl/languages.py new file mode 100644 index 0000000..cc205a3 --- /dev/null +++ b/snips_inference_agl/languages.py @@ -0,0 +1,44 @@ +from __future__ import unicode_literals + +import re +import string + +_PUNCTUATION_REGEXES = dict() +_NUM2WORDS_SUPPORT = dict() + + +# pylint:disable=unused-argument +def get_default_sep(language): + return " " + + +# pylint:enable=unused-argument + +# pylint:disable=unused-argument +def get_punctuation(language): + return string.punctuation + + +# pylint:enable=unused-argument + + +def get_punctuation_regex(language): + global _PUNCTUATION_REGEXES + if language not in _PUNCTUATION_REGEXES: + pattern = r"|".join(re.escape(p) for p in get_punctuation(language)) + _PUNCTUATION_REGEXES[language] = re.compile(pattern) + return _PUNCTUATION_REGEXES[language] + + +def supports_num2words(language): + from num2words import num2words + + global _NUM2WORDS_SUPPORT + + if language not in _NUM2WORDS_SUPPORT: + try: + num2words(0, lang=language) + _NUM2WORDS_SUPPORT[language] = True + except NotImplementedError: + _NUM2WORDS_SUPPORT[language] = False + return _NUM2WORDS_SUPPORT[language] diff --git a/snips_inference_agl/nlu_engine/__init__.py b/snips_inference_agl/nlu_engine/__init__.py new file mode 100644 index 0000000..b45eb32 --- /dev/null +++ b/snips_inference_agl/nlu_engine/__init__.py @@ -0,0 +1 @@ +from snips_inference_agl.nlu_engine.nlu_engine import SnipsNLUEngine diff --git a/snips_inference_agl/nlu_engine/nlu_engine.py b/snips_inference_agl/nlu_engine/nlu_engine.py new file mode 100644 index 0000000..3f19aba --- /dev/null +++ b/snips_inference_agl/nlu_engine/nlu_engine.py @@ -0,0 +1,330 @@ +from __future__ import unicode_literals + +import json +import logging +from builtins import str +from pathlib import Path + +from future.utils import itervalues + +from snips_inference_agl.__about__ import __model_version__, __version__ +from snips_inference_agl.common.log_utils import log_elapsed_time +from snips_inference_agl.common.utils import (fitted_required) +from snips_inference_agl.constants import ( + AUTOMATICALLY_EXTENSIBLE, BUILTIN_ENTITY_PARSER, CUSTOM_ENTITY_PARSER, + ENTITIES, ENTITY_KIND, LANGUAGE, RESOLVED_VALUE, RES_ENTITY, + RES_INTENT, RES_INTENT_NAME, RES_MATCH_RANGE, RES_PROBA, RES_SLOTS, + RES_VALUE, RESOURCES, BYPASS_VERSION_CHECK) +# from snips_inference_agl.dataset import validate_and_format_dataset +from snips_inference_agl.entity_parser import CustomEntityParser +from snips_inference_agl.entity_parser.builtin_entity_parser import ( + BuiltinEntityParser, is_builtin_entity) +from snips_inference_agl.exceptions import ( + InvalidInputError, IntentNotFoundError, LoadingError, + IncompatibleModelError) +from snips_inference_agl.intent_parser import IntentParser +from snips_inference_agl.pipeline.configs import NLUEngineConfig +from snips_inference_agl.pipeline.processing_unit import ProcessingUnit +from snips_inference_agl.resources import load_resources_from_dir +from snips_inference_agl.result import ( + builtin_slot, custom_slot, empty_result, extraction_result, is_empty, + parsing_result) + +logger = logging.getLogger(__name__) + + +@ProcessingUnit.register("nlu_engine") +class SnipsNLUEngine(ProcessingUnit): + """Main class to use for intent parsing + + A 
:class:`SnipsNLUEngine` relies on a list of :class:`.IntentParser` + object to parse intents, by calling them successively using the first + positive output. + + With the default parameters, it will use the two following intent parsers + in this order: + + - a :class:`.DeterministicIntentParser` + - a :class:`.ProbabilisticIntentParser` + + The logic behind is to first use a conservative parser which has a very + good precision while its recall is modest, so simple patterns will be + caught, and then fallback on a second parser which is machine-learning + based and will be able to parse unseen utterances while ensuring a good + precision and recall. + """ + + config_type = NLUEngineConfig + + def __init__(self, config=None, **shared): + """The NLU engine can be configured by passing a + :class:`.NLUEngineConfig`""" + super(SnipsNLUEngine, self).__init__(config, **shared) + self.intent_parsers = [] + """list of :class:`.IntentParser`""" + self.dataset_metadata = None + + @classmethod + def default_config(cls): + # Do not use the global default config, and use per-language default + # configs instead + return None + + @property + def fitted(self): + """Whether or not the nlu engine has already been fitted""" + return self.dataset_metadata is not None + + @log_elapsed_time(logger, logging.DEBUG, "Parsed input in {elapsed_time}") + @fitted_required + def parse(self, text, intents=None, top_n=None): + """Performs intent parsing on the provided *text* by calling its intent + parsers successively + + Args: + text (str): Input + intents (str or list of str, optional): If provided, reduces the + scope of intent parsing to the provided list of intents. + The ``None`` intent is never filtered out, meaning that it can + be returned even when using an intents scope. + top_n (int, optional): when provided, this method will return a + list of at most ``top_n`` most likely intents, instead of a + single parsing result. + Note that the returned list can contain less than ``top_n`` + elements, for instance when the parameter ``intents`` is not + None, or when ``top_n`` is greater than the total number of + intents. + + Returns: + dict or list: the most likely intent(s) along with the extracted + slots. See :func:`.parsing_result` and :func:`.extraction_result` + for the output format. 
+ + Raises: + NotTrained: When the nlu engine is not fitted + InvalidInputError: When input type is not unicode + """ + if not isinstance(text, str): + raise InvalidInputError("Expected unicode but received: %s" + % type(text)) + + if isinstance(intents, str): + intents = {intents} + elif isinstance(intents, list): + intents = set(intents) + + if intents is not None: + for intent in intents: + if intent not in self.dataset_metadata["slot_name_mappings"]: + raise IntentNotFoundError(intent) + + if top_n is None: + none_proba = 0.0 + for parser in self.intent_parsers: + res = parser.parse(text, intents) + if is_empty(res): + none_proba = res[RES_INTENT][RES_PROBA] + continue + resolved_slots = self._resolve_slots(text, res[RES_SLOTS]) + return parsing_result(text, intent=res[RES_INTENT], + slots=resolved_slots) + return empty_result(text, none_proba) + + intents_results = self.get_intents(text) + if intents is not None: + intents_results = [res for res in intents_results + if res[RES_INTENT_NAME] is None + or res[RES_INTENT_NAME] in intents] + intents_results = intents_results[:top_n] + results = [] + for intent_res in intents_results: + slots = self.get_slots(text, intent_res[RES_INTENT_NAME]) + results.append(extraction_result(intent_res, slots)) + return results + + @log_elapsed_time(logger, logging.DEBUG, "Got intents in {elapsed_time}") + @fitted_required + def get_intents(self, text): + """Performs intent classification on the provided *text* and returns + the list of intents ordered by decreasing probability + + The length of the returned list is exactly the number of intents in the + dataset + 1 for the None intent + + .. note:: + + The probabilities returned along with each intent are not + guaranteed to sum to 1.0. They should be considered as scores + between 0 and 1. 
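A hedged usage sketch; `engine` is assumed to be a loaded SnipsNLUEngine, and the utterance and 0.3 cut-off are arbitrary:

from snips_inference_agl.constants import RES_INTENT_NAME, RES_PROBA

ranked = engine.get_intents("turn the volume up")
# Keep every intent scoring at least 0.3; the list may include None,
# which stands for the implicit None intent.
candidates = [r[RES_INTENT_NAME] for r in ranked if r[RES_PROBA] >= 0.3]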
+ """ + results = None + for parser in self.intent_parsers: + parser_results = parser.get_intents(text) + if results is None: + results = {res[RES_INTENT_NAME]: res for res in parser_results} + continue + + for res in parser_results: + intent = res[RES_INTENT_NAME] + proba = max(res[RES_PROBA], results[intent][RES_PROBA]) + results[intent][RES_PROBA] = proba + + return sorted(itervalues(results), key=lambda res: -res[RES_PROBA]) + + @log_elapsed_time(logger, logging.DEBUG, "Parsed slots in {elapsed_time}") + @fitted_required + def get_slots(self, text, intent): + """Extracts slots from a text input, with the knowledge of the intent + + Args: + text (str): input + intent (str): the intent which the input corresponds to + + Returns: + list: the list of extracted slots + + Raises: + IntentNotFoundError: When the intent was not part of the training + data + InvalidInputError: When input type is not unicode + """ + if not isinstance(text, str): + raise InvalidInputError("Expected unicode but received: %s" + % type(text)) + + if intent is None: + return [] + + if intent not in self.dataset_metadata["slot_name_mappings"]: + raise IntentNotFoundError(intent) + + for parser in self.intent_parsers: + slots = parser.get_slots(text, intent) + if not slots: + continue + return self._resolve_slots(text, slots) + return [] + + + @classmethod + def from_path(cls, path, **shared): + """Loads a :class:`SnipsNLUEngine` instance from a directory path + + The data at the given path must have been generated using + :func:`~SnipsNLUEngine.persist` + + Args: + path (str): The path where the nlu engine is stored + + Raises: + LoadingError: when some files are missing + IncompatibleModelError: when trying to load an engine model which + is not compatible with the current version of the lib + """ + directory_path = Path(path) + model_path = directory_path / "nlu_engine.json" + if not model_path.exists(): + raise LoadingError("Missing nlu engine model file: %s" + % model_path.name) + + with model_path.open(encoding="utf8") as f: + model = json.load(f) + model_version = model.get("model_version") + if model_version is None or model_version != __model_version__: + bypass_version_check = shared.get(BYPASS_VERSION_CHECK, False) + if bypass_version_check: + logger.warning( + "Incompatible model version found. The library expected " + "'%s' but the loaded engine is '%s'. 
The NLU engine may " + "not load correctly.", __model_version__, model_version) + else: + raise IncompatibleModelError(model_version) + + dataset_metadata = model["dataset_metadata"] + if shared.get(RESOURCES) is None and dataset_metadata is not None: + language = dataset_metadata["language_code"] + resources_dir = directory_path / "resources" / language + if resources_dir.is_dir(): + resources = load_resources_from_dir(resources_dir) + shared[RESOURCES] = resources + + if shared.get(BUILTIN_ENTITY_PARSER) is None: + path = model["builtin_entity_parser"] + if path is not None: + parser_path = directory_path / path + shared[BUILTIN_ENTITY_PARSER] = BuiltinEntityParser.from_path( + parser_path) + + if shared.get(CUSTOM_ENTITY_PARSER) is None: + path = model["custom_entity_parser"] + if path is not None: + parser_path = directory_path / path + shared[CUSTOM_ENTITY_PARSER] = CustomEntityParser.from_path( + parser_path) + + config = cls.config_type.from_dict(model["config"]) + nlu_engine = cls(config=config, **shared) + nlu_engine.dataset_metadata = dataset_metadata + intent_parsers = [] + for parser_idx, parser_name in enumerate(model["intent_parsers"]): + parser_config = config.intent_parsers_configs[parser_idx] + intent_parser_path = directory_path / parser_name + intent_parser = IntentParser.load_from_path( + intent_parser_path, parser_config.unit_name, **shared) + intent_parsers.append(intent_parser) + nlu_engine.intent_parsers = intent_parsers + return nlu_engine + + def _resolve_slots(self, text, slots): + builtin_scope = [slot[RES_ENTITY] for slot in slots + if is_builtin_entity(slot[RES_ENTITY])] + custom_scope = [slot[RES_ENTITY] for slot in slots + if not is_builtin_entity(slot[RES_ENTITY])] + # Do not use cached entities here as datetimes must be computed using + # current context + builtin_entities = self.builtin_entity_parser.parse( + text, builtin_scope, use_cache=False) + custom_entities = self.custom_entity_parser.parse( + text, custom_scope, use_cache=True) + + resolved_slots = [] + for slot in slots: + entity_name = slot[RES_ENTITY] + raw_value = slot[RES_VALUE] + is_builtin = is_builtin_entity(entity_name) + if is_builtin: + entities = builtin_entities + parser = self.builtin_entity_parser + slot_builder = builtin_slot + use_cache = False + extensible = False + else: + entities = custom_entities + parser = self.custom_entity_parser + slot_builder = custom_slot + use_cache = True + extensible = self.dataset_metadata[ENTITIES][entity_name][ + AUTOMATICALLY_EXTENSIBLE] + + resolved_slot = None + for ent in entities: + if ent[ENTITY_KIND] == entity_name and \ + ent[RES_MATCH_RANGE] == slot[RES_MATCH_RANGE]: + resolved_slot = slot_builder(slot, ent[RESOLVED_VALUE]) + break + if resolved_slot is None: + matches = parser.parse( + raw_value, scope=[entity_name], use_cache=use_cache) + if matches: + match = matches[0] + if is_builtin or len(match[RES_VALUE]) == len(raw_value): + resolved_slot = slot_builder( + slot, match[RESOLVED_VALUE]) + + if resolved_slot is None and extensible: + resolved_slot = slot_builder(slot) + + if resolved_slot is not None: + resolved_slots.append(resolved_slot) + + return resolved_slots
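Putting the engine together end to end, a minimal usage sketch: the model directory, utterance and intent name are placeholders, and the directory is assumed to hold an engine persisted with a compatible model version.

from snips_inference_agl.nlu_engine import SnipsNLUEngine

# Hypothetical directory produced by a prior training/persist step.
engine = SnipsNLUEngine.from_path("path/to/trained_engine")

# Single best parse: intent plus resolved slots.
result = engine.parse("set the temperature to 21 degrees")

# Optionally restrict the intent scope or ask for several candidates.
scoped = engine.parse("set the temperature to 21 degrees",
                      intents=["SetTemperature"])
top3 = engine.parse("set the temperature to 21 degrees", top_n=3)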
\ No newline at end of file diff --git a/snips_inference_agl/pipeline/__init__.py b/snips_inference_agl/pipeline/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/snips_inference_agl/pipeline/__init__.py diff --git a/snips_inference_agl/pipeline/configs/__init__.py b/snips_inference_agl/pipeline/configs/__init__.py new file mode 100644 index 0000000..027f286 --- /dev/null +++ b/snips_inference_agl/pipeline/configs/__init__.py @@ -0,0 +1,10 @@ +from .config import Config, ProcessingUnitConfig +from .features import default_features_factories +from .intent_classifier import (CooccurrenceVectorizerConfig, FeaturizerConfig, + IntentClassifierDataAugmentationConfig, + LogRegIntentClassifierConfig) +from .intent_parser import (DeterministicIntentParserConfig, + LookupIntentParserConfig, + ProbabilisticIntentParserConfig) +from .nlu_engine import NLUEngineConfig +from .slot_filler import CRFSlotFillerConfig, SlotFillerDataAugmentationConfig diff --git a/snips_inference_agl/pipeline/configs/config.py b/snips_inference_agl/pipeline/configs/config.py new file mode 100644 index 0000000..4267fa2 --- /dev/null +++ b/snips_inference_agl/pipeline/configs/config.py @@ -0,0 +1,49 @@ +from __future__ import unicode_literals + +from abc import ABCMeta, abstractmethod, abstractproperty +from builtins import object + +from future.utils import with_metaclass + + +class Config(with_metaclass(ABCMeta, object)): + @abstractmethod + def to_dict(self): + pass + + @classmethod + def from_dict(cls, obj_dict): + raise NotImplementedError + + +class ProcessingUnitConfig(with_metaclass(ABCMeta, Config)): + """Represents the configuration object needed to initialize a + :class:`.ProcessingUnit`""" + + @abstractproperty + def unit_name(self): + raise NotImplementedError + + def set_unit_name(self, value): + pass + + def get_required_resources(self): + return None + + +class DefaultProcessingUnitConfig(dict, ProcessingUnitConfig): + """Default config implemented as a simple dict""" + + @property + def unit_name(self): + return self["unit_name"] + + def set_unit_name(self, value): + self["unit_name"] = value + + def to_dict(self): + return self + + @classmethod + def from_dict(cls, obj_dict): + return cls(obj_dict) diff --git a/snips_inference_agl/pipeline/configs/features.py b/snips_inference_agl/pipeline/configs/features.py new file mode 100644 index 0000000..fa12e1a --- /dev/null +++ b/snips_inference_agl/pipeline/configs/features.py @@ -0,0 +1,81 @@ +def default_features_factories(): + """These are the default features used by the :class:`.CRFSlotFiller` + objects""" + + from snips_inference_agl.slot_filler.crf_utils import TaggingScheme + from snips_inference_agl.slot_filler.feature_factory import ( + NgramFactory, IsDigitFactory, IsFirstFactory, IsLastFactory, + ShapeNgramFactory, CustomEntityMatchFactory, BuiltinEntityMatchFactory) + + return [ + { + "args": { + "common_words_gazetteer_name": None, + "use_stemming": False, + "n": 1 + }, + "factory_name": NgramFactory.name, + "offsets": [-2, -1, 0, 1, 2] + }, + { + "args": { + "common_words_gazetteer_name": None, + "use_stemming": False, + "n": 2 + }, + "factory_name": NgramFactory.name, + "offsets": [-2, 1] + }, + { + "args": {}, + "factory_name": IsDigitFactory.name, + "offsets": [-1, 0, 1] + }, + { + "args": {}, + "factory_name": IsFirstFactory.name, + "offsets": [-2, -1, 0] + }, + { + "args": {}, + "factory_name": IsLastFactory.name, + "offsets": [0, 1, 2] + }, + { + "args": { + "n": 1 + }, + "factory_name": ShapeNgramFactory.name, + 
"offsets": [0] + }, + { + "args": { + "n": 2 + }, + "factory_name": ShapeNgramFactory.name, + "offsets": [-1, 0] + }, + { + "args": { + "n": 3 + }, + "factory_name": ShapeNgramFactory.name, + "offsets": [-1] + }, + { + "args": { + "use_stemming": False, + "tagging_scheme_code": TaggingScheme.BILOU.value, + }, + "factory_name": CustomEntityMatchFactory.name, + "offsets": [-2, -1, 0], + "drop_out": 0.5 + }, + { + "args": { + "tagging_scheme_code": TaggingScheme.BIO.value, + }, + "factory_name": BuiltinEntityMatchFactory.name, + "offsets": [-2, -1, 0] + }, + ] diff --git a/snips_inference_agl/pipeline/configs/intent_classifier.py b/snips_inference_agl/pipeline/configs/intent_classifier.py new file mode 100644 index 0000000..fc22c87 --- /dev/null +++ b/snips_inference_agl/pipeline/configs/intent_classifier.py @@ -0,0 +1,307 @@ +from __future__ import unicode_literals + +from snips_inference_agl.common.from_dict import FromDict +from snips_inference_agl.constants import ( + CUSTOM_ENTITY_PARSER_USAGE, NOISE, STEMS, STOP_WORDS, WORD_CLUSTERS) +from snips_inference_agl.entity_parser.custom_entity_parser import ( + CustomEntityParserUsage) +from snips_inference_agl.pipeline.configs import Config, ProcessingUnitConfig +from snips_inference_agl.resources import merge_required_resources + + +class LogRegIntentClassifierConfig(FromDict, ProcessingUnitConfig): + """Configuration of a :class:`.LogRegIntentClassifier`""" + + # pylint: disable=line-too-long + def __init__(self, data_augmentation_config=None, featurizer_config=None, + noise_reweight_factor=1.0): + """ + Args: + data_augmentation_config (:class:`IntentClassifierDataAugmentationConfig`): + Defines the strategy of the underlying data augmentation + featurizer_config (:class:`FeaturizerConfig`): Configuration of the + :class:`.Featurizer` used underneath + noise_reweight_factor (float, optional): this parameter allows to + change the weight of the None class. By default, the class + weights are computed using a "balanced" strategy. The + noise_reweight_factor allows to deviate from this strategy. 
+ """ + if data_augmentation_config is None: + data_augmentation_config = IntentClassifierDataAugmentationConfig() + if featurizer_config is None: + featurizer_config = FeaturizerConfig() + self._data_augmentation_config = None + self.data_augmentation_config = data_augmentation_config + self._featurizer_config = None + self.featurizer_config = featurizer_config + self.noise_reweight_factor = noise_reweight_factor + + # pylint: enable=line-too-long + + @property + def data_augmentation_config(self): + return self._data_augmentation_config + + @data_augmentation_config.setter + def data_augmentation_config(self, value): + if isinstance(value, dict): + self._data_augmentation_config = \ + IntentClassifierDataAugmentationConfig.from_dict(value) + elif isinstance(value, IntentClassifierDataAugmentationConfig): + self._data_augmentation_config = value + else: + raise TypeError("Expected instance of " + "IntentClassifierDataAugmentationConfig or dict" + "but received: %s" % type(value)) + + @property + def featurizer_config(self): + return self._featurizer_config + + @featurizer_config.setter + def featurizer_config(self, value): + if isinstance(value, dict): + self._featurizer_config = \ + FeaturizerConfig.from_dict(value) + elif isinstance(value, FeaturizerConfig): + self._featurizer_config = value + else: + raise TypeError("Expected instance of FeaturizerConfig or dict" + "but received: %s" % type(value)) + + @property + def unit_name(self): + from snips_inference_agl.intent_classifier import LogRegIntentClassifier + return LogRegIntentClassifier.unit_name + + def get_required_resources(self): + resources = self.data_augmentation_config.get_required_resources() + resources = merge_required_resources( + resources, self.featurizer_config.get_required_resources()) + return resources + + def to_dict(self): + return { + "unit_name": self.unit_name, + "data_augmentation_config": + self.data_augmentation_config.to_dict(), + "featurizer_config": self.featurizer_config.to_dict(), + "noise_reweight_factor": self.noise_reweight_factor, + } + + +class IntentClassifierDataAugmentationConfig(FromDict, Config): + """Configuration used by a :class:`.LogRegIntentClassifier` which defines + how to augment data to improve the training of the classifier""" + + def __init__(self, min_utterances=20, noise_factor=5, + add_builtin_entities_examples=True, unknown_word_prob=0, + unknown_words_replacement_string=None, + max_unknown_words=None): + """ + Args: + min_utterances (int, optional): The minimum number of utterances to + automatically generate for each intent, based on the existing + utterances. Default is 20. + noise_factor (int, optional): Defines the size of the noise to + generate to train the implicit *None* intent, as a multiplier + of the average size of the other intents. Default is 5. + add_builtin_entities_examples (bool, optional): If True, some + builtin entity examples will be automatically added to the + training data. Default is True. 
+ """ + self.min_utterances = min_utterances + self.noise_factor = noise_factor + self.add_builtin_entities_examples = add_builtin_entities_examples + self.unknown_word_prob = unknown_word_prob + self.unknown_words_replacement_string = \ + unknown_words_replacement_string + if max_unknown_words is not None and max_unknown_words < 0: + raise ValueError("max_unknown_words must be None or >= 0") + self.max_unknown_words = max_unknown_words + if unknown_word_prob > 0 and unknown_words_replacement_string is None: + raise ValueError("unknown_word_prob is positive (%s) but the " + "replacement string is None" % unknown_word_prob) + + @staticmethod + def get_required_resources(): + return { + NOISE: True, + STOP_WORDS: True + } + + def to_dict(self): + return { + "min_utterances": self.min_utterances, + "noise_factor": self.noise_factor, + "add_builtin_entities_examples": + self.add_builtin_entities_examples, + "unknown_word_prob": self.unknown_word_prob, + "unknown_words_replacement_string": + self.unknown_words_replacement_string, + "max_unknown_words": self.max_unknown_words + } + + +class FeaturizerConfig(FromDict, ProcessingUnitConfig): + """Configuration of a :class:`.Featurizer` object""" + + # pylint: disable=line-too-long + def __init__(self, tfidf_vectorizer_config=None, + cooccurrence_vectorizer_config=None, + pvalue_threshold=0.4, + added_cooccurrence_feature_ratio=0): + """ + Args: + tfidf_vectorizer_config (:class:`.TfidfVectorizerConfig`, optional): + empty configuration of the featurizer's + :attr:`tfidf_vectorizer` + cooccurrence_vectorizer_config: (:class:`.CooccurrenceVectorizerConfig`, optional): + configuration of the featurizer's + :attr:`cooccurrence_vectorizer` + pvalue_threshold (float): after fitting the training set to + extract tfidf features, a univariate feature selection is + applied. Features are tested for independence using a Chi-2 + test, under the null hypothesis that each feature should be + equally present in each class. Only features having a p-value + lower than the threshold are kept + added_cooccurrence_feature_ratio (float, optional): proportion of + cooccurrence features to add with respect to the number of + tfidf features. 
For instance with a ratio of 0.5, if 100 tfidf + features are remaining after feature selection, a maximum of 50 + cooccurrence features will be added + """ + self.pvalue_threshold = pvalue_threshold + self.added_cooccurrence_feature_ratio = \ + added_cooccurrence_feature_ratio + + if tfidf_vectorizer_config is None: + tfidf_vectorizer_config = TfidfVectorizerConfig() + elif isinstance(tfidf_vectorizer_config, dict): + tfidf_vectorizer_config = TfidfVectorizerConfig.from_dict( + tfidf_vectorizer_config) + self.tfidf_vectorizer_config = tfidf_vectorizer_config + + if cooccurrence_vectorizer_config is None: + cooccurrence_vectorizer_config = CooccurrenceVectorizerConfig() + elif isinstance(cooccurrence_vectorizer_config, dict): + cooccurrence_vectorizer_config = CooccurrenceVectorizerConfig \ + .from_dict(cooccurrence_vectorizer_config) + self.cooccurrence_vectorizer_config = cooccurrence_vectorizer_config + + # pylint: enable=line-too-long + + @property + def unit_name(self): + from snips_inference_agl.intent_classifier import Featurizer + return Featurizer.unit_name + + def get_required_resources(self): + required_resources = self.tfidf_vectorizer_config \ + .get_required_resources() + if self.cooccurrence_vectorizer_config: + required_resources = merge_required_resources( + required_resources, + self.cooccurrence_vectorizer_config.get_required_resources()) + return required_resources + + def to_dict(self): + return { + "unit_name": self.unit_name, + "pvalue_threshold": self.pvalue_threshold, + "added_cooccurrence_feature_ratio": + self.added_cooccurrence_feature_ratio, + "tfidf_vectorizer_config": self.tfidf_vectorizer_config.to_dict(), + "cooccurrence_vectorizer_config": + self.cooccurrence_vectorizer_config.to_dict(), + } + + +class TfidfVectorizerConfig(FromDict, ProcessingUnitConfig): + """Configuration of a :class:`.TfidfVectorizerConfig` object""" + + def __init__(self, word_clusters_name=None, use_stemming=False): + """ + Args: + word_clusters_name (str, optional): if a word cluster name is + provided then the featurizer will use the word clusters IDs + detected in the utterances and add them to the utterance text + before computing the tfidf. Default to None + use_stemming (bool, optional): use stemming before computing the + tfdif. Defaults to False (no stemming used) + """ + self.word_clusters_name = word_clusters_name + self.use_stemming = use_stemming + + @property + def unit_name(self): + from snips_inference_agl.intent_classifier import TfidfVectorizer + return TfidfVectorizer.unit_name + + def get_required_resources(self): + resources = {STEMS: True if self.use_stemming else False} + if self.word_clusters_name: + resources[WORD_CLUSTERS] = {self.word_clusters_name} + return resources + + def to_dict(self): + return { + "unit_name": self.unit_name, + "word_clusters_name": self.word_clusters_name, + "use_stemming": self.use_stemming + } + + +class CooccurrenceVectorizerConfig(FromDict, ProcessingUnitConfig): + """Configuration of a :class:`.CooccurrenceVectorizer` object""" + + def __init__(self, window_size=None, unknown_words_replacement_string=None, + filter_stop_words=True, keep_order=True): + """ + Args: + window_size (int, optional): if provided, word cooccurrences will + be taken into account only in a context window of size + :attr:`window_size`. If the window size is 3 then given a word + w[i], the vectorizer will only extract the following pairs: + (w[i], w[i + 1]), (w[i], w[i + 2]) and (w[i], w[i + 3]). 
+ Defaults to None, which means that we consider all words + unknown_words_replacement_string (str, optional) + filter_stop_words (bool, optional): if True, stop words are ignored + when computing cooccurrences + keep_order (bool, optional): if True then cooccurrence are computed + taking the words order into account, which means the pairs + (w1, w2) and (w2, w1) will count as two separate features. + Defaults to `True`. + """ + self.window_size = window_size + self.unknown_words_replacement_string = \ + unknown_words_replacement_string + self.filter_stop_words = filter_stop_words + self.keep_order = keep_order + + @property + def unit_name(self): + from snips_inference_agl.intent_classifier import CooccurrenceVectorizer + return CooccurrenceVectorizer.unit_name + + def get_required_resources(self): + return { + STOP_WORDS: self.filter_stop_words, + # We require the parser to be trained without stems because we + # don't normalize and stem when processing in the + # CooccurrenceVectorizer (in order to run the builtin and + # custom parser on the same unormalized input). + # Requiring no stems ensures we'll be able to parse the unstemmed + # input + CUSTOM_ENTITY_PARSER_USAGE: CustomEntityParserUsage.WITHOUT_STEMS + } + + def to_dict(self): + return { + "unit_name": self.unit_name, + "unknown_words_replacement_string": + self.unknown_words_replacement_string, + "window_size": self.window_size, + "filter_stop_words": self.filter_stop_words, + "keep_order": self.keep_order + } diff --git a/snips_inference_agl/pipeline/configs/intent_parser.py b/snips_inference_agl/pipeline/configs/intent_parser.py new file mode 100644 index 0000000..f017472 --- /dev/null +++ b/snips_inference_agl/pipeline/configs/intent_parser.py @@ -0,0 +1,127 @@ +from __future__ import unicode_literals + +from snips_inference_agl.common.from_dict import FromDict +from snips_inference_agl.constants import CUSTOM_ENTITY_PARSER_USAGE, STOP_WORDS +from snips_inference_agl.entity_parser import CustomEntityParserUsage +from snips_inference_agl.pipeline.configs import ProcessingUnitConfig +from snips_inference_agl.resources import merge_required_resources + + +class ProbabilisticIntentParserConfig(FromDict, ProcessingUnitConfig): + """Configuration of a :class:`.ProbabilisticIntentParser` object + + Args: + intent_classifier_config (:class:`.ProcessingUnitConfig`): The + configuration of the underlying intent classifier, by default + it uses a :class:`.LogRegIntentClassifierConfig` + slot_filler_config (:class:`.ProcessingUnitConfig`): The configuration + that will be used for the underlying slot fillers, by default it + uses a :class:`.CRFSlotFillerConfig` + """ + + def __init__(self, intent_classifier_config=None, slot_filler_config=None): + from snips_inference_agl.intent_classifier import IntentClassifier + from snips_inference_agl.slot_filler import SlotFiller + + if intent_classifier_config is None: + from snips_inference_agl.pipeline.configs import LogRegIntentClassifierConfig + intent_classifier_config = LogRegIntentClassifierConfig() + if slot_filler_config is None: + from snips_inference_agl.pipeline.configs import CRFSlotFillerConfig + slot_filler_config = CRFSlotFillerConfig() + self.intent_classifier_config = IntentClassifier.get_config( + intent_classifier_config) + self.slot_filler_config = SlotFiller.get_config(slot_filler_config) + + @property + def unit_name(self): + from snips_inference_agl.intent_parser import ProbabilisticIntentParser + return ProbabilisticIntentParser.unit_name + + def 
get_required_resources(self): + resources = self.intent_classifier_config.get_required_resources() + resources = merge_required_resources( + resources, self.slot_filler_config.get_required_resources()) + return resources + + def to_dict(self): + return { + "unit_name": self.unit_name, + "slot_filler_config": self.slot_filler_config.to_dict(), + "intent_classifier_config": self.intent_classifier_config.to_dict() + } + + +class DeterministicIntentParserConfig(FromDict, ProcessingUnitConfig): + """Configuration of a :class:`.DeterministicIntentParser` + + Args: + max_queries (int, optional): Maximum number of regex patterns per + intent. 50 by default. + max_pattern_length (int, optional): Maximum length of regex patterns. + ignore_stop_words (bool, optional): If True, stop words will be + removed before building patterns. + + + This allows to deactivate the usage of regular expression when they are + too big to avoid explosion in time and memory + + Note: + In the future, a FST will be used instead of regexps, removing the need + for all this + """ + + def __init__(self, max_queries=100, max_pattern_length=1000, + ignore_stop_words=False): + self.max_queries = max_queries + self.max_pattern_length = max_pattern_length + self.ignore_stop_words = ignore_stop_words + + @property + def unit_name(self): + from snips_inference_agl.intent_parser import DeterministicIntentParser + return DeterministicIntentParser.unit_name + + def get_required_resources(self): + return { + CUSTOM_ENTITY_PARSER_USAGE: CustomEntityParserUsage.WITHOUT_STEMS, + STOP_WORDS: self.ignore_stop_words + } + + def to_dict(self): + return { + "unit_name": self.unit_name, + "max_queries": self.max_queries, + "max_pattern_length": self.max_pattern_length, + "ignore_stop_words": self.ignore_stop_words + } + + +class LookupIntentParserConfig(FromDict, ProcessingUnitConfig): + """Configuration of a :class:`.LookupIntentParser` + + Args: + ignore_stop_words (bool, optional): If True, stop words will be + removed before building patterns. + """ + + def __init__(self, ignore_stop_words=False): + self.ignore_stop_words = ignore_stop_words + + @property + def unit_name(self): + from snips_inference_agl.intent_parser.lookup_intent_parser import \ + LookupIntentParser + return LookupIntentParser.unit_name + + def get_required_resources(self): + return { + CUSTOM_ENTITY_PARSER_USAGE: CustomEntityParserUsage.WITHOUT_STEMS, + STOP_WORDS: self.ignore_stop_words + } + + def to_dict(self): + return { + "unit_name": self.unit_name, + "ignore_stop_words": self.ignore_stop_words + } diff --git a/snips_inference_agl/pipeline/configs/nlu_engine.py b/snips_inference_agl/pipeline/configs/nlu_engine.py new file mode 100644 index 0000000..3826702 --- /dev/null +++ b/snips_inference_agl/pipeline/configs/nlu_engine.py @@ -0,0 +1,55 @@ +from __future__ import unicode_literals + +from snips_inference_agl.common.from_dict import FromDict +from snips_inference_agl.constants import CUSTOM_ENTITY_PARSER_USAGE +from snips_inference_agl.entity_parser import CustomEntityParserUsage +from snips_inference_agl.pipeline.configs import ProcessingUnitConfig +from snips_inference_agl.resources import merge_required_resources + + +class NLUEngineConfig(FromDict, ProcessingUnitConfig): + """Configuration of a :class:`.SnipsNLUEngine` object + + Args: + intent_parsers_configs (list): List of intent parser configs + (:class:`.ProcessingUnitConfig`). The order in the list determines + the order in which each parser will be called by the nlu engine. 
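
As an illustration of the parser ordering described above, here is a minimal sketch that builds an engine configuration with the deterministic parser tried first and the probabilistic parser as a fallback (which is also the default order). It assumes these config classes are re-exported by snips_inference_agl.pipeline.configs, as the imports elsewhere in this diff suggest:

    from snips_inference_agl.pipeline.configs import (
        DeterministicIntentParserConfig, NLUEngineConfig,
        ProbabilisticIntentParserConfig)

    # Parsers are tried in list order: pattern matching first,
    # then the CRF-based probabilistic parser as a fallback.
    engine_config = NLUEngineConfig([
        DeterministicIntentParserConfig(ignore_stop_words=True),
        ProbabilisticIntentParserConfig(),
    ])

    # to_dict() serializes the nested parser configs along with their unit names.
    print([c["unit_name"] for c in
           engine_config.to_dict()["intent_parsers_configs"]])
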
+ """ + + def __init__(self, intent_parsers_configs=None, random_seed=None): + from snips_inference_agl.intent_parser import IntentParser + + if intent_parsers_configs is None: + from snips_inference_agl.pipeline.configs import ( + ProbabilisticIntentParserConfig, + DeterministicIntentParserConfig) + intent_parsers_configs = [ + DeterministicIntentParserConfig(), + ProbabilisticIntentParserConfig() + ] + self.intent_parsers_configs = [ + IntentParser.get_config(conf) for conf in intent_parsers_configs] + self.random_seed = random_seed + + @property + def unit_name(self): + from snips_inference_agl.nlu_engine.nlu_engine import SnipsNLUEngine + return SnipsNLUEngine.unit_name + + def get_required_resources(self): + # Resolving custom slot values must be done without stemming + resources = { + CUSTOM_ENTITY_PARSER_USAGE: CustomEntityParserUsage.WITHOUT_STEMS + } + for config in self.intent_parsers_configs: + resources = merge_required_resources( + resources, config.get_required_resources()) + return resources + + def to_dict(self): + return { + "unit_name": self.unit_name, + "intent_parsers_configs": [ + config.to_dict() for config in self.intent_parsers_configs + ] + } diff --git a/snips_inference_agl/pipeline/configs/slot_filler.py b/snips_inference_agl/pipeline/configs/slot_filler.py new file mode 100644 index 0000000..be36e9c --- /dev/null +++ b/snips_inference_agl/pipeline/configs/slot_filler.py @@ -0,0 +1,145 @@ +from __future__ import unicode_literals + +from snips_inference_agl.common.from_dict import FromDict +from snips_inference_agl.constants import STOP_WORDS +from snips_inference_agl.pipeline.configs import ( + Config, ProcessingUnitConfig, default_features_factories) +from snips_inference_agl.resources import merge_required_resources + + +class CRFSlotFillerConfig(FromDict, ProcessingUnitConfig): + # pylint: disable=line-too-long + """Configuration of a :class:`.CRFSlotFiller` + + Args: + feature_factory_configs (list, optional): List of configurations that + specify the list of :class:`.CRFFeatureFactory` to use with the CRF + tagging_scheme (:class:`.TaggingScheme`, optional): Tagging scheme to + use to enrich CRF labels (default=BIO) + crf_args (dict, optional): Allow to overwrite the parameters of the CRF + defined in *sklearn_crfsuite*, see :class:`sklearn_crfsuite.CRF` + (default={"c1": .1, "c2": .1, "algorithm": "lbfgs"}) + data_augmentation_config (dict or :class:`.SlotFillerDataAugmentationConfig`, optional): + Specify how to augment data before training the CRF, see the + corresponding config object for more details. 
+ random_seed (int, optional): Specify to make the CRF training + deterministic and reproducible (default=None) + """ + + # pylint: enable=line-too-long + + def __init__(self, feature_factory_configs=None, + tagging_scheme=None, crf_args=None, + data_augmentation_config=None): + if tagging_scheme is None: + from snips_inference_agl.slot_filler.crf_utils import TaggingScheme + tagging_scheme = TaggingScheme.BIO + if feature_factory_configs is None: + feature_factory_configs = default_features_factories() + if crf_args is None: + crf_args = _default_crf_args() + if data_augmentation_config is None: + data_augmentation_config = SlotFillerDataAugmentationConfig() + self.feature_factory_configs = feature_factory_configs + self._tagging_scheme = None + self.tagging_scheme = tagging_scheme + self.crf_args = crf_args + self._data_augmentation_config = None + self.data_augmentation_config = data_augmentation_config + + @property + def tagging_scheme(self): + return self._tagging_scheme + + @tagging_scheme.setter + def tagging_scheme(self, value): + from snips_inference_agl.slot_filler.crf_utils import TaggingScheme + if isinstance(value, TaggingScheme): + self._tagging_scheme = value + elif isinstance(value, int): + self._tagging_scheme = TaggingScheme(value) + else: + raise TypeError("Expected instance of TaggingScheme or int but" + "received: %s" % type(value)) + + @property + def data_augmentation_config(self): + return self._data_augmentation_config + + @data_augmentation_config.setter + def data_augmentation_config(self, value): + if isinstance(value, dict): + self._data_augmentation_config = \ + SlotFillerDataAugmentationConfig.from_dict(value) + elif isinstance(value, SlotFillerDataAugmentationConfig): + self._data_augmentation_config = value + else: + raise TypeError("Expected instance of " + "SlotFillerDataAugmentationConfig or dict but " + "received: %s" % type(value)) + + @property + def unit_name(self): + from snips_inference_agl.slot_filler import CRFSlotFiller + return CRFSlotFiller.unit_name + + def get_required_resources(self): + # Import here to avoid circular imports + from snips_inference_agl.slot_filler.feature_factory import CRFFeatureFactory + + resources = self.data_augmentation_config.get_required_resources() + for config in self.feature_factory_configs: + factory = CRFFeatureFactory.from_config(config) + resources = merge_required_resources( + resources, factory.get_required_resources()) + return resources + + def to_dict(self): + return { + "unit_name": self.unit_name, + "feature_factory_configs": self.feature_factory_configs, + "crf_args": self.crf_args, + "tagging_scheme": self.tagging_scheme.value, + "data_augmentation_config": + self.data_augmentation_config.to_dict() + } + + +class SlotFillerDataAugmentationConfig(FromDict, Config): + """Specify how to augment data before training the CRF + + Data augmentation essentially consists in creating additional utterances + by combining utterance patterns and slot values + + Args: + min_utterances (int, optional): Specify the minimum amount of + utterances to generate per intent (default=200) + capitalization_ratio (float, optional): If an entity has one or more + capitalized values, the data augmentation will randomly capitalize + its values with a ratio of *capitalization_ratio* (default=.2) + add_builtin_entities_examples (bool, optional): If True, some builtin + entity examples will be automatically added to the training data. + Default is True. 
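
To make the dict-to-config coercion described above concrete, the following sketch (assuming the package and its dependencies are importable) shows that CRFSlotFillerConfig accepts the augmentation settings either as a SlotFillerDataAugmentationConfig instance or as a plain dict, which the setter converts through from_dict:

    from snips_inference_agl.pipeline.configs import CRFSlotFillerConfig

    # A plain dict is accepted and coerced by the data_augmentation_config
    # setter into a SlotFillerDataAugmentationConfig instance.
    config = CRFSlotFillerConfig(
        data_augmentation_config={
            "min_utterances": 500,
            "capitalization_ratio": 0.1,
            "add_builtin_entities_examples": False,
        })
    print(type(config.data_augmentation_config).__name__)   # SlotFillerDataAugmentationConfig
    print(config.data_augmentation_config.min_utterances)   # 500
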
+ """ + + def __init__(self, min_utterances=200, capitalization_ratio=.2, + add_builtin_entities_examples=True): + self.min_utterances = min_utterances + self.capitalization_ratio = capitalization_ratio + self.add_builtin_entities_examples = add_builtin_entities_examples + + def get_required_resources(self): + return { + STOP_WORDS: True + } + + def to_dict(self): + return { + "min_utterances": self.min_utterances, + "capitalization_ratio": self.capitalization_ratio, + "add_builtin_entities_examples": self.add_builtin_entities_examples + } + + +def _default_crf_args(): + return {"c1": .1, "c2": .1, "algorithm": "lbfgs"} diff --git a/snips_inference_agl/pipeline/processing_unit.py b/snips_inference_agl/pipeline/processing_unit.py new file mode 100644 index 0000000..1928470 --- /dev/null +++ b/snips_inference_agl/pipeline/processing_unit.py @@ -0,0 +1,177 @@ +from __future__ import unicode_literals + +import io +import json +import shutil +from abc import ABCMeta, abstractmethod, abstractproperty +from builtins import str, bytes +from pathlib import Path + +from future.utils import with_metaclass + +from snips_inference_agl.common.abc_utils import abstractclassmethod, classproperty +from snips_inference_agl.common.io_utils import temp_dir, unzip_archive +from snips_inference_agl.common.registrable import Registrable +from snips_inference_agl.common.utils import ( + json_string, check_random_state) +from snips_inference_agl.constants import ( + BUILTIN_ENTITY_PARSER, CUSTOM_ENTITY_PARSER, CUSTOM_ENTITY_PARSER_USAGE, + RESOURCES, LANGUAGE, RANDOM_STATE) +from snips_inference_agl.entity_parser import ( + BuiltinEntityParser, CustomEntityParser, CustomEntityParserUsage) +from snips_inference_agl.exceptions import LoadingError +from snips_inference_agl.pipeline.configs import ProcessingUnitConfig +from snips_inference_agl.pipeline.configs.config import DefaultProcessingUnitConfig +from snips_inference_agl.resources import load_resources + + +class ProcessingUnit(with_metaclass(ABCMeta, Registrable)): + """Abstraction of a NLU pipeline unit + + Pipeline processing units such as intent parsers, intent classifiers and + slot fillers must implement this class. + + A :class:`ProcessingUnit` is associated with a *config_type*, which + represents the :class:`.ProcessingUnitConfig` used to initialize it. 
+ """ + + def __init__(self, config, **shared): + if config is None: + self.config = self.default_config() + elif isinstance(config, ProcessingUnitConfig): + self.config = config + elif isinstance(config, dict): + self.config = self.config_type.from_dict(config) + else: + raise ValueError("Unexpected config type: %s" % type(config)) + if self.config is not None: + self.config.set_unit_name(self.unit_name) + self.builtin_entity_parser = shared.get(BUILTIN_ENTITY_PARSER) + self.custom_entity_parser = shared.get(CUSTOM_ENTITY_PARSER) + self.resources = shared.get(RESOURCES) + self.random_state = check_random_state(shared.get(RANDOM_STATE)) + + @classproperty + def config_type(cls): # pylint:disable=no-self-argument + return DefaultProcessingUnitConfig + + @classmethod + def default_config(cls): + config = cls.config_type() # pylint:disable=no-value-for-parameter + config.set_unit_name(cls.unit_name) + return config + + @classproperty + def unit_name(cls): # pylint:disable=no-self-argument + return ProcessingUnit.registered_name(cls) + + @classmethod + def from_config(cls, unit_config, **shared): + """Build a :class:`ProcessingUnit` from the provided config""" + unit = cls.by_name(unit_config.unit_name) + return unit(unit_config, **shared) + + @classmethod + def load_from_path(cls, unit_path, unit_name=None, **shared): + """Load a :class:`ProcessingUnit` from a persisted processing unit + directory + + Args: + unit_path (str or :class:`pathlib.Path`): path to the persisted + processing unit + unit_name (str, optional): Name of the processing unit to load. + By default, the unit name is assumed to be stored in a + "metadata.json" file located in the directory at unit_path. + + Raises: + LoadingError: when unit_name is None and no metadata file is found + in the processing unit directory + """ + unit_path = Path(unit_path) + if unit_name is None: + metadata_path = unit_path / "metadata.json" + if not metadata_path.exists(): + raise LoadingError( + "Missing metadata for processing unit at path %s" + % str(unit_path)) + with metadata_path.open(encoding="utf8") as f: + metadata = json.load(f) + unit_name = metadata["unit_name"] + unit = cls.by_name(unit_name) + return unit.from_path(unit_path, **shared) + + @classmethod + def get_config(cls, unit_config): + """Returns the :class:`.ProcessingUnitConfig` corresponding to + *unit_config*""" + if isinstance(unit_config, ProcessingUnitConfig): + return unit_config + elif isinstance(unit_config, dict): + unit_name = unit_config["unit_name"] + processing_unit_type = cls.by_name(unit_name) + return processing_unit_type.config_type.from_dict(unit_config) + elif isinstance(unit_config, (str, bytes)): + unit_name = unit_config + unit_config = {"unit_name": unit_name} + processing_unit_type = cls.by_name(unit_name) + return processing_unit_type.config_type.from_dict(unit_config) + else: + raise ValueError( + "Expected `unit_config` to be an instance of " + "ProcessingUnitConfig or dict or str but found: %s" + % type(unit_config)) + + @abstractproperty + def fitted(self): + """Whether or not the processing unit has already been trained""" + pass + + def load_resources_if_needed(self, language): + if self.resources is None or self.fitted: + required_resources = None + if self.config is not None: + required_resources = self.config.get_required_resources() + self.resources = load_resources(language, required_resources) + + def fit_builtin_entity_parser_if_needed(self, dataset): + # We only fit a builtin entity parser when the unit has already been + # fitted or if 
the parser is none. + # In the other cases the parser is provided fitted by another unit. + if self.builtin_entity_parser is None or self.fitted: + self.builtin_entity_parser = BuiltinEntityParser.build( + dataset=dataset) + return self + + def fit_custom_entity_parser_if_needed(self, dataset): + # We only fit a custom entity parser when the unit has already been + # fitted or if the parser is none. + # In the other cases the parser is provided fitted by another unit. + required_resources = self.config.get_required_resources() + if not required_resources or not required_resources.get( + CUSTOM_ENTITY_PARSER_USAGE): + # In these cases we need a custom entity parser only to do the + # final slot resolution step, which must be done without stemming. + parser_usage = CustomEntityParserUsage.WITHOUT_STEMS + else: + parser_usage = required_resources[CUSTOM_ENTITY_PARSER_USAGE] + + if self.custom_entity_parser is None or self.fitted: + self.load_resources_if_needed(dataset[LANGUAGE]) + self.custom_entity_parser = CustomEntityParser.build( + dataset, parser_usage, self.resources) + return self + + def persist_metadata(self, path, **kwargs): + metadata = {"unit_name": self.unit_name} + metadata.update(kwargs) + metadata_json = json_string(metadata) + with (path / "metadata.json").open(mode="w", encoding="utf8") as f: + f.write(metadata_json) + + # @abstractmethod + def persist(self, path): + pass + + @abstractclassmethod + def from_path(cls, path, **shared): + pass diff --git a/snips_inference_agl/preprocessing.py b/snips_inference_agl/preprocessing.py new file mode 100644 index 0000000..cfb4aa5 --- /dev/null +++ b/snips_inference_agl/preprocessing.py @@ -0,0 +1,97 @@ +# coding=utf-8 +from __future__ import unicode_literals + +from builtins import object + +from snips_inference_agl.resources import get_stems + + +def stem(string, language, resources): + from snips_nlu_utils import normalize + + normalized_string = normalize(string) + tokens = tokenize_light(normalized_string, language) + stemmed_tokens = [_stem(token, resources) for token in tokens] + return " ".join(stemmed_tokens) + + +def stem_token(token, resources): + from snips_nlu_utils import normalize + + if token.stemmed_value: + return token.stemmed_value + if not token.normalized_value: + token.normalized_value = normalize(token.value) + token.stemmed_value = _stem(token.normalized_value, resources) + return token.stemmed_value + + +def normalize_token(token): + from snips_nlu_utils import normalize + + if token.normalized_value: + return token.normalized_value + token.normalized_value = normalize(token.value) + return token.normalized_value + + +def _stem(string, resources): + return get_stems(resources).get(string, string) + + +class Token(object): + """Token object which is output by the tokenization + + Attributes: + value (str): Tokenized string + start (int): Start position of the token within the sentence + end (int): End position of the token within the sentence + normalized_value (str): Normalized value of the tokenized string + stemmed_value (str): Stemmed value of the tokenized string + """ + + def __init__(self, value, start, end, normalized_value=None, + stemmed_value=None): + self.value = value + self.start = start + self.end = end + self.normalized_value = normalized_value + self.stemmed_value = stemmed_value + + def __eq__(self, other): + if not isinstance(other, type(self)): + return False + return (self.value == other.value + and self.start == other.start + and self.end == other.end) + + def __ne__(self, other): + 
return not self.__eq__(other) + + +def tokenize(string, language): + """Tokenizes the input + + Args: + string (str): Input to tokenize + language (str): Language to use during tokenization + + Returns: + list of :class:`.Token`: The list of tokenized values + """ + from snips_nlu_utils import tokenize as _tokenize + + tokens = [Token(value=token["value"], + start=token["char_range"]["start"], + end=token["char_range"]["end"]) + for token in _tokenize(string, language)] + return tokens + + +def tokenize_light(string, language): + """Same behavior as :func:`tokenize` but returns tokenized strings instead + of :class:`Token` objects""" + from snips_nlu_utils import tokenize_light as _tokenize_light + + tokenized_string = _tokenize_light(string, language) + return tokenized_string diff --git a/snips_inference_agl/resources.py b/snips_inference_agl/resources.py new file mode 100644 index 0000000..3b3e478 --- /dev/null +++ b/snips_inference_agl/resources.py @@ -0,0 +1,267 @@ +from __future__ import unicode_literals + +import json +from copy import deepcopy +from pathlib import Path + +from snips_inference_agl.common.utils import get_package_path, is_package +from snips_inference_agl.constants import ( + CUSTOM_ENTITY_PARSER_USAGE, DATA_PATH, GAZETTEERS, NOISE, + STEMS, STOP_WORDS, WORD_CLUSTERS, METADATA) +from snips_inference_agl.entity_parser.custom_entity_parser import ( + CustomEntityParserUsage) + + +class MissingResource(LookupError): + pass + + +def load_resources(name, required_resources=None): + """Load language specific resources + + Args: + name (str): Resource name as in ``snips-nlu download <name>``. Can also + be the name of a python package or a directory path. + required_resources (dict, optional): Resources requirement + dict which, when provided, allows to limit the amount of resources + to load. By default, all existing resources are loaded. + """ + if name in set(d.name for d in DATA_PATH.iterdir()): + return load_resources_from_dir(DATA_PATH / name, required_resources) + elif is_package(name): + package_path = get_package_path(name) + resources_sub_dir = get_resources_sub_directory(package_path) + return load_resources_from_dir(resources_sub_dir, required_resources) + elif Path(name).exists(): + path = Path(name) + if (path / "__init__.py").exists(): + path = get_resources_sub_directory(path) + return load_resources_from_dir(path, required_resources) + else: + raise MissingResource("Language resource '{r}' not found. 
This may be " + "solved by running " + "'python -m snips_nlu download {r}'" + .format(r=name)) + + +def load_resources_from_dir(resources_dir, required_resources=None): + with (resources_dir / "metadata.json").open(encoding="utf8") as f: + metadata = json.load(f) + metadata = _update_metadata(metadata, required_resources) + gazetteer_names = metadata["gazetteers"] + clusters_names = metadata["word_clusters"] + stop_words_filename = metadata["stop_words"] + stems_filename = metadata["stems"] + noise_filename = metadata["noise"] + + gazetteers = _get_gazetteers(resources_dir / "gazetteers", gazetteer_names) + word_clusters = _get_word_clusters(resources_dir / "word_clusters", + clusters_names) + + stems = None + stop_words = None + noise = None + + if stems_filename is not None: + stems = _get_stems(resources_dir / "stemming", stems_filename) + if stop_words_filename is not None: + stop_words = _get_stop_words(resources_dir, stop_words_filename) + if noise_filename is not None: + noise = _get_noise(resources_dir, noise_filename) + + return { + METADATA: metadata, + WORD_CLUSTERS: word_clusters, + GAZETTEERS: gazetteers, + STOP_WORDS: stop_words, + NOISE: noise, + STEMS: stems, + } + + +def _update_metadata(metadata, required_resources): + if required_resources is None: + return metadata + metadata = deepcopy(metadata) + required_gazetteers = required_resources.get(GAZETTEERS, []) + required_word_clusters = required_resources.get(WORD_CLUSTERS, []) + for gazetter in required_gazetteers: + if gazetter not in metadata["gazetteers"]: + raise ValueError("Unknown gazetteer for language '%s': '%s'" + % (metadata["language"], gazetter)) + for word_clusters in required_word_clusters: + if word_clusters not in metadata["word_clusters"]: + raise ValueError("Unknown word clusters for language '%s': '%s'" + % (metadata["language"], word_clusters)) + metadata["gazetteers"] = required_gazetteers + metadata["word_clusters"] = required_word_clusters + if not required_resources.get(STEMS, False): + metadata["stems"] = None + if not required_resources.get(NOISE, False): + metadata["noise"] = None + if not required_resources.get(STOP_WORDS, False): + metadata["stop_words"] = None + return metadata + + +def get_resources_sub_directory(resources_dir): + resources_dir = Path(resources_dir) + with (resources_dir / "metadata.json").open(encoding="utf8") as f: + metadata = json.load(f) + resource_name = metadata["name"] + version = metadata["version"] + sub_dir_name = "{r}-{v}".format(r=resource_name, v=version) + return resources_dir / sub_dir_name + + +def get_stop_words(resources): + return _get_resource(resources, STOP_WORDS) + + +def get_noise(resources): + return _get_resource(resources, NOISE) + + +def get_word_clusters(resources): + return _get_resource(resources, WORD_CLUSTERS) + + +def get_word_cluster(resources, cluster_name): + word_clusters = get_word_clusters(resources) + if cluster_name not in word_clusters: + raise MissingResource("Word cluster '{}' not found" % cluster_name) + return word_clusters[cluster_name] + + +def get_gazetteer(resources, gazetteer_name): + gazetteers = _get_resource(resources, GAZETTEERS) + if gazetteer_name not in gazetteers: + raise MissingResource("Gazetteer '%s' not found in resources" + % gazetteer_name) + return gazetteers[gazetteer_name] + + +def get_stems(resources): + return _get_resource(resources, STEMS) + + +def merge_required_resources(lhs, rhs): + if not lhs: + return dict() if rhs is None else rhs + if not rhs: + return dict() if lhs is None else lhs + 
merged_resources = dict() + if lhs.get(NOISE, False) or rhs.get(NOISE, False): + merged_resources[NOISE] = True + if lhs.get(STOP_WORDS, False) or rhs.get(STOP_WORDS, False): + merged_resources[STOP_WORDS] = True + if lhs.get(STEMS, False) or rhs.get(STEMS, False): + merged_resources[STEMS] = True + lhs_parser_usage = lhs.get(CUSTOM_ENTITY_PARSER_USAGE) + rhs_parser_usage = rhs.get(CUSTOM_ENTITY_PARSER_USAGE) + parser_usage = CustomEntityParserUsage.merge_usages( + lhs_parser_usage, rhs_parser_usage) + merged_resources[CUSTOM_ENTITY_PARSER_USAGE] = parser_usage + gazetteers = lhs.get(GAZETTEERS, set()).union(rhs.get(GAZETTEERS, set())) + if gazetteers: + merged_resources[GAZETTEERS] = gazetteers + word_clusters = lhs.get(WORD_CLUSTERS, set()).union( + rhs.get(WORD_CLUSTERS, set())) + if word_clusters: + merged_resources[WORD_CLUSTERS] = word_clusters + return merged_resources + + +def _get_resource(resources, resource_name): + if resource_name not in resources or resources[resource_name] is None: + raise MissingResource("Resource '%s' not found" % resource_name) + return resources[resource_name] + + +def _get_stop_words(resources_dir, stop_words_filename): + if not stop_words_filename: + return None + stop_words_path = (resources_dir / stop_words_filename).with_suffix(".txt") + return _load_stop_words(stop_words_path) + + +def _load_stop_words(stop_words_path): + with stop_words_path.open(encoding="utf8") as f: + stop_words = set(l.strip() for l in f if l) + return stop_words + + +def _get_noise(resources_dir, noise_filename): + if not noise_filename: + return None + noise_path = (resources_dir / noise_filename).with_suffix(".txt") + return _load_noise(noise_path) + + +def _load_noise(noise_path): + with noise_path.open(encoding="utf8") as f: + # Here we split on a " " knowing that it's always ignored by + # the tokenization (see tokenization unit tests) + # It is not important to tokenize precisely as this noise is just used + # to generate utterances for the None intent + noise = [word for l in f for word in l.split()] + return noise + + +def _get_word_clusters(word_clusters_dir, clusters_names): + if not clusters_names: + return dict() + + clusters = dict() + for clusters_name in clusters_names: + clusters_path = (word_clusters_dir / clusters_name).with_suffix(".txt") + clusters[clusters_name] = _load_word_clusters(clusters_path) + return clusters + + +def _load_word_clusters(path): + clusters = dict() + with path.open(encoding="utf8") as f: + for line in f: + split = line.rstrip().split("\t") + if not split: + continue + clusters[split[0]] = split[1] + return clusters + + +def _get_gazetteers(gazetteers_dir, gazetteer_names): + if not gazetteer_names: + return dict() + + gazetteers = dict() + for gazetteer_name in gazetteer_names: + gazetteer_path = (gazetteers_dir / gazetteer_name).with_suffix(".txt") + gazetteers[gazetteer_name] = _load_gazetteer(gazetteer_path) + return gazetteers + + +def _load_gazetteer(path): + with path.open(encoding="utf8") as f: + gazetteer = set(v.strip() for v in f if v) + return gazetteer + + + +def _get_stems(stems_dir, filename): + if not filename: + return None + stems_path = (stems_dir / filename).with_suffix(".txt") + return _load_stems(stems_path) + + +def _load_stems(path): + with path.open(encoding="utf8") as f: + stems = dict() + for line in f: + elements = line.strip().split(',') + stem = elements[0] + for value in elements[1:]: + stems[value] = stem + return stems + diff --git a/snips_inference_agl/result.py b/snips_inference_agl/result.py new 
file mode 100644 index 0000000..39930c8 --- /dev/null +++ b/snips_inference_agl/result.py @@ -0,0 +1,342 @@ +from __future__ import unicode_literals + +from snips_inference_agl.constants import ( + RES_ENTITY, RES_INPUT, RES_INTENT, RES_INTENT_NAME, RES_MATCH_RANGE, + RES_PROBA, RES_RAW_VALUE, RES_SLOTS, RES_SLOT_NAME, RES_VALUE, ENTITY_KIND, + RESOLVED_VALUE, VALUE) + + +def intent_classification_result(intent_name, probability): + """Creates an intent classification result to be returned by + :meth:`.IntentClassifier.get_intent` + + Example: + + >>> intent_classification_result("GetWeather", 0.93) + {'intentName': 'GetWeather', 'probability': 0.93} + """ + return { + RES_INTENT_NAME: intent_name, + RES_PROBA: probability + } + + +def unresolved_slot(match_range, value, entity, slot_name): + """Creates an internal slot yet to be resolved + + Example: + + >>> from snips_inference_agl.common.utils import json_string + >>> slot = unresolved_slot([0, 8], "tomorrow", "snips/datetime", \ + "startDate") + >>> print(json_string(slot, indent=4, sort_keys=True)) + { + "entity": "snips/datetime", + "range": { + "end": 8, + "start": 0 + }, + "slotName": "startDate", + "value": "tomorrow" + } + """ + return { + RES_MATCH_RANGE: _convert_range(match_range), + RES_VALUE: value, + RES_ENTITY: entity, + RES_SLOT_NAME: slot_name + } + + +def custom_slot(internal_slot, resolved_value=None): + """Creates a custom slot with *resolved_value* being the reference value + of the slot + + Example: + + >>> s = unresolved_slot([10, 19], "earl grey", "beverage", "beverage") + >>> from snips_inference_agl.common.utils import json_string + >>> print(json_string(custom_slot(s, "tea"), indent=4, sort_keys=True)) + { + "entity": "beverage", + "range": { + "end": 19, + "start": 10 + }, + "rawValue": "earl grey", + "slotName": "beverage", + "value": { + "kind": "Custom", + "value": "tea" + } + } + """ + + if resolved_value is None: + resolved_value = internal_slot[RES_VALUE] + return { + RES_MATCH_RANGE: _convert_range(internal_slot[RES_MATCH_RANGE]), + RES_RAW_VALUE: internal_slot[RES_VALUE], + RES_VALUE: { + "kind": "Custom", + "value": resolved_value + }, + RES_ENTITY: internal_slot[RES_ENTITY], + RES_SLOT_NAME: internal_slot[RES_SLOT_NAME] + } + + +def builtin_slot(internal_slot, resolved_value): + """Creates a builtin slot with *resolved_value* being the resolved value + of the slot + + Example: + + >>> rng = [10, 32] + >>> raw_value = "twenty degrees celsius" + >>> entity = "snips/temperature" + >>> slot_name = "beverageTemperature" + >>> s = unresolved_slot(rng, raw_value, entity, slot_name) + >>> resolved = { + ... "kind": "Temperature", + ... "value": 20, + ... "unit": "celsius" + ... 
} + >>> from snips_inference_agl.common.utils import json_string + >>> print(json_string(builtin_slot(s, resolved), indent=4)) + { + "entity": "snips/temperature", + "range": { + "end": 32, + "start": 10 + }, + "rawValue": "twenty degrees celsius", + "slotName": "beverageTemperature", + "value": { + "kind": "Temperature", + "unit": "celsius", + "value": 20 + } + } + """ + return { + RES_MATCH_RANGE: _convert_range(internal_slot[RES_MATCH_RANGE]), + RES_RAW_VALUE: internal_slot[RES_VALUE], + RES_VALUE: resolved_value, + RES_ENTITY: internal_slot[RES_ENTITY], + RES_SLOT_NAME: internal_slot[RES_SLOT_NAME] + } + + +def resolved_slot(match_range, raw_value, resolved_value, entity, slot_name): + """Creates a resolved slot + + Args: + match_range (dict): Range of the slot within the sentence + (ex: {"start": 3, "end": 10}) + raw_value (str): Slot value as it appears in the sentence + resolved_value (dict): Resolved value of the slot + entity (str): Entity which the slot belongs to + slot_name (str): Slot type + + Returns: + dict: The resolved slot + + Example: + + >>> resolved_value = { + ... "kind": "Temperature", + ... "value": 20, + ... "unit": "celsius" + ... } + >>> slot = resolved_slot({"start": 10, "end": 19}, "earl grey", + ... resolved_value, "beverage", "beverage") + >>> from snips_inference_agl.common.utils import json_string + >>> print(json_string(slot, indent=4, sort_keys=True)) + { + "entity": "beverage", + "range": { + "end": 19, + "start": 10 + }, + "rawValue": "earl grey", + "slotName": "beverage", + "value": { + "kind": "Temperature", + "unit": "celsius", + "value": 20 + } + } + """ + return { + RES_MATCH_RANGE: match_range, + RES_RAW_VALUE: raw_value, + RES_VALUE: resolved_value, + RES_ENTITY: entity, + RES_SLOT_NAME: slot_name + } + + +def parsing_result(input, intent, slots): # pylint:disable=redefined-builtin + """Create the final output of :meth:`.SnipsNLUEngine.parse` or + :meth:`.IntentParser.parse` + + Example: + + >>> text = "Hello Bill!" + >>> intent_result = intent_classification_result("Greeting", 0.95) + >>> internal_slot = unresolved_slot([6, 10], "Bill", "name", + ... "greetee") + >>> slots = [custom_slot(internal_slot, "William")] + >>> res = parsing_result(text, intent_result, slots) + >>> from snips_inference_agl.common.utils import json_string + >>> print(json_string(res, indent=4, sort_keys=True)) + { + "input": "Hello Bill!", + "intent": { + "intentName": "Greeting", + "probability": 0.95 + }, + "slots": [ + { + "entity": "name", + "range": { + "end": 10, + "start": 6 + }, + "rawValue": "Bill", + "slotName": "greetee", + "value": { + "kind": "Custom", + "value": "William" + } + } + ] + } + """ + return { + RES_INPUT: input, + RES_INTENT: intent, + RES_SLOTS: slots + } + + +def extraction_result(intent, slots): + """Create the items in the output of :meth:`.SnipsNLUEngine.parse` or + :meth:`.IntentParser.parse` when called with a defined ``top_n`` value + + This differs from :func:`.parsing_result` in that the input is omitted. + + Example: + + >>> intent_result = intent_classification_result("Greeting", 0.95) + >>> internal_slot = unresolved_slot([6, 10], "Bill", "name", + ... 
"greetee") + >>> slots = [custom_slot(internal_slot, "William")] + >>> res = extraction_result(intent_result, slots) + >>> from snips_inference_agl.common.utils import json_string + >>> print(json_string(res, indent=4, sort_keys=True)) + { + "intent": { + "intentName": "Greeting", + "probability": 0.95 + }, + "slots": [ + { + "entity": "name", + "range": { + "end": 10, + "start": 6 + }, + "rawValue": "Bill", + "slotName": "greetee", + "value": { + "kind": "Custom", + "value": "William" + } + } + ] + } + """ + return { + RES_INTENT: intent, + RES_SLOTS: slots + } + + +def is_empty(result): + """Check if a result is empty + + Example: + + >>> res = empty_result("foo bar", 1.0) + >>> is_empty(res) + True + """ + return result[RES_INTENT][RES_INTENT_NAME] is None + + +def empty_result(input, probability): # pylint:disable=redefined-builtin + """Creates an empty parsing result of the same format as the one of + :func:`parsing_result` + + An empty is typically returned by a :class:`.SnipsNLUEngine` or + :class:`.IntentParser` when no intent nor slots were found. + + Example: + + >>> res = empty_result("foo bar", 0.8) + >>> from snips_inference_agl.common.utils import json_string + >>> print(json_string(res, indent=4, sort_keys=True)) + { + "input": "foo bar", + "intent": { + "intentName": null, + "probability": 0.8 + }, + "slots": [] + } + """ + intent = intent_classification_result(None, probability) + return parsing_result(input=input, intent=intent, slots=[]) + + +def parsed_entity(entity_kind, entity_value, entity_resolved_value, + entity_range): + """Create the items in the output of + :meth:`snips_inference_agl.entity_parser.EntityParser.parse` + + Example: + >>> resolved_value = dict(age=28, role="datascientist") + >>> range = dict(start=0, end=6) + >>> ent = parsed_entity("snipster", "adrien", resolved_value, range) + >>> import json + >>> print(json.dumps(ent, indent=4, sort_keys=True)) + { + "entity_kind": "snipster", + "range": { + "end": 6, + "start": 0 + }, + "resolved_value": { + "age": 28, + "role": "datascientist" + }, + "value": "adrien" + } + """ + return { + VALUE: entity_value, + RESOLVED_VALUE: entity_resolved_value, + ENTITY_KIND: entity_kind, + RES_MATCH_RANGE: entity_range + } + + +def _convert_range(rng): + if isinstance(rng, dict): + return rng + return { + "start": rng[0], + "end": rng[1] + } diff --git a/snips_inference_agl/slot_filler/__init__.py b/snips_inference_agl/slot_filler/__init__.py new file mode 100644 index 0000000..70974aa --- /dev/null +++ b/snips_inference_agl/slot_filler/__init__.py @@ -0,0 +1,3 @@ +from .crf_slot_filler import CRFSlotFiller +from .feature import Feature +from .slot_filler import SlotFiller diff --git a/snips_inference_agl/slot_filler/crf_slot_filler.py b/snips_inference_agl/slot_filler/crf_slot_filler.py new file mode 100644 index 0000000..e6ec7e6 --- /dev/null +++ b/snips_inference_agl/slot_filler/crf_slot_filler.py @@ -0,0 +1,467 @@ +from __future__ import unicode_literals + +import base64 +import json +import logging +import math +import os +import shutil +import tempfile +from builtins import range +from copy import deepcopy +from pathlib import Path + +from future.utils import iteritems + +from snips_inference_agl.common.dataset_utils import get_slot_name_mapping +from snips_inference_agl.common.dict_utils import UnupdatableDict +from snips_inference_agl.common.io_utils import mkdir_p +from snips_inference_agl.common.log_utils import DifferedLoggingMessage, log_elapsed_time +from snips_inference_agl.common.utils import ( + 
check_persisted_path, fitted_required, json_string) +from snips_inference_agl.constants import DATA, LANGUAGE +from snips_inference_agl.data_augmentation import augment_utterances +from snips_inference_agl.dataset import validate_and_format_dataset +from snips_inference_agl.exceptions import LoadingError +from snips_inference_agl.pipeline.configs import CRFSlotFillerConfig +from snips_inference_agl.preprocessing import tokenize +from snips_inference_agl.slot_filler.crf_utils import ( + OUTSIDE, TAGS, TOKENS, tags_to_slots, utterance_to_sample) +from snips_inference_agl.slot_filler.feature import TOKEN_NAME +from snips_inference_agl.slot_filler.feature_factory import CRFFeatureFactory +from snips_inference_agl.slot_filler.slot_filler import SlotFiller + +CRF_MODEL_FILENAME = "model.crfsuite" + +logger = logging.getLogger(__name__) + + +@SlotFiller.register("crf_slot_filler") +class CRFSlotFiller(SlotFiller): + """Slot filler which uses Linear-Chain Conditional Random Fields underneath + + Check https://en.wikipedia.org/wiki/Conditional_random_field to learn + more about CRFs + """ + + config_type = CRFSlotFillerConfig + + def __init__(self, config=None, **shared): + """The CRF slot filler can be configured by passing a + :class:`.CRFSlotFillerConfig`""" + # The CRFSlotFillerConfig must be deep-copied as it is mutated when + # fitting the feature factories + config = deepcopy(config) + super(CRFSlotFiller, self).__init__(config, **shared) + self.crf_model = None + self.features_factories = [ + CRFFeatureFactory.from_config(conf, **shared) + for conf in self.config.feature_factory_configs] + self._features = None + self.language = None + self.intent = None + self.slot_name_mapping = None + + @property + def features(self): + """List of :class:`.Feature` used by the CRF""" + if self._features is None: + self._features = [] + feature_names = set() + for factory in self.features_factories: + for feature in factory.build_features(): + if feature.name in feature_names: + raise KeyError("Duplicated feature: %s" % feature.name) + feature_names.add(feature.name) + self._features.append(feature) + return self._features + + @property + def labels(self): + """List of CRF labels + + These labels differ from the slot names as they contain an additional + prefix which depends on the :class:`.TaggingScheme` that is used + (BIO by default). 
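
For reference, the prefixed labels described here are what crf_utils.tags_to_slots (further down in this diff) turns back into slot dictionaries. A minimal sketch, assuming English tokenization via the snips-nlu-utils dependency and a hypothetical "room" slot mapped to a "house_room" entity:

    from snips_inference_agl.preprocessing import tokenize
    from snips_inference_agl.slot_filler.crf_utils import TaggingScheme, tags_to_slots

    text = "turn on the kitchen light"
    tokens = tokenize(text, "en")
    tags = ["O", "O", "O", "B-room", "O"]  # one decoded label per token
    slots = tags_to_slots(text, tokens, tags, TaggingScheme.BIO,
                          {"room": "house_room"})
    print(slots)
    # [{'range': {'start': 12, 'end': 19}, 'value': 'kitchen',
    #   'entity': 'house_room', 'slotName': 'room'}]
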
+ """ + labels = [] + if self.crf_model.tagger_ is not None: + labels = [_decode_tag(label) for label in + self.crf_model.tagger_.labels()] + return labels + + @property + def fitted(self): + """Whether or not the slot filler has already been fitted""" + return self.slot_name_mapping is not None + + @log_elapsed_time(logger, logging.INFO, + "Fitted CRFSlotFiller in {elapsed_time}") + # pylint:disable=arguments-differ + def fit(self, dataset, intent): + """Fits the slot filler + + Args: + dataset (dict): A valid Snips dataset + intent (str): The specific intent of the dataset to train + the slot filler on + + Returns: + :class:`CRFSlotFiller`: The same instance, trained + """ + logger.info("Fitting %s slot filler...", intent) + dataset = validate_and_format_dataset(dataset) + self.load_resources_if_needed(dataset[LANGUAGE]) + self.fit_builtin_entity_parser_if_needed(dataset) + self.fit_custom_entity_parser_if_needed(dataset) + + for factory in self.features_factories: + factory.custom_entity_parser = self.custom_entity_parser + factory.builtin_entity_parser = self.builtin_entity_parser + factory.resources = self.resources + + self.language = dataset[LANGUAGE] + self.intent = intent + self.slot_name_mapping = get_slot_name_mapping(dataset, intent) + + if not self.slot_name_mapping: + # No need to train the CRF if the intent has no slots + return self + + augmented_intent_utterances = augment_utterances( + dataset, self.intent, language=self.language, + resources=self.resources, random_state=self.random_state, + **self.config.data_augmentation_config.to_dict()) + + crf_samples = [ + utterance_to_sample(u[DATA], self.config.tagging_scheme, + self.language) + for u in augmented_intent_utterances] + + for factory in self.features_factories: + factory.fit(dataset, intent) + + # Ensure that X, Y are safe and that the OUTSIDE label is learnt to + # avoid segfault at inference time + # pylint: disable=C0103 + X = [self.compute_features(sample[TOKENS], drop_out=True) + for sample in crf_samples] + Y = [[tag for tag in sample[TAGS]] for sample in crf_samples] + X, Y = _ensure_safe(X, Y) + + # ensure ascii tags + Y = [[_encode_tag(tag) for tag in y] for y in Y] + + # pylint: enable=C0103 + self.crf_model = _get_crf_model(self.config.crf_args) + self.crf_model.fit(X, Y) + + logger.debug( + "Most relevant features for %s:\n%s", self.intent, + DifferedLoggingMessage(self.log_weights)) + return self + + # pylint:enable=arguments-differ + + @fitted_required + def get_slots(self, text): + """Extracts slots from the provided text + + Returns: + list of dict: The list of extracted slots + + Raises: + NotTrained: When the slot filler is not fitted + """ + if not self.slot_name_mapping: + # Early return if the intent has no slots + return [] + + tokens = tokenize(text, self.language) + if not tokens: + return [] + features = self.compute_features(tokens) + tags = self.crf_model.predict_single(features) + logger.debug(DifferedLoggingMessage( + self.log_inference_weights, text, tokens=tokens, features=features, + tags=tags)) + decoded_tags = [_decode_tag(t) for t in tags] + return tags_to_slots(text, tokens, decoded_tags, + self.config.tagging_scheme, + self.slot_name_mapping) + + def compute_features(self, tokens, drop_out=False): + """Computes features on the provided tokens + + The *drop_out* parameters allows to activate drop out on features that + have a positive drop out ratio. This should only be used during + training. 
+ """ + + cache = [{TOKEN_NAME: token} for token in tokens] + features = [] + for i in range(len(tokens)): + token_features = UnupdatableDict() + for feature in self.features: + f_drop_out = feature.drop_out + if drop_out and self.random_state.rand() < f_drop_out: + continue + value = feature.compute(i, cache) + if value is not None: + token_features[feature.name] = value + features.append(token_features) + return features + + @fitted_required + def get_sequence_probability(self, tokens, labels): + """Gives the joint probability of a sequence of tokens and CRF labels + + Args: + tokens (list of :class:`.Token`): list of tokens + labels (list of str): CRF labels with their tagging scheme prefix + ("B-color", "I-color", "O", etc) + + Note: + The absolute value returned here is generally not very useful, + however it can be used to compare a sequence of labels relatively + to another one. + """ + if not self.slot_name_mapping: + return 0.0 if any(label != OUTSIDE for label in labels) else 1.0 + features = self.compute_features(tokens) + return self._get_sequence_probability(features, labels) + + @fitted_required + def _get_sequence_probability(self, features, labels): + # Use a default substitution label when a label was not seen during + # training + substitution_label = OUTSIDE if OUTSIDE in self.labels else \ + self.labels[0] + cleaned_labels = [ + _encode_tag(substitution_label if l not in self.labels else l) + for l in labels] + self.crf_model.tagger_.set(features) + return self.crf_model.tagger_.probability(cleaned_labels) + + @fitted_required + def log_weights(self): + """Returns a logs for both the label-to-label and label-to-features + weights""" + if not self.slot_name_mapping: + return "No weights to display: intent '%s' has no slots" \ + % self.intent + log = "" + transition_features = self.crf_model.transition_features_ + transition_features = sorted( + iteritems(transition_features), key=_weight_absolute_value, + reverse=True) + log += "\nTransition weights: \n\n" + for (state_1, state_2), weight in transition_features: + log += "\n%s %s: %s" % ( + _decode_tag(state_1), _decode_tag(state_2), weight) + feature_weights = self.crf_model.state_features_ + feature_weights = sorted( + iteritems(feature_weights), key=_weight_absolute_value, + reverse=True) + log += "\n\nFeature weights: \n\n" + for (feat, tag), weight in feature_weights: + log += "\n%s %s: %s" % (feat, _decode_tag(tag), weight) + return log + + def log_inference_weights(self, text, tokens, features, tags): + model_features = set( + f for (f, _), w in iteritems(self.crf_model.state_features_)) + log = "Feature weights for \"%s\":\n\n" % text + max_index = len(tokens) - 1 + tokens_logs = [] + for i, (token, feats, tag) in enumerate(zip(tokens, features, tags)): + token_log = "# Token \"%s\" (tagged as %s):" \ + % (token.value, _decode_tag(tag)) + if i != 0: + weights = sorted(self._get_outgoing_weights(tags[i - 1]), + key=_weight_absolute_value, reverse=True) + if weights: + token_log += "\n\nTransition weights from previous tag:" + weight_lines = ( + "- (%s, %s) -> %s" + % (_decode_tag(a), _decode_tag(b), w) + for (a, b), w in weights + ) + token_log += "\n" + "\n".join(weight_lines) + else: + token_log += \ + "\n\nNo transition from previous tag seen at" \ + " train time !" 
+ + if i != max_index: + weights = sorted(self._get_incoming_weights(tags[i + 1]), + key=_weight_absolute_value, reverse=True) + if weights: + token_log += "\n\nTransition weights to next tag:" + weight_lines = ( + "- (%s, %s) -> %s" + % (_decode_tag(a), _decode_tag(b), w) + for (a, b), w in weights + ) + token_log += "\n" + "\n".join(weight_lines) + else: + token_log += \ + "\n\nNo transition to next tag seen at train time !" + feats = [":".join(f) for f in iteritems(feats)] + weights = (w for f in feats for w in self._get_feature_weight(f)) + weights = sorted(weights, key=_weight_absolute_value, reverse=True) + if weights: + token_log += "\n\nFeature weights:\n" + token_log += "\n".join( + "- (%s, %s) -> %s" + % (f, _decode_tag(t), w) for (f, t), w in weights + ) + else: + token_log += "\n\nNo feature weights !" + + unseen_features = sorted( + set(f for f in feats if f not in model_features)) + if unseen_features: + token_log += "\n\nFeatures not seen at train time:\n%s" % \ + "\n".join("- %s" % f for f in unseen_features) + tokens_logs.append(token_log) + + log += "\n\n\n".join(tokens_logs) + return log + + @fitted_required + def _get_incoming_weights(self, tag): + return [((first, second), w) for (first, second), w + in iteritems(self.crf_model.transition_features_) + if second == tag] + + @fitted_required + def _get_outgoing_weights(self, tag): + return [((first, second), w) for (first, second), w + in iteritems(self.crf_model.transition_features_) + if first == tag] + + @fitted_required + def _get_feature_weight(self, feature): + return [((f, tag), w) for (f, tag), w + in iteritems(self.crf_model.state_features_) if f == feature] + + @check_persisted_path + def persist(self, path): + """Persists the object at the given path""" + path.mkdir() + + crf_model_file = None + if self.crf_model is not None: + crf_model_file = CRF_MODEL_FILENAME + destination = path / crf_model_file + shutil.copy(self.crf_model.modelfile.name, str(destination)) + # On windows, permissions of crfsuite files are correct + if os.name == "posix": + umask = os.umask(0o022) # retrieve the system umask + os.umask(umask) # restore the sys umask to its original value + os.chmod(str(destination), 0o644 & ~umask) + + model = { + "language_code": self.language, + "intent": self.intent, + "crf_model_file": crf_model_file, + "slot_name_mapping": self.slot_name_mapping, + "config": self.config.to_dict(), + } + model_json = json_string(model) + model_path = path / "slot_filler.json" + with model_path.open(mode="w", encoding="utf8") as f: + f.write(model_json) + self.persist_metadata(path) + + @classmethod + def from_path(cls, path, **shared): + """Loads a :class:`CRFSlotFiller` instance from a path + + The data at the given path must have been generated using + :func:`~CRFSlotFiller.persist` + """ + path = Path(path) + model_path = path / "slot_filler.json" + if not model_path.exists(): + raise LoadingError( + "Missing slot filler model file: %s" % model_path.name) + + with model_path.open(encoding="utf8") as f: + model = json.load(f) + + slot_filler_config = cls.config_type.from_dict(model["config"]) + slot_filler = cls(config=slot_filler_config, **shared) + slot_filler.language = model["language_code"] + slot_filler.intent = model["intent"] + slot_filler.slot_name_mapping = model["slot_name_mapping"] + crf_model_file = model["crf_model_file"] + if crf_model_file is not None: + crf = _crf_model_from_path(path / crf_model_file) + slot_filler.crf_model = crf + return slot_filler + + def _cleanup(self): + if self.crf_model is 
not None: + self.crf_model.modelfile.cleanup() + + def __del__(self): + self._cleanup() + + +def _get_crf_model(crf_args): + from sklearn_crfsuite import CRF + + model_filename = crf_args.get("model_filename", None) + if model_filename is not None: + directory = Path(model_filename).parent + if not directory.is_dir(): + mkdir_p(directory) + + return CRF(model_filename=model_filename, **crf_args) + + +def _encode_tag(tag): + return base64.b64encode(tag.encode("utf8")) + + +def _decode_tag(tag): + return base64.b64decode(tag).decode("utf8") + + +def _crf_model_from_path(crf_model_path): + from sklearn_crfsuite import CRF + + with crf_model_path.open(mode="rb") as f: + crf_model_data = f.read() + with tempfile.NamedTemporaryFile(suffix=".crfsuite", prefix="model", + delete=False) as f: + f.write(crf_model_data) + f.flush() + crf = CRF(model_filename=f.name) + return crf + + +# pylint: disable=invalid-name +def _ensure_safe(X, Y): + """Ensures that Y has at least one not empty label, otherwise the CRF model + does not contain any label and crashes at + + Args: + X: features + Y: labels + + Returns: + (safe_X, safe_Y): a pair of safe features and labels + """ + safe_X = list(X) + safe_Y = list(Y) + if not any(X) or not any(Y): + safe_X.append([""]) # empty feature + safe_Y.append([OUTSIDE]) # outside label + return safe_X, safe_Y + + +def _weight_absolute_value(x): + return math.fabs(x[1]) diff --git a/snips_inference_agl/slot_filler/crf_utils.py b/snips_inference_agl/slot_filler/crf_utils.py new file mode 100644 index 0000000..817a59b --- /dev/null +++ b/snips_inference_agl/slot_filler/crf_utils.py @@ -0,0 +1,219 @@ +from __future__ import unicode_literals + +from builtins import range +from enum import Enum, unique + +from snips_inference_agl.constants import END, SLOT_NAME, START, TEXT +from snips_inference_agl.preprocessing import Token, tokenize +from snips_inference_agl.result import unresolved_slot + +BEGINNING_PREFIX = "B-" +INSIDE_PREFIX = "I-" +LAST_PREFIX = "L-" +UNIT_PREFIX = "U-" +OUTSIDE = "O" + +RANGE = "range" +TAGS = "tags" +TOKENS = "tokens" + + +@unique +class TaggingScheme(Enum): + """CRF Coding Scheme""" + + IO = 0 + """Inside-Outside scheme""" + BIO = 1 + """Beginning-Inside-Outside scheme""" + BILOU = 2 + """Beginning-Inside-Last-Outside-Unit scheme, sometimes referred as + BWEMO""" + + +def tag_name_to_slot_name(tag): + return tag[2:] + + +def start_of_io_slot(tags, i): + if i == 0: + return tags[i] != OUTSIDE + if tags[i] == OUTSIDE: + return False + return tags[i - 1] == OUTSIDE + + +def end_of_io_slot(tags, i): + if i + 1 == len(tags): + return tags[i] != OUTSIDE + if tags[i] == OUTSIDE: + return False + return tags[i + 1] == OUTSIDE + + +def start_of_bio_slot(tags, i): + if i == 0: + return tags[i] != OUTSIDE + if tags[i] == OUTSIDE: + return False + if tags[i].startswith(BEGINNING_PREFIX): + return True + if tags[i - 1] != OUTSIDE: + return False + return True + + +def end_of_bio_slot(tags, i): + if i + 1 == len(tags): + return tags[i] != OUTSIDE + if tags[i] == OUTSIDE: + return False + if tags[i + 1].startswith(INSIDE_PREFIX): + return False + return True + + +def start_of_bilou_slot(tags, i): + if i == 0: + return tags[i] != OUTSIDE + if tags[i] == OUTSIDE: + return False + if tags[i].startswith(BEGINNING_PREFIX): + return True + if tags[i].startswith(UNIT_PREFIX): + return True + if tags[i - 1].startswith(UNIT_PREFIX): + return True + if tags[i - 1].startswith(LAST_PREFIX): + return True + if tags[i - 1] != OUTSIDE: + return False + return True + + +def 
end_of_bilou_slot(tags, i): + if i + 1 == len(tags): + return tags[i] != OUTSIDE + if tags[i] == OUTSIDE: + return False + if tags[i + 1] == OUTSIDE: + return True + if tags[i].startswith(LAST_PREFIX): + return True + if tags[i].startswith(UNIT_PREFIX): + return True + if tags[i + 1].startswith(BEGINNING_PREFIX): + return True + if tags[i + 1].startswith(UNIT_PREFIX): + return True + return False + + +def _tags_to_preslots(tags, tokens, is_start_of_slot, is_end_of_slot): + slots = [] + current_slot_start = 0 + for i, tag in enumerate(tags): + if is_start_of_slot(tags, i): + current_slot_start = i + if is_end_of_slot(tags, i): + slots.append({ + RANGE: { + START: tokens[current_slot_start].start, + END: tokens[i].end + }, + SLOT_NAME: tag_name_to_slot_name(tag) + }) + current_slot_start = i + return slots + + +def tags_to_preslots(tokens, tags, tagging_scheme): + if tagging_scheme == TaggingScheme.IO: + slots = _tags_to_preslots(tags, tokens, start_of_io_slot, + end_of_io_slot) + elif tagging_scheme == TaggingScheme.BIO: + slots = _tags_to_preslots(tags, tokens, start_of_bio_slot, + end_of_bio_slot) + elif tagging_scheme == TaggingScheme.BILOU: + slots = _tags_to_preslots(tags, tokens, start_of_bilou_slot, + end_of_bilou_slot) + else: + raise ValueError("Unknown tagging scheme %s" % tagging_scheme) + return slots + + +def tags_to_slots(text, tokens, tags, tagging_scheme, intent_slots_mapping): + slots = tags_to_preslots(tokens, tags, tagging_scheme) + return [ + unresolved_slot(match_range=slot[RANGE], + value=text[slot[RANGE][START]:slot[RANGE][END]], + entity=intent_slots_mapping[slot[SLOT_NAME]], + slot_name=slot[SLOT_NAME]) + for slot in slots + ] + + +def positive_tagging(tagging_scheme, slot_name, slot_size): + if slot_name == OUTSIDE: + return [OUTSIDE for _ in range(slot_size)] + + if tagging_scheme == TaggingScheme.IO: + tags = [INSIDE_PREFIX + slot_name for _ in range(slot_size)] + elif tagging_scheme == TaggingScheme.BIO: + if slot_size > 0: + tags = [BEGINNING_PREFIX + slot_name] + tags += [INSIDE_PREFIX + slot_name for _ in range(1, slot_size)] + else: + tags = [] + elif tagging_scheme == TaggingScheme.BILOU: + if slot_size == 0: + tags = [] + elif slot_size == 1: + tags = [UNIT_PREFIX + slot_name] + else: + tags = [BEGINNING_PREFIX + slot_name] + tags += [INSIDE_PREFIX + slot_name + for _ in range(1, slot_size - 1)] + tags.append(LAST_PREFIX + slot_name) + else: + raise ValueError("Invalid tagging scheme %s" % tagging_scheme) + return tags + + +def negative_tagging(size): + return [OUTSIDE for _ in range(size)] + + +def utterance_to_sample(query_data, tagging_scheme, language): + tokens, tags = [], [] + current_length = 0 + for chunk in query_data: + chunk_tokens = tokenize(chunk[TEXT], language) + tokens += [Token(t.value, current_length + t.start, + current_length + t.end) for t in chunk_tokens] + current_length += len(chunk[TEXT]) + if SLOT_NAME not in chunk: + tags += negative_tagging(len(chunk_tokens)) + else: + tags += positive_tagging(tagging_scheme, chunk[SLOT_NAME], + len(chunk_tokens)) + return {TOKENS: tokens, TAGS: tags} + + +def get_scheme_prefix(index, indexes, tagging_scheme): + if tagging_scheme == TaggingScheme.IO: + return INSIDE_PREFIX + elif tagging_scheme == TaggingScheme.BIO: + if index == indexes[0]: + return BEGINNING_PREFIX + return INSIDE_PREFIX + elif tagging_scheme == TaggingScheme.BILOU: + if len(indexes) == 1: + return UNIT_PREFIX + if index == indexes[0]: + return BEGINNING_PREFIX + if index == indexes[-1]: + return LAST_PREFIX + return 
INSIDE_PREFIX + else: + raise ValueError("Invalid tagging scheme %s" % tagging_scheme) diff --git a/snips_inference_agl/slot_filler/feature.py b/snips_inference_agl/slot_filler/feature.py new file mode 100644 index 0000000..a6da552 --- /dev/null +++ b/snips_inference_agl/slot_filler/feature.py @@ -0,0 +1,69 @@ +from __future__ import unicode_literals + +from builtins import object + +TOKEN_NAME = "token" + + +class Feature(object): + """CRF Feature which is used by :class:`.CRFSlotFiller` + + Attributes: + base_name (str): Feature name (e.g. 'is_digit', 'is_first' etc) + func (function): The actual feature function for example: + + def is_first(tokens, token_index): + return "1" if token_index == 0 else None + + offset (int, optional): Token offset to consider when computing + the feature (e.g -1 for computing the feature on the previous word) + drop_out (float, optional): Drop out to use when computing the + feature during training + + Note: + The easiest way to add additional features to the existing ones is + to create a :class:`.CRFFeatureFactory` + """ + + def __init__(self, base_name, func, offset=0, drop_out=0): + if base_name == TOKEN_NAME: + raise ValueError("'%s' name is reserved" % TOKEN_NAME) + self.offset = offset + self._name = None + self._base_name = None + self.base_name = base_name + self.function = func + self.drop_out = drop_out + + @property + def name(self): + return self._name + + @property + def base_name(self): + return self._base_name + + @base_name.setter + def base_name(self, value): + self._name = _offset_name(value, self.offset) + self._base_name = _offset_name(value, 0) + + def compute(self, token_index, cache): + if not 0 <= (token_index + self.offset) < len(cache): + return None + + if self.base_name in cache[token_index + self.offset]: + return cache[token_index + self.offset][self.base_name] + + tokens = [c["token"] for c in cache] + value = self.function(tokens, token_index + self.offset) + cache[token_index + self.offset][self.base_name] = value + return value + + +def _offset_name(name, offset): + if offset > 0: + return "%s[+%s]" % (name, offset) + if offset < 0: + return "%s[%s]" % (name, offset) + return name diff --git a/snips_inference_agl/slot_filler/feature_factory.py b/snips_inference_agl/slot_filler/feature_factory.py new file mode 100644 index 0000000..50f4598 --- /dev/null +++ b/snips_inference_agl/slot_filler/feature_factory.py @@ -0,0 +1,568 @@ +from __future__ import unicode_literals + +import logging + +from abc import ABCMeta, abstractmethod +from builtins import str + +from future.utils import with_metaclass + +from snips_inference_agl.common.abc_utils import classproperty +from snips_inference_agl.common.registrable import Registrable +from snips_inference_agl.common.utils import check_random_state +from snips_inference_agl.constants import ( + CUSTOM_ENTITY_PARSER_USAGE, END, GAZETTEERS, LANGUAGE, RES_MATCH_RANGE, + START, STEMS, WORD_CLUSTERS, CUSTOM_ENTITY_PARSER, BUILTIN_ENTITY_PARSER, + RESOURCES, RANDOM_STATE, AUTOMATICALLY_EXTENSIBLE, ENTITIES) +from snips_inference_agl.dataset import ( + extract_intent_entities, get_dataset_gazetteer_entities) +from snips_inference_agl.entity_parser.builtin_entity_parser import is_builtin_entity +from snips_inference_agl.entity_parser.custom_entity_parser import ( + CustomEntityParserUsage) +from snips_inference_agl.languages import get_default_sep +from snips_inference_agl.preprocessing import Token, normalize_token, stem_token +from snips_inference_agl.resources import get_gazetteer, 
get_word_cluster +from snips_inference_agl.slot_filler.crf_utils import TaggingScheme, get_scheme_prefix +from snips_inference_agl.slot_filler.feature import Feature +from snips_inference_agl.slot_filler.features_utils import ( + entity_filter, get_word_chunk, initial_string_from_tokens) + +logger = logging.getLogger(__name__) + + +class CRFFeatureFactory(with_metaclass(ABCMeta, Registrable)): + """Abstraction to implement to build CRF features + + A :class:`CRFFeatureFactory` is initialized with a dict which describes + the feature, it must contains the three following keys: + + - 'factory_name' + - 'args': the parameters of the feature, if any + - 'offsets': the offsets to consider when using the feature in the CRF. + An empty list corresponds to no feature. + + + In addition, a 'drop_out' to use at training time can be specified. + """ + + def __init__(self, factory_config, **shared): + self.factory_config = factory_config + self.resources = shared.get(RESOURCES) + self.builtin_entity_parser = shared.get(BUILTIN_ENTITY_PARSER) + self.custom_entity_parser = shared.get(CUSTOM_ENTITY_PARSER) + self.random_state = check_random_state(shared.get(RANDOM_STATE)) + + @classmethod + def from_config(cls, factory_config, **shared): + """Retrieve the :class:`CRFFeatureFactory` corresponding the provided + config + + Raises: + NotRegisteredError: when the factory is not registered + """ + factory_name = factory_config["factory_name"] + factory = cls.by_name(factory_name) + return factory(factory_config, **shared) + + @classproperty + def name(cls): # pylint:disable=no-self-argument + return CRFFeatureFactory.registered_name(cls) + + @property + def args(self): + return self.factory_config["args"] + + @property + def offsets(self): + return self.factory_config["offsets"] + + @property + def drop_out(self): + return self.factory_config.get("drop_out", 0.0) + + def fit(self, dataset, intent): # pylint: disable=unused-argument + """Fit the factory, if needed, with the provided *dataset* and *intent* + """ + return self + + @abstractmethod + def build_features(self): + """Build a list of :class:`.Feature`""" + pass + + def get_required_resources(self): + return None + + +class SingleFeatureFactory(with_metaclass(ABCMeta, CRFFeatureFactory)): + """A CRF feature factory which produces only one feature""" + + @property + def feature_name(self): + # by default, use the factory name + return self.name + + @abstractmethod + def compute_feature(self, tokens, token_index): + pass + + def build_features(self): + return [ + Feature( + base_name=self.feature_name, + func=self.compute_feature, + offset=offset, + drop_out=self.drop_out) for offset in self.offsets + ] + + +@CRFFeatureFactory.register("is_digit") +class IsDigitFactory(SingleFeatureFactory): + """Feature: is the considered token a digit?""" + + def compute_feature(self, tokens, token_index): + return "1" if tokens[token_index].value.isdigit() else None + + +@CRFFeatureFactory.register("is_first") +class IsFirstFactory(SingleFeatureFactory): + """Feature: is the considered token the first in the input?""" + + def compute_feature(self, tokens, token_index): + return "1" if token_index == 0 else None + + +@CRFFeatureFactory.register("is_last") +class IsLastFactory(SingleFeatureFactory): + """Feature: is the considered token the last in the input?""" + + def compute_feature(self, tokens, token_index): + return "1" if token_index == len(tokens) - 1 else None + + +@CRFFeatureFactory.register("ngram") +class NgramFactory(SingleFeatureFactory): + """Feature: the 
n-gram consisting of the considered token and potentially + the following ones + + This feature has several parameters: + + - 'n' (int): Corresponds to the size of the n-gram. n=1 corresponds to a + unigram, n=2 is a bigram etc + - 'use_stemming' (bool): Whether or not to stem the n-gram + - 'common_words_gazetteer_name' (str, optional): If defined, use a + gazetteer of common words and replace out-of-corpus ngram with the + alias + 'rare_word' + + """ + + def __init__(self, factory_config, **shared): + super(NgramFactory, self).__init__(factory_config, **shared) + self.n = self.args["n"] + if self.n < 1: + raise ValueError("n should be >= 1") + + self.use_stemming = self.args["use_stemming"] + self.common_words_gazetteer_name = self.args[ + "common_words_gazetteer_name"] + self._gazetteer = None + self._language = None + self.language = self.args.get("language_code") + + @property + def language(self): + return self._language + + @language.setter + def language(self, value): + if value is not None: + self._language = value + self.args["language_code"] = self.language + + @property + def gazetteer(self): + # Load the gazetteer lazily + if self.common_words_gazetteer_name is None: + return None + if self._gazetteer is None: + self._gazetteer = get_gazetteer( + self.resources, self.common_words_gazetteer_name) + return self._gazetteer + + @property + def feature_name(self): + return "ngram_%s" % self.n + + def fit(self, dataset, intent): + self.language = dataset[LANGUAGE] + + def compute_feature(self, tokens, token_index): + max_len = len(tokens) + end = token_index + self.n + if 0 <= token_index < max_len and end <= max_len: + if self.gazetteer is None: + if self.use_stemming: + stems = (stem_token(t, self.resources) + for t in tokens[token_index:end]) + return get_default_sep(self.language).join(stems) + normalized_values = (normalize_token(t) + for t in tokens[token_index:end]) + return get_default_sep(self.language).join(normalized_values) + words = [] + for t in tokens[token_index:end]: + if self.use_stemming: + value = stem_token(t, self.resources) + else: + value = normalize_token(t) + words.append(value if value in self.gazetteer else "rare_word") + return get_default_sep(self.language).join(words) + return None + + def get_required_resources(self): + resources = dict() + if self.common_words_gazetteer_name is not None: + resources[GAZETTEERS] = {self.common_words_gazetteer_name} + if self.use_stemming: + resources[STEMS] = True + return resources + + +@CRFFeatureFactory.register("shape_ngram") +class ShapeNgramFactory(SingleFeatureFactory): + """Feature: the shape of the n-gram consisting of the considered token and + potentially the following ones + + This feature has one parameters, *n*, which corresponds to the size of the + n-gram. 
+ + Possible types of shape are: + + - 'xxx' -> lowercased + - 'Xxx' -> Capitalized + - 'XXX' -> UPPERCASED + - 'xX' -> None of the above + """ + + def __init__(self, factory_config, **shared): + super(ShapeNgramFactory, self).__init__(factory_config, **shared) + self.n = self.args["n"] + if self.n < 1: + raise ValueError("n should be >= 1") + self._language = None + self.language = self.args.get("language_code") + + @property + def language(self): + return self._language + + @language.setter + def language(self, value): + if value is not None: + self._language = value + self.args["language_code"] = value + + @property + def feature_name(self): + return "shape_ngram_%s" % self.n + + def fit(self, dataset, intent): + self.language = dataset[LANGUAGE] + + def compute_feature(self, tokens, token_index): + from snips_nlu_utils import get_shape + + max_len = len(tokens) + end = token_index + self.n + if 0 <= token_index < max_len and end <= max_len: + return get_default_sep(self.language).join( + get_shape(t.value) for t in tokens[token_index:end]) + return None + + +@CRFFeatureFactory.register("word_cluster") +class WordClusterFactory(SingleFeatureFactory): + """Feature: The cluster which the considered token belongs to, if any + + This feature has several parameters: + + - 'cluster_name' (str): the name of the word cluster to use + - 'use_stemming' (bool): whether or not to stem the token before looking + for its cluster + + Typical words clusters are the Brown Clusters in which words are + clustered into a binary tree resulting in clusters of the form '100111001' + See https://en.wikipedia.org/wiki/Brown_clustering + """ + + def __init__(self, factory_config, **shared): + super(WordClusterFactory, self).__init__(factory_config, **shared) + self.cluster_name = self.args["cluster_name"] + self.use_stemming = self.args["use_stemming"] + self._cluster = None + + @property + def cluster(self): + if self._cluster is None: + self._cluster = get_word_cluster(self.resources, self.cluster_name) + return self._cluster + + @property + def feature_name(self): + return "word_cluster_%s" % self.cluster_name + + def compute_feature(self, tokens, token_index): + if self.use_stemming: + value = stem_token(tokens[token_index], self.resources) + else: + value = normalize_token(tokens[token_index]) + return self.cluster.get(value, None) + + def get_required_resources(self): + return { + WORD_CLUSTERS: {self.cluster_name}, + STEMS: self.use_stemming + } + + +@CRFFeatureFactory.register("entity_match") +class CustomEntityMatchFactory(CRFFeatureFactory): + """Features: does the considered token belongs to the values of one of the + entities in the training dataset + + This factory builds as many features as there are entities in the dataset, + one per entity. + + It has the following parameters: + + - 'use_stemming' (bool): whether or not to stem the token before looking + for it among the (stemmed) entity values + - 'tagging_scheme_code' (int): Represents a :class:`.TaggingScheme`. This + allows to give more information about the match. + - 'entity_filter' (dict): a filter applied to select the custom entities + for which the custom match feature will be computed. 
Available + filters: + - 'automatically_extensible': if True, selects automatically + extensible entities only, if False selects non automatically + extensible entities only + """ + + def __init__(self, factory_config, **shared): + super(CustomEntityMatchFactory, self).__init__(factory_config, + **shared) + self.use_stemming = self.args["use_stemming"] + self.tagging_scheme = TaggingScheme( + self.args["tagging_scheme_code"]) + self._entities = None + self.entities = self.args.get("entities") + ent_filter = self.args.get("entity_filter") + if ent_filter: + try: + _check_custom_entity_filter(ent_filter) + except _InvalidCustomEntityFilter as e: + logger.warning( + "Invalid filter '%s', invalid arguments have been ignored:" + " %s", ent_filter, e, + ) + self.entity_filter = ent_filter or dict() + + @property + def entities(self): + return self._entities + + @entities.setter + def entities(self, value): + if value is not None: + self._entities = value + self.args["entities"] = value + + def fit(self, dataset, intent): + entities_names = extract_intent_entities( + dataset, lambda e: not is_builtin_entity(e))[intent] + extensible = self.entity_filter.get(AUTOMATICALLY_EXTENSIBLE) + if extensible is not None: + entities_names = [ + e for e in entities_names + if dataset[ENTITIES][e][AUTOMATICALLY_EXTENSIBLE] == extensible + ] + self.entities = list(entities_names) + return self + + def _transform(self, tokens): + if self.use_stemming: + light_tokens = (stem_token(t, self.resources) for t in tokens) + else: + light_tokens = (normalize_token(t) for t in tokens) + current_index = 0 + transformed_tokens = [] + for light_token in light_tokens: + transformed_token = Token( + value=light_token, + start=current_index, + end=current_index + len(light_token)) + transformed_tokens.append(transformed_token) + current_index = transformed_token.end + 1 + return transformed_tokens + + def build_features(self): + features = [] + for entity_name in self.entities: + # We need to call this wrapper in order to properly capture + # `entity_name` + entity_match = self._build_entity_match_fn(entity_name) + + for offset in self.offsets: + feature = Feature("entity_match_%s" % entity_name, + entity_match, offset, self.drop_out) + features.append(feature) + return features + + def _build_entity_match_fn(self, entity): + + def entity_match(tokens, token_index): + transformed_tokens = self._transform(tokens) + text = initial_string_from_tokens(transformed_tokens) + token_start = transformed_tokens[token_index].start + token_end = transformed_tokens[token_index].end + custom_entities = self.custom_entity_parser.parse( + text, scope=[entity], use_cache=True) + # only keep builtin entities (of type `entity`) which overlap with + # the current token + custom_entities = [ent for ent in custom_entities + if entity_filter(ent, token_start, token_end)] + if custom_entities: + # In most cases, 0 or 1 entity will be found. 
We fall back to + # the first entity if 2 or more were found + ent = custom_entities[0] + indexes = [] + for index, token in enumerate(transformed_tokens): + if entity_filter(ent, token.start, token.end): + indexes.append(index) + return get_scheme_prefix(token_index, indexes, + self.tagging_scheme) + return None + + return entity_match + + def get_required_resources(self): + if self.use_stemming: + return { + STEMS: True, + CUSTOM_ENTITY_PARSER_USAGE: CustomEntityParserUsage.WITH_STEMS + } + return { + STEMS: False, + CUSTOM_ENTITY_PARSER_USAGE: + CustomEntityParserUsage.WITHOUT_STEMS + } + + +class _InvalidCustomEntityFilter(ValueError): + pass + + +CUSTOM_ENTITIES_FILTER_KEYS = {"automatically_extensible"} + + +# pylint: disable=redefined-outer-name +def _check_custom_entity_filter(entity_filter): + for k in entity_filter: + if k not in CUSTOM_ENTITIES_FILTER_KEYS: + msg = "Invalid custom entity filter key '%s'. Accepted filter " \ + "keys are %s" % (k, list(CUSTOM_ENTITIES_FILTER_KEYS)) + raise _InvalidCustomEntityFilter(msg) + + +@CRFFeatureFactory.register("builtin_entity_match") +class BuiltinEntityMatchFactory(CRFFeatureFactory): + """Features: is the considered token part of a builtin entity such as a + date, a temperature etc + + This factory builds as many features as there are builtin entities + available in the considered language. + + It has one parameter, *tagging_scheme_code*, which represents a + :class:`.TaggingScheme`. This allows to give more information about the + match. + """ + + def __init__(self, factory_config, **shared): + super(BuiltinEntityMatchFactory, self).__init__(factory_config, + **shared) + self.tagging_scheme = TaggingScheme( + self.args["tagging_scheme_code"]) + self.builtin_entities = None + self.builtin_entities = self.args.get("entity_labels") + self._language = None + self.language = self.args.get("language_code") + + @property + def language(self): + return self._language + + @language.setter + def language(self, value): + if value is not None: + self._language = value + self.args["language_code"] = self.language + + def fit(self, dataset, intent): + self.language = dataset[LANGUAGE] + self.builtin_entities = sorted( + self._get_builtin_entity_scope(dataset, intent)) + self.args["entity_labels"] = self.builtin_entities + + def build_features(self): + features = [] + + for builtin_entity in self.builtin_entities: + # We need to call this wrapper in order to properly capture + # `builtin_entity` + builtin_entity_match = self._build_entity_match_fn(builtin_entity) + for offset in self.offsets: + feature_name = "builtin_entity_match_%s" % builtin_entity + feature = Feature(feature_name, builtin_entity_match, offset, + self.drop_out) + features.append(feature) + + return features + + def _build_entity_match_fn(self, builtin_entity): + + def builtin_entity_match(tokens, token_index): + text = initial_string_from_tokens(tokens) + start = tokens[token_index].start + end = tokens[token_index].end + + builtin_entities = self.builtin_entity_parser.parse( + text, scope=[builtin_entity], use_cache=True) + # only keep builtin entities (of type `builtin_entity`) which + # overlap with the current token + builtin_entities = [ent for ent in builtin_entities + if entity_filter(ent, start, end)] + if builtin_entities: + # In most cases, 0 or 1 entity will be found. 
We fall back to + # the first entity if 2 or more were found + ent = builtin_entities[0] + entity_start = ent[RES_MATCH_RANGE][START] + entity_end = ent[RES_MATCH_RANGE][END] + indexes = [] + for index, token in enumerate(tokens): + if (entity_start <= token.start < entity_end) \ + and (entity_start < token.end <= entity_end): + indexes.append(index) + return get_scheme_prefix(token_index, indexes, + self.tagging_scheme) + return None + + return builtin_entity_match + + @staticmethod + def _get_builtin_entity_scope(dataset, intent=None): + from snips_nlu_parsers import get_supported_grammar_entities + + language = dataset[LANGUAGE] + grammar_entities = list(get_supported_grammar_entities(language)) + gazetteer_entities = list( + get_dataset_gazetteer_entities(dataset, intent)) + return grammar_entities + gazetteer_entities diff --git a/snips_inference_agl/slot_filler/features_utils.py b/snips_inference_agl/slot_filler/features_utils.py new file mode 100644 index 0000000..483e9c0 --- /dev/null +++ b/snips_inference_agl/slot_filler/features_utils.py @@ -0,0 +1,47 @@ +from __future__ import unicode_literals + +from copy import deepcopy + +from snips_inference_agl.common.dict_utils import LimitedSizeDict +from snips_inference_agl.constants import END, RES_MATCH_RANGE, START + +_NGRAMS_CACHE = LimitedSizeDict(size_limit=1000) + + +def get_all_ngrams(tokens): + from snips_nlu_utils import compute_all_ngrams + + if not tokens: + return [] + key = "<||>".join(tokens) + if key not in _NGRAMS_CACHE: + ngrams = compute_all_ngrams(tokens, len(tokens)) + _NGRAMS_CACHE[key] = ngrams + return deepcopy(_NGRAMS_CACHE[key]) + + +def get_word_chunk(word, chunk_size, chunk_start, reverse=False): + if chunk_size < 1: + raise ValueError("chunk size should be >= 1") + if chunk_size > len(word): + return None + start = chunk_start - chunk_size if reverse else chunk_start + end = chunk_start if reverse else chunk_start + chunk_size + return word[start:end] + + +def initial_string_from_tokens(tokens): + current_index = 0 + s = "" + for t in tokens: + if t.start > current_index: + s += " " * (t.start - current_index) + s += t.value + current_index = t.end + return s + + +def entity_filter(entity, start, end): + entity_start = entity[RES_MATCH_RANGE][START] + entity_end = entity[RES_MATCH_RANGE][END] + return entity_start <= start < end <= entity_end diff --git a/snips_inference_agl/slot_filler/keyword_slot_filler.py b/snips_inference_agl/slot_filler/keyword_slot_filler.py new file mode 100644 index 0000000..087d997 --- /dev/null +++ b/snips_inference_agl/slot_filler/keyword_slot_filler.py @@ -0,0 +1,70 @@ +from __future__ import unicode_literals + +import json + +from snips_inference_agl.common.utils import json_string +from snips_inference_agl.preprocessing import tokenize +from snips_inference_agl.result import unresolved_slot +from snips_inference_agl.slot_filler import SlotFiller + + +@SlotFiller.register("keyword_slot_filler") +class KeywordSlotFiller(SlotFiller): + def __init__(self, config=None, **shared): + super(KeywordSlotFiller, self).__init__(config, **shared) + self.slots_keywords = None + self.language = None + + @property + def fitted(self): + return self.slots_keywords is not None + + def fit(self, dataset, intent): + self.language = dataset["language"] + self.slots_keywords = dict() + utterances = dataset["intents"][intent]["utterances"] + for utterance in utterances: + for chunk in utterance["data"]: + if "slot_name" in chunk: + text = chunk["text"] + if self.config.get("lowercase", False): + text 
= text.lower() + self.slots_keywords[text] = [ + chunk["entity"], + chunk["slot_name"] + ] + return self + + def get_slots(self, text): + tokens = tokenize(text, self.language) + slots = [] + for token in tokens: + normalized_value = token.value + if self.config.get("lowercase", False): + normalized_value = normalized_value.lower() + if normalized_value in self.slots_keywords: + entity = self.slots_keywords[normalized_value][0] + slot_name = self.slots_keywords[normalized_value][1] + slot = unresolved_slot((token.start, token.end), token.value, + entity, slot_name) + slots.append(slot) + return slots + + def persist(self, path): + model = { + "language": self.language, + "slots_keywords": self.slots_keywords, + "config": self.config.to_dict() + } + with path.open(mode="w", encoding="utf8") as f: + f.write(json_string(model)) + + @classmethod + def from_path(cls, path, **shared): + with path.open() as f: + model = json.load(f) + slot_filler = cls() + slot_filler.language = model["language"] + slot_filler.slots_keywords = model["slots_keywords"] + slot_filler.config = cls.config_type.from_dict(model["config"]) + return slot_filler diff --git a/snips_inference_agl/slot_filler/slot_filler.py b/snips_inference_agl/slot_filler/slot_filler.py new file mode 100644 index 0000000..a1fc937 --- /dev/null +++ b/snips_inference_agl/slot_filler/slot_filler.py @@ -0,0 +1,33 @@ +from abc import abstractmethod, ABCMeta + +from future.utils import with_metaclass + +from snips_inference_agl.common.abc_utils import classproperty +from snips_inference_agl.pipeline.processing_unit import ProcessingUnit + + +class SlotFiller(with_metaclass(ABCMeta, ProcessingUnit)): + """Abstraction which performs slot filling + + A custom slot filler must inherit this class to be used in a + :class:`.ProbabilisticIntentParser` + """ + + @classproperty + def unit_name(cls): # pylint:disable=no-self-argument + return SlotFiller.registered_name(cls) + + @abstractmethod + def fit(self, dataset, intent): + """Fit the slot filler with a valid Snips dataset""" + pass + + @abstractmethod + def get_slots(self, text): + """Performs slot extraction (slot filling) on the provided *text* + + Returns: + list of dict: The list of extracted slots. 
See + :func:`.unresolved_slot` for the output format of a slot + """ + pass diff --git a/snips_inference_agl/string_variations.py b/snips_inference_agl/string_variations.py new file mode 100644 index 0000000..f65e34e --- /dev/null +++ b/snips_inference_agl/string_variations.py @@ -0,0 +1,195 @@ +# coding=utf-8 +from __future__ import unicode_literals + +import itertools +import re +from builtins import range, str, zip + +from future.utils import iteritems + +from snips_inference_agl.constants import ( + END, LANGUAGE_DE, LANGUAGE_EN, LANGUAGE_ES, LANGUAGE_FR, RESOLVED_VALUE, + RES_MATCH_RANGE, SNIPS_NUMBER, START, VALUE) +from snips_inference_agl.languages import ( + get_default_sep, get_punctuation_regex, supports_num2words) +from snips_inference_agl.preprocessing import tokenize_light + +AND_UTTERANCES = { + LANGUAGE_EN: ["and", "&"], + LANGUAGE_FR: ["et", "&"], + LANGUAGE_ES: ["y", "&"], + LANGUAGE_DE: ["und", "&"], +} + +AND_REGEXES = { + language: re.compile( + r"|".join(r"(?<=\s)%s(?=\s)" % re.escape(u) for u in utterances), + re.IGNORECASE) + for language, utterances in iteritems(AND_UTTERANCES) +} + +MAX_ENTITY_VARIATIONS = 10 + + +def build_variated_query(string, ranges_and_utterances): + variated_string = "" + current_ix = 0 + for rng, u in ranges_and_utterances: + start = rng[START] + end = rng[END] + variated_string += string[current_ix:start] + variated_string += u + current_ix = end + variated_string += string[current_ix:] + return variated_string + + +def and_variations(string, language): + and_regex = AND_REGEXES.get(language, None) + if and_regex is None: + return set() + + matches = [m for m in and_regex.finditer(string)] + if not matches: + return set() + + matches = sorted(matches, key=lambda x: x.start()) + and_utterances = AND_UTTERANCES[language] + values = [({START: m.start(), END: m.end()}, and_utterances) + for m in matches] + + n_values = len(values) + n_and_utterances = len(and_utterances) + if n_and_utterances ** n_values > MAX_ENTITY_VARIATIONS: + return set() + + combinations = itertools.product(range(n_and_utterances), repeat=n_values) + variations = set() + for c in combinations: + ranges_and_utterances = [(values[i][0], values[i][1][ix]) + for i, ix in enumerate(c)] + variations.add(build_variated_query(string, ranges_and_utterances)) + return variations + + +def punctuation_variations(string, language): + matches = [m for m in get_punctuation_regex(language).finditer(string)] + if not matches: + return set() + + matches = sorted(matches, key=lambda x: x.start()) + values = [({START: m.start(), END: m.end()}, (m.group(0), "")) + for m in matches] + + n_values = len(values) + if 2 ** n_values > MAX_ENTITY_VARIATIONS: + return set() + + combinations = itertools.product(range(2), repeat=n_values) + variations = set() + for c in combinations: + ranges_and_utterances = [(values[i][0], values[i][1][ix]) + for i, ix in enumerate(c)] + variations.add(build_variated_query(string, ranges_and_utterances)) + return variations + + +def digit_value(number_entity): + value = number_entity[RESOLVED_VALUE][VALUE] + if value == int(value): + # Convert 24.0 into "24" instead of "24.0" + value = int(value) + return str(value) + + +def alphabetic_value(number_entity, language): + from num2words import num2words + + value = number_entity[RESOLVED_VALUE][VALUE] + if value != int(value): # num2words does not handle floats correctly + return None + return num2words(int(value), lang=language) + + +def numbers_variations(string, language, builtin_entity_parser): + if not 
supports_num2words(language): + return set() + + number_entities = builtin_entity_parser.parse( + string, scope=[SNIPS_NUMBER], use_cache=True) + + number_entities = sorted(number_entities, + key=lambda x: x[RES_MATCH_RANGE][START]) + if not number_entities: + return set() + + digit_values = [digit_value(e) for e in number_entities] + alpha_values = [alphabetic_value(e, language) for e in number_entities] + + values = [(n[RES_MATCH_RANGE], (d, a)) for (n, d, a) in + zip(number_entities, digit_values, alpha_values) + if a is not None] + + n_values = len(values) + if 2 ** n_values > MAX_ENTITY_VARIATIONS: + return set() + + combinations = itertools.product(range(2), repeat=n_values) + variations = set() + for c in combinations: + ranges_and_utterances = [(values[i][0], values[i][1][ix]) + for i, ix in enumerate(c)] + variations.add(build_variated_query(string, ranges_and_utterances)) + return variations + + +def case_variations(string): + return {string.lower(), string.title()} + + +def normalization_variations(string): + from snips_nlu_utils import normalize + + return {normalize(string)} + + +def flatten(results): + return set(i for r in results for i in r) + + +def get_string_variations(string, language, builtin_entity_parser, + numbers=True, case=True, and_=True, + punctuation=True): + variations = {string} + if case: + variations.update(flatten(case_variations(v) for v in variations)) + + variations.update(flatten(normalization_variations(v) for v in variations)) + # We re-generate case variations as normalization can produce new + # variations + if case: + variations.update(flatten(case_variations(v) for v in variations)) + if and_: + variations.update( + flatten(and_variations(v, language) for v in variations)) + if punctuation: + variations.update( + flatten(punctuation_variations(v, language) for v in variations)) + + # Special case of number variation which are long to generate due to the + # BuilinEntityParser running on each variation + if numbers: + variations.update( + flatten(numbers_variations(v, language, builtin_entity_parser) + for v in variations) + ) + + # Add single space variations + single_space_variations = set(" ".join(v.split()) for v in variations) + variations.update(single_space_variations) + # Add tokenized variations + tokenized_variations = set( + get_default_sep(language).join(tokenize_light(v, language)) for v in + variations) + variations.update(tokenized_variations) + return variations |
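
The crf_utils helpers added above convert between token-level CRF tags and character-level slot ranges. The following is a minimal round-trip sketch, not part of the change-set itself: it assumes the snips_inference_agl package from this diff is importable, and the sentence, slot name and tag layout are purely illustrative.

from snips_inference_agl.preprocessing import tokenize
from snips_inference_agl.slot_filler.crf_utils import (
    TaggingScheme, negative_tagging, positive_tagging, tags_to_preslots)

language = "en"
text = "turn on the living room lights"
tokens = tokenize(text, language)                            # 6 tokens

# Tag tokens 3 and 4 ("living room") as a slot named "room", BILOU-style
tags = negative_tagging(3)                                   # O O O
tags += positive_tagging(TaggingScheme.BILOU, "room", 2)     # B-room L-room
tags += negative_tagging(1)                                  # O

slots = tags_to_preslots(tokens, tags, TaggingScheme.BILOU)
# -> a single pre-slot covering characters 12..23 of the text,
#    i.e. text[12:23] == "living room", with slot name "room"
print(slots)

tags_to_slots performs the same decoding and additionally resolves each decoded slot name to its entity through the intent's slot/entity mapping.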
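
feature.py evaluates a feature function at a token index shifted by the feature's offset and memoises the result in a per-token cache. A small sketch of that mechanism follows; it assumes, as Feature.compute above does, that the cache is a list with one dict per token keyed by "token".

from snips_inference_agl.preprocessing import tokenize
from snips_inference_agl.slot_filler.feature import Feature

def is_first(tokens, token_index):
    return "1" if token_index == 0 else None

# offset=-1 evaluates the function on the *previous* token
feature = Feature("is_first", is_first, offset=-1)

tokens = tokenize("hello brave world", "en")
cache = [{"token": t} for t in tokens]   # one dict per token, as compute() reads

print(feature.name)               # "is_first[-1]"
print(feature.compute(0, cache))  # None: there is no token at index -1
print(feature.compute(1, cache))  # "1": token 0 is the first token; the value is now cached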
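
New CRF features are meant to be added by registering a CRFFeatureFactory subclass, as the Feature docstring suggests. Below is a hypothetical "is_capitalized" factory sketched against the registry and config layout shown above; the factory name, arguments and offsets are invented for illustration.

from snips_inference_agl.slot_filler.feature_factory import (
    CRFFeatureFactory, SingleFeatureFactory)

@CRFFeatureFactory.register("is_capitalized")
class IsCapitalizedFactory(SingleFeatureFactory):
    """Feature: does the considered token start with an uppercase letter?"""

    def compute_feature(self, tokens, token_index):
        return "1" if tokens[token_index].value[:1].isupper() else None

# Configured like the built-in factories: registered name, args and CRF offsets
config = {"factory_name": "is_capitalized", "args": {}, "offsets": [-1, 0, 1]}
factory = CRFFeatureFactory.from_config(config)
features = factory.build_features()   # one Feature per offset
print([f.name for f in features])     # e.g. ['is_capitalized[-1]', 'is_capitalized', 'is_capitalized[+1]']

Once registered, such a factory can presumably be referenced from a slot filler configuration by its factory_name, in the same way as the built-in factories above.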
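
string_variations.py generates surface variants of a string: case changes, "&"/"and" substitutions, punctuation removal and, when enabled, spelled-out numbers. A short usage sketch follows; passing numbers=False means the builtin entity parser is never consulted, so None is passed in its place (an assumption drawn from the code above rather than a documented contract), and the snips-nlu-utils dependency is assumed to be installed for normalization.

from snips_inference_agl.string_variations import get_string_variations

variations = get_string_variations(
    "Living-Room & Kitchen", "en",
    builtin_entity_parser=None,   # unused because numbers=False
    numbers=False, case=True, and_=True, punctuation=True)

# The set keeps the original string and, among others, should also contain
# the lowercased form and the form with "&" spelled out as "and":
assert "living-room & kitchen" in variations
assert "Living-Room and Kitchen" in variations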