aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/cli
diff options
context:
space:
mode:
Diffstat (limited to 'snips_inference_agl/cli')
-rw-r--r--snips_inference_agl/cli/__init__.py39
-rw-r--r--snips_inference_agl/cli/inference.py66
-rw-r--r--snips_inference_agl/cli/utils.py79
-rw-r--r--snips_inference_agl/cli/versions.py19
4 files changed, 203 insertions, 0 deletions
diff --git a/snips_inference_agl/cli/__init__.py b/snips_inference_agl/cli/__init__.py
new file mode 100644
index 0000000..ccfbf18
--- /dev/null
+++ b/snips_inference_agl/cli/__init__.py
@@ -0,0 +1,39 @@
+import argparse
+
+
+class Formatter(argparse.ArgumentDefaultsHelpFormatter):
+ def __init__(self, prog):
+ super(Formatter, self).__init__(prog, max_help_position=35, width=150)
+
+
+def get_arg_parser():
+ from snips_inference_agl.cli.inference import add_parse_parser
+ from snips_inference_agl.cli.versions import (
+ add_version_parser, add_model_version_parser)
+
+ arg_parser = argparse.ArgumentParser(
+ description="Snips NLU command line interface",
+ prog="python -m snips_nlu", formatter_class=Formatter)
+ arg_parser.add_argument("-v", "--version", action="store_true",
+ help="Print package version")
+ subparsers = arg_parser.add_subparsers(
+ title="available commands", metavar="command [options ...]")
+ add_parse_parser(subparsers, formatter_class=Formatter)
+ add_version_parser(subparsers, formatter_class=Formatter)
+ add_model_version_parser(subparsers, formatter_class=Formatter)
+ return arg_parser
+
+
+def main():
+ from snips_inference_agl.__about__ import __version__
+
+ arg_parser = get_arg_parser()
+ args = arg_parser.parse_args()
+
+ if hasattr(args, "func"):
+ args.func(args)
+ elif "version" in args:
+ print(__version__)
+ else:
+ arg_parser.print_help()
+ exit(1)
diff --git a/snips_inference_agl/cli/inference.py b/snips_inference_agl/cli/inference.py
new file mode 100644
index 0000000..4ef5618
--- /dev/null
+++ b/snips_inference_agl/cli/inference.py
@@ -0,0 +1,66 @@
+from __future__ import unicode_literals, print_function
+
+
+def add_parse_parser(subparsers, formatter_class):
+ subparser = subparsers.add_parser(
+ "parse", formatter_class=formatter_class,
+ help="Load a trained NLU engine and perform parsing")
+ subparser.add_argument("training_path", type=str,
+ help="Path to a trained engine")
+ subparser.add_argument("-q", "--query", type=str,
+ help="Query to parse. If provided, it disables the "
+ "interactive behavior.")
+ subparser.add_argument("-v", "--verbosity", action="count", default=0,
+ help="Increase output verbosity")
+ subparser.add_argument("-f", "--intents-filter", type=str,
+ help="Intents filter as a comma-separated list")
+ subparser.set_defaults(func=_parse)
+ return subparser
+
+
+def _parse(args_namespace):
+ return parse(args_namespace.training_path, args_namespace.query,
+ args_namespace.verbosity, args_namespace.intents_filter)
+
+
+def parse(training_path, query, verbose=False, intents_filter=None):
+ """Load a trained NLU engine and play with its parsing API interactively"""
+ import csv
+ import logging
+ from builtins import input, str
+ from snips_inference_agl import SnipsNLUEngine
+ from snips_inference_agl.cli.utils import set_nlu_logger
+
+ if verbose == 1:
+ set_nlu_logger(logging.INFO)
+ elif verbose >= 2:
+ set_nlu_logger(logging.DEBUG)
+ if intents_filter:
+ # use csv in order to properly handle commas and other special
+ # characters in intent names
+ intents_filter = next(csv.reader([intents_filter]))
+ else:
+ intents_filter = None
+
+ engine = SnipsNLUEngine.from_path(training_path)
+
+ if query:
+ print_parsing_result(engine, query, intents_filter)
+ return
+
+ while True:
+ query = input("Enter a query (type 'q' to quit): ").strip()
+ if not isinstance(query, str):
+ query = query.decode("utf-8")
+ if query == "q":
+ break
+ print_parsing_result(engine, query, intents_filter)
+
+
+def print_parsing_result(engine, query, intents_filter):
+ from snips_inference_agl.common.utils import unicode_string, json_string
+
+ query = unicode_string(query)
+ json_dump = json_string(engine.parse(query, intents_filter),
+ sort_keys=True, indent=2)
+ print(json_dump)
diff --git a/snips_inference_agl/cli/utils.py b/snips_inference_agl/cli/utils.py
new file mode 100644
index 0000000..0e5464c
--- /dev/null
+++ b/snips_inference_agl/cli/utils.py
@@ -0,0 +1,79 @@
+from __future__ import print_function, unicode_literals
+
+import logging
+import sys
+from enum import Enum, unique
+
+import requests
+
+import snips_inference_agl
+from snips_inference_agl import __about__
+
+
+@unique
+class PrettyPrintLevel(Enum):
+ INFO = 0
+ WARNING = 1
+ ERROR = 2
+ SUCCESS = 3
+
+
+FMT = "[%(levelname)s][%(asctime)s.%(msecs)03d][%(name)s]: %(message)s"
+DATE_FMT = "%H:%M:%S"
+
+
+def pretty_print(*texts, **kwargs):
+ """Print formatted message
+
+ Args:
+ *texts (str): Texts to print. Each argument is rendered as paragraph.
+ **kwargs: 'title' becomes coloured headline. exits=True performs sys
+ exit.
+ """
+ exits = kwargs.get("exits")
+ title = kwargs.get("title")
+ level = kwargs.get("level", PrettyPrintLevel.INFO)
+ title_color = _color_from_level(level)
+ if title:
+ title = "\033[{color}m{title}\033[0m\n".format(title=title,
+ color=title_color)
+ else:
+ title = ""
+ message = "\n\n".join([text for text in texts])
+ print("\n{title}{message}\n".format(title=title, message=message))
+ if exits is not None:
+ sys.exit(exits)
+
+
+def _color_from_level(level):
+ if level == PrettyPrintLevel.INFO:
+ return "92"
+ if level == PrettyPrintLevel.WARNING:
+ return "93"
+ if level == PrettyPrintLevel.ERROR:
+ return "91"
+ if level == PrettyPrintLevel.SUCCESS:
+ return "92"
+ else:
+ raise ValueError("Unknown PrettyPrintLevel: %s" % level)
+
+
+def get_json(url, desc):
+ r = requests.get(url)
+ if r.status_code != 200:
+ raise OSError("%s: Received status code %s when fetching the resource"
+ % (desc, r.status_code))
+ return r.json()
+
+
+def set_nlu_logger(level=logging.INFO):
+ logger = logging.getLogger(snips_inference_agl.__name__)
+ logger.setLevel(level)
+
+ formatter = logging.Formatter(FMT, DATE_FMT)
+
+ handler = logging.StreamHandler(sys.stdout)
+ handler.setFormatter(formatter)
+ handler.setLevel(level)
+
+ logger.addHandler(handler)
diff --git a/snips_inference_agl/cli/versions.py b/snips_inference_agl/cli/versions.py
new file mode 100644
index 0000000..d7922ef
--- /dev/null
+++ b/snips_inference_agl/cli/versions.py
@@ -0,0 +1,19 @@
+from __future__ import print_function
+
+
+def add_version_parser(subparsers, formatter_class):
+ from snips_inference_agl.__about__ import __version__
+ subparser = subparsers.add_parser(
+ "version", formatter_class=formatter_class,
+ help="Print the package version")
+ subparser.set_defaults(func=lambda _: print(__version__))
+ return subparser
+
+
+def add_model_version_parser(subparsers, formatter_class):
+ from snips_inference_agl.__about__ import __model_version__
+ subparser = subparsers.add_parser(
+ "model-version", formatter_class=formatter_class,
+ help="Print the model version")
+ subparser.set_defaults(func=lambda _: print(__model_version__))
+ return subparser