# path: root/snips_inference_agl/pipeline/processing_unit.py
# blob: 19284701ce5a8ab555d907c743f1eb131f1fc516 (plain)
from __future__ import unicode_literals

import io
import json
import shutil
from abc import ABCMeta, abstractmethod, abstractproperty
from builtins import str, bytes
from pathlib import Path

from future.utils import with_metaclass

from snips_inference_agl.common.abc_utils import abstractclassmethod, classproperty
from snips_inference_agl.common.io_utils import temp_dir, unzip_archive
from snips_inference_agl.common.registrable import Registrable
from snips_inference_agl.common.utils import (
    json_string, check_random_state)
from snips_inference_agl.constants import (
    BUILTIN_ENTITY_PARSER, CUSTOM_ENTITY_PARSER, CUSTOM_ENTITY_PARSER_USAGE,
    RESOURCES, LANGUAGE, RANDOM_STATE)
from snips_inference_agl.entity_parser import (
    BuiltinEntityParser, CustomEntityParser, CustomEntityParserUsage)
from snips_inference_agl.exceptions import LoadingError
from snips_inference_agl.pipeline.configs import ProcessingUnitConfig
from snips_inference_agl.pipeline.configs.config import DefaultProcessingUnitConfig
from snips_inference_agl.resources import load_resources


