aboutsummaryrefslogtreecommitdiffstats
path: root/snips_inference_agl/common/registrable.py
blob: bbc7bdc62c54696d8e32eed6c55d4dc3ecbd6f7b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
# This module is largely inspired by the AllenNLP library
# See github.com/allenai/allennlp/blob/master/allennlp/common/registrable.py

from collections import defaultdict
from future.utils import iteritems

from snips_inference_agl.exceptions import AlreadyRegisteredError, NotRegisteredError


class Registrable(object):
    """
    Any class that inherits from ``Registrable`` gains access to a named
    registry for its subclasses. To register them, just decorate them with the
    classmethod ``@BaseClass.register(name)``.

    After which you can call ``BaseClass.list_available()`` to get the keys
    for the registered subclasses, and ``BaseClass.by_name(name)`` to get the
    corresponding subclass.

    Note that if you use this class to implement a new ``Registrable``
    abstract class, you must ensure that all subclasses of the abstract class
    are loaded when the module is loaded, because the subclasses register
    themselves in their respective files. You can achieve this by having the
    abstract class and all subclasses in the __init__.py of the module in
    which they reside (as this causes any import of either the abstract class
    or a subclass to load all other subclasses and the abstract class).
    """
    _registry = defaultdict(dict)

    @classmethod
    def register(cls, name, override=False):
        """Decorator used to add the decorated subclass to the registry of the
        base class

        Args:
            name (str): name use to identify the registered subclass
            override (bool, optional): this parameter controls the behavior in
                case where a subclass is registered with the same identifier.
                If True, then the previous subclass will be unregistered in
                profit of the new subclass.

        Raises:
            AlreadyRegisteredError: when ``override`` is False, while trying
                to register a subclass with a name already used by another
                registered subclass
        """
        registry = Registrable._registry[cls]

        def add_subclass_to_registry(subclass):
            # Add to registry, raise an error if key has already been used.
            if not override and name in registry:
                raise AlreadyRegisteredError(name, cls, registry[name])
            registry[name] = subclass
            return subclass

        return add_subclass_to_registry

    @classmethod
    def registered_name(cls, registered_class):
        for name, subclass in iteritems(Registrable._registry[cls]):
            if subclass == registered_class:
                return name
        raise NotRegisteredError(cls, registered_cls=registered_class)

    @classmethod
    def by_name(cls, name):
        if name not in Registrable._registry[cls]:
            raise NotRegisteredError(cls, name=name)
        return Registrable._registry[cls][name]

    @classmethod
    def list_available(cls):
        return list(Registrable._registry[cls].keys())