Source code for vis4d.config.config_dict

"""Config dict module."""

from __future__ import annotations

import importlib
from collections.abc import Callable, Iterable, Mapping
from typing import Any

import yaml
from ml_collections import ConfigDict, FieldReference, FrozenConfigDict

from vis4d.common.named_tuple import get_all_keys, is_namedtuple
from vis4d.common.typing import ArgsType


# NOTE: Most of these functions need to deal with unknown parameters and are
# therefore not strictly typed
class FieldConfigDict(ConfigDict):  # type: ignore # pylint: disable=too-many-instance-attributes, line-too-long
    """A configuration dict which allows to access fields via dot notation.

    This class is a subclass of ConfigDict and overwrites the dot notation to
    return a FieldReference instead of a dict.

    For more information on the ConfigDict class, see:
    ml_collections.ConfigDict.

    Examples of using the ref and value mode:
        >>> config = FieldConfigDict({"a": 1, "b": 2})
        >>> type(config.a)
        <class 'ml_collections.field_reference.FieldReference'>
        >>> config.value_mode()  # Set the config to return values
        >>> type(config.a)
        <class 'int'>
    """

    def __init__(  # type: ignore
        self,
        initial_dictionary: Mapping[str, Any] | None = None,
        type_safe: bool = True,
        convert_dict: bool = True,
    ):
        """Creates an instance of FieldConfigDict.

        Args:
            initial_dictionary: May be one of the following:

                1) dict. In this case, all values of initial_dictionary that
                are dictionaries are also be converted to ConfigDict. However,
                dictionaries within values of non-dict type are untouched.

                2) ConfigDict. In this case, all attributes are uncopied, and
                only the top-level object (self) is re-addressed. This is the
                same behavior as Python dict, list, and tuple.

                3) FrozenConfigDict. In this case, initial_dictionary is
                converted to a ConfigDict version of the initial dictionary
                for the FrozenConfigDict (reversing any mutability changes
                FrozenConfigDict made).
            type_safe: If set to True, once an attribute value is assigned,
                its type cannot be overridden without .ignore_type() context
                manager.
            convert_dict: If set to True, all dict used as value in the
                ConfigDict will automatically be converted to ConfigDict.
        """
        super().__init__(initial_dictionary, type_safe, convert_dict)
        # New configs start in reference mode (see __getitem__).
        object.__setattr__(self, "_return_refs", True)

    @classmethod
    def from_yaml(cls, path: str) -> FieldConfigDict:
        """Creates a config from a .yaml file.

        Args:
            path: The path to the .yaml file that should be loaded.

        Returns:
            FieldConfigDict: The loaded configuration.
        """
        # NOTE: UnsafeLoader is needed to restore the Python types that
        # to_yaml serializes, but it can construct arbitrary objects. Only
        # load files from trusted sources.
        with open(path, "r", encoding="utf-8") as file:
            content = yaml.load(file, Loader=yaml.UnsafeLoader)
        return cls(content)

    def to_yaml(self, **kwargs: ArgsType) -> str:
        """Returns a YAML representation of the object.

        ConfigDict serializes types of fields as well as the values of fields
        themselves. Deserializing the YAML representation hence requires using
        YAML's UnsafeLoader:

        ```
        yaml.load(cfg.to_yaml(), Loader=yaml.UnsafeLoader)
        ```

        or equivalently:

        ```
        yaml.unsafe_load(cfg.to_yaml())
        ```

        Please see the PyYAML documentation and https://msg.pyyaml.org/load
        for more details on the consequences of this.

        The reference/value mode of this config is restored afterwards, so
        serializing does not change how the config behaves.

        Args:
            **kwargs: Keyword arguments for yaml.dump.

        Returns:
            YAML representation of the object.
        """
        previous_mode = self._return_refs
        try:
            resolved = copy_and_resolve_references(self.value_mode())
        finally:
            # value_mode() mutates self; undo that side effect.
            self.set_ref_mode(previous_mode)
        return resolved.to_yaml(**kwargs)

    def dump(self, output_path: str) -> None:
        """Writes the config to a .yaml file.

        Args:
            output_path: The path to the output file.
        """
        with open(output_path, "w", encoding="utf-8") as file:
            file.write(self.to_yaml())

    def set_ref_mode(self, ref_mode: bool) -> None:
        """Sets whether this config and all nested configs return references.

        The mode is propagated to every FieldConfigDict reachable from this
        config, including those nested in lists, tuples, dicts and
        ConfigDicts.

        Args:
            ref_mode: If True, item/attribute access returns FieldReferences;
                if False, it returns plain values.
        """

        def _collect_configs(  # type: ignore
            iterable: Iterable[Any], cfgs: list[FieldConfigDict]
        ) -> None:
            """Recursively adds all FieldConfigDicts in iterable to cfgs."""
            for item in iterable:
                # In ref mode, values() may yield FieldReferences that wrap
                # the actual sub-configs.
                if isinstance(item, FieldReference):
                    item = item.get()
                if isinstance(item, FieldConfigDict):
                    cfgs.append(item)
                elif isinstance(item, (list, tuple)):
                    _collect_configs(item, cfgs)
                elif isinstance(item, (dict, ConfigDict)):
                    _collect_configs(item.values(), cfgs)

        # Update value of this dict
        object.__setattr__(self, "_return_refs", ref_mode)

        # Propagate the *requested* mode to all sub configs.
        sub_cfgs: list[FieldConfigDict] = []
        _collect_configs(self.values(), sub_cfgs)
        for cfg in sub_cfgs:
            cfg.set_ref_mode(ref_mode)

    def ref_mode(self) -> FieldConfigDict:
        """Sets the config to return references instead of values.

        Returns:
            FieldConfigDict: self, to allow chaining.
        """
        self.set_ref_mode(True)
        return self

    def value_mode(self) -> FieldConfigDict:
        """Sets the config to return values instead of references.

        Returns:
            FieldConfigDict: self, to allow chaining.
        """
        self.set_ref_mode(False)
        return self

    def __getitem__(self, key: str) -> Any:
        """Returns the entry for key.

        In ref mode this is a FieldReference to the field; otherwise (or if
        no reference can be created) it is the plain value.
        """
        if self._return_refs:
            try:
                return super().get_ref(key)
            except ValueError:
                # get_ref could not create a reference (presumably e.g. for
                # a None-valued field); fall back to the plain value.
                pass
        return super().__getitem__(key)
