1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
|
from __future__ import unicode_literals
import importlib
import json
import numbers
import re
from builtins import bytes as newbytes, str as newstr
from datetime import datetime
from functools import wraps
from pathlib import Path
from future.utils import text_type
from snips_inference_agl.constants import (END, ENTITY_KIND, RES_MATCH_RANGE, RES_VALUE,
START)
from snips_inference_agl.exceptions import NotTrained, PersistingError
# Regex metacharacters that must be backslash-escaped to be matched
# literally; this mirrors the meta-character set of the Rust regex-syntax
# crate referenced in regex_escape's docstring.
REGEX_PUNCT = {'\\', '.', '+', '*', '?', '(', ')', '|', '[', ']', '{', '}',
               '^', '$', '#', '&', '-', '~'}


# pylint:disable=line-too-long
def regex_escape(s):
    """Escapes all regular expression meta characters in *s*

    The string returned may be safely used as a literal in a regular
    expression: every character found in ``REGEX_PUNCT`` is prefixed with a
    backslash, all other characters are copied unchanged.

    This function is more precise than :func:`re.escape`, the latter escapes
    all non-alphanumeric characters which can cause cross-platform
    compatibility issues.

    Args:
        s (str): Text to escape.

    Returns:
        str: The escaped text (empty input gives an empty string).

    References:
        - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/lib.rs#L1685
        - https://github.com/rust-lang/regex/blob/master/regex-syntax/src/parser.rs#L1378
    """
    # A single join instead of repeated ``+=``: string concatenation in a
    # loop is quadratic on interpreters lacking CPython's in-place
    # concatenation optimization.
    return "".join("\\" + c if c in REGEX_PUNCT else c for c in s)
# pylint:enable=line-too-long
def check_random_state(seed):
    """Turn seed into a :class:`numpy.random.RandomState` instance

    Args:
        seed: One of
            - ``None`` or the ``numpy.random`` module itself: return the
              RandomState singleton used by ``np.random``;
            - an integer (Python ``numbers.Integral`` or ``np.integer``):
              return a new RandomState seeded with it;
            - an existing RandomState: return it unchanged.

    Returns:
        numpy.random.RandomState

    Raises:
        ValueError: if *seed* is none of the above.
    """
    # Local import; presumably keeps numpy from being required at module
    # import time.
    import numpy as np
    # pylint: disable=W0212
    # pylint: disable=c-extension-no-member
    if seed is None or seed is np.random:
        # Private singleton backing the np.random.* convenience functions.
        return np.random.mtrand._rand  # pylint: disable=c-extension-no-member
    if isinstance(seed, (numbers.Integral, np.integer)):
        return np.random.RandomState(seed)
    if isinstance(seed, np.random.RandomState):
        return seed
    # Wrap seed in a 1-tuple: a bare tuple seed would otherwise be unpacked
    # by %-formatting and raise TypeError instead of this ValueError.
    raise ValueError('%r cannot be used to seed a numpy.random.RandomState'
                     ' instance' % (seed,))
def ranges_overlap(lhs_range, rhs_range):
    """Tell whether two half-open ranges share at least one position.

    Both ranges must be of the same flavour: either two dicts keyed by
    ``START``/``END``, or two ``(start, end)`` tuples/lists. Ranges that
    merely touch (one ends exactly where the other starts) do not overlap,
    because both comparisons are strict.

    Raises:
        TypeError: if the arguments are not both dicts or both tuples/lists.
    """
    sequence_types = (tuple, list)
    if isinstance(lhs_range, dict) and isinstance(rhs_range, dict):
        return (lhs_range[END] > rhs_range[START]
                and lhs_range[START] < rhs_range[END])
    both_sequences = (isinstance(lhs_range, sequence_types)
                      and isinstance(rhs_range, sequence_types))
    if not both_sequences:
        raise TypeError("Cannot check overlap on objects of type: %s and %s"
                        % (type(lhs_range), type(rhs_range)))
    return lhs_range[1] > rhs_range[0] and lhs_range[0] < rhs_range[1]