class ProcessingUnit(with_metaclass(ABCMeta, Registrable)):
    """Abstraction of a NLU pipeline unit

    Pipeline processing units such as intent parsers, intent classifiers and
    slot fillers must implement this class.

    A :class:`ProcessingUnit` is associated with a *config_type*, which
    represents the :class:`.ProcessingUnitConfig` used to initialize it.
    """

    def __init__(self, config, **shared):
        """Initialize the unit from a config and optional shared objects

        Args:
            config (:class:`.ProcessingUnitConfig` or dict or None): config
                of the unit. When None, :meth:`default_config` is used. When
                a dict, it is converted with ``config_type.from_dict``.
            **shared: optional shared objects, read from the
                BUILTIN_ENTITY_PARSER, CUSTOM_ENTITY_PARSER, RESOURCES and
                RANDOM_STATE keys. Missing keys default to None.

        Raises:
            ValueError: when *config* is neither None, a
                :class:`.ProcessingUnitConfig` nor a dict
        """
        if config is None:
            self.config = self.default_config()
        elif isinstance(config, ProcessingUnitConfig):
            self.config = config
        elif isinstance(config, dict):
            self.config = self.config_type.from_dict(config)
        else:
            raise ValueError("Unexpected config type: %s" % type(config))
        if self.config is not None:
            self.config.set_unit_name(self.unit_name)
        self.builtin_entity_parser = shared.get(BUILTIN_ENTITY_PARSER)
        self.custom_entity_parser = shared.get(CUSTOM_ENTITY_PARSER)
        self.resources = shared.get(RESOURCES)
        # The raw shared value (possibly None) is normalized by
        # check_random_state before being stored.
        self.random_state = check_random_state(shared.get(RANDOM_STATE))

    @classproperty
    def config_type(cls):  # pylint:disable=no-self-argument
        """Config class associated with this unit type"""
        return DefaultProcessingUnitConfig

    @classmethod
    def default_config(cls):
        """Build a config from *config_type* with no argument and set its
        unit name to the one of this class"""
        config = cls.config_type()  # pylint:disable=no-value-for-parameter
        config.set_unit_name(cls.unit_name)
        return config

    @classproperty
    def unit_name(cls):  # pylint:disable=no-self-argument
        """Name of the unit, as returned by
        ``ProcessingUnit.registered_name``"""
        return ProcessingUnit.registered_name(cls)

    @classmethod
    def from_config(cls, unit_config, **shared):
        """Build a :class:`ProcessingUnit` from the provided config

        Args:
            unit_config: config object exposing a *unit_name* attribute,
                which is used to look up the unit type with ``by_name``
            **shared: shared objects forwarded to the unit constructor
        """
        unit = cls.by_name(unit_config.unit_name)
        return unit(unit_config, **shared)

    @classmethod
    def load_from_path(cls, unit_path, unit_name=None, **shared):
        """Load a :class:`ProcessingUnit` from a persisted processing unit
        directory

        Args:
            unit_path (str or :class:`pathlib.Path`): path to the persisted
                processing unit
            unit_name (str, optional): Name of the processing unit to load.
                By default, the unit name is assumed to be stored in a
                "metadata.json" file located in the directory at unit_path.
            **shared: shared objects forwarded to ``from_path``

        Returns:
            The object returned by ``from_path`` of the unit type looked up
            with ``by_name``.

        Raises:
            LoadingError: when unit_name is None and no metadata file is found
                in the processing unit directory
        """
        unit_path = Path(unit_path)
        if unit_name is None:
            metadata_path = unit_path / "metadata.json"
            if not metadata_path.exists():
                raise LoadingError(
                    "Missing metadata for processing unit at path %s"
                    % str(unit_path))
            with metadata_path.open(encoding="utf8") as f:
                metadata = json.load(f)
            unit_name = metadata["unit_name"]
        unit = cls.by_name(unit_name)
        return unit.from_path(unit_path, **shared)

    @classmethod
    def get_config(cls, unit_config):
        """Returns the :class:`.ProcessingUnitConfig` corresponding to
        *unit_config*

        Args:
            unit_config (:class:`.ProcessingUnitConfig` or dict or str):
                a config instance (returned as is), a dict containing a
                "unit_name" key, or directly the name of the unit

        Raises:
            ValueError: when *unit_config* has an unsupported type
        """
        if isinstance(unit_config, ProcessingUnitConfig):
            return unit_config
        elif isinstance(unit_config, dict):
            unit_name = unit_config["unit_name"]
            processing_unit_type = cls.by_name(unit_name)
            return processing_unit_type.config_type.from_dict(unit_config)
        elif isinstance(unit_config, (str, bytes)):
            # A bare name is turned into a minimal config dict
            unit_name = unit_config
            unit_config = {"unit_name": unit_name}
            processing_unit_type = cls.by_name(unit_name)
            return processing_unit_type.config_type.from_dict(unit_config)
        else:
            raise ValueError(
                "Expected `unit_config` to be an instance of "
                "ProcessingUnitConfig or dict or str but found: %s"
                % type(unit_config))

    @abstractproperty
    def fitted(self):
        """Whether or not the processing unit has already been trained"""
        pass

    def load_resources_if_needed(self, language):
        """Load the resources of *language* when the unit has no resources
        yet or when it is already fitted

        The resources required by the config, if any, are passed to
        ``load_resources``; None is passed when the unit has no config.
        """
        if self.resources is None or self.fitted:
            required_resources = None
            if self.config is not None:
                required_resources = self.config.get_required_resources()
            self.resources = load_resources(language, required_resources)

    def fit_builtin_entity_parser_if_needed(self, dataset):
        """Build a builtin entity parser from *dataset* when needed

        Returns:
            The unit itself.
        """
        # We only fit a builtin entity parser when the unit has already been
        # fitted or if the parser is none.
        # In the other cases the parser is provided fitted by another unit.
        if self.builtin_entity_parser is None or self.fitted:
            self.builtin_entity_parser = BuiltinEntityParser.build(
                dataset=dataset)
        return self

    def fit_custom_entity_parser_if_needed(self, dataset):
        """Build a custom entity parser from *dataset* when needed

        The parser usage is read from the resources required by the config,
        and defaults to ``CustomEntityParserUsage.WITHOUT_STEMS``.

        Returns:
            The unit itself.
        """
        # We only fit a custom entity parser when the unit has already been
        # fitted or if the parser is none.
        # In the other cases the parser is provided fitted by another unit.
        # A unit without config is handled like a unit without required
        # resources, consistently with load_resources_if_needed.
        required_resources = None
        if self.config is not None:
            required_resources = self.config.get_required_resources()
        if not required_resources or not required_resources.get(
                CUSTOM_ENTITY_PARSER_USAGE):
            # In these cases we need a custom entity parser only to do the
            # final slot resolution step, which must be done without stemming.
            parser_usage = CustomEntityParserUsage.WITHOUT_STEMS
        else:
            parser_usage = required_resources[CUSTOM_ENTITY_PARSER_USAGE]

        if self.custom_entity_parser is None or self.fitted:
            self.load_resources_if_needed(dataset[LANGUAGE])
            self.custom_entity_parser = CustomEntityParser.build(
                dataset, parser_usage, self.resources)
        return self

    def persist_metadata(self, path, **kwargs):
        """Write a "metadata.json" file in the directory at *path*

        Args:
            path (str or :class:`pathlib.Path`): directory in which the
                metadata file is written
            **kwargs: additional entries merged into the metadata, on top of
                the "unit_name" entry
        """
        # Accept plain strings as well, like load_from_path does
        path = Path(path)
        metadata = {"unit_name": self.unit_name}
        metadata.update(kwargs)
        metadata_json = json_string(metadata)
        with (path / "metadata.json").open(mode="w", encoding="utf8") as f:
            f.write(metadata_json)

    # NOTE(review): this method is not declared abstract (the decorator was
    # previously commented out), so subclasses are not forced to override
    # it -- confirm this is intended.
    def persist(self, path):
        """Persist the unit at *path*

        The base implementation does nothing.
        """
        pass

    @abstractclassmethod
    def from_path(cls, path, **shared):
        """Load a unit from the directory at *path*

        Called by :meth:`load_from_path` with the unit directory and the
        shared objects. Must be implemented by subclasses.
        """
        pass