# Source code for vis4d.config.replicator

"""Replication methods to perform different parameters sweeps."""

from __future__ import annotations

import re
from collections.abc import Callable, Generator, Iterable
from queue import Queue
from string import Formatter
from typing import Any

from ml_collections import ConfigDict

from vis4d.common.typing import ArgsType


def iterable_sampler(  # type: ignore
    samples: Iterable[Any],
) -> Callable[[], Generator[Any, None, None]]:
    """Creates a sampler from an iterable.

    This function returns a method that returns a generator that iterates
    over all values provided in the 'samples' iterable. The returned method
    can be called repeatedly; every call starts a fresh pass over the values.

    Args:
        samples (Iterable[Any]): Iterable over which to sample. One-shot
            iterators (e.g. generator objects) are consumed eagerly into a
            tuple when the sampler is created and must therefore be finite.

    Returns:
        Callable[[], Generator[Any, None, None]]: Function that returns a
            generator which iterates over all elements in the given iterable.
    """
    # An iterator returns itself from iter() and can only be walked once.
    # Snapshot it so that the second and later calls of the sampler still
    # see every value instead of an already exhausted iterator.
    if iter(samples) is samples:
        samples = tuple(samples)

    def _sampler() -> Generator[Any, None, None]:
        yield from samples

    return _sampler
def linspace_sampler(
    min_value: float, max_value: float, n_steps: int = 1
) -> Callable[[], Generator[float, None, None]]:
    """Creates a linear space sampler.

    This function returns a method that returns a generator that iterates
    from min_value to max_value in n_steps evenly spaced values.

    With n_steps == 1 only min_value is produced; with n_steps <= 0 the
    generator is empty.

    Args:
        min_value (float): Lower value bound
        max_value (float): Upper value bound
        n_steps (int, optional): Number of steps. Defaults to 1.

    Returns:
        Callable[[], Generator[float, None, None]]: Function that returns a
            generator which iterates from min to max in n_steps.
    """

    def _sampler() -> Generator[float, None, None]:
        last = n_steps - 1
        # Guard against a zero division for the single step case.
        denominator = max(last, 1)
        span = max_value - min_value
        for i in range(n_steps):
            if last > 0 and i == last:
                # min_value + (max_value - min_value) is not guaranteed to
                # round back to max_value, so emit the upper bound itself.
                yield float(max_value)
            else:
                yield min_value + i / denominator * span

    return _sampler
def logspace_sampler(
    min_exponent: float,
    max_exponent: float,
    n_steps: int = 1,
    base: float = 10,
) -> Callable[[], Generator[float, None, None]]:
    """Creates a logarithmic space sampler.

    The returned method produces a generator that walks from
    base^min_exponent to base^max_exponent in n_steps. The exponents are
    taken from linspace_sampler(min_exponent, max_exponent, n_steps) and
    each one is used as the power of 'base'.

    Args:
        min_exponent (float): Exponent of the first value.
        max_exponent (float): Exponent of the last value.
        n_steps (int, optional): Number of steps. Defaults to 1.
        base (float): Base value for exponential calculation. Defaults to 10.

    Returns:
        Callable[[], Generator[float, None, None]]: Function that returns a
            generator which iterates from base^min to base^max in n_steps.
    """

    def _sampler() -> Generator[float, None, None]:
        exponents = linspace_sampler(min_exponent, max_exponent, n_steps)
        yield from (base**exponent for exponent in exponents())

    return _sampler