def elapsed_since(time):
    """Return the timedelta from *time* until now (naive ``datetime.now()``)."""
    now = datetime.now()
    return now - time
def json_debug_string(dict_data):
    """Pretty-print *dict_data* as JSON for debugging.

    Uses a 2-space indent and sorted keys, and keeps non-ASCII characters
    as-is instead of escaping them.
    """
    return json.dumps(dict_data, indent=2, sort_keys=True, ensure_ascii=False)
def json_string(json_object, indent=2, sort_keys=True):
    """Serialize *json_object* to JSON and return it as unicode text.

    Item and key separators are fixed to ``','`` and ``': '``; *indent* and
    *sort_keys* are forwarded to :func:`json.dumps`. Non-ASCII characters
    are escaped (``json.dumps`` default).
    """
    serialized = json.dumps(
        json_object,
        indent=indent,
        sort_keys=sort_keys,
        separators=(',', ': '),
    )
    return unicode_string(serialized)
def unicode_string(string):
    """Return *string* as native unicode text.

    Accepts unicode text, UTF-8 encoded bytes, and the ``future`` package's
    ``newstr`` / ``newbytes`` wrapper types.

    Raises:
        TypeError: if *string* is of any other type.
        UnicodeDecodeError: if byte input is not valid UTF-8.
    """
    if isinstance(string, text_type):
        return string
    if isinstance(string, bytes):
        return string.decode("utf8")
    if isinstance(string, newstr):
        return text_type(string)
    if isinstance(string, newbytes):
        # Return the decoded value; merely assigning it would fall through
        # to the TypeError below.
        return bytes(string).decode("utf8")
    raise TypeError("Cannot convert %s into unicode string" % type(string))
def check_persisted_path(func):
    """Decorate a method whose first argument after ``self`` is a path.

    The wrapper converts that argument to :class:`pathlib.Path`. If something
    already exists at that location it raises :class:`PersistingError`;
    otherwise it calls *func* with the converted path and the remaining
    arguments.
    """
    @wraps(func)
    def func_wrapper(self, path, *args, **kwargs):
        target = Path(path)
        if not target.exists():
            return func(self, target, *args, **kwargs)
        raise PersistingError(target)

    return func_wrapper
def fitted_required(func):
    """Decorate a method so it only runs on a fitted unit.

    The wrapper raises :class:`NotTrained` (message built from
    ``self.unit_name``) whenever ``self.fitted`` is falsy. Otherwise it
    forwards all arguments to *func* and returns its result.
    """
    @wraps(func)
    def func_wrapper(self, *args, **kwargs):
        if self.fitted:
            return func(self, *args, **kwargs)
        raise NotTrained("%s must be fitted" % self.unit_name)

    return func_wrapper
def is_package(name):
    """Check whether *name* matches a distribution installed via pip.

    The comparison is case-insensitive and treats ``-`` and ``_`` as equal.

    Args:
        name (str): Name of package

    Returns:
        bool: True if an installed distribution corresponds to this name,
        False otherwise.
    """
    import pkg_resources

    def _normalize(label):
        return label.lower().replace("-", "_")

    wanted = _normalize(name)
    installed = pkg_resources.working_set.by_key.keys()
    return any(_normalize(key) == wanted for key in installed)
def get_package_path(name):
    """Get the directory of an importable package.

    The name is lowercased and ``-`` is replaced by ``_`` before importing.

    Args:
        name (str): Package name

    Returns:
        class:`.Path`: Directory containing the package's ``__file__``
    """
    module_name = name.lower().replace("-", "_")
    package_file = importlib.import_module(module_name).__file__
    return Path(package_file).parent
