aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/slot_filler/keyword_slot_filler.py
diff options
context:
space:
mode:
Diffstat (limited to 'snips_inference_agl/slot_filler/keyword_slot_filler.py')
-rw-r--r--snips_inference_agl/slot_filler/keyword_slot_filler.py70
1 files changed, 70 insertions, 0 deletions
diff --git a/snips_inference_agl/slot_filler/keyword_slot_filler.py b/snips_inference_agl/slot_filler/keyword_slot_filler.py
new file mode 100644
index 0000000..087d997
--- /dev/null
+++ b/snips_inference_agl/slot_filler/keyword_slot_filler.py
@@ -0,0 +1,70 @@
+from __future__ import unicode_literals
+
+import json
+
+from snips_inference_agl.common.utils import json_string
+from snips_inference_agl.preprocessing import tokenize
+from snips_inference_agl.result import unresolved_slot
+from snips_inference_agl.slot_filler import SlotFiller
+
+
+@SlotFiller.register("keyword_slot_filler")
+class KeywordSlotFiller(SlotFiller):
+ def __init__(self, config=None, **shared):
+ super(KeywordSlotFiller, self).__init__(config, **shared)
+ self.slots_keywords = None
+ self.language = None
+
+ @property
+ def fitted(self):
+ return self.slots_keywords is not None
+
+ def fit(self, dataset, intent):
+ self.language = dataset["language"]
+ self.slots_keywords = dict()
+ utterances = dataset["intents"][intent]["utterances"]
+ for utterance in utterances:
+ for chunk in utterance["data"]:
+ if "slot_name" in chunk:
+ text = chunk["text"]
+ if self.config.get("lowercase", False):
+ text = text.lower()
+ self.slots_keywords[text] = [
+ chunk["entity"],
+ chunk["slot_name"]
+ ]
+ return self
+
+ def get_slots(self, text):
+ tokens = tokenize(text, self.language)
+ slots = []
+ for token in tokens:
+ normalized_value = token.value
+ if self.config.get("lowercase", False):
+ normalized_value = normalized_value.lower()
+ if normalized_value in self.slots_keywords:
+ entity = self.slots_keywords[normalized_value][0]
+ slot_name = self.slots_keywords[normalized_value][1]
+ slot = unresolved_slot((token.start, token.end), token.value,
+ entity, slot_name)
+ slots.append(slot)
+ return slots
+
+ def persist(self, path):
+ model = {
+ "language": self.language,
+ "slots_keywords": self.slots_keywords,
+ "config": self.config.to_dict()
+ }
+ with path.open(mode="w", encoding="utf8") as f:
+ f.write(json_string(model))
+
+ @classmethod
+ def from_path(cls, path, **shared):
+ with path.open() as f:
+ model = json.load(f)
+ slot_filler = cls()
+ slot_filler.language = model["language"]
+ slot_filler.slots_keywords = model["slots_keywords"]
+ slot_filler.config = cls.config_type.from_dict(model["config"])
+ return slot_filler