def replicate_config(  # type: ignore
    configuration: ConfigDict,
    sampling_args: list[
        tuple[str, Callable[[], Generator[Any, None, None]] | Iterable[Any]]
    ],
    method: str = "grid",
    fstring: str = "",
) -> Generator[ConfigDict, None, None]:
    """Function used to replicate a config.

    This function takes a ConfigDict and a list of (key, values) pairs. It
    yields the config once per parameter combination, with the given keys
    assigned to the sampled values and 'experiment_name' set to
    '<original experiment_name>_<resolved fstring>'. The configuration must
    therefore define 'experiment_name'. With the default empty fstring the
    name ends with the separating underscore.

    Example:
        >>> config = ConfigDict(
        ...     {"experiment_name": "exp", "trainer": {"lr": 0.2, "bs": 2}}
        ... )
        >>> replicated_config = replicate_config(
        ...     config,
        ...     sampling_args=[("trainer.lr", linspace_sampler(0.01, 0.1, 3))],
        ...     method="grid",
        ...     fstring="lr_{trainer.lr}",
        ... )
        >>> lrs = [c.trainer.lr for c in replicated_config]

        'lrs' then holds the three sampled learning rates (approximately
        0.01, 0.055 and 0.1) while 'trainer.bs' stays 2 in every config.

    NOTE, the very same config dict instance is yielded every time and is
    updated in place to preserve references. Collecting the yielded objects,
    e.g. with list(replicated_config), therefore gives a list whose entries
    all show the values of the last combination. Please resolve the
    references and copy the dict if you need a list:

        >>> configs = [
        ...     c.copy_and_resolve_references() for c in replicated_config
        ... ]

    Since this is a generator function, nothing (including the validation
    of 'method') happens before the first element is requested.

    Args:
        configuration (ConfigDict): Configuration to replicate. It is
            modified in place.
        sampling_args (list[tuple[str, Callable[[], Any] | Iterable[Any]]]):
            (key, values) pairs. 'values' is either a callable returning a
            fresh generator over the values to assign to the key, or an
            iterable of these values. One-shot iterators are consumed
            eagerly and must be finite.
        method (str): What replication method to use. Currently only 'grid'
            and 'linear' is supported. Grid combines the sampling arguments
            in a grid wise fashion
            ([1,2],[3,4] -> [1,3],[1,4],[2,3],[2,4]) whereas 'linear' will
            only select elements at the same index
            ([1,2],[3,4] -> [1,3],[2,4]).
        fstring (str): Format string to use for the experiment name.
            Defaults to an empty string. The format string will be resolved
            with the values of the config dict. For example, if the config
            dict contains a key 'trainer.lr' with value 0.1, the format
            string '{trainer.lr}' will be resolved to '0.1'.

    Yields:
        ConfigDict: The (shared) configuration, updated for the current
            parameter combination.

    Raises:
        ValueError: if the replication method is unknown.
    """
    # Validate first so that no input iterator is consumed for nothing.
    if method == "grid":
        replicate = _replicate_config_grid
    elif method == "linear":
        replicate = _replicate_config_linear
    else:
        raise ValueError(f"Unknown replication method {method}")

    sampling_queue: Queue[  # type: ignore
        tuple[str, Callable[[], Generator[Any, None, None]]]
    ] = Queue()
    for key, value in sampling_args:
        if isinstance(value, Iterable):
            values: Iterable[Any] = value
            # The samplers are invoked once per combination of the other
            # keys. A one-shot iterator would be empty from the second
            # invocation on and silently drop combinations, so snapshot it.
            if iter(values) is values:
                values = tuple(values)
            # Convert the iterable to a callable generator
            sampling_queue.put((key, iterable_sampler(values)))
        else:
            sampling_queue.put((key, value))

    # Read the name once: it is overwritten on the shared config below.
    original_name = configuration.experiment_name
    for config in replicate(configuration, sampling_queue):
        # Update config name
        config.experiment_name = (
            f"{original_name}_{_resolve_fstring(fstring, config)}"
        )
        yield config
def _resolve_fstring(fstring: str, config: ConfigDict) -> str:
    """Resolves a format string with the values from the config.

    The format string follows the str.format syntax. Every replacement
    field ({key}, {key:format} or {key!conversion}) is replaced with the
    value obtained through getattr(config, key). Dotted keys such as
    'params.lr' are handed to getattr unchanged, so resolving nested keys is
    left to the config object. Whitespace around a key is ignored, literal
    text is copied verbatim and '{{' / '}}' produce literal braces.
    Replacement fields nested inside a format specification
    (e.g. '{key:>{width}}') are resolved from the config as well.

    This function fails with the exception raised by the config object if
    the format string contains a key that is not present in the config.

    Args:
        fstring (str): The format string. E.g. "lr_{params.lr}".
        config (ConfigDict): The config dict. E.g. {"params": {"lr": 0.1}}.

    Returns:
        str: The resolved format string. E.g. "lr_0.1"
    """
    formatter = Formatter()
    resolved: list[str] = []
    # Walk the parsed fields instead of rewriting the template text: this
    # keeps literal text that happens to equal a key untouched and honours
    # brace escapes as well as conversions.
    for literal, field_name, format_spec, conversion in formatter.parse(
        fstring
    ):
        resolved.append(literal)
        if field_name is None:
            continue  # Trailing literal text without a replacement field
        value = getattr(config, field_name.strip())
        value = formatter.convert_field(value, conversion)
        # The specification may itself reference config values.
        spec = _resolve_fstring(format_spec, config) if format_spec else ""
        resolved.append(formatter.format_field(value, spec))
    return "".join(resolved)


