aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/slot_filler/crf_utils.py
diff options
context:
space:
mode:
Diffstat (limited to 'snips_inference_agl/slot_filler/crf_utils.py')
-rw-r--r--snips_inference_agl/slot_filler/crf_utils.py219
1 files changed, 219 insertions, 0 deletions
diff --git a/snips_inference_agl/slot_filler/crf_utils.py b/snips_inference_agl/slot_filler/crf_utils.py
new file mode 100644
index 0000000..817a59b
--- /dev/null
+++ b/snips_inference_agl/slot_filler/crf_utils.py
@@ -0,0 +1,219 @@
+from __future__ import unicode_literals
+
+from builtins import range
+from enum import Enum, unique
+
+from snips_inference_agl.constants import END, SLOT_NAME, START, TEXT
+from snips_inference_agl.preprocessing import Token, tokenize
+from snips_inference_agl.result import unresolved_slot
+
+BEGINNING_PREFIX = "B-"
+INSIDE_PREFIX = "I-"
+LAST_PREFIX = "L-"
+UNIT_PREFIX = "U-"
+OUTSIDE = "O"
+
+RANGE = "range"
+TAGS = "tags"
+TOKENS = "tokens"
+
+
+@unique
+class TaggingScheme(Enum):
+    """Tagging schemes usable to encode slot boundaries for the CRF.
+
+    Each member maps to a distinct integer value (enforced by ``@unique``).
+    """
+
+    #: Inside-Outside scheme
+    IO = 0
+
+    #: Beginning-Inside-Outside scheme
+    BIO = 1
+
+    #: Beginning-Inside-Last-Outside-Unit scheme, sometimes referred to as
+    #: BMEWO
+    BILOU = 2
+
+
+def tag_name_to_slot_name(tag):
+    """Return ``tag`` without its first two characters.
+
+    Every scheme prefix defined in this module ("B-", "I-", "L-", "U-") is
+    two characters long, so dropping them leaves the slot name. The bare
+    outside tag "O" therefore yields an empty string.
+    """
+    prefix_length = 2
+    return tag[prefix_length:]
+
+
+def start_of_io_slot(tags, i):
+    """Tell whether position ``i`` opens a slot under the IO scheme.
+
+    A slot starts on any non-outside tag that is either the very first tag
+    or directly preceded by an outside tag.
+    """
+    if tags[i] == OUTSIDE:
+        return False
+    return i == 0 or tags[i - 1] == OUTSIDE
+
+
+def end_of_io_slot(tags, i):
+    """Tell whether position ``i`` closes a slot under the IO scheme.
+
+    A slot ends on any non-outside tag that is either the very last tag or
+    directly followed by an outside tag.
+    """
+    if tags[i] == OUTSIDE:
+        return False
+    return i + 1 == len(tags) or tags[i + 1] == OUTSIDE
+
+
+def start_of_bio_slot(tags, i):
+    """Tell whether position ``i`` opens a slot under the BIO scheme.
+
+    A non-outside tag starts a slot when it is the first tag, when it
+    carries the beginning prefix, or when the previous tag is outside.
+    """
+    current = tags[i]
+    if current == OUTSIDE:
+        return False
+    if i == 0 or current.startswith(BEGINNING_PREFIX):
+        return True
+    # A tag without the beginning prefix only opens a slot right after an
+    # outside tag.
+    return tags[i - 1] == OUTSIDE
+
+
+def end_of_bio_slot(tags, i):
+    """Tell whether position ``i`` closes a slot under the BIO scheme.
+
+    A non-outside tag ends a slot when it is the last tag or when the next
+    tag does not carry the inside prefix.
+    """
+    if tags[i] == OUTSIDE:
+        return False
+    if i + 1 == len(tags):
+        return True
+    return not tags[i + 1].startswith(INSIDE_PREFIX)
+
+
+def start_of_bilou_slot(tags, i):
+    """Tell whether position ``i`` opens a slot under the BILOU scheme.
+
+    A non-outside tag starts a slot when it is the first tag, when it
+    carries the beginning or unit prefix, or when the previous tag is
+    outside or carries the unit or last prefix.
+    """
+    current = tags[i]
+    if current == OUTSIDE:
+        return False
+    if i == 0:
+        return True
+    if current.startswith((BEGINNING_PREFIX, UNIT_PREFIX)):
+        return True
+    # Otherwise the decision rests on the previous tag: a slot can only
+    # open here if nothing is still running before it.
+    previous = tags[i - 1]
+    return (previous == OUTSIDE
+            or previous.startswith((UNIT_PREFIX, LAST_PREFIX)))
+
+
+def end_of_bilou_slot(tags, i):
+    """Tell whether position ``i`` closes a slot under the BILOU scheme.
+
+    A non-outside tag ends a slot when it is the last tag, when the next
+    tag is outside, when it carries the last or unit prefix itself, or
+    when the next tag carries the beginning or unit prefix.
+    """
+    current = tags[i]
+    if current == OUTSIDE:
+        return False
+    if i + 1 == len(tags):
+        return True
+    following = tags[i + 1]
+    if following == OUTSIDE:
+        return True
+    return (current.startswith((LAST_PREFIX, UNIT_PREFIX))
+            or following.startswith((BEGINNING_PREFIX, UNIT_PREFIX)))
+
+
+def _tags_to_preslots(tags, tokens, is_start_of_slot, is_end_of_slot):
+    """Group tagged tokens into slot dicts using the given boundary tests.
+
+    ``is_start_of_slot`` and ``is_end_of_slot`` are called as
+    ``predicate(tags, position)``. Each returned dict holds the range
+    spanned by the slot (``start`` of its first token, ``end`` of its last
+    token) and the slot name derived from the tag that closes the slot.
+    """
+    preslots = []
+    first_token_index = 0
+    for position, tag in enumerate(tags):
+        if is_start_of_slot(tags, position):
+            first_token_index = position
+        if not is_end_of_slot(tags, position):
+            continue
+        span = {
+            START: tokens[first_token_index].start,
+            END: tokens[position].end
+        }
+        preslots.append({
+            RANGE: span,
+            SLOT_NAME: tag_name_to_slot_name(tag)
+        })
+        # Move the anchor forward so that a later end without an explicit
+        # start cannot reach back into the slot that was just closed.
+        first_token_index = position
+    return preslots
+
+
+def tags_to_preslots(tokens, tags, tagging_scheme):
+    """Convert ``tags`` into slot dicts using the given tagging scheme.
+
+    Selects the start/end predicates matching ``tagging_scheme`` and
+    delegates to ``_tags_to_preslots``.
+
+    Raises:
+        ValueError: if ``tagging_scheme`` is not IO, BIO or BILOU.
+    """
+    boundary_tests = (
+        (TaggingScheme.IO, start_of_io_slot, end_of_io_slot),
+        (TaggingScheme.BIO, start_of_bio_slot, end_of_bio_slot),
+        (TaggingScheme.BILOU, start_of_bilou_slot, end_of_bilou_slot),
+    )
+    for scheme, is_start, is_end in boundary_tests:
+        if tagging_scheme == scheme:
+            return _tags_to_preslots(tags, tokens, is_start, is_end)
+    raise ValueError("Unknown tagging scheme %s" % tagging_scheme)
+
+
+def tags_to_slots(text, tokens, tags, tagging_scheme, intent_slots_mapping):
+    """Build one unresolved slot per slot found in ``tags``.
+
+    For each slot the value is the slice of ``text`` covered by the slot
+    range, and the entity is looked up in ``intent_slots_mapping`` by slot
+    name (the lookup uses ``[]``, so an unknown slot name is an error).
+    """
+    resolved = []
+    for preslot in tags_to_preslots(tokens, tags, tagging_scheme):
+        span = preslot[RANGE]
+        name = preslot[SLOT_NAME]
+        resolved.append(
+            unresolved_slot(match_range=span,
+                            value=text[span[START]:span[END]],
+                            entity=intent_slots_mapping[name],
+                            slot_name=name))
+    return resolved
+
+
+def positive_tagging(tagging_scheme, slot_name, slot_size):
+    """Return the tags labelling ``slot_size`` tokens of slot ``slot_name``.
+
+    Args:
+        tagging_scheme: the ``TaggingScheme`` member to encode with.
+        slot_name: name appended to each scheme prefix; when it equals the
+            outside tag, only outside tags are produced and the scheme is
+            not inspected.
+        slot_size: number of tags to produce; a value of zero or less
+            yields an empty list for every scheme.
+
+    Returns:
+        A list of ``slot_size`` tag strings.
+
+    Raises:
+        ValueError: if ``tagging_scheme`` is not IO, BIO or BILOU (and
+            ``slot_name`` is not the outside tag).
+    """
+    if slot_name == OUTSIDE:
+        return [OUTSIDE] * slot_size
+
+    if tagging_scheme == TaggingScheme.IO:
+        return [INSIDE_PREFIX + slot_name] * slot_size
+
+    if tagging_scheme == TaggingScheme.BIO:
+        if slot_size <= 0:
+            return []
+        return ([BEGINNING_PREFIX + slot_name]
+                + [INSIDE_PREFIX + slot_name] * (slot_size - 1))
+
+    if tagging_scheme == TaggingScheme.BILOU:
+        # Guard non-positive sizes explicitly: without it a negative size
+        # would fall through to the multi-token case and emit a spurious
+        # beginning/last pair, unlike the IO and BIO branches.
+        if slot_size <= 0:
+            return []
+        if slot_size == 1:
+            return [UNIT_PREFIX + slot_name]
+        return ([BEGINNING_PREFIX + slot_name]
+                + [INSIDE_PREFIX + slot_name] * (slot_size - 2)
+                + [LAST_PREFIX + slot_name])
+
+    raise ValueError("Invalid tagging scheme %s" % tagging_scheme)
+
+
+def negative_tagging(size):
+    """Return a list made of ``size`` outside tags."""
+    return [OUTSIDE] * size
+
+
+def utterance_to_sample(query_data, tagging_scheme, language):
+    """Turn the chunks of an utterance into aligned tokens and tags.
+
+    Each chunk's text is tokenized on its own; token positions are then
+    shifted by the cumulated length of the previous chunk texts. Chunks
+    carrying a slot name are tagged with ``positive_tagging``, the others
+    with ``negative_tagging``.
+
+    Returns:
+        A dict with the token list under ``TOKENS`` and the tag list under
+        ``TAGS``.
+    """
+    sample_tokens = []
+    sample_tags = []
+    offset = 0
+    for chunk in query_data:
+        chunk_text = chunk[TEXT]
+        chunk_tokens = tokenize(chunk_text, language)
+        for token in chunk_tokens:
+            sample_tokens.append(
+                Token(token.value, offset + token.start, offset + token.end))
+        offset += len(chunk_text)
+        nb_tokens = len(chunk_tokens)
+        if SLOT_NAME in chunk:
+            sample_tags.extend(
+                positive_tagging(tagging_scheme, chunk[SLOT_NAME], nb_tokens))
+        else:
+            sample_tags.extend(negative_tagging(nb_tokens))
+    return {TOKENS: sample_tokens, TAGS: sample_tags}
+
+
+def get_scheme_prefix(index, indexes, tagging_scheme):
+    """Return the tag prefix for ``index`` given the slot's ``indexes``.
+
+    IO always yields the inside prefix. BIO yields the beginning prefix
+    for the first index and the inside prefix otherwise. BILOU yields the
+    unit prefix for a single-index slot, else beginning / last / inside
+    depending on whether ``index`` is the first, the last or neither.
+
+    Raises:
+        ValueError: if ``tagging_scheme`` is not IO, BIO or BILOU.
+    """
+    if tagging_scheme == TaggingScheme.IO:
+        return INSIDE_PREFIX
+
+    if tagging_scheme == TaggingScheme.BIO:
+        return BEGINNING_PREFIX if index == indexes[0] else INSIDE_PREFIX
+
+    if tagging_scheme == TaggingScheme.BILOU:
+        if len(indexes) == 1:
+            return UNIT_PREFIX
+        if index == indexes[0]:
+            return BEGINNING_PREFIX
+        return LAST_PREFIX if index == indexes[-1] else INSIDE_PREFIX
+
+    raise ValueError("Invalid tagging scheme %s" % tagging_scheme)