aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/common
diff options
context:
space:
mode:
Diffstat (limited to 'snips_inference_agl/common')
-rw-r--r--snips_inference_agl/common/__init__.py0
-rw-r--r--snips_inference_agl/common/abc_utils.py36
-rw-r--r--snips_inference_agl/common/dataset_utils.py48
-rw-r--r--snips_inference_agl/common/dict_utils.py36
-rw-r--r--snips_inference_agl/common/from_dict.py30
-rw-r--r--snips_inference_agl/common/io_utils.py36
-rw-r--r--snips_inference_agl/common/log_utils.py61
-rw-r--r--snips_inference_agl/common/registrable.py73
-rw-r--r--snips_inference_agl/common/utils.py239
9 files changed, 559 insertions, 0 deletions
diff --git a/snips_inference_agl/common/__init__.py b/snips_inference_agl/common/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/snips_inference_agl/common/__init__.py
diff --git a/snips_inference_agl/common/abc_utils.py b/snips_inference_agl/common/abc_utils.py
new file mode 100644
index 0000000..db7b933
--- /dev/null
+++ b/snips_inference_agl/common/abc_utils.py
@@ -0,0 +1,36 @@
class abstractclassmethod(classmethod):  # pylint: disable=invalid-name
    """``classmethod`` variant flagged as abstract.

    Both the descriptor itself and the wrapped callable carry
    ``__isabstractmethod__ = True``, which is the marker ``abc.ABCMeta``
    inspects to decide whether a class can be instantiated.
    """
    __isabstractmethod__ = True

    def __init__(self, callable):
        super(abstractclassmethod, self).__init__(callable)
        # Flag the underlying function as well as the descriptor
        callable.__isabstractmethod__ = True
+
+
class ClassPropertyDescriptor(object):
    """Descriptor exposing a class-level getter (and optional setter).

    ``fget`` and ``fset`` are expected to be ``classmethod`` or
    ``staticmethod`` objects: they are bound through their own ``__get__``
    before being called.
    """

    def __init__(self, fget, fset=None):
        self.fget, self.fset = fget, fset

    def __get__(self, obj, klass=None):
        # Fall back on the instance's type when no owner class is provided
        owner = type(obj) if klass is None else klass
        bound_getter = self.fget.__get__(obj, owner)
        return bound_getter()

    def __set__(self, obj, value):
        if self.fset:
            bound_setter = self.fset.__get__(obj, type(obj))
            return bound_setter(value)
        raise AttributeError("can't set attribute")

    def setter(self, func):
        """Registers *func* as the setter and returns the descriptor"""
        already_wrapped = isinstance(func, (classmethod, staticmethod))
        self.fset = func if already_wrapped else classmethod(func)
        return self
+
+
def classproperty(func):
    """Decorator turning *func* into a read-only class-level property

    *func* is wrapped in a ``classmethod`` unless it already is a
    ``classmethod`` or a ``staticmethod``.
    """
    if isinstance(func, (classmethod, staticmethod)):
        getter = func
    else:
        getter = classmethod(func)
    return ClassPropertyDescriptor(getter)
diff --git a/snips_inference_agl/common/dataset_utils.py b/snips_inference_agl/common/dataset_utils.py
new file mode 100644
index 0000000..34648e6
--- /dev/null
+++ b/snips_inference_agl/common/dataset_utils.py
@@ -0,0 +1,48 @@
+from snips_inference_agl.constants import INTENTS, UTTERANCES, DATA, SLOT_NAME, ENTITY
+from snips_inference_agl.exceptions import DatasetFormatError
+
+
def type_error(expected_type, found_type, object_label=None):
    """Raises a DatasetFormatError describing a type mismatch

    When *object_label* is provided, it is used in the message to name the
    offending object.
    """
    if object_label is not None:
        raise DatasetFormatError(
            "Invalid type for '%s': expected %s but found %s"
            % (object_label, expected_type, found_type))
    raise DatasetFormatError("Invalid type: expected %s but found %s"
                             % (expected_type, found_type))