def _replicate_config_grid(  # type: ignore
    configuration: ConfigDict,
    sampling_queue: Queue[
        tuple[str, Callable[[], Generator[Any, None, None]]]
    ],
) -> Generator[ConfigDict, None, None]:
    """Internal function used to replicate a config.

    This function takes a ConfigDict and a queue with (key, generator)
    entries. It will then recursively call itself and yield the ConfigDict
    with updated values for every key in the sampling_queue.

    Each key combination will be yielded exactly once, resulting in
    prod(len(generator)) entries. For example, a parameter sweep using
    'lr: [0,1], bs: [8, 16]' will yield [0, 8], [0, 16], [1, 8], [1, 16] as
    combinations.

    The same configuration object is updated in place and yielded each
    time. A sampler is taken from the queue while its values are iterated
    and appended to the back of the queue afterwards. With three or more
    keys the nesting order of the inner keys can therefore differ between
    successive values of an outer key; the set of combinations is not
    affected.

    Args:
        configuration (ConfigDict): Configuration to replicate
        sampling_queue (Queue[tuple[str, Callable[[], Any]]]): The queue,
            that contains (key, iterator) pairs where the iterator yields
            the values which should be assigned to the key.

    Yields:
        ConfigDict: Replicated configuration with updated key values.
    """
    # Termination criterion reached, we processed all samplers
    if sampling_queue.empty():
        yield configuration
        return

    # Get next key we want to replicate
    (key_name, sampler) = sampling_queue.get()

    # Iterate over all possible assignment values for this key
    for value in sampler():
        # Update value ignoring type errors
        # (e.g. user set default lr to 1 instead 1.0 would
        # otherwise give a type error (float != int))
        with configuration.ignore_type():
            configuration.update_from_flattened_dict({key_name: value})

        # Let the other samplers process the remaining keys
        yield from _replicate_config_grid(configuration, sampling_queue)

    # Add back this sampler for next round
    sampling_queue.put((key_name, sampler))


def _replicate_config_linear(  # type: ignore
    configuration: ConfigDict,
    sampling_queue: Queue[
        tuple[str, Callable[[], Generator[Any, None, None]]]
    ],
    current_position: int | None = None,
) -> Generator[ConfigDict, None, None]:
    """Internal function used to replicate a config in a linear way.

    This function takes a ConfigDict and a queue with (key, generator)
    entries. It will then recursively call itself and yield the ConfigDict
    with updated values for every key in the sampling_queue.

    For example, a parameter sweep using 'lr: [0,1], bs: [8, 16]' will
    yield [0, 8], [1, 16] as combinations.

    The first sampler in the queue drives the iteration: the top level call
    walks all of its values and counts the position. Every recursive call
    re-invokes its sampler and only applies the value found at that
    position. A position is only yielded if all remaining samplers provide
    a value at that index. The same configuration object is updated in
    place and yielded each time.

    Args:
        configuration (ConfigDict): Configuration to replicate
        sampling_queue (Queue[tuple[str, Callable[[], Any]]]): The queue,
            that contains (key, iterator) pairs where the iterator yields
            the values which should be assigned to the key.
        current_position (int, optional): Current position of the top level
            sampling module. Used and updated internally. None marks the
            top level call.

    Yields:
        ConfigDict: Replicated configuration with updated key values.
    """
    # Termination criterion reached, we processed all samplers
    if sampling_queue.empty():
        yield configuration
        return

    # Get next key we want to replicate
    (key_name, sampler) = sampling_queue.get()

    is_top_level = False
    if current_position is None:
        is_top_level = True  # This is the top level call.
        current_position = 0

    # Iterate over all possible assignment values for this key
    for idx, value in enumerate(sampler()):
        if not is_top_level and idx != current_position:
            continue  # only yield entry that matches requested position

        # Update value ignoring type errors
        # (e.g. user set default lr to 1 instead 1.0 would
        # otherwise give a type error (float != int))
        with configuration.ignore_type():
            configuration.update_from_flattened_dict({key_name: value})

        # Let the other samplers process the remaining keys
        yield from _replicate_config_linear(
            configuration, sampling_queue, current_position
        )

        if is_top_level:
            current_position += 1

    # Add back this sampler for next round
    sampling_queue.put((key_name, sampler))