def deduplicate_overlapping_items(items, overlap_fn, sort_key_fn):
    """Greedily keep the items that do not overlap an already-kept item.

    Items are visited in ascending ``sort_key_fn`` order, and ties keep their
    input order because the sort is stable. Items sorted earlier therefore
    take precedence. An item is dropped as soon as
    ``overlap_fn(item, kept_item)`` is truthy for some previously kept item.
    The kept items are returned in visiting order.
    """
    kept = []
    for candidate in sorted(items, key=sort_key_fn):
        for previous in kept:
            if overlap_fn(candidate, previous):
                break
        else:
            kept.append(candidate)
    return kept
def replace_entities_with_placeholders(text, entities, placeholder_fn):
    """Replace entity values in *text* with placeholders.

    Overlapping entities are first deduplicated with
    :func:`deduplicate_overlapping_entities`. Each remaining entity's matched
    span is then replaced by ``placeholder_fn(entity[ENTITY_KIND])``.

    Returns:
        tuple: ``(range_mapping, processed_text)``. ``range_mapping`` maps
        each placeholder's ``(start, end)`` span in ``processed_text`` to
        the entity's original ``RES_MATCH_RANGE`` value. When *entities* is
        empty, ``({}, text)`` is returned.
    """
    if not entities:
        return dict(), text
    ordered = sorted(deduplicate_overlapping_entities(entities),
                     key=lambda e: e[RES_MATCH_RANGE][START])
    range_mapping = dict()
    chunks = []
    # Cumulative length change introduced by the placeholders so far; it
    # shifts every later span in the processed text.
    shift = 0
    cursor = 0
    for entity in ordered:
        match_range = entity[RES_MATCH_RANGE]
        start, end = match_range[START], match_range[END]
        placeholder = placeholder_fn(entity[ENTITY_KIND])
        new_start = start + shift
        shift += len(placeholder) - (end - start)
        new_end = end + shift
        chunks.append(text[cursor:start])
        chunks.append(placeholder)
        range_mapping[(new_start, new_end)] = match_range
        cursor = end
    chunks.append(text[cursor:])
    return range_mapping, "".join(chunks)
def deduplicate_overlapping_entities(entities):
    """Deduplicate entities whose match ranges overlap.

    When ranges overlap, the entity with the longer ``RES_VALUE`` wins, and
    ties go to the entity that comes first in the input. The survivors are
    returned sorted by the start of their match range.
    """
    def _by_longest_value(entity):
        return -len(entity[RES_VALUE])

    def _ranges_clash(first, second):
        return ranges_overlap(first[RES_MATCH_RANGE],
                              second[RES_MATCH_RANGE])

    survivors = deduplicate_overlapping_items(
        entities, _ranges_clash, _by_longest_value)
    return sorted(survivors,
                  key=lambda entity: entity[RES_MATCH_RANGE][START])
# Semantic-version pattern based on the regex suggested by semver.org,
# extended with an optional fourth numeric "subpatch" component
# (e.g. "1.2.3.4"). The separators between numeric components must be
# escaped (\.) so only a literal dot matches; an unescaped "." would also
# accept strings such as "1x2y3".
SEMVER_PATTERN = r"^(?P<major>0|[1-9]\d*)" \
                 r"\.(?P<minor>0|[1-9]\d*)" \
                 r"\.(?P<patch>0|[1-9]\d*)" \
                 r"(?:\.(?P<subpatch>0|[1-9]\d*))?" \
                 r"(?:-(?P<prerelease>(?:0|[1-9]\d*|\d*[a-zA-Z-]" \
                 r"[0-9a-zA-Z-]*)" \
                 r"(?:\.(?:0|[1-9]\d*|\d*[a-zA-Z-][0-9a-zA-Z-]*))*))?" \
                 r"(?:\+(?P<buildmetadata>[0-9a-zA-Z-]+(?:\.[0-9a-zA-Z-]+)*)" \
                 r")?$"
SEMVER_REGEX = re.compile(SEMVER_PATTERN)
|