diff options
Diffstat (limited to 'snips_inference_agl/cli')
-rw-r--r-- | snips_inference_agl/cli/__init__.py | 39 | ||||
-rw-r--r-- | snips_inference_agl/cli/inference.py | 66 | ||||
-rw-r--r-- | snips_inference_agl/cli/utils.py | 79 | ||||
-rw-r--r-- | snips_inference_agl/cli/versions.py | 19 |
4 files changed, 203 insertions, 0 deletions
diff --git a/snips_inference_agl/cli/__init__.py b/snips_inference_agl/cli/__init__.py new file mode 100644 index 0000000..ccfbf18 --- /dev/null +++ b/snips_inference_agl/cli/__init__.py @@ -0,0 +1,39 @@ +import argparse + + +class Formatter(argparse.ArgumentDefaultsHelpFormatter): + def __init__(self, prog): + super(Formatter, self).__init__(prog, max_help_position=35, width=150) + + +def get_arg_parser(): + from snips_inference_agl.cli.inference import add_parse_parser + from snips_inference_agl.cli.versions import ( + add_version_parser, add_model_version_parser) + + arg_parser = argparse.ArgumentParser( + description="Snips NLU command line interface", + prog="python -m snips_nlu", formatter_class=Formatter) + arg_parser.add_argument("-v", "--version", action="store_true", + help="Print package version") + subparsers = arg_parser.add_subparsers( + title="available commands", metavar="command [options ...]") + add_parse_parser(subparsers, formatter_class=Formatter) + add_version_parser(subparsers, formatter_class=Formatter) + add_model_version_parser(subparsers, formatter_class=Formatter) + return arg_parser + + +def main(): + from snips_inference_agl.__about__ import __version__ + + arg_parser = get_arg_parser() + args = arg_parser.parse_args() + + if hasattr(args, "func"): + args.func(args) + elif "version" in args: + print(__version__) + else: + arg_parser.print_help() + exit(1) diff --git a/snips_inference_agl/cli/inference.py b/snips_inference_agl/cli/inference.py new file mode 100644 index 0000000..4ef5618 --- /dev/null +++ b/snips_inference_agl/cli/inference.py @@ -0,0 +1,66 @@ +from __future__ import unicode_literals, print_function + + +def add_parse_parser(subparsers, formatter_class): + subparser = subparsers.add_parser( + "parse", formatter_class=formatter_class, + help="Load a trained NLU engine and perform parsing") + subparser.add_argument("training_path", type=str, + help="Path to a trained engine") + subparser.add_argument("-q", "--query", type=str, + help="Query to parse. If provided, it disables the " + "interactive behavior.") + subparser.add_argument("-v", "--verbosity", action="count", default=0, + help="Increase output verbosity") + subparser.add_argument("-f", "--intents-filter", type=str, + help="Intents filter as a comma-separated list") + subparser.set_defaults(func=_parse) + return subparser + + +def _parse(args_namespace): + return parse(args_namespace.training_path, args_namespace.query, + args_namespace.verbosity, args_namespace.intents_filter) + + +def parse(training_path, query, verbose=False, intents_filter=None): + """Load a trained NLU engine and play with its parsing API interactively""" + import csv + import logging + from builtins import input, str + from snips_inference_agl import SnipsNLUEngine + from snips_inference_agl.cli.utils import set_nlu_logger + + if verbose == 1: + set_nlu_logger(logging.INFO) + elif verbose >= 2: + set_nlu_logger(logging.DEBUG) + if intents_filter: + # use csv in order to properly handle commas and other special + # characters in intent names + intents_filter = next(csv.reader([intents_filter])) + else: + intents_filter = None + + engine = SnipsNLUEngine.from_path(training_path) + + if query: + print_parsing_result(engine, query, intents_filter) + return + + while True: + query = input("Enter a query (type 'q' to quit): ").strip() + if not isinstance(query, str): + query = query.decode("utf-8") + if query == "q": + break + print_parsing_result(engine, query, intents_filter) + + +def print_parsing_result(engine, query, intents_filter): + from snips_inference_agl.common.utils import unicode_string, json_string + + query = unicode_string(query) + json_dump = json_string(engine.parse(query, intents_filter), + sort_keys=True, indent=2) + print(json_dump) diff --git a/snips_inference_agl/cli/utils.py b/snips_inference_agl/cli/utils.py new file mode 100644 index 0000000..0e5464c --- /dev/null +++ b/snips_inference_agl/cli/utils.py @@ -0,0 +1,79 @@ +from __future__ import print_function, unicode_literals + +import logging +import sys +from enum import Enum, unique + +import requests + +import snips_inference_agl +from snips_inference_agl import __about__ + + +@unique +class PrettyPrintLevel(Enum): + INFO = 0 + WARNING = 1 + ERROR = 2 + SUCCESS = 3 + + +FMT = "[%(levelname)s][%(asctime)s.%(msecs)03d][%(name)s]: %(message)s" +DATE_FMT = "%H:%M:%S" + + +def pretty_print(*texts, **kwargs): + """Print formatted message + + Args: + *texts (str): Texts to print. Each argument is rendered as paragraph. + **kwargs: 'title' becomes coloured headline. exits=True performs sys + exit. + """ + exits = kwargs.get("exits") + title = kwargs.get("title") + level = kwargs.get("level", PrettyPrintLevel.INFO) + title_color = _color_from_level(level) + if title: + title = "\033[{color}m{title}\033[0m\n".format(title=title, + color=title_color) + else: + title = "" + message = "\n\n".join([text for text in texts]) + print("\n{title}{message}\n".format(title=title, message=message)) + if exits is not None: + sys.exit(exits) + + +def _color_from_level(level): + if level == PrettyPrintLevel.INFO: + return "92" + if level == PrettyPrintLevel.WARNING: + return "93" + if level == PrettyPrintLevel.ERROR: + return "91" + if level == PrettyPrintLevel.SUCCESS: + return "92" + else: + raise ValueError("Unknown PrettyPrintLevel: %s" % level) + + +def get_json(url, desc): + r = requests.get(url) + if r.status_code != 200: + raise OSError("%s: Received status code %s when fetching the resource" + % (desc, r.status_code)) + return r.json() + + +def set_nlu_logger(level=logging.INFO): + logger = logging.getLogger(snips_inference_agl.__name__) + logger.setLevel(level) + + formatter = logging.Formatter(FMT, DATE_FMT) + + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + handler.setLevel(level) + + logger.addHandler(handler) diff --git a/snips_inference_agl/cli/versions.py b/snips_inference_agl/cli/versions.py new file mode 100644 index 0000000..d7922ef --- /dev/null +++ b/snips_inference_agl/cli/versions.py @@ -0,0 +1,19 @@ +from __future__ import print_function + + +def add_version_parser(subparsers, formatter_class): + from snips_inference_agl.__about__ import __version__ + subparser = subparsers.add_parser( + "version", formatter_class=formatter_class, + help="Print the package version") + subparser.set_defaults(func=lambda _: print(__version__)) + return subparser + + +def add_model_version_parser(subparsers, formatter_class): + from snips_inference_agl.__about__ import __model_version__ + subparser = subparsers.add_parser( + "model-version", formatter_class=formatter_class, + help="Print the model version") + subparser.set_defaults(func=lambda _: print(__model_version__)) + return subparser |