vis4d.common.ckpt

This module contains convenience functions for checkpoint loading.

The code is based on https://github.com/open-mmlab/mmcv/

Functions

get_torchvision_models()

Get full URLs of all torchvision paths.

load_from_http(filename[, map_location])

Load checkpoint through HTTP or HTTPS scheme path.

load_from_local(filename[, map_location])

Load checkpoint by local file path.

load_from_torchvision(filename[, map_location])

Load checkpoint through the file path prefixed with torchvision.

load_model_checkpoint(model, weights[, ...])

Load checkpoint from a file or URI.

load_state_dict(module, state_dict[, strict])

Load state_dict to a module.

Classes

CheckpointLoader()

A general checkpoint loader to manage all schemes.

class CheckpointLoader

A general checkpoint loader to manage all schemes.

classmethod load_checkpoint(filename, map_location=None)

Load checkpoint through URL scheme path.

Parameters:
  • filename (str) – checkpoint file name with given prefix

  • map_location (str, optional) – Same as torch.load(). Default: None

Returns:

The loaded checkpoint.

Return type:

dict or OrderedDict

classmethod register_scheme(prefixes, force=False)

Register a loader to CheckpointLoader.

This method should be used as a decorator.

Parameters:
  • prefixes (str or Sequence[str]) – The register prefix of the loader.

  • force (bool, optional) – Whether to override the loader if the prefix has already been registered. Defaults to False.

Return type:

Callable[[Callable[[str, Union[str, device, None]], Dict[str, Any]]], Callable[[str, Union[str, device, None]], Dict[str, Any]]]
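The decorator-based registration and prefix dispatch described above can be sketched in plain Python. This is a minimal illustration, not the vis4d implementation: the class `SchemeRegistry`, its `load` method, and the two fake loaders are hypothetical names, the loaders take only a filename (no `map_location`), and the rule of trying longer prefixes first is an assumption.

```python
from typing import Any, Callable, Dict, Sequence, Union

Loader = Callable[[str], Dict[str, Any]]


class SchemeRegistry:
    """Toy prefix-to-loader registry (hypothetical, not vis4d code)."""

    _schemes: Dict[str, Loader] = {}

    @classmethod
    def register_scheme(
        cls, prefixes: Union[str, Sequence[str]], force: bool = False
    ) -> Callable[[Loader], Loader]:
        """Return a decorator that registers a loader under each prefix."""
        if isinstance(prefixes, str):
            prefixes = [prefixes]

        def decorator(loader: Loader) -> Loader:
            for prefix in prefixes:
                # Refuse to overwrite an existing loader unless force is set.
                if prefix in cls._schemes and not force:
                    raise KeyError(f"{prefix!r} is already registered")
                cls._schemes[prefix] = loader
            return loader

        return decorator

    @classmethod
    def load(cls, filename: str) -> Dict[str, Any]:
        """Dispatch to the loader whose prefix matches the filename."""
        # Longer prefixes are tried first, so the empty prefix is checked last.
        for prefix in sorted(cls._schemes, key=len, reverse=True):
            if filename.startswith(prefix):
                return cls._schemes[prefix](filename)
        raise ValueError(f"no loader registered for {filename!r}")


@SchemeRegistry.register_scheme(prefixes=("http://", "https://"))
def fake_http_loader(filename: str) -> Dict[str, Any]:
    # A real loader would download and deserialize the checkpoint here.
    return {"source": "http", "filename": filename}


@SchemeRegistry.register_scheme(prefixes="")
def fake_local_loader(filename: str) -> Dict[str, Any]:
    # The empty prefix matches every string, so it acts as the fallback.
    return {"source": "local", "filename": filename}
```

In this sketch, `SchemeRegistry.load("https://example.com/model.pt")` reaches the HTTP loader, while a plain path such as `"weights/model.pt"` falls through to the local loader. Registering a second loader for an existing prefix raises `KeyError` unless `force=True` is passed.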

get_torchvision_models()

Get full URLs of all torchvision paths.

Requires torchvision >= 0.14.0a0.

Return type:

dict[str, str]

load_from_http(filename, map_location=None)

Load checkpoint through HTTP or HTTPS scheme path.

In a distributed setting, this function downloads the checkpoint only on local rank 0.

Parameters:
  • filename (str) – checkpoint file path with modelzoo or torchvision prefix

  • map_location (str, optional) – Same as torch.load().

Returns:

The loaded checkpoint.

Return type:

TorchCheckpoint

load_from_local(filename, map_location=None)

Load checkpoint by local file path.

Parameters:
  • filename (str) – local checkpoint file path

  • map_location (str, optional) – Same as torch.load().

Raises:

FileNotFoundError – If file not found.

Returns:

The loaded checkpoint.

Return type:

TorchCheckpoint

load_from_torchvision(filename, map_location=None)

Load checkpoint through the file path prefixed with torchvision.

Parameters:
  • filename (str) – checkpoint file path with modelzoo or torchvision prefix

  • map_location (str, optional) – Same as torch.load().

Returns:

The loaded checkpoint.

Return type:

dict or OrderedDict

load_model_checkpoint(model, weights, strict=False, rev_keys=None, map_location='cpu')

Load checkpoint from a file or URI.

Parameters:
  • model (Module) – Module to load checkpoint.

  • weights (str) – Local file path, URL, or a prefixed URI such as torchvision://xxx.

  • strict (bool) – Whether to strictly enforce that the parameters of the model and the checkpoint match. Default: False.

  • rev_keys (tuple[tuple[str, str]]) – Custom (pattern, replacement) pairs, applied as regular-expression substitutions to the keys of the checkpoint's state_dict. Default: strip the prefix 'module.' with [(r'^module.', '')].

  • map_location (str | torch.device | None) – Same as torch.load(). Default: 'cpu'.

Return type:

None
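The effect of rev_keys on checkpoint keys can be illustrated with a small standalone sketch. The helper `revise_keys` is a hypothetical name, not a vis4d function, and the assumption that every pair is applied in order to each key via `re.sub` is inferred from the parameter description rather than taken from the library source. The pattern here escapes the dot (`r"^module\."`) so that it matches a literal period.

```python
import re
from collections import OrderedDict
from typing import Any, Iterable, Tuple


def revise_keys(
    state_dict: "OrderedDict[str, Any]",
    rev_keys: Iterable[Tuple[str, str]],
) -> "OrderedDict[str, Any]":
    """Apply each (pattern, replacement) pair to every key, in order."""
    pairs = list(rev_keys)
    revised: "OrderedDict[str, Any]" = OrderedDict()
    for key, value in state_dict.items():
        for pattern, replacement in pairs:
            key = re.sub(pattern, replacement, key)
        revised[key] = value
    return revised


# Placeholder integers stand in for tensors to keep the sketch dependency-free.
checkpoint = OrderedDict(
    [("module.backbone.conv1.weight", 1), ("module.head.fc.bias", 2)]
)

# Strip the "module." prefix that wrappers such as DataParallel add.
stripped = revise_keys(checkpoint, [(r"^module\.", "")])

# Pairs are chained: strip the prefix first, then rename "backbone." keys.
renamed = revise_keys(
    checkpoint, [(r"^module\.", ""), (r"^backbone\.", "encoder.")]
)
```

Here `stripped` has the keys `backbone.conv1.weight` and `head.fc.bias`, and `renamed` has `encoder.conv1.weight` and `head.fc.bias`. Because the pairs are applied sequentially, a later pattern sees the output of the earlier ones.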

load_state_dict(module, state_dict, strict=False)

Load state_dict to a module.

This method is modified from torch.nn.Module.load_state_dict(). Default value for strict is set to False and the message for param mismatch will be shown even if strict is False.

Raises:

RuntimeError – If strict is True and module and state_dict do not match completely.

Parameters:
  • module (Module) – Module that receives the state_dict.

  • state_dict (dict or OrderedDict) – Weights.

  • strict (bool) – whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: False.

Return type:

None
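The difference between strict and non-strict loading comes down to how key mismatches are reported. The following sketch models only that bookkeeping on plain key lists: `compare_keys` is a hypothetical helper, the message wording is invented, and the real function additionally copies the tensors into the module, which is omitted here.

```python
from typing import Iterable, List, Tuple


def compare_keys(
    module_keys: Iterable[str],
    checkpoint_keys: Iterable[str],
    strict: bool = False,
) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) keys; raise on mismatch only if strict."""
    module_set = set(module_keys)
    checkpoint_set = set(checkpoint_keys)

    # Keys the module expects but the checkpoint does not provide.
    missing = sorted(module_set - checkpoint_set)
    # Keys the checkpoint provides but the module does not have.
    unexpected = sorted(checkpoint_set - module_set)

    messages = []
    if unexpected:
        messages.append("unexpected keys: " + ", ".join(unexpected))
    if missing:
        messages.append("missing keys: " + ", ".join(missing))

    if messages:
        report = "\n".join(messages)
        if strict:
            raise RuntimeError(report)
        # In non-strict mode the mismatch is still reported, just not fatal.
        print(report)
    return missing, unexpected
```

Calling `compare_keys(["a.weight", "a.bias"], ["a.weight", "b.weight"])` prints the mismatch report and returns the two lists, whereas the same call with `strict=True` raises `RuntimeError`.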