aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/slot_filler/keyword_slot_filler.py
blob: 087d9979ab126cdbc15886103f4ec6acf361b9ef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
from __future__ import unicode_literals

import json

from snips_inference_agl.common.utils import json_string
from snips_inference_agl.preprocessing import tokenize
from snips_inference_agl.result import unresolved_slot
from snips_inference_agl.slot_filler import SlotFiller


@SlotFiller.register("keyword_slot_filler")
class KeywordSlotFiller(SlotFiller):
    """Slot filler that tags tokens by exact lookup in a keyword table.

    The table is built in :meth:`fit` from the slot chunks of one intent's
    utterances and maps a chunk text to ``[entity, slot_name]``.
    :meth:`get_slots` then looks up every token of the input text in that
    table. When the ``"lowercase"`` config entry is truthy, both the
    stored keywords and the tokens are lowercased before comparison.
    """

    def __init__(self, config=None, **shared):
        """Create an unfitted slot filler.

        Args:
            config: configuration forwarded to the ``SlotFiller`` base class.
            **shared: extra keyword arguments forwarded unchanged to the
                base class constructor.
        """
        super(KeywordSlotFiller, self).__init__(config, **shared)
        # Mapping: keyword text -> [entity, slot_name]; None until fitted
        # or loaded.
        self.slots_keywords = None
        self.language = None

    @property
    def fitted(self):
        """Whether a keyword table is available (built or loaded)."""
        return self.slots_keywords is not None

    def fit(self, dataset, intent):
        """Build the keyword table from the utterances of *intent*.

        Args:
            dataset (dict): must provide ``dataset["language"]`` and
                ``dataset["intents"][intent]["utterances"]``, where each
                utterance has a ``"data"`` list of chunks.
            intent: key of the intent inside ``dataset["intents"]``.

        Returns:
            KeywordSlotFiller: this instance, to allow call chaining.
        """
        self.language = dataset["language"]
        self.slots_keywords = dict()
        # Loop invariant: read the option once instead of once per chunk.
        lowercase = self.config.get("lowercase", False)
        utterances = dataset["intents"][intent]["utterances"]
        for utterance in utterances:
            for chunk in utterance["data"]:
                # Only chunks annotated with a slot contribute keywords.
                if "slot_name" not in chunk:
                    continue
                text = chunk["text"]
                if lowercase:
                    text = text.lower()
                # A later chunk with the same text overrides earlier ones.
                self.slots_keywords[text] = [
                    chunk["entity"],
                    chunk["slot_name"]
                ]
        return self

    def get_slots(self, text):
        """Return one unresolved slot per token found in the keyword table.

        Args:
            text: the input to tokenize with ``self.language``.

        Returns:
            list: the values returned by ``unresolved_slot`` for every
            matching token, in token order. The slot keeps the token's
            original (non-lowercased) value and its ``(start, end)`` range.
        """
        # NOTE(review): matching is done token by token, so a keyword
        # that tokenizes into several tokens presumably never matches;
        # confirm against ``tokenize``.
        tokens = tokenize(text, self.language)
        lowercase = self.config.get("lowercase", False)
        slots = []
        for token in tokens:
            normalized_value = token.value
            if lowercase:
                normalized_value = normalized_value.lower()
            if normalized_value not in self.slots_keywords:
                continue
            entity, slot_name = self.slots_keywords[normalized_value]
            slots.append(unresolved_slot(
                (token.start, token.end), token.value, entity, slot_name))
        return slots

    def persist(self, path):
        """Write the language, keyword table and config to *path* as JSON.

        Args:
            path: object exposing ``open(mode=..., encoding=...)``
                (presumably a ``pathlib.Path``; verify against callers).
                It is opened directly as a file, encoded as UTF-8.
        """
        model = {
            "language": self.language,
            "slots_keywords": self.slots_keywords,
            "config": self.config.to_dict()
        }
        with path.open(mode="w", encoding="utf8") as f:
            f.write(json_string(model))

    @classmethod
    def from_path(cls, path, **shared):
        """Load a slot filler from a file written by :meth:`persist`.

        Args:
            path: object exposing ``open(encoding=...)`` and pointing to
                the JSON file produced by :meth:`persist`.
            **shared: extra keyword arguments forwarded to the constructor.

        Returns:
            KeywordSlotFiller: instance with language, keyword table and
            config restored from the file.
        """
        # Read with the same encoding persist() writes with, rather than
        # relying on the platform's default encoding.
        with path.open(encoding="utf8") as f:
            model = json.load(f)
        # Forward the shared resources instead of silently dropping them.
        slot_filler = cls(**shared)
        slot_filler.language = model["language"]
        slot_filler.slots_keywords = model["slots_keywords"]
        slot_filler.config = cls.config_type.from_dict(model["config"])
        return slot_filler