def resolve_class_name(clazz: type | Callable[..., Any] | str) -> str:
    """Resolves the full class name of the given class object, callable or str.

    This function takes a class object and returns the class name as a string.

    Args:
        clazz (type | Callable[..., Any] | str): The object to resolve the full
            path of.

    Returns:
        str: The full path of the given object. Strings are returned as-is and
            objects from the ``builtins`` module are returned by bare name.

    Raises:
        ValueError: If the given object is a lambda function.

    Examples:
        >>> class MyClass: pass
        >>> resolve_class_name(MyClass)
        '__main__.MyClass'
        >>> resolve_class_name("path.to.MyClass")
        'path.to.MyClass'
        >>> def my_function(): pass
        >>> resolve_class_name(my_function)
        '__main__.my_function'
    """
    if isinstance(clazz, str):
        return clazz

    # Lambdas are named "<lambda>" and cannot be re-imported by path.
    if clazz.__name__ == "<lambda>":
        raise ValueError(
            "Resolving the full class path of lambda functions "
            "is not supported. Please define an inline function instead."
        )

    module = clazz.__module__
    # str.__class__.__module__ == "builtins": builtins need no module prefix.
    if module is None or module == str.__class__.__module__:
        return clazz.__name__
    return module + "." + clazz.__name__
def class_config(
    clazz: type | Callable[..., Any] | str,
    **kwargs: ArgsType,
) -> ConfigDict:
    """Creates a configuration which can be instantiated as a class.

    This function creates a configuration dict which can be passed to
    'instantiate_classes' to create a instance of the given class or functor.

    Example:
        >>> class_cfg_obj = class_config("your.module.Module", arg1="arg1", arg2=2)
        >>> print(class_cfg_obj)
        class_path: your.module.Module
        init_args:
          arg1: arg1
          arg2: 2
        >>> # instantiate object
        >>> inst_obj = instantiate_classes(class_cfg_obj)
        >>> print(type(inst_obj))  # -> <class 'your.module.Module'>

        >>> # Example by directly passing objects:
        >>> class MyClass:
        ...     def __init__(self, name: str, age: int):
        ...         self.name = name
        ...         self.age = age
        >>> class_cfg_obj = class_config(MyClass, name="John", age=25)
        >>> print(class_cfg_obj)
        class_path: __main__.MyClass
        init_args:
          age: 25
          name: John
        >>> inst_obj = instantiate_classes(class_cfg_obj)
        >>> print(type(inst_obj))  # -> <class '__main__.MyClass'>
        >>> print(inst_obj.name)  # -> John

    Args:
        clazz (type | Callable[..., Any] | str): class type or functor or
            class string path.
        **kwargs (ArgsType): Kwargs to pass to the class constructor.

    Returns:
        ConfigDict: A config with a ``class_path`` entry and, if any kwargs
            were given, an ``init_args`` ConfigDict holding them.
    """
    class_path = resolve_class_name(clazz)  # always a str
    if not kwargs:
        return ConfigDict({"class_path": class_path})
    return ConfigDict(
        {"class_path": class_path, "init_args": ConfigDict(kwargs)}
    )