+
+
def validate_type(obj, expected_type, object_label=None):
    """Checks that *obj* is an instance of *expected_type*

    Delegates to :func:`type_error` (which raises) when the check fails.
    """
    if isinstance(obj, expected_type):
        return
    type_error(expected_type, type(obj), object_label)
+
+
def missing_key_error(key, object_label=None):
    """Raises a DatasetFormatError reporting that *key* is missing

    Args:
        key: the missing key
        object_label (str, optional): name of the object which was expected
            to contain *key*, used to make the message more explicit

    Raises:
        DatasetFormatError: always
    """
    if object_label is None:
        # The 1-tuple keeps tuple keys from being unpacked as several
        # formatting arguments
        raise DatasetFormatError("Missing key: '%s'" % (key,))
    raise DatasetFormatError("Expected %s to have key: '%s'"
                             % (object_label, key))
+
+
def validate_key(obj, key, object_label=None):
    """Checks that *key* is in *obj*

    Delegates to :func:`missing_key_error` (which raises) otherwise.
    """
    if key in obj:
        return
    missing_key_error(key, object_label)
+
+
def validate_keys(obj, keys, object_label=None):
    """Checks, in order, that each key of *keys* is in *obj*

    Stops at the first missing key, as :func:`validate_key` raises for it.
    """
    for required_key in keys:
        validate_key(obj, required_key, object_label)
+
+
def get_slot_name_mapping(dataset, intent):
    """Returns a dict which maps slot names to entities for the provided intent

    When a slot name appears in several chunks, the entity of the last chunk
    encountered is the one kept.
    """
    utterances = dataset[INTENTS][intent][UTTERANCES]
    return {
        chunk[SLOT_NAME]: chunk[ENTITY]
        for utterance in utterances
        for chunk in utterance[DATA]
        if SLOT_NAME in chunk
    }
+
def get_slot_name_mappings(dataset):
    """Returns a dict which maps intents to their slot name mapping"""
    mappings = dict()
    for intent in dataset[INTENTS]:
        mappings[intent] = get_slot_name_mapping(dataset, intent)
    return mappings
diff --git a/snips_inference_agl/common/dict_utils.py b/snips_inference_agl/common/dict_utils.py
new file mode 100644
index 0000000..a70b217
--- /dev/null
+++ b/snips_inference_agl/common/dict_utils.py
@@ -0,0 +1,36 @@
+from collections import OrderedDict
+
+
class LimitedSizeDict(OrderedDict):
    """Ordered dict which drops its oldest items once ``size_limit`` is
    exceeded

    ``size_limit`` must be passed as a keyword argument; ``None`` disables
    the limit. Two instances are equal only when both their content and
    their ``size_limit`` match.
    """

    def __init__(self, *args, **kwds):
        if "size_limit" not in kwds:
            raise ValueError("'size_limit' must be passed as a keyword "
                             "argument")
        self.size_limit = kwds.pop("size_limit")
        if len(args) > 1:
            raise TypeError('expected at most 1 arguments, got %d' % len(args))
        # A None limit means "unbounded": skip the check rather than
        # comparing an int with None
        if self.size_limit is not None and len(args) == 1 \
                and len(args[0]) + len(kwds) > self.size_limit:
            raise ValueError("Tried to initialize LimitedSizeDict with more "
                             "values than permitted by 'size_limit'")
        super(LimitedSizeDict, self).__init__(*args, **kwds)

    def __setitem__(self, key, value, dict_setitem=OrderedDict.__setitem__):
        dict_setitem(self, key, value)
        self._check_size_limit()

    def _check_size_limit(self):
        """Evicts the oldest items until the size limit is respected"""
        if self.size_limit is not None:
            while len(self) > self.size_limit:
                self.popitem(last=False)

    def __eq__(self, other):
        # Objects without a size limit can never be equal to this dict
        if not isinstance(other, LimitedSizeDict):
            return False
        if self.size_limit != other.size_limit:
            return False
        return super(LimitedSizeDict, self).__eq__(other)

    def __ne__(self, other):
        # OrderedDict ships its own __ne__ which ignores size_limit, hence
        # the explicit override to stay consistent with __eq__
        return not self == other
