aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/common/utils.py
blob: 816dae7b5cb7b446b24882f3e795e6daa9870a5e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
from __future__ import unicode_literals

import importlib
import json
import numbers
import re
from builtins import bytes as newbytes, str as newstr
from datetime import datetime
from functools import wraps
from pathlib import Path

from future.utils import text_type

from snips_inference_agl.constants import (END, ENTITY_KIND, RES_MATCH_RANGE, RES_VALUE,
                                 START)
from snips_inference_agl.exceptions import NotTrained, PersistingError

# Characters that regex_escape prefixes with a backslash
REGEX_PUNCT = set('\\.+*?()|[]{}^$#&-~')


# pylint:disable=line-too-long
def regex_escape(s):
    """Backslash-escape the regular expression meta characters found in *s*

    Only the characters listed in ``REGEX_PUNCT`` are escaped, every other
    character is copied unchanged. This is narrower than :func:`re.escape`,
    which escapes all non-alphanumeric characters and can cause
    cross-platform compatibility issues.

    Args:
        s (str): Text to be used as a literal inside a regular expression

    Returns:
        str: Copy of *s* in which each meta character is preceded by a
            backslash

    References:

    - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/lib.rs#L1685
    - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/parser.rs#L1378
    """
    return "".join("\\" + char if char in REGEX_PUNCT else char
                   for char in s)


# pylint:enable=line-too-long


def check_random_state(seed):
    """Turn seed into a :class:`numpy.random.RandomState` instance

    Args:
        seed: One of the following:

            - ``None`` or the ``numpy.random`` module itself: the RandomState
              singleton used by ``np.random`` is returned.
            - an integer (Python or numpy): a new RandomState instance seeded
              with it is returned.
            - a RandomState instance: it is returned unchanged.

    Returns:
        numpy.random.RandomState: The random state matching *seed*

    Raises:
        ValueError: If *seed* is none of the supported kinds
    """
    import numpy as np

    # pylint: disable=W0212
    # pylint: disable=c-extension-no-member
    if seed is None or seed is np.random:
        return np.random.mtrand._rand  # pylint: disable=c-extension-no-member
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    # Wrap the seed in a 1-tuple: a bare tuple seed would otherwise be
    # unpacked by the % operator and raise TypeError instead of ValueError
    raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
                     ' instance' % (seed,))


def ranges_overlap(lhs_range, rhs_range):
    """Check whether two ranges overlap

    Both ranges must be given in the same form: either two dicts indexed by
    the ``START`` and ``END`` keys, or two tuples/lists whose first two items
    are the start and the end. Ranges that merely touch (one ends where the
    other starts) do not overlap.

    Args:
        lhs_range (dict or tuple or list): First range
        rhs_range (dict or tuple or list): Second range

    Returns:
        bool: True if the ranges overlap, False otherwise

    Raises:
        TypeError: If the ranges are not both dicts or both tuples/lists
    """
    both_dicts = isinstance(lhs_range, dict) and isinstance(rhs_range, dict)
    if both_dicts:
        return (lhs_range[END] > rhs_range[START]
                and lhs_range[START] < rhs_range[END])

    sequence_types = (tuple, list)
    both_sequences = (isinstance(lhs_range, sequence_types)
                      and isinstance(rhs_range, sequence_types))
    if not both_sequences:
        raise TypeError("Cannot check overlap on objects of type: %s and %s"
                        % (type(lhs_range), type(rhs_range)))
    return lhs_range[1] > rhs_range[0] and lhs_range[0] < rhs_range[1]


def elapsed_since(time):
    """Return the duration between *time* and the current local time

    Args:
        time (datetime): Reference instant, subtracted from
            ``datetime.now()``

    Returns:
        timedelta: Time elapsed since *time*
    """
    now = datetime.now()
    return now - time


