diff options
Diffstat (limited to 'snips_inference_agl/common')
-rw-r--r-- | snips_inference_agl/common/__init__.py | 0 | ||||
-rw-r--r-- | snips_inference_agl/common/abc_utils.py | 36 | ||||
-rw-r--r-- | snips_inference_agl/common/dataset_utils.py | 48 | ||||
-rw-r--r-- | snips_inference_agl/common/dict_utils.py | 36 | ||||
-rw-r--r-- | snips_inference_agl/common/from_dict.py | 30 | ||||
-rw-r--r-- | snips_inference_agl/common/io_utils.py | 36 | ||||
-rw-r--r-- | snips_inference_agl/common/log_utils.py | 61 | ||||
-rw-r--r-- | snips_inference_agl/common/registrable.py | 73 | ||||
-rw-r--r-- | snips_inference_agl/common/utils.py | 239 |
9 files changed, 559 insertions, 0 deletions
# ---------------------------------------------------------------------------
# snips_inference_agl/common/__init__.py
# (new file, intentionally empty)
# ---------------------------------------------------------------------------


# ---------------------------------------------------------------------------
# snips_inference_agl/common/abc_utils.py
# ---------------------------------------------------------------------------
class abstractclassmethod(classmethod):  # pylint: disable=invalid-name
    """``classmethod`` flavour that carries the ``__isabstractmethod__`` flag.

    The flag is set both on the descriptor class and on the wrapped callable.
    """
    __isabstractmethod__ = True

    def __init__(self, callable):
        # Mark the wrapped function itself before handing it to classmethod.
        callable.__isabstractmethod__ = True
        super(abstractclassmethod, self).__init__(callable)


class ClassPropertyDescriptor(object):
    """Descriptor exposing a getter (and optional setter) as a property
    that is resolved against the class instead of the instance.
    """

    def __init__(self, fget, fset=None):
        self.fget = fget
        self.fset = fset

    def __get__(self, obj, klass=None):
        # Fall back to the instance's type when no owner class is supplied.
        owner = type(obj) if klass is None else klass
        bound_getter = self.fget.__get__(obj, owner)
        return bound_getter()

    def __set__(self, obj, value):
        if not self.fset:
            raise AttributeError("can't set attribute")
        bound_setter = self.fset.__get__(obj, type(obj))
        return bound_setter(value)

    def setter(self, func):
        """Attach *func* as the setter and return the descriptor itself."""
        if isinstance(func, (classmethod, staticmethod)):
            self.fset = func
        else:
            # Plain functions are promoted to classmethods.
            self.fset = classmethod(func)
        return self


def classproperty(func):
    """Decorator turning *func* into a :class:`ClassPropertyDescriptor`.

    Plain functions are wrapped in ``classmethod`` first; existing
    ``classmethod``/``staticmethod`` objects are used as they are.
    """
    already_wrapped = isinstance(func, (classmethod, staticmethod))
    getter = func if already_wrapped else classmethod(func)
    return ClassPropertyDescriptor(getter)


# ---------------------------------------------------------------------------
# snips_inference_agl/common/dataset_utils.py
# ---------------------------------------------------------------------------
from snips_inference_agl.constants import (
    INTENTS, UTTERANCES, DATA, SLOT_NAME, ENTITY)
from snips_inference_agl.exceptions import DatasetFormatError


def type_error(expected_type, found_type, object_label=None):
    """Always raise a ``DatasetFormatError`` describing a type mismatch.

    When *object_label* is given, it is included in the message.
    """
    if object_label is not None:
        message = "Invalid type for '%s': expected %s but found %s" \
                  % (object_label, expected_type, found_type)
    else:
        message = "Invalid type: expected %s but found %s" \
                  % (expected_type, found_type)
    raise DatasetFormatError(message)


def validate_type(obj, expected_type, object_label=None):
    """Raise ``DatasetFormatError`` unless *obj* is an *expected_type*."""
    if isinstance(obj, expected_type):
        return
    type_error(expected_type, type(obj), object_label)


def missing_key_error(key, object_label=None):
    """Always raise a ``DatasetFormatError`` reporting a missing *key*.

    When *object_label* is given, it is included in the message.
    """
    if object_label is not None:
        message = "Expected %s to have key: '%s'" % (object_label, key)
    else:
        message = "Missing key: '%s'" % key
    raise DatasetFormatError(message)


def validate_key(obj, key, object_label=None):
    """Raise ``DatasetFormatError`` unless *key* is contained in *obj*."""
    if key in obj:
        return
    missing_key_error(key, object_label)


def validate_keys(obj, keys, object_label=None):
    """Apply :func:`validate_key` to every element of *keys*, in order."""
    for required_key in keys:
        validate_key(obj, required_key, object_label)


def get_slot_name_mapping(dataset, intent):
    """Returns a dict which maps slot names to entities for the provided intent
    """
    # Later chunks overwrite earlier ones that share the same slot name.
    return {
        chunk[SLOT_NAME]: chunk[ENTITY]
        for utterance in dataset[INTENTS][intent][UTTERANCES]
        for chunk in utterance[DATA]
        if SLOT_NAME in chunk
    }


def get_slot_name_mappings(dataset):
    """Returns a dict which maps intents to their slot name mapping"""
    mappings = dict()
    for intent in dataset[INTENTS]:
        mappings[intent] = get_slot_name_mapping(dataset, intent)
    return mappings
\ No newline at end of file diff --git a/snips_inference_agl/common/dict_utils.py b/snips_inference_agl/common/dict_utils.py new file mode 100644 index 0000000..a70b217 --- /dev/null +++ b/snips_inference_agl/common/dict_utils.py @@ -0,0 +1,36 @@ +from collections import OrderedDict + + +class LimitedSizeDict(OrderedDict): + def __init__(self, *args, **kwds): + if "size_limit" not in kwds: + raise ValueError("'size_limit' must be passed as a keyword " + "argument") + self.size_limit = kwds.pop("size_limit") + if len(args) > 1: + raise TypeError('expected at most 1 arguments, got %d' % len(args)) + if len(args) == 1 and len(args[0]) + len(kwds) > self.size_limit: + raise ValueError("Tried to initialize LimitedSizedDict with more " + "value than permitted with 'limit_size'") + super(LimitedSizeDict, self).__init__(*args, **kwds) + + def __setitem__(self, key, value, dict_setitem=OrderedDict.__setitem__): + dict_setitem(self, key, value) + self._check_size_limit() + + def _check_size_limit(self): + if self.size_limit is not None: + while len(self) > self.size_limit: + self.popitem(last=False) + + def __eq__(self, other): + if self.size_limit != other.size_limit: + return False + return super(LimitedSizeDict, self).__eq__(other) + + +class UnupdatableDict(dict): + def __setitem__(self, key, value): + if key in self: + raise KeyError("Can't update key '%s'" % key) + super(UnupdatableDict, self).__setitem__(key, value) diff --git a/snips_inference_agl/common/from_dict.py b/snips_inference_agl/common/from_dict.py new file mode 100644 index 0000000..2b776a6 --- /dev/null +++ b/snips_inference_agl/common/from_dict.py @@ -0,0 +1,30 @@ +try: + import funcsigs as inspect +except ImportError: + import inspect + +from future.utils import iteritems + +KEYWORD_KINDS = {inspect.Parameter.POSITIONAL_OR_KEYWORD, + inspect.Parameter.KEYWORD_ONLY} + + +class FromDict(object): + @classmethod + def from_dict(cls, dict): + if dict is None: + return cls() + params = 
inspect.signature(cls.__init__).parameters + + if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in + params.values()): + return cls(**dict) + + param_names = set() + for i, (name, param) in enumerate(iteritems(params)): + if i == 0 and name == "self": + continue + if param.kind in KEYWORD_KINDS: + param_names.add(name) + filtered_dict = {k: v for k, v in iteritems(dict) if k in param_names} + return cls(**filtered_dict) diff --git a/snips_inference_agl/common/io_utils.py b/snips_inference_agl/common/io_utils.py new file mode 100644 index 0000000..f4407a3 --- /dev/null +++ b/snips_inference_agl/common/io_utils.py @@ -0,0 +1,36 @@ +import errno +import os +import shutil +from contextlib import contextmanager +from pathlib import Path +from tempfile import mkdtemp +from zipfile import ZipFile, ZIP_DEFLATED + + +def mkdir_p(path): + """Reproduces the 'mkdir -p shell' command + + See + http://stackoverflow.com/questions/600268/mkdir-p-functionality-in-python + """ + try: + os.makedirs(str(path)) + except OSError as exc: # Python >2.5 + if exc.errno == errno.EEXIST and path.is_dir(): + pass + else: + raise + + +@contextmanager +def temp_dir(): + tmp_dir = mkdtemp() + try: + yield Path(tmp_dir) + finally: + shutil.rmtree(tmp_dir) + + +def unzip_archive(archive_file, destination_dir): + with ZipFile(archive_file, "r", ZIP_DEFLATED) as zipf: + zipf.extractall(str(destination_dir)) diff --git a/snips_inference_agl/common/log_utils.py b/snips_inference_agl/common/log_utils.py new file mode 100644 index 0000000..47da34e --- /dev/null +++ b/snips_inference_agl/common/log_utils.py @@ -0,0 +1,61 @@ +from __future__ import unicode_literals + +from builtins import str +from datetime import datetime +from functools import wraps + +from snips_inference_agl.common.utils import json_debug_string + + +class DifferedLoggingMessage(object): + + def __init__(self, fn, *args, **kwargs): + self.fn = fn + self.args = args + self.kwargs = kwargs + + def __str__(self): + return 
str(self.fn(*self.args, **self.kwargs)) + + +def log_elapsed_time(logger, level, output_msg=None): + if output_msg is None: + output_msg = "Elapsed time ->:\n{elapsed_time}" + + def get_wrapper(fn): + @wraps(fn) + def wrapped(*args, **kwargs): + start = datetime.now() + msg_fmt = dict() + res = fn(*args, **kwargs) + if "elapsed_time" in output_msg: + msg_fmt["elapsed_time"] = datetime.now() - start + logger.log(level, output_msg.format(**msg_fmt)) + return res + + return wrapped + + return get_wrapper + + +def log_result(logger, level, output_msg=None): + if output_msg is None: + output_msg = "Result ->:\n{result}" + + def get_wrapper(fn): + @wraps(fn) + def wrapped(*args, **kwargs): + msg_fmt = dict() + res = fn(*args, **kwargs) + if "result" in output_msg: + try: + res_debug_string = json_debug_string(res) + except TypeError: + res_debug_string = str(res) + msg_fmt["result"] = res_debug_string + logger.log(level, output_msg.format(**msg_fmt)) + return res + + return wrapped + + return get_wrapper diff --git a/snips_inference_agl/common/registrable.py b/snips_inference_agl/common/registrable.py new file mode 100644 index 0000000..bbc7bdc --- /dev/null +++ b/snips_inference_agl/common/registrable.py @@ -0,0 +1,73 @@ +# This module is largely inspired by the AllenNLP library +# See github.com/allenai/allennlp/blob/master/allennlp/common/registrable.py + +from collections import defaultdict +from future.utils import iteritems + +from snips_inference_agl.exceptions import AlreadyRegisteredError, NotRegisteredError + + +class Registrable(object): + """ + Any class that inherits from ``Registrable`` gains access to a named + registry for its subclasses. To register them, just decorate them with the + classmethod ``@BaseClass.register(name)``. + + After which you can call ``BaseClass.list_available()`` to get the keys + for the registered subclasses, and ``BaseClass.by_name(name)`` to get the + corresponding subclass. 
+ + Note that if you use this class to implement a new ``Registrable`` + abstract class, you must ensure that all subclasses of the abstract class + are loaded when the module is loaded, because the subclasses register + themselves in their respective files. You can achieve this by having the + abstract class and all subclasses in the __init__.py of the module in + which they reside (as this causes any import of either the abstract class + or a subclass to load all other subclasses and the abstract class). + """ + _registry = defaultdict(dict) + + @classmethod + def register(cls, name, override=False): + """Decorator used to add the decorated subclass to the registry of the + base class + + Args: + name (str): name use to identify the registered subclass + override (bool, optional): this parameter controls the behavior in + case where a subclass is registered with the same identifier. + If True, then the previous subclass will be unregistered in + profit of the new subclass. + + Raises: + AlreadyRegisteredError: when ``override`` is False, while trying + to register a subclass with a name already used by another + registered subclass + """ + registry = Registrable._registry[cls] + + def add_subclass_to_registry(subclass): + # Add to registry, raise an error if key has already been used. 
+ if not override and name in registry: + raise AlreadyRegisteredError(name, cls, registry[name]) + registry[name] = subclass + return subclass + + return add_subclass_to_registry + + @classmethod + def registered_name(cls, registered_class): + for name, subclass in iteritems(Registrable._registry[cls]): + if subclass == registered_class: + return name + raise NotRegisteredError(cls, registered_cls=registered_class) + + @classmethod + def by_name(cls, name): + if name not in Registrable._registry[cls]: + raise NotRegisteredError(cls, name=name) + return Registrable._registry[cls][name] + + @classmethod + def list_available(cls): + return list(Registrable._registry[cls].keys()) diff --git a/snips_inference_agl/common/utils.py b/snips_inference_agl/common/utils.py new file mode 100644 index 0000000..816dae7 --- /dev/null +++ b/snips_inference_agl/common/utils.py @@ -0,0 +1,239 @@ +from __future__ import unicode_literals + +import importlib +import json +import numbers +import re +from builtins import bytes as newbytes, str as newstr +from datetime import datetime +from functools import wraps +from pathlib import Path + +from future.utils import text_type + +from snips_inference_agl.constants import (END, ENTITY_KIND, RES_MATCH_RANGE, RES_VALUE, + START) +from snips_inference_agl.exceptions import NotTrained, PersistingError + +REGEX_PUNCT = {'\\', '.', '+', '*', '?', '(', ')', '|', '[', ']', '{', '}', + '^', '$', '#', '&', '-', '~'} + + +# pylint:disable=line-too-long +def regex_escape(s): + """Escapes all regular expression meta characters in *s* + + The string returned may be safely used as a literal in a regular + expression. + + This function is more precise than :func:`re.escape`, the latter escapes + all non-alphanumeric characters which can cause cross-platform + compatibility issues. 
+ + References: + + - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/lib.rs#L1685 + - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/parser.rs#L1378 + """ + escaped_string = "" + for c in s: + if c in REGEX_PUNCT: + escaped_string += "\\" + escaped_string += c + return escaped_string + + +# pylint:enable=line-too-long + + +def check_random_state(seed): + """Turn seed into a :class:`numpy.random.RandomState` instance + + If seed is None, return the RandomState singleton used by np.random. + If seed is an int, return a new RandomState instance seeded with seed. + If seed is already a RandomState instance, return it. + Otherwise raise ValueError. + """ + import numpy as np + + # pylint: disable=W0212 + # pylint: disable=c-extension-no-member + if seed is None or seed is np.random: + return np.random.mtrand._rand # pylint: disable=c-extension-no-member + if isinstance(seed, (numbers.Integral, np.integer)): + return np.random.RandomState(seed) + if isinstance(seed, np.random.RandomState): + return seed + raise ValueError('%r cannot be used to seed a numpy.random.RandomState' + ' instance' % seed) + + +def ranges_overlap(lhs_range, rhs_range): + if isinstance(lhs_range, dict) and isinstance(rhs_range, dict): + return lhs_range[END] > rhs_range[START] \ + and lhs_range[START] < rhs_range[END] + elif isinstance(lhs_range, (tuple, list)) \ + and isinstance(rhs_range, (tuple, list)): + return lhs_range[1] > rhs_range[0] and lhs_range[0] < rhs_range[1] + else: + raise TypeError("Cannot check overlap on objects of type: %s and %s" + % (type(lhs_range), type(rhs_range))) + + +def elapsed_since(time): + return datetime.now() - time + + +def json_debug_string(dict_data): + return json.dumps(dict_data, ensure_ascii=False, indent=2, sort_keys=True) + +def json_string(json_object, indent=2, sort_keys=True): + json_dump = json.dumps(json_object, indent=indent, sort_keys=sort_keys, + separators=(',', ': ')) + return unicode_string(json_dump) + +def 
unicode_string(string): + if isinstance(string, text_type): + return string + if isinstance(string, bytes): + return string.decode("utf8") + if isinstance(string, newstr): + return text_type(string) + if isinstance(string, newbytes): + string = bytes(string).decode("utf8") + + raise TypeError("Cannot convert %s into unicode string" % type(string)) + + +def check_persisted_path(func): + @wraps(func) + def func_wrapper(self, path, *args, **kwargs): + path = Path(path) + if path.exists(): + raise PersistingError(path) + return func(self, path, *args, **kwargs) + + return func_wrapper + + +def fitted_required(func): + @wraps(func) + def func_wrapper(self, *args, **kwargs): + if not self.fitted: + raise NotTrained("%s must be fitted" % self.unit_name) + return func(self, *args, **kwargs) + + return func_wrapper + + +def is_package(name): + """Check if name maps to a package installed via pip. + + Args: + name (str): Name of package + + Returns: + bool: True if an installed packaged corresponds to this name, False + otherwise. + """ + import pkg_resources + + name = name.lower().replace("-", "_") + packages = pkg_resources.working_set.by_key.keys() + for package in packages: + if package.lower().replace("-", "_") == name: + return True + return False + + +def get_package_path(name): + """Get the path to an installed package. 
+ + Args: + name (str): Package name + + Returns: + class:`.Path`: Path to the installed package + """ + name = name.lower().replace("-", "_") + pkg = importlib.import_module(name) + return Path(pkg.__file__).parent + + +def deduplicate_overlapping_items(items, overlap_fn, sort_key_fn): + """Deduplicates the items by looping over the items, sorted using + sort_key_fn, and checking overlaps with previously seen items using + overlap_fn + """ + sorted_items = sorted(items, key=sort_key_fn) + deduplicated_items = [] + for item in sorted_items: + if not any(overlap_fn(item, dedup_item) + for dedup_item in deduplicated_items): + deduplicated_items.append(item) + return deduplicated_items + + +def replace_entities_with_placeholders(text, entities, placeholder_fn): + """Processes the text in order to replace entity values with placeholders + as defined by the placeholder function + """ + if not entities: + return dict(), text + + entities = deduplicate_overlapping_entities(entities) + entities = sorted( + entities, key=lambda e: e[RES_MATCH_RANGE][START]) + + range_mapping = dict() + processed_text = "" + offset = 0 + current_ix = 0 + for ent in entities: + ent_start = ent[RES_MATCH_RANGE][START] + ent_end = ent[RES_MATCH_RANGE][END] + rng_start = ent_start + offset + + processed_text += text[current_ix:ent_start] + + entity_length = ent_end - ent_start + entity_place_holder = placeholder_fn(ent[ENTITY_KIND]) + + offset += len(entity_place_holder) - entity_length + + processed_text += entity_place_holder + rng_end = ent_end + offset + new_range = (rng_start, rng_end) + range_mapping[new_range] = ent[RES_MATCH_RANGE] + current_ix = ent_end + + processed_text += text[current_ix:] + return range_mapping, processed_text + + +def deduplicate_overlapping_entities(entities): + """Deduplicates entities based on overlapping ranges""" + + def overlap(lhs_entity, rhs_entity): + return ranges_overlap(lhs_entity[RES_MATCH_RANGE], + rhs_entity[RES_MATCH_RANGE]) + + def 
sort_key_fn(entity): + return -len(entity[RES_VALUE]) + + deduplicated_entities = deduplicate_overlapping_items( + entities, overlap, sort_key_fn) + return sorted(deduplicated_entities, + key=lambda entity: entity[RES_MATCH_RANGE][START]) + + +SEMVER_PATTERN = r"^(?P<major>0|[1-9]\d*)" \ + r".(?P<minor>0|[1-9]\d*)" \ + r".(?P<patch>0|[1-9]\d*)" \ + r"(?:.(?P<subpatch>0|[1-9]\d*))?" \ + r"(?:-(?P<prerelease>(?:0|[1-9]\d*|\d*[a-zA-Z-]" \ + r"[0-9a-zA-Z-]*)" \ + r"(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?" \ + r"(?:\+(?P<buildmetadata>[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*)" \ + r")?$" +SEMVER_REGEX = re.compile(SEMVER_PATTERN) |