+
+
class UnupdatableDict(dict):
    """dict which refuses item assignment on keys it already contains

    Only ``__setitem__`` is guarded: ``dict.update`` and the other mutating
    methods inherited from ``dict`` are not overridden here.
    """

    def __setitem__(self, key, value):
        if key in self:
            # The 1-tuple keeps tuple keys from being unpacked as several
            # formatting arguments, which would raise a TypeError
            raise KeyError("Can't update key '%s'" % (key,))
        super(UnupdatableDict, self).__setitem__(key, value)
diff --git a/snips_inference_agl/common/from_dict.py b/snips_inference_agl/common/from_dict.py
new file mode 100644
index 0000000..2b776a6
--- /dev/null
+++ b/snips_inference_agl/common/from_dict.py
@@ -0,0 +1,30 @@
+try:
+ import funcsigs as inspect
+except ImportError:
+ import inspect
+
+from future.utils import iteritems
+
KEYWORD_KINDS = {inspect.Parameter.POSITIONAL_OR_KEYWORD,
                 inspect.Parameter.KEYWORD_ONLY}


class FromDict(object):
    """Mixin providing a ``from_dict`` alternate constructor"""

    @classmethod
    def from_dict(cls, dict):
        """Builds an instance of *cls* out of a dict of init arguments

        ``None`` yields ``cls()``. When ``__init__`` accepts ``**kwargs`` the
        dict is forwarded untouched, otherwise the keys which do not match a
        keyword-compatible parameter of ``__init__`` are dropped.
        """
        if dict is None:
            return cls()

        init_params = inspect.signature(cls.__init__).parameters
        accepts_any_keyword = any(
            param.kind == inspect.Parameter.VAR_KEYWORD
            for param in init_params.values())
        if accepts_any_keyword:
            return cls(**dict)

        accepted_names = {
            name
            for position, (name, param) in enumerate(init_params.items())
            # Skip the leading "self" parameter
            if not (position == 0 and name == "self")
            and param.kind in KEYWORD_KINDS
        }
        init_kwargs = {key: value for key, value in dict.items()
                       if key in accepted_names}
        return cls(**init_kwargs)
diff --git a/snips_inference_agl/common/io_utils.py b/snips_inference_agl/common/io_utils.py
new file mode 100644
index 0000000..f4407a3
--- /dev/null
+++ b/snips_inference_agl/common/io_utils.py
@@ -0,0 +1,36 @@
+import errno
+import os
+import shutil
+from contextlib import contextmanager
+from pathlib import Path
+from tempfile import mkdtemp
+from zipfile import ZipFile, ZIP_DEFLATED
+
+
def mkdir_p(path):
    """Reproduces the 'mkdir -p shell' command

    See
    http://stackoverflow.com/questions/600268/mkdir-p-functionality-in-python

    Args:
        path (str or Path): directory to create, along with any missing
            parent

    Raises:
        OSError: when the creation fails for any reason other than the
            directory already existing
    """
    try:
        os.makedirs(str(path))
    except OSError as exc:  # Python >2.5
        # os.path.isdir is used rather than path.is_dir() so that plain
        # string paths are handled as well as Path objects
        if exc.errno != errno.EEXIST or not os.path.isdir(str(path)):
            raise
+
+
@contextmanager
def temp_dir():
    """Context manager yielding the Path of a fresh temporary directory

    The directory and its content are removed on exit, whether or not the
    managed block raised.
    """
    raw_path = mkdtemp()
    yielded_path = Path(raw_path)
    try:
        yield yielded_path
    finally:
        shutil.rmtree(raw_path)
+
+
def unzip_archive(archive_file, destination_dir):
    """Extracts every member of the zip *archive_file* into
    *destination_dir*
    """
    target = str(destination_dir)
    with ZipFile(archive_file, mode="r", compression=ZIP_DEFLATED) as archive:
        archive.extractall(target)