def delay_instantiation(instantiable: ConfigDict) -> ConfigDict:
    """Wraps a class config so that it is instantiated lazily.

    The ``class_path`` key of ``instantiable`` is renamed in place to
    ``_class_path``, which hides it from ``instantiate_classes``. The renamed
    config is then wrapped into a class config for ``DelayedInstantiator``;
    instantiating that yields a callable which builds the original class when
    invoked. This is a somewhat hacky way to e.g. delay the construction of an
    optimizer.

    Args:
        instantiable (ConfigDict): The configuration object whose
            instantiation should be delayed. It is modified in place.

    Returns:
        ConfigDict: Class config for a DelayedInstantiator wrapping
            ``instantiable``.
    """
    target_path = instantiable["class_path"]
    instantiable["_class_path"] = target_path
    del instantiable["class_path"]
    return class_config(DelayedInstantiator, instantiable=instantiable)
class DelayedInstantiator:
    """Callable that instantiates a stored class config on demand.

    The stored config is expected to carry its class path under
    ``_class_path`` (see ``delay_instantiation``), so that
    ``instantiate_classes`` leaves it alone until this object is called.

    Args:
        instantiable (ConfigDict): The configuration object to delay the
            instantiation of.
    """

    def __init__(self, instantiable: ConfigDict) -> None:
        """Stores the configuration whose instantiation is delayed."""
        self.instantiable = instantiable

    def __call__(self, **kwargs: ArgsType) -> Any:  # type: ignore
        """Instantiates the stored configuration object.

        Args:
            **kwargs: Extra constructor arguments merged into ``init_args``.

        Returns:
            Any: The instantiated object.
        """
        stored = self.instantiable
        target_path = stored["_class_path"]
        stored_args = stored.get("init_args", {})
        config = class_config(target_path, **stored_args)
        return instantiate_classes(config, **kwargs)
def instantiate_classes(  # type: ignore
    data: ConfigDict | FieldReference, **kwargs: ArgsType
) -> ConfigDict | Any:
    """Instantiates all classes in a given ConfigDict.

    This function iterates over the configuration data and instantiates
    all classes. Class definitions are provided by a config dict that has
    the following structure:

    {
        'class_path': 'path.to.my.class.Class',
        'init_args': ConfigDict(
            {
                'arg1': 'value1',
                'arg2': 'value2',
            }
        )
    }

    The given config is not modified by ``kwargs``; they are merged into a
    resolved copy that is then instantiated.

    Args:
        data (ConfigDict | FieldReference): The general configuration object.
        **kwargs (ArgsType): Additional arguments to pass to the class
            constructor.

    Returns:
        ConfigDict | Any: The instantiated objects.
    """
    if isinstance(data, FieldReference):
        # De-Reference the field reference
        data = data.get()

    assert isinstance(data, ConfigDict), "Data must be a ConfigDict."
    if isinstance(data, FieldConfigDict):
        data.value_mode()  # make sure data is in value mode

    # Work on a resolved copy so per-call kwargs do not leak into the
    # caller's config across calls.
    resolved_data = copy_and_resolve_references(data)
    if len(kwargs) > 0:
        if "init_args" not in resolved_data:
            init_args = ConfigDict(kwargs)
        else:
            init_args = resolved_data["init_args"]
            for k, v in kwargs.items():
                init_args[k] = v
        # Resolve references / rebuild containers introduced via kwargs,
        # just like the rest of the config was resolved.
        resolved_data["init_args"] = copy_and_resolve_references(init_args)

    return _instantiate_classes(resolved_data)
