aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/slot_filler/features_utils.py
blob: 483e9c0721a8ab59e0dcb1c0d78aba98dda908be (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
from __future__ import unicode_literals

from copy import deepcopy

from snips_inference_agl.common.dict_utils import LimitedSizeDict
from snips_inference_agl.constants import END, RES_MATCH_RANGE, START

# Module-level memo used by get_all_ngrams to avoid recomputing n-grams for a
# token sequence it has already seen. LimitedSizeDict caps it at 1000 entries;
# which entries are dropped once the cap is reached is decided by
# LimitedSizeDict (presumably oldest-first — TODO confirm).
_NGRAMS_CACHE = LimitedSizeDict(size_limit=1000)


def get_all_ngrams(tokens):
    """Return all n-grams of *tokens*, for n up to the number of tokens.

    Results are memoised in the module-level ``_NGRAMS_CACHE`` keyed on the
    token sequence. A deep copy of the cached value is returned, so callers
    may mutate the result without corrupting the cache entry.

    Args:
        tokens: sequence of token strings.

    Returns:
        A deep copy of ``snips_nlu_utils.compute_all_ngrams(tokens,
        len(tokens))``, or ``[]`` when *tokens* is empty.
    """
    if not tokens:
        return []
    # Deferred import: the native helper is only needed when there is
    # actually something to compute.
    from snips_nlu_utils import compute_all_ngrams

    # A tuple key is collision-free, unlike a separator-joined string where
    # e.g. ["a<||>b"] and ["a", "b"] would map to the same cache entry.
    key = tuple(tokens)
    if key not in _NGRAMS_CACHE:
        _NGRAMS_CACHE[key] = compute_all_ngrams(tokens, len(tokens))
    return deepcopy(_NGRAMS_CACHE[key])


def get_word_chunk(word, chunk_size, chunk_start, reverse=False):
    """Slice a chunk of ``chunk_size`` characters out of *word*.

    In forward mode the chunk is ``word[chunk_start:chunk_start + chunk_size]``;
    with ``reverse=True`` the chunk instead ends at ``chunk_start``, i.e.
    ``word[chunk_start - chunk_size:chunk_start]``.

    Args:
        word: the string to slice.
        chunk_size: number of characters wanted; must be >= 1.
        chunk_start: slice anchor (start in forward mode, end in reverse mode).
        reverse: when True, take the chunk that ends at ``chunk_start``.

    Returns:
        The sliced string, or None when *word* is shorter than ``chunk_size``.

    Raises:
        ValueError: if ``chunk_size`` is less than 1.
    """
    if chunk_size < 1:
        raise ValueError("chunk size should be >= 1")
    if len(word) < chunk_size:
        return None
    # NOTE(review): no further bounds checks. In reverse mode with
    # chunk_start < chunk_size the lower index goes negative and Python
    # slicing counts it from the end of the word.
    if reverse:
        lo, hi = chunk_start - chunk_size, chunk_start
    else:
        lo, hi = chunk_start, chunk_start + chunk_size
    return word[lo:hi]


def initial_string_from_tokens(tokens):
    """Rebuild a string from *tokens* using their character offsets.

    Each token's ``value`` is appended in iteration order. Whenever a token's
    ``start`` lies beyond the ``end`` of the previous token (or beyond offset
    0 for the first token), the gap is filled with spaces. The original
    whitespace characters are not recoverable, so spaces are always used.

    Args:
        tokens: iterable of objects exposing ``value``, ``start`` and ``end``
            attributes (presumably ordered by ``start`` — confirm with callers).

    Returns:
        The reconstructed string; ``""`` when there are no tokens.
    """
    parts = []
    cursor = 0
    for token in tokens:
        gap = token.start - cursor
        if gap > 0:
            # Pad so the value lands at its start offset, assuming
            # len(value) == end - start for every token.
            parts.append(" " * gap)
        parts.append(token.value)
        cursor = token.end
    # A single join avoids the repeated copying of `s += ...` in a loop.
    return "".join(parts)


def entity_filter(entity, start, end):
    """Tell whether the span ``(start, end)`` fits inside *entity*'s match range.

    Args:
        entity: mapping whose ``RES_MATCH_RANGE`` entry is itself a mapping
            holding ``START`` and ``END`` offsets.
        start: span start offset.
        end: span end offset.

    Returns:
        True when the span is non-empty (``start < end``), starts no earlier
        than the entity's start, and ends no later than the entity's end.
    """
    match_range = entity[RES_MATCH_RANGE]
    lower = match_range[START]
    upper = match_range[END]
    return lower <= start and start < end and end <= upper