diff --git a/snips_inference_agl/common/log_utils.py b/snips_inference_agl/common/log_utils.py
new file mode 100644
index 0000000..47da34e
--- /dev/null
+++ b/snips_inference_agl/common/log_utils.py
@@ -0,0 +1,61 @@
+from __future__ import unicode_literals
+
+from builtins import str
+from datetime import datetime
+from functools import wraps
+
+from snips_inference_agl.common.utils import json_debug_string
+
+
class DifferedLoggingMessage(object):
    """Message whose content is only computed when it is converted to a
    string

    ``fn(*args, **kwargs)`` is called on each ``str()`` conversion.
    """

    def __init__(self, fn, *args, **kwargs):
        self.fn, self.args, self.kwargs = fn, args, kwargs

    def __str__(self):
        content = self.fn(*self.args, **self.kwargs)
        return str(content)
+
+
def log_elapsed_time(logger, level, output_msg=None):
    """Decorator factory logging the time spent in the decorated function

    The message is built with ``output_msg.format``; the ``elapsed_time``
    field is provided only when "elapsed_time" appears in *output_msg*.
    Nothing is logged when the decorated function raises.
    """
    template = ("Elapsed time ->:\n{elapsed_time}"
                if output_msg is None else output_msg)

    def get_wrapper(fn):
        @wraps(fn)
        def wrapped(*args, **kwargs):
            started_at = datetime.now()
            result = fn(*args, **kwargs)
            fields = {}
            if "elapsed_time" in template:
                fields["elapsed_time"] = datetime.now() - started_at
            logger.log(level, template.format(**fields))
            return result

        return wrapped

    return get_wrapper
+
+
def log_result(logger, level, output_msg=None):
    """Decorator factory logging the value returned by the decorated function

    The message is built with ``output_msg.format``; the ``result`` field is
    provided only when "result" appears in *output_msg*. The result is
    rendered with ``json_debug_string``, falling back on ``str`` when that
    raises a TypeError.
    """
    template = "Result ->:\n{result}" if output_msg is None else output_msg

    def get_wrapper(fn):
        @wraps(fn)
        def wrapped(*args, **kwargs):
            result = fn(*args, **kwargs)
            fields = {}
            if "result" in template:
                try:
                    fields["result"] = json_debug_string(result)
                except TypeError:
                    fields["result"] = str(result)
            logger.log(level, template.format(**fields))
            return result

        return wrapped

    return get_wrapper
diff --git a/snips_inference_agl/common/registrable.py b/snips_inference_agl/common/registrable.py
new file mode 100644
index 0000000..bbc7bdc
--- /dev/null
+++ b/snips_inference_agl/common/registrable.py
@@ -0,0 +1,73 @@
+# This module is largely inspired by the AllenNLP library
+# See github.com/allenai/allennlp/blob/master/allennlp/common/registrable.py
+
+from collections import defaultdict
+from future.utils import iteritems
+
+from snips_inference_agl.exceptions import AlreadyRegisteredError, NotRegisteredError
+
+
class Registrable(object):
    """
    Any class that inherits from ``Registrable`` gains access to a named
    registry for its subclasses. To register them, just decorate them with the
    classmethod ``@BaseClass.register(name)``.

    After which you can call ``BaseClass.list_available()`` to get the keys
    for the registered subclasses, and ``BaseClass.by_name(name)`` to get the
    corresponding subclass.

    Note that if you use this class to implement a new ``Registrable``
    abstract class, you must ensure that all subclasses of the abstract class
    are loaded when the module is loaded, because the subclasses register
    themselves in their respective files. You can achieve this by having the
    abstract class and all subclasses in the __init__.py of the module in
    which they reside (as this causes any import of either the abstract class
    or a subclass to load all other subclasses and the abstract class).
    """
    # Maps each base class to its own {name: subclass} registry
    _registry = defaultdict(dict)

    @classmethod
    def register(cls, name, override=False):
        """Decorator used to add the decorated subclass to the registry of the
        base class

        Args:
            name (str): name used to identify the registered subclass
            override (bool, optional): this parameter controls the behavior in
                case where a subclass is registered with the same identifier.
                If True, then the previous subclass will be unregistered in
                favor of the new subclass.

        Raises:
            AlreadyRegisteredError: when ``override`` is False, while trying
                to register a subclass with a name already used by another
                registered subclass
        """
        registry = Registrable._registry[cls]

        def add_subclass_to_registry(subclass):
            # A name can only be reused when overriding is explicitly allowed
            name_is_taken = name in registry
            if name_is_taken and not override:
                raise AlreadyRegisteredError(name, cls, registry[name])
            registry[name] = subclass
            return subclass

        return add_subclass_to_registry

    @classmethod
    def registered_name(cls, registered_class):
        """Returns the first name under which *registered_class* is
        registered

        Raises:
            NotRegisteredError: when *registered_class* is not in the
                registry of *cls*
        """
        registry = Registrable._registry[cls]
        for name, subclass in registry.items():
            if subclass == registered_class:
                return name
        raise NotRegisteredError(cls, registered_cls=registered_class)

    @classmethod
    def by_name(cls, name):
        """Returns the subclass registered under *name*

        Raises:
            NotRegisteredError: when no subclass is registered under *name*
        """
        registry = Registrable._registry[cls]
        if name in registry:
            return registry[name]
        raise NotRegisteredError(cls, name=name)

    @classmethod
    def list_available(cls):
        """Returns the list of the names registered for *cls*"""
        return list(Registrable._registry[cls])