def json_debug_string(dict_data):
    """Serialize *dict_data* into a human readable JSON string

    The output is indented by 2 spaces, has its keys sorted and keeps
    non-ascii characters unescaped.

    Args:
        dict_data: JSON serializable object

    Returns:
        str: The formatted JSON document
    """
    dump_options = dict(ensure_ascii=False, indent=2, sort_keys=True)
    return json.dumps(dict_data, **dump_options)

def json_string(json_object, indent=2, sort_keys=True):
    """Serialize *json_object* into a unicode JSON string

    Args:
        json_object: JSON serializable object
        indent (int, optional): Indentation passed to :func:`json.dumps`
        sort_keys (bool, optional): Whether the keys are sorted in the output

    Returns:
        The JSON document, converted with ``unicode_string``
    """
    serialized = json.dumps(
        json_object, separators=(',', ': '), indent=indent,
        sort_keys=sort_keys)
    return unicode_string(serialized)

def unicode_string(string):
    """Convert *string* into a unicode text string

    ``text_type`` instances are returned as is and byte strings are decoded
    as UTF-8. ``newstr`` and ``newbytes`` are the ``str`` and ``bytes`` names
    imported from ``builtins`` at the top of the file (presumably the
    python-future backports when running on Python 2 -- to be confirmed).

    Args:
        string: Text or byte string to convert

    Returns:
        The unicode version of *string*

    Raises:
        TypeError: If *string* is neither a text string nor a byte string
    """
    if isinstance(string, text_type):
        return string
    if isinstance(string, bytes):
        return string.decode("utf8")
    if isinstance(string, newstr):
        return text_type(string)
    if isinstance(string, newbytes):
        # Cast to the native bytes type before decoding. The decoded value
        # must be returned, otherwise execution falls through to the
        # TypeError below
        return bytes(string).decode("utf8")

    raise TypeError("Cannot convert %s into unicode string" % type(string))


def check_persisted_path(func):
    """Decorator for methods taking a path as first argument, which refuses
    to run when that path already exists

    The path is converted to a :class:`.Path` before being checked and
    before being forwarded to the decorated method.

    Args:
        func (callable): Method with a ``(self, path, *args, **kwargs)``
            signature

    Returns:
        callable: The wrapped method

    Raises:
        PersistingError: When the wrapped method is called with an existing
            path
    """

    @wraps(func)
    def persist_if_free(self, path, *args, **kwargs):
        target = Path(path)
        if not target.exists():
            return func(self, target, *args, **kwargs)
        raise PersistingError(target)

    return persist_if_free


def fitted_required(func):
    """Decorator for methods which can only be called on a fitted object

    The instance is expected to expose a ``fitted`` attribute, checked on
    each call, and a ``unit_name`` attribute used in the error message.

    Args:
        func (callable): Method to protect

    Returns:
        callable: The wrapped method

    Raises:
        NotTrained: When the wrapped method is called while ``self.fitted``
            is falsy
    """

    @wraps(func)
    def ensure_fitted(self, *args, **kwargs):
        if self.fitted:
            return func(self, *args, **kwargs)
        raise NotTrained("%s must be fitted" % self.unit_name)

    return ensure_fitted


def is_package(name):
    """Check if name maps to an installed package.

    The comparison ignores case and treats ``-`` and ``_`` as equivalent. It
    is made against the keys of ``pkg_resources.working_set.by_key``.

    Args:
        name (str): Name of package

    Returns:
        bool: True if an installed package corresponds to this name, False
            otherwise.
    """
    import pkg_resources

    def normalize(package_name):
        return package_name.lower().replace("-", "_")

    target = normalize(name)
    return any(normalize(installed) == target
               for installed in pkg_resources.working_set.by_key)


def get_package_path(name):
    """Get the path to an installed package.

    The name is lowercased and its ``-`` are replaced with ``_`` before the
    corresponding module is imported.

    Args:
        name (str): Package name

    Returns:
        :class:`.Path`: Directory containing the file of the imported module
    """
    module_name = name.lower().replace("-", "_")
    module = importlib.import_module(module_name)
    module_file = Path(module.__file__)
    return module_file.parent


