aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/dataset/entity.py
diff options
context:
space:
mode:
Diffstat (limited to 'snips_inference_agl/dataset/entity.py')
-rw-r--r--snips_inference_agl/dataset/entity.py175
1 files changed, 175 insertions, 0 deletions
diff --git a/snips_inference_agl/dataset/entity.py b/snips_inference_agl/dataset/entity.py
new file mode 100644
index 0000000..65b9994
--- /dev/null
+++ b/snips_inference_agl/dataset/entity.py
@@ -0,0 +1,175 @@
+# coding=utf-8
+from __future__ import unicode_literals
+
+from builtins import str
+from io import IOBase
+
+from snips_inference_agl.constants import (
+ AUTOMATICALLY_EXTENSIBLE, DATA, MATCHING_STRICTNESS, SYNONYMS,
+ USE_SYNONYMS, VALUE)
+from snips_inference_agl.exceptions import EntityFormatError
+
+
+class Entity(object):
+ """Entity data of a :class:`.Dataset`
+
+ This class can represents both a custom or a builtin entity. When the
+ entity is a builtin one, only the `name` attribute is relevant.
+
+ Attributes:
+ name (str): name of the entity
+ utterances (list of :class:`.EntityUtterance`): entity utterances
+ (only for custom entities)
+ automatically_extensible (bool): whether or not the entity can be
+ extended to values not present in the data (only for custom
+ entities)
+ use_synonyms (bool): whether or not to map entity values using
+ synonyms (only for custom entities)
+ matching_strictness (float): controls the matching strictness of the
+ entity (only for custom entities). Must be between 0.0 and 1.0.
+ """
+
+ def __init__(self, name, utterances=None, automatically_extensible=True,
+ use_synonyms=True, matching_strictness=1.0):
+ if utterances is None:
+ utterances = []
+ self.name = name
+ self.utterances = utterances
+ self.automatically_extensible = automatically_extensible
+ self.use_synonyms = use_synonyms
+ self.matching_strictness = matching_strictness
+
+ @property
+ def is_builtin(self):
+ from snips_nlu_parsers import get_all_builtin_entities
+
+ return self.name in get_all_builtin_entities()
+
+ @classmethod
+ def from_yaml(cls, yaml_dict):
+ """Build an :class:`.Entity` from its YAML definition object
+
+ Args:
+ yaml_dict (dict or :class:`.IOBase`): object containing the YAML
+ definition of the entity. It can be either a stream, or the
+ corresponding python dict.
+
+ Examples:
+ An entity can be defined with a YAML document following the schema
+ illustrated in the example below:
+
+ >>> import io
+ >>> from snips_inference_agl.common.utils import json_string
+ >>> entity_yaml = io.StringIO('''
+ ... # City Entity
+ ... ---
+ ... type: entity
+ ... name: city
+ ... automatically_extensible: false # default value is true
+ ... use_synonyms: false # default value is true
+ ... matching_strictness: 0.8 # default value is 1.0
+ ... values:
+ ... - london
+ ... - [new york, big apple]
+ ... - [paris, city of lights]''')
+ >>> entity = Entity.from_yaml(entity_yaml)
+ >>> print(json_string(entity.json, indent=4, sort_keys=True))
+ {
+ "automatically_extensible": false,
+ "data": [
+ {
+ "synonyms": [],
+ "value": "london"
+ },
+ {
+ "synonyms": [
+ "big apple"
+ ],
+ "value": "new york"
+ },
+ {
+ "synonyms": [
+ "city of lights"
+ ],
+ "value": "paris"
+ }
+ ],
+ "matching_strictness": 0.8,
+ "use_synonyms": false
+ }
+
+ Raises:
+ EntityFormatError: When the YAML dict does not correspond to the
+ :ref:`expected entity format <yaml_entity_format>`
+ """
+ if isinstance(yaml_dict, IOBase):
+ from snips_inference_agl.dataset.yaml_wrapper import yaml
+
+ yaml_dict = yaml.safe_load(yaml_dict)
+
+ object_type = yaml_dict.get("type")
+ if object_type and object_type != "entity":
+ raise EntityFormatError("Wrong type: '%s'" % object_type)
+ entity_name = yaml_dict.get("name")
+ if not entity_name:
+ raise EntityFormatError("Missing 'name' attribute")
+ auto_extensible = yaml_dict.get(AUTOMATICALLY_EXTENSIBLE, True)
+ use_synonyms = yaml_dict.get(USE_SYNONYMS, True)
+ matching_strictness = yaml_dict.get("matching_strictness", 1.0)
+ utterances = []
+ for entity_value in yaml_dict.get("values", []):
+ if isinstance(entity_value, list):
+ utterance = EntityUtterance(entity_value[0], entity_value[1:])
+ elif isinstance(entity_value, str):
+ utterance = EntityUtterance(entity_value)
+ else:
+ raise EntityFormatError(
+ "YAML entity values must be either strings or lists, but "
+ "found: %s" % type(entity_value))
+ utterances.append(utterance)
+
+ return cls(name=entity_name,
+ utterances=utterances,
+ automatically_extensible=auto_extensible,
+ use_synonyms=use_synonyms,
+ matching_strictness=matching_strictness)
+
+ @property
+ def json(self):
+ """Returns the entity in json format"""
+ if self.is_builtin:
+ return dict()
+ return {
+ AUTOMATICALLY_EXTENSIBLE: self.automatically_extensible,
+ USE_SYNONYMS: self.use_synonyms,
+ DATA: [u.json for u in self.utterances],
+ MATCHING_STRICTNESS: self.matching_strictness
+ }
+
+
+class EntityUtterance(object):
+ """Represents a value of a :class:`.CustomEntity` with potential synonyms
+
+ Attributes:
+ value (str): entity value
+ synonyms (list of str): The values to remap to the utterance value
+ """
+
+ def __init__(self, value, synonyms=None):
+ self.value = value
+ if synonyms is None:
+ synonyms = []
+ self.synonyms = synonyms
+
+ @property
+ def variations(self):
+ return [self.value] + self.synonyms
+
+ @property
+ def json(self):
+ return {VALUE: self.value, SYNONYMS: self.synonyms}
+
+
+def utf_8_encoder(f):
+ for line in f:
+ yield line.encode("utf-8")