diff --git a/snips_inference_agl/common/utils.py b/snips_inference_agl/common/utils.py
new file mode 100644
index 0000000..816dae7
--- /dev/null
+++ b/snips_inference_agl/common/utils.py
@@ -0,0 +1,239 @@
+from __future__ import unicode_literals
+
+import importlib
+import json
+import numbers
+import re
+from builtins import bytes as newbytes, str as newstr
+from datetime import datetime
+from functools import wraps
+from pathlib import Path
+
+from future.utils import text_type
+
+from snips_inference_agl.constants import (END, ENTITY_KIND, RES_MATCH_RANGE, RES_VALUE,
+ START)
+from snips_inference_agl.exceptions import NotTrained, PersistingError
+
REGEX_PUNCT = {'\\', '.', '+', '*', '?', '(', ')', '|', '[', ']', '{', '}',
               '^', '$', '#', '&', '-', '~'}


# pylint:disable=line-too-long
def regex_escape(s):
    """Escapes all regular expression meta characters in *s*

    The string returned may be safely used as a literal in a regular
    expression.

    This function is more precise than :func:`re.escape`, the latter escapes
    all non-alphanumeric characters which can cause cross-platform
    compatibility issues.

    References:

    - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/lib.rs#L1685
    - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/parser.rs#L1378
    """
    # Only the characters listed in REGEX_PUNCT get a backslash prefix
    return "".join(
        "\\" + char if char in REGEX_PUNCT else char for char in s)
+
+
+# pylint:enable=line-too-long
+
+
def check_random_state(seed):
    """Turn seed into a :class:`numpy.random.RandomState` instance

    If seed is None, return the RandomState singleton used by np.random.
    If seed is an int, return a new RandomState instance seeded with seed.
    If seed is already a RandomState instance, return it.
    Otherwise raise ValueError.
    """
    import numpy as np

    # pylint: disable=W0212
    # pylint: disable=c-extension-no-member
    if seed is None or seed is np.random:
        return np.random.mtrand._rand  # pylint: disable=c-extension-no-member
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    # The 1-tuple keeps tuple seeds from being unpacked as several formatting
    # arguments, which would raise a TypeError instead of the ValueError
    raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
                     ' instance' % (seed,))
