aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/slot_filler/feature.py
blob: a6da5526b017cdf385b5be5cc6639b2444f9bbfe (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
from __future__ import unicode_literals

from builtins import object

# Name reserved for the raw token entry of the per-token cache; `Feature`
# refuses it as a base name so a feature value can never shadow that entry.
TOKEN_NAME = "token"


class Feature(object):
    """CRF Feature which is used by :class:`.CRFSlotFiller`

    Attributes:
        base_name (str): Feature name (e.g. 'is_digit', 'is_first' etc).
            Assigning it also refreshes :attr:`name`. The value of
            ``TOKEN_NAME`` is reserved and rejected with a ``ValueError``.
        name (str): Read-only. The base name decorated with the offset that
            was set when ``base_name`` was last assigned (e.g.
            'is_digit[-1]', 'is_digit[+2]'), or the bare base name when the
            offset is 0.
        function (function): The actual feature function (passed to the
            constructor as ``func``), for example:

            def is_first(tokens, token_index):
                return "1" if token_index == 0 else None

        offset (int, optional): Token offset to consider when computing
            the feature (e.g -1 for computing the feature on the previous word)
        drop_out (float, optional): Drop out to use when computing the
            feature during training. It is only stored here; this class
            does not apply it.

    Note:
        The easiest way to add additional features to the existing ones is
        to create a :class:`.CRFFeatureFactory`
    """

    def __init__(self, base_name, func, offset=0, drop_out=0):
        """Create a feature.

        Args:
            base_name (str): Feature name, must differ from ``TOKEN_NAME``.
            func (function): Callable taking ``(tokens, token_index)`` and
                returning the feature value.
            offset (int, optional): Relative position of the token the
                feature is computed on. Defaults to 0.
            drop_out (float, optional): Drop out value. Defaults to 0.

        Raises:
            ValueError: If ``base_name`` is the reserved ``TOKEN_NAME``.
        """
        # The offset must be in place before `base_name` is assigned: the
        # setter reads it to build the offset-decorated `name`.
        self.offset = offset
        self._name = None
        self._base_name = None
        self.base_name = base_name
        self.function = func
        self.drop_out = drop_out

    @property
    def name(self):
        """str: Feature name including the offset suffix, if any."""
        return self._name

    @property
    def base_name(self):
        """str: Feature name without any offset suffix."""
        return self._base_name

    @base_name.setter
    def base_name(self, value):
        # Validate here rather than only in __init__, so that re-assigning
        # the base name after construction cannot sneak in the reserved key
        # and let `compute` overwrite the raw token stored in the cache.
        if value == TOKEN_NAME:
            raise ValueError("'%s' name is reserved" % TOKEN_NAME)
        self._name = _offset_name(value, self.offset)
        self._base_name = value

    def compute(self, token_index, cache):
        """Compute the feature for the token at ``token_index + offset``.

        Args:
            token_index (int): Index of the token currently being processed.
            cache (list of dict): One dict per token. Each dict holds the
                token under the ``TOKEN_NAME`` key and is used to memoize
                feature values under their base name.

        Returns:
            The feature value, or None when ``token_index + offset`` falls
            outside of ``cache``. A freshly computed value is stored in the
            target token's cache entry before being returned.
        """
        target_index = token_index + self.offset
        if not 0 <= target_index < len(cache):
            return None

        target_cache = cache[target_index]
        # Membership test (rather than .get) so that a memoized None value
        # is still treated as a cache hit.
        if self.base_name in target_cache:
            return target_cache[self.base_name]

        tokens = [c[TOKEN_NAME] for c in cache]
        value = self.function(tokens, target_index)
        target_cache[self.base_name] = value
        return value


def _offset_name(name, offset):
    if offset > 0:
        return "%s[+%s]" % (name, offset)
    if offset < 0:
        return "%s[%s]" % (name, offset)
    return name