aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/cli/inference.py
diff options
context:
space:
mode:
Diffstat (limited to 'snips_inference_agl/cli/inference.py')
-rw-r--r--snips_inference_agl/cli/inference.py66
1 files changed, 66 insertions, 0 deletions
diff --git a/snips_inference_agl/cli/inference.py b/snips_inference_agl/cli/inference.py
new file mode 100644
index 0000000..4ef5618
--- /dev/null
+++ b/snips_inference_agl/cli/inference.py
@@ -0,0 +1,66 @@
+from __future__ import unicode_literals, print_function
+
+
+def add_parse_parser(subparsers, formatter_class):
+ subparser = subparsers.add_parser(
+ "parse", formatter_class=formatter_class,
+ help="Load a trained NLU engine and perform parsing")
+ subparser.add_argument("training_path", type=str,
+ help="Path to a trained engine")
+ subparser.add_argument("-q", "--query", type=str,
+ help="Query to parse. If provided, it disables the "
+ "interactive behavior.")
+ subparser.add_argument("-v", "--verbosity", action="count", default=0,
+ help="Increase output verbosity")
+ subparser.add_argument("-f", "--intents-filter", type=str,
+ help="Intents filter as a comma-separated list")
+ subparser.set_defaults(func=_parse)
+ return subparser
+
+
+def _parse(args_namespace):
+ return parse(args_namespace.training_path, args_namespace.query,
+ args_namespace.verbosity, args_namespace.intents_filter)
+
+
+def parse(training_path, query, verbose=False, intents_filter=None):
+ """Load a trained NLU engine and play with its parsing API interactively"""
+ import csv
+ import logging
+ from builtins import input, str
+ from snips_inference_agl import SnipsNLUEngine
+ from snips_inference_agl.cli.utils import set_nlu_logger
+
+ if verbose == 1:
+ set_nlu_logger(logging.INFO)
+ elif verbose >= 2:
+ set_nlu_logger(logging.DEBUG)
+ if intents_filter:
+ # use csv in order to properly handle commas and other special
+ # characters in intent names
+ intents_filter = next(csv.reader([intents_filter]))
+ else:
+ intents_filter = None
+
+ engine = SnipsNLUEngine.from_path(training_path)
+
+ if query:
+ print_parsing_result(engine, query, intents_filter)
+ return
+
+ while True:
+ query = input("Enter a query (type 'q' to quit): ").strip()
+ if not isinstance(query, str):
+ query = query.decode("utf-8")
+ if query == "q":
+ break
+ print_parsing_result(engine, query, intents_filter)
+
+
+def print_parsing_result(engine, query, intents_filter):
+ from snips_inference_agl.common.utils import unicode_string, json_string
+
+ query = unicode_string(query)
+ json_dump = json_string(engine.parse(query, intents_filter),
+ sort_keys=True, indent=2)
+ print(json_dump)