aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/entity_parser/builtin_entity_parser.py
blob: 02fa6107c9c441e711513c74c00f3565fb12e6e4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
from __future__ import unicode_literals

import json
import shutil

from future.builtins import str

from snips_inference_agl.common.io_utils import temp_dir
from snips_inference_agl.common.utils import json_string
from snips_inference_agl.constants import DATA_PATH, ENTITIES, LANGUAGE
from snips_inference_agl.entity_parser.entity_parser import EntityParser
from snips_inference_agl.result import parsed_entity

_BUILTIN_ENTITY_PARSERS = dict()

try:
    FileNotFoundError
except NameError:
    FileNotFoundError = IOError


class BuiltinEntityParser(EntityParser):
    """Entity parser backed by the ``snips_nlu_parsers`` builtin parser.

    Instances are usually obtained through :meth:`build`, which caches
    parsers per (language, gazetteer entity scope) for the whole process.
    """

    def __init__(self, parser):
        super(BuiltinEntityParser, self).__init__()
        self._parser = parser

    def _parse(self, text, scope=None):
        """Extract builtin entities from *text*, optionally limited to the
        entity kinds listed in *scope*.

        Matching is performed on the lowercased text; the underlying
        parser's matches are converted to ``parsed_entity`` dicts.
        """
        raw_matches = self._parser.parse(text.lower(), scope=scope)
        return [
            parsed_entity(
                entity_kind=match["entity_kind"],
                entity_value=match["value"],
                entity_resolved_value=match["entity"],
                entity_range=match["range"])
            for match in raw_matches
        ]

    def persist(self, path):
        """Persist the wrapped parser at *path*."""
        self._parser.persist(path)

    @classmethod
    def from_path(cls, path):
        """Load a :class:`BuiltinEntityParser` from a persisted directory."""
        from snips_nlu_parsers import (
            BuiltinEntityParser as _BuiltinEntityParser)

        return cls(_BuiltinEntityParser.from_path(path))

    @classmethod
    def build(cls, dataset=None, language=None, gazetteer_entity_scope=None):
        """Build, or fetch from the module-level cache, a parser.

        Either *dataset* (from which language and gazetteer scope are
        derived) or *language* must be provided.

        Raises:
            ValueError: if neither a dataset nor a language is given, or
                if a requested gazetteer entity is unsupported for the
                language.
        """
        from snips_nlu_parsers import get_supported_gazetteer_entities

        global _BUILTIN_ENTITY_PARSERS

        if dataset is not None:
            language = dataset[LANGUAGE]
            gazetteer_entity_scope = [e for e in dataset[ENTITIES]
                                      if is_gazetteer_entity(e)]

        if language is None:
            raise ValueError("Either a dataset or a language must be provided "
                             "in order to build a BuiltinEntityParser")

        if gazetteer_entity_scope is None:
            gazetteer_entity_scope = []
        cache_key = _get_caching_key(language, gazetteer_entity_scope)
        if cache_key not in _BUILTIN_ENTITY_PARSERS:
            # Validate the requested scope before doing any expensive build.
            supported = get_supported_gazetteer_entities(language)
            for gazetteer_entity in gazetteer_entity_scope:
                if gazetteer_entity not in supported:
                    raise ValueError(
                        "Gazetteer entity '%s' is not supported in "
                        "language '%s'" % (gazetteer_entity, language))
            _BUILTIN_ENTITY_PARSERS[cache_key] = _build_builtin_parser(
                language, gazetteer_entity_scope)
        return _BUILTIN_ENTITY_PARSERS[cache_key]


def _build_builtin_parser(language, gazetteer_entities):
    """Assemble a builtin parser on disk and load it.

    A temporary serialization directory is created holding an optional
    gazetteer sub-parser plus a ``metadata.json`` file, then a parser is
    instantiated from that directory and wrapped.
    """
    from snips_nlu_parsers import BuiltinEntityParser as _BuiltinEntityParser

    with temp_dir() as serialization_dir:
        if gazetteer_entities:
            gazetteer_entity_parser = _build_gazetteer_parser(
                serialization_dir, gazetteer_entities, language)
        else:
            gazetteer_entity_parser = None

        # The Rust parser expects the language code uppercased in metadata.
        metadata = {
            "language": language.upper(),
            "gazetteer_parser": gazetteer_entity_parser
        }
        metadata_path = serialization_dir / "metadata.json"
        with metadata_path.open("w", encoding="utf-8") as f:
            f.write(json_string(metadata))
        return BuiltinEntityParser(
            _BuiltinEntityParser.from_path(serialization_dir))


def _build_gazetteer_parser(target_dir, gazetteer_entities, language):
    """Copy each entity's compiled gazetteer parser under *target_dir* and
    write the aggregate metadata file.

    Returns the name of the sub-directory created inside *target_dir*.
    """
    from snips_nlu_parsers import get_builtin_entity_shortname

    gazetteer_parser_name = "gazetteer_entity_parser"
    gazetteer_parser_path = target_dir / gazetteer_parser_name
    parsers_metadata = []
    for entity in sorted(gazetteer_entities):
        short_name = get_builtin_entity_shortname(entity).lower()
        # Fetch the compiled parser from the resources and copy it over
        source_parser_path = find_gazetteer_entity_data_path(language, entity)
        shutil.copytree(str(source_parser_path),
                        str(gazetteer_parser_path / short_name))
        parsers_metadata.append({
            "entity_identifier": entity,
            "entity_parser": short_name
        })
    # Dump the aggregated parser metadata
    metadata_path = gazetteer_parser_path / "metadata.json"
    with metadata_path.open("w", encoding="utf-8") as f:
        f.write(json_string({"parsers_metadata": parsers_metadata}))
    return gazetteer_parser_name


def is_builtin_entity(entity_label):
    """Return True when *entity_label* names a supported builtin entity."""
    from snips_nlu_parsers import get_all_builtin_entities

    builtin_entities = get_all_builtin_entities()
    return entity_label in builtin_entities


def is_gazetteer_entity(entity_label):
    """Return True when *entity_label* names a supported gazetteer entity."""
    from snips_nlu_parsers import get_all_gazetteer_entities

    gazetteer_entities = get_all_gazetteer_entities()
    return entity_label in gazetteer_entities


def find_gazetteer_entity_data_path(language, entity_name):
    """Locate the downloaded resource directory for a gazetteer entity.

    Scans every resource directory under ``DATA_PATH`` for a
    ``metadata.json`` matching both *entity_name* and *language*.

    Raises:
        FileNotFoundError: when no matching resources are installed.
    """
    for candidate in DATA_PATH.iterdir():
        metadata_path = candidate / "metadata.json"
        if not candidate.is_dir() or not metadata_path.exists():
            continue
        with metadata_path.open(encoding="utf8") as f:
            metadata = json.load(f)
        is_match = (metadata.get("entity_name") == entity_name
                    and metadata.get("language") == language)
        if is_match:
            return candidate / metadata["data_directory"]
    raise FileNotFoundError(
        "No data found for the '{e}' builtin entity in language '{lang}'. "
        "You must download the corresponding resources by running "
        "'python -m snips_nlu download-entity {e} {lang}' before you can use "
        "this builtin entity.".format(e=entity_name, lang=language))


def _get_caching_key(language, entity_scope):
    tuple_key = (language,)
    tuple_key += tuple(entity for entity in sorted(entity_scope))
    return tuple_key