+
+
def ranges_overlap(lhs_range, rhs_range):
    """Checks whether two ranges overlap

    Both ranges must be dicts indexed by ``START`` and ``END``, or both must
    be tuples/lists of the form ``(start, end)``. Ranges which merely touch
    (the end of one equals the start of the other) do not overlap.

    Raises:
        TypeError: for any other combination of types
    """
    both_dicts = isinstance(lhs_range, dict) and isinstance(rhs_range, dict)
    if both_dicts:
        return (lhs_range[END] > rhs_range[START]
                and lhs_range[START] < rhs_range[END])

    sequence_types = (tuple, list)
    both_sequences = (isinstance(lhs_range, sequence_types)
                      and isinstance(rhs_range, sequence_types))
    if both_sequences:
        return lhs_range[1] > rhs_range[0] and lhs_range[0] < rhs_range[1]

    raise TypeError("Cannot check overlap on objects of type: %s and %s"
                    % (type(lhs_range), type(rhs_range)))
+
+
def elapsed_since(time):
    """Returns the difference between ``datetime.now()`` and *time*"""
    now = datetime.now()
    return now - time
+
+
def json_debug_string(dict_data):
    """Serializes *dict_data* to an indented JSON string with sorted keys,
    leaving non-ASCII characters unescaped
    """
    dump_options = dict(ensure_ascii=False, indent=2, sort_keys=True)
    return json.dumps(dict_data, **dump_options)
+
def json_string(json_object, indent=2, sort_keys=True):
    """Serializes *json_object* to JSON and returns it as a unicode string

    ``(',', ': ')`` separators are used so that lines carry no trailing
    whitespace when indenting.
    """
    serialized = json.dumps(json_object, separators=(',', ': '),
                            sort_keys=sort_keys, indent=indent)
    return unicode_string(serialized)
+
def unicode_string(string):
    """Converts *string* into a unicode string

    Args:
        string: a text or bytes object; bytes are decoded as utf8

    Returns:
        The corresponding unicode (``text_type``) string

    Raises:
        TypeError: when *string* is neither a text nor a bytes object
    """
    if isinstance(string, text_type):
        return string
    if isinstance(string, bytes):
        return string.decode("utf8")
    if isinstance(string, newstr):
        return text_type(string)
    if isinstance(string, newbytes):
        # The decoded value must be returned: it used to be assigned and
        # then discarded by the TypeError below
        return bytes(string).decode("utf8")

    raise TypeError("Cannot convert %s into unicode string" % type(string))
+
+
def check_persisted_path(func):
    """Method decorator refusing to persist to a path which already exists

    The ``path`` argument is converted to a ``Path`` before being handed to
    the decorated method; a PersistingError is raised when it exists.
    """

    @wraps(func)
    def func_wrapper(self, path, *args, **kwargs):
        target = Path(path)
        if not target.exists():
            return func(self, target, *args, **kwargs)
        raise PersistingError(target)

    return func_wrapper
+
+
def fitted_required(func):
    """Method decorator raising NotTrained unless ``self.fitted`` is truthy

    The error message relies on ``self.unit_name``.
    """

    @wraps(func)
    def func_wrapper(self, *args, **kwargs):
        if self.fitted:
            return func(self, *args, **kwargs)
        raise NotTrained("%s must be fitted" % self.unit_name)

    return func_wrapper
+
+
def is_package(name):
    """Check if name maps to a package installed via pip.

    Names are compared case-insensitively, with "-" and "_" considered
    equivalent.

    Args:
        name (str): Name of package

    Returns:
        bool: True if an installed packaged corresponds to this name, False
            otherwise.
    """

    def normalize(package_name):
        return package_name.lower().replace("-", "_")

    try:
        # Standard library (Python 3.8+), preferred over the deprecated
        # pkg_resources which requires setuptools to be installed
        from importlib.metadata import distributions
    except ImportError:
        import pkg_resources

        installed_names = list(pkg_resources.working_set.by_key.keys())
    else:
        installed_names = []
        for distribution in distributions():
            metadata = distribution.metadata
            # Distributions with missing or broken metadata have no name
            dist_name = None if metadata is None else metadata.get("Name")
            if dist_name:
                installed_names.append(dist_name)

    target = normalize(name)
    return any(normalize(installed) == target
               for installed in installed_names)
+
+
def get_package_path(name):
    """Get the path to an installed package.

    The name is lowercased and its "-" replaced with "_" before the
    corresponding module is imported.

    Args:
        name (str): Package name

    Returns:
        class:`.Path`: Path to the installed package
    """
    module_name = name.lower().replace("-", "_")
    module = importlib.import_module(module_name)
    module_file = Path(module.__file__)
    return module_file.parent
+
+
def deduplicate_overlapping_items(items, overlap_fn, sort_key_fn):
    """Deduplicates the items by looping over the items, sorted using
    sort_key_fn, and checking overlaps with previously seen items using
    overlap_fn

    The returned list follows the sort_key_fn order: an item is kept only
    when it overlaps none of the items kept before it.
    """
    kept_items = []
    for candidate in sorted(items, key=sort_key_fn):
        clashes = any(overlap_fn(candidate, kept) for kept in kept_items)
        if clashes:
            continue
        kept_items.append(candidate)
    return kept_items
+
+
def replace_entities_with_placeholders(text, entities, placeholder_fn):
    """Processes the text in order to replace entity values with placeholders
    as defined by the placeholder function

    Returns a ``(range_mapping, processed_text)`` tuple where range_mapping
    maps each ``(start, end)`` placeholder range in the processed text to the
    original range of the entity it replaces.
    """
    if not entities:
        return dict(), text

    entities = deduplicate_overlapping_entities(entities)
    entities = sorted(
        entities, key=lambda e: e[RES_MATCH_RANGE][START])

    range_mapping = dict()
    pieces = []
    # Cumulated length difference between the processed and original texts
    shift = 0
    # Position in the original text up to which content has been consumed
    cursor = 0
    for entity in entities:
        original_range = entity[RES_MATCH_RANGE]
        start = original_range[START]
        end = original_range[END]

        # Untouched text sitting between the previous entity and this one
        pieces.append(text[cursor:start])

        placeholder = placeholder_fn(entity[ENTITY_KIND])
        pieces.append(placeholder)

        # The new start relies on the shift accumulated before this entity,
        # the new end on the shift which includes it
        new_start = start + shift
        shift += len(placeholder) - (end - start)
        range_mapping[(new_start, end + shift)] = original_range
        cursor = end

    pieces.append(text[cursor:])
    return range_mapping, "".join(pieces)
+
+
def deduplicate_overlapping_entities(entities):
    """Deduplicates entities based on overlapping ranges

    When entities overlap, the ones with the longest value take precedence.
    The remaining entities are returned sorted by start position.
    """

    def start_position(entity):
        return entity[RES_MATCH_RANGE][START]

    def longest_value_first(entity):
        return -len(entity[RES_VALUE])

    def overlap(lhs_entity, rhs_entity):
        lhs_range = lhs_entity[RES_MATCH_RANGE]
        rhs_range = rhs_entity[RES_MATCH_RANGE]
        return ranges_overlap(lhs_range, rhs_range)

    kept_entities = deduplicate_overlapping_items(
        entities, overlap, longest_value_first)
    return sorted(kept_entities, key=start_position)
+
+
# Semantic version pattern with an optional fourth "subpatch" component:
# major.minor.patch[.subpatch][-prerelease][+buildmetadata]
# The component separators are escaped dots: an unescaped "." would match
# any character, e.g. it would read the "-1" of "1.2.3-1" as a subpatch.
SEMVER_PATTERN = r"^(?P<major>0|[1-9]\d*)" \
                 r"\.(?P<minor>0|[1-9]\d*)" \
                 r"\.(?P<patch>0|[1-9]\d*)" \
                 r"(?:\.(?P<subpatch>0|[1-9]\d*))?" \
                 r"(?:-(?P<prerelease>(?:0|[1-9]\d*|\d*[a-zA-Z-]" \
                 r"[0-9a-zA-Z-]*)" \
                 r"(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?" \
                 r"(?:\+(?P<buildmetadata>[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*)" \
                 r")?$"
SEMVER_REGEX = re.compile(SEMVER_PATTERN)