def deduplicate_overlapping_items(items, overlap_fn, sort_key_fn):
    """Drop the items which overlap with an item of higher priority

    Items are visited in the order given by *sort_key_fn* and an item is kept
    only when it overlaps with none of the items kept so far.

    Args:
        items (iterable): Items to deduplicate
        overlap_fn (callable): Called as ``overlap_fn(item, kept_item)``,
            returns a truthy value when the two items overlap
        sort_key_fn (callable): Sort key defining the visiting order

    Returns:
        list: The kept items, in visiting order
    """
    kept_items = []
    for candidate in sorted(items, key=sort_key_fn):
        for kept in kept_items:
            if overlap_fn(candidate, kept):
                break
        else:
            # No break: the candidate clashes with none of the kept items
            kept_items.append(candidate)
    return kept_items


def replace_entities_with_placeholders(text, entities, placeholder_fn):
    """Substitute each entity value found in *text* with a placeholder

    Overlapping entities are deduplicated first, then each remaining entity
    match is replaced with ``placeholder_fn(entity_kind)``.

    Args:
        text (str): Original text
        entities (list): Entities, each providing a match range and an entity
            kind
        placeholder_fn (callable): Maps an entity kind to its placeholder

    Returns:
        tuple: A ``(range_mapping, processed_text)`` pair, where
            *range_mapping* maps each ``(start, end)`` placeholder range in
            the processed text to the match range of the entity it replaces
    """
    if not entities:
        return dict(), text

    entities = deduplicate_overlapping_entities(entities)
    entities = sorted(entities, key=lambda e: e[RES_MATCH_RANGE][START])

    range_mapping = dict()
    chunks = []
    # Cumulated length difference between the processed and original texts
    shift = 0
    # Position in the original text up to which everything has been copied
    cursor = 0
    for entity in entities:
        start = entity[RES_MATCH_RANGE][START]
        end = entity[RES_MATCH_RANGE][END]
        new_start = start + shift

        chunks.append(text[cursor:start])

        placeholder = placeholder_fn(entity[ENTITY_KIND])
        shift += len(placeholder) - (end - start)
        chunks.append(placeholder)

        range_mapping[(new_start, end + shift)] = entity[RES_MATCH_RANGE]
        cursor = end

    chunks.append(text[cursor:])
    return range_mapping, "".join(chunks)


def deduplicate_overlapping_entities(entities):
    """Remove the entities whose match range overlaps with the match range
    of an entity having a longer value

    Args:
        entities (list): Entities providing a match range and a value

    Returns:
        list: The remaining entities, sorted by the start of their match
            range
    """

    def match_range(entity):
        return entity[RES_MATCH_RANGE]

    survivors = deduplicate_overlapping_items(
        entities,
        lambda lhs, rhs: ranges_overlap(match_range(lhs), match_range(rhs)),
        # Longest values are visited first and thus take precedence
        lambda entity: -len(entity[RES_VALUE]))
    ordered = list(survivors)
    ordered.sort(key=lambda entity: match_range(entity)[START])
    return ordered


# Pattern for versions of the form major.minor.patch[.subpatch], optionally
# followed by a "-prerelease" and a "+buildmetadata" suffix.
# The dots separating the numeric components are escaped so that only a
# literal "." is accepted (an unescaped "." would match any character and let
# strings such as "12345" or "1-2-3" pass as versions).
SEMVER_PATTERN = r"^(?P<major>0|[1-9]\d*)" \
                 r"\.(?P<minor>0|[1-9]\d*)" \
                 r"\.(?P<patch>0|[1-9]\d*)" \
                 r"(?:\.(?P<subpatch>0|[1-9]\d*))?" \
                 r"(?:-(?P<prerelease>(?:0|[1-9]\d*|\d*[a-zA-Z-]" \
                 r"[0-9a-zA-Z-]*)" \
                 r"(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?" \
                 r"(?:\+(?P<buildmetadata>[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*)" \
                 r")?$"
SEMVER_REGEX = re.compile(SEMVER_PATTERN)