def copy_and_resolve_references(  # type: ignore
    data: Any, visit_map: dict[int, Any] | None = None
) -> Any:
    """Returns a copy of data with FieldReferences replaced by their values.

    Lists, tuples, named tuples and dicts are rebuilt recursively, so that
    FieldReferences nested inside them are resolved too. ConfigDicts are
    copied into a plain ConfigDict, or into a FrozenConfigDict if data is a
    FrozenConfigDict. Any other object is returned unchanged (not copied).

    Note:
        This method is overwritten from the ConfigDict class and allows to
        also resolve FieldReferences in list, tuple and dict.

    Args:
        data (Any): object to copy.
        visit_map (dict[int, Any], optional): A mapping from ConfigDict
            object ids to their copy, threaded through the recursive calls.

    Returns:
        Any: Copy of data with previous FieldReferences replaced by values.
    """
    if isinstance(data, FieldReference):
        data = data.get()

    if is_namedtuple(data):
        return type(data)(
            **{
                name: copy_and_resolve_references(
                    getattr(data, name), visit_map
                )
                for name in get_all_keys(data)
            }
        )
    if isinstance(data, (list, tuple)):
        return type(data)(
            copy_and_resolve_references(value, visit_map) for value in data
        )
    if isinstance(data, dict):
        return {
            k: copy_and_resolve_references(v, visit_map)
            for k, v in data.items()
        }
    if not isinstance(data, ConfigDict):
        return data

    visit_map = visit_map or {}
    is_frozen = isinstance(data, FrozenConfigDict)
    # A frozen source needs a frozen copy: only FrozenConfigDict provides
    # _frozen_setattr, which is used to fill it below.
    config_dict = FrozenConfigDict() if is_frozen else ConfigDict()

    # copy attributes from the source before inserting values
    super(ConfigDict, config_dict).__setattr__(
        "_convert_dict", data.convert_dict
    )
    # NOTE(review): the map is keyed by the copy's id, not id(data), so the
    # lookup below effectively never matches a source sub-config; shared or
    # cyclic sub-configs are therefore copied independently. Keying by
    # id(data) would make shared sub-configs share one copy (and thus one
    # instantiated object), so it is left unchanged — confirm intent.
    visit_map[id(config_dict)] = config_dict

    for key, value in data._fields.items():  # pylint: disable=protected-access
        if isinstance(value, FieldReference):
            value = value.get()

        if id(value) in visit_map:
            value = visit_map[id(value)]
        else:
            # Handles ConfigDicts, named tuples, lists, tuples and dicts;
            # other values are returned as-is.
            value = copy_and_resolve_references(value, visit_map)

        if is_frozen:
            config_dict._frozen_setattr(  # pylint:disable=protected-access
                key, value
            )
        else:
            config_dict[key] = value

    # copy attributes
    super(ConfigDict, config_dict).__setattr__("_locked", data.is_locked)
    super(ConfigDict, config_dict).__setattr__("_type_safe", data.is_type_safe)
    return config_dict
def _get_index(data: Any) -> Any:  # type: ignore
    """Internal function to generate a Sequence of indexes for a given object.

    Example:
        >>> [data[idx] for idx in _get_index(data)]

    Args:
        data (Any): The data entry to get an index for.

    Returns:
        Any: Iterable that can be used to index the data entry using e.g.
            [data[idx] for idx in _get_index(data)]
    """
    if isinstance(data, (list, tuple)):
        return range(len(data))
    return data


def _instantiate_classes(data: Any) -> Any:  # type: ignore
    """Instantiates all classes in a given data.

    Data could be ConfigDict, FieldReference, tuple, list or dict.
    This is the recursive implementation of the 'instantiate_classes'.

    This function iterates over the configuration data and instantiates
    all classes. Class definitions are provided by a config dict that has
    the following structure:

    {
        'class_path': 'path.to.my.class.Class',
        'init_args': ConfigDict(
            {
                'arg1': 'value1',
                'arg2': 'value2',
            }
        )
    }

    A ``class_path`` without a module (e.g. ``"int"``, as produced by
    ``resolve_class_name`` for builtins) is looked up in ``builtins``.

    Args:
        data (Any): The general configuration object.

    Returns:
        Any: The data with all classes initialized. Dicts and ConfigDicts
            are updated in place; lists, tuples and named tuples are rebuilt.
            If the top level element is a class config, the returned element
            will be the instantiated class.
    """
    if isinstance(data, FieldReference):
        data = data.get()

    # Sequences are rebuilt rather than modified in place: tuples (and named
    # tuples) do not support item assignment.
    if is_namedtuple(data):
        return type(data)(
            **{
                name: _instantiate_classes(getattr(data, name))
                for name in get_all_keys(data)
            }
        )
    if isinstance(data, (list, tuple)):
        return type(data)(_instantiate_classes(item) for item in data)
    if not isinstance(data, (ConfigDict, dict)):
        return data

    for key in _get_index(data):
        value = data[key]
        if isinstance(value, FieldReference):
            value = value.get()
        if not isinstance(value, (ConfigDict, dict, list, tuple)):
            continue  # leaf value, nothing to instantiate

        instantiated = _instantiate_classes(value)
        if isinstance(data, ConfigDict):
            # The instantiated object usually has a different type than the
            # config it replaces, so bypass type safety for this write.
            with data.ignore_type():
                data[key] = instantiated
        else:
            data[key] = instantiated

    # Instantiate class
    if "class_path" in data and not isinstance(data["class_path"], ConfigDict):
        module_name, _, class_name = data["class_path"].rpartition(".")
        init_args = data.get("init_args", {})
        # Convert ConfigDict to normal dictionary
        if isinstance(init_args, ConfigDict):
            init_args = init_args.to_dict()
        module = importlib.import_module(module_name or "builtins")
        return getattr(module, class_name)(**init_args)

    return data