aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/common/utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'snips_inference_agl/common/utils.py')
-rw-r--r--snips_inference_agl/common/utils.py239
1 files changed, 239 insertions, 0 deletions
diff --git a/snips_inference_agl/common/utils.py b/snips_inference_agl/common/utils.py
new file mode 100644
index 0000000..816dae7
--- /dev/null
+++ b/snips_inference_agl/common/utils.py
@@ -0,0 +1,239 @@
+from __future__ import unicode_literals
+
+import importlib
+import json
+import numbers
+import re
+from builtins import bytes as newbytes, str as newstr
+from datetime import datetime
+from functools import wraps
+from pathlib import Path
+
+from future.utils import text_type
+
+from snips_inference_agl.constants import (END, ENTITY_KIND, RES_MATCH_RANGE, RES_VALUE,
+ START)
+from snips_inference_agl.exceptions import NotTrained, PersistingError
+
+REGEX_PUNCT = {'\\', '.', '+', '*', '?', '(', ')', '|', '[', ']', '{', '}',
+ '^', '$', '#', '&', '-', '~'}
+
+
+# pylint:disable=line-too-long
+def regex_escape(s):
+ """Escapes all regular expression meta characters in *s*
+
+ The string returned may be safely used as a literal in a regular
+ expression.
+
+ This function is more precise than :func:`re.escape`, the latter escapes
+ all non-alphanumeric characters which can cause cross-platform
+ compatibility issues.
+
+ References:
+
+ - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/lib.rs#L1685
+ - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/parser.rs#L1378
+ """
+ escaped_string = ""
+ for c in s:
+ if c in REGEX_PUNCT:
+ escaped_string += "\\"
+ escaped_string += c
+ return escaped_string
+
+
+# pylint:enable=line-too-long
+
+
+def check_random_state(seed):
+ """Turn seed into a :class:`numpy.random.RandomState` instance
+
+ If seed is None, return the RandomState singleton used by np.random.
+ If seed is an int, return a new RandomState instance seeded with seed.
+ If seed is already a RandomState instance, return it.
+ Otherwise raise ValueError.
+ """
+ import numpy as np
+
+ # pylint: disable=W0212
+ # pylint: disable=c-extension-no-member
+ if seed is None or seed is np.random:
+ return np.random.mtrand._rand # pylint: disable=c-extension-no-member
+ if isinstance(seed, (numbers.Integral, np.integer)):
+ return np.random.RandomState(seed)
+ if isinstance(seed, np.random.RandomState):
+ return seed
+ raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
+ ' instance' % seed)
+
+
+def ranges_overlap(lhs_range, rhs_range):
+ if isinstance(lhs_range, dict) and isinstance(rhs_range, dict):
+ return lhs_range[END] > rhs_range[START] \
+ and lhs_range[START] < rhs_range[END]
+ elif isinstance(lhs_range, (tuple, list)) \
+ and isinstance(rhs_range, (tuple, list)):
+ return lhs_range[1] > rhs_range[0] and lhs_range[0] < rhs_range[1]
+ else:
+ raise TypeError("Cannot check overlap on objects of type: %s and %s"
+ % (type(lhs_range), type(rhs_range)))
+
+
+def elapsed_since(time):
+ return datetime.now() - time
+
+
+def json_debug_string(dict_data):
+ return json.dumps(dict_data, ensure_ascii=False, indent=2, sort_keys=True)
+
+def json_string(json_object, indent=2, sort_keys=True):
+ json_dump = json.dumps(json_object, indent=indent, sort_keys=sort_keys,
+ separators=(',', ': '))
+ return unicode_string(json_dump)
+
+def unicode_string(string):
+ if isinstance(string, text_type):
+ return string
+ if isinstance(string, bytes):
+ return string.decode("utf8")
+ if isinstance(string, newstr):
+ return text_type(string)
+ if isinstance(string, newbytes):
+ string = bytes(string).decode("utf8")
+
+ raise TypeError("Cannot convert %s into unicode string" % type(string))
+
+
+def check_persisted_path(func):
+ @wraps(func)
+ def func_wrapper(self, path, *args, **kwargs):
+ path = Path(path)
+ if path.exists():
+ raise PersistingError(path)
+ return func(self, path, *args, **kwargs)
+
+ return func_wrapper
+
+
+def fitted_required(func):
+ @wraps(func)
+ def func_wrapper(self, *args, **kwargs):
+ if not self.fitted:
+ raise NotTrained("%s must be fitted" % self.unit_name)
+ return func(self, *args, **kwargs)
+
+ return func_wrapper
+
+
+def is_package(name):
+ """Check if name maps to a package installed via pip.
+
+ Args:
+ name (str): Name of package
+
+ Returns:
+ bool: True if an installed packaged corresponds to this name, False
+ otherwise.
+ """
+ import pkg_resources
+
+ name = name.lower().replace("-", "_")
+ packages = pkg_resources.working_set.by_key.keys()
+ for package in packages:
+ if package.lower().replace("-", "_") == name:
+ return True
+ return False
+
+
+def get_package_path(name):
+ """Get the path to an installed package.
+
+ Args:
+ name (str): Package name
+
+ Returns:
+ class:`.Path`: Path to the installed package
+ """
+ name = name.lower().replace("-", "_")
+ pkg = importlib.import_module(name)
+ return Path(pkg.__file__).parent
+
+
+def deduplicate_overlapping_items(items, overlap_fn, sort_key_fn):
+ """Deduplicates the items by looping over the items, sorted using
+ sort_key_fn, and checking overlaps with previously seen items using
+ overlap_fn
+ """
+ sorted_items = sorted(items, key=sort_key_fn)
+ deduplicated_items = []
+ for item in sorted_items:
+ if not any(overlap_fn(item, dedup_item)
+ for dedup_item in deduplicated_items):
+ deduplicated_items.append(item)
+ return deduplicated_items
+
+
+def replace_entities_with_placeholders(text, entities, placeholder_fn):
+ """Processes the text in order to replace entity values with placeholders
+ as defined by the placeholder function
+ """
+ if not entities:
+ return dict(), text
+
+ entities = deduplicate_overlapping_entities(entities)
+ entities = sorted(
+ entities, key=lambda e: e[RES_MATCH_RANGE][START])
+
+ range_mapping = dict()
+ processed_text = ""
+ offset = 0
+ current_ix = 0
+ for ent in entities:
+ ent_start = ent[RES_MATCH_RANGE][START]
+ ent_end = ent[RES_MATCH_RANGE][END]
+ rng_start = ent_start + offset
+
+ processed_text += text[current_ix:ent_start]
+
+ entity_length = ent_end - ent_start
+ entity_place_holder = placeholder_fn(ent[ENTITY_KIND])
+
+ offset += len(entity_place_holder) - entity_length
+
+ processed_text += entity_place_holder
+ rng_end = ent_end + offset
+ new_range = (rng_start, rng_end)
+ range_mapping[new_range] = ent[RES_MATCH_RANGE]
+ current_ix = ent_end
+
+ processed_text += text[current_ix:]
+ return range_mapping, processed_text
+
+
+def deduplicate_overlapping_entities(entities):
+ """Deduplicates entities based on overlapping ranges"""
+
+ def overlap(lhs_entity, rhs_entity):
+ return ranges_overlap(lhs_entity[RES_MATCH_RANGE],
+ rhs_entity[RES_MATCH_RANGE])
+
+ def sort_key_fn(entity):
+ return -len(entity[RES_VALUE])
+
+ deduplicated_entities = deduplicate_overlapping_items(
+ entities, overlap, sort_key_fn)
+ return sorted(deduplicated_entities,
+ key=lambda entity: entity[RES_MATCH_RANGE][START])
+
+
+SEMVER_PATTERN = r"^(?P<major>0|[1-9]\d*)" \
+ r".(?P<minor>0|[1-9]\d*)" \
+ r".(?P<patch>0|[1-9]\d*)" \
+ r"(?:.(?P<subpatch>0|[1-9]\d*))?" \
+ r"(?:-(?P<prerelease>(?:0|[1-9]\d*|\d*[a-zA-Z-]" \
+ r"[0-9a-zA-Z-]*)" \
+ r"(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?" \
+ r"(?:\+(?P<buildmetadata>[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*)" \
+ r")?$"
+SEMVER_REGEX = re.compile(SEMVER_PATTERN)