vis4d.common.ckpt
This module contains convenience functions for checkpoint loading.
The code is based on https://github.com/open-mmlab/mmcv/
Functions
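get_torchvision_models() | Get full URLs of all torchvision paths.
load_from_http(filename[, map_location]) | Load checkpoint through HTTP or HTTPS scheme path.
load_from_local(filename[, map_location]) | Load checkpoint by local file path.
load_from_torchvision(filename[, map_location]) | Load checkpoint through the file path prefixed with torchvision.
load_model_checkpoint(model, weights[, ...]) | Load checkpoint from a file or URI.
load_state_dict(module, state_dict[, strict]) | Load state_dict to a module.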
Classes
CheckpointLoader | A general checkpoint loader to manage all schemes.
- class CheckpointLoader
A general checkpoint loader to manage all schemes.
- classmethod load_checkpoint(filename, map_location=None)
Load a checkpoint through the loader registered for the scheme prefix of filename (e.g. a URL scheme).
- Parameters:
filename (str) – checkpoint file name with given prefix
map_location (str, optional) – Same as torch.load(). Default: None.
- Returns:
The loaded checkpoint.
- Return type:
dict or OrderedDict
- classmethod register_scheme(prefixes, force=False)
Register a loader with CheckpointLoader.
This method should be used as a decorator.
- Parameters:
prefixes (str or Sequence[str]) – The prefix or prefixes under which the loader is registered.
force (bool, optional) – Whether to override the loader if the prefix has already been registered. Defaults to False.
- Return type:
Callable[[Callable[[str, Union[str, device, None]], Dict[str, Any]]], Callable[[str, Union[str, device, None]], Dict[str, Any]]]
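The prefix-based dispatch behind register_scheme and load_checkpoint can be sketched in plain Python. This is a minimal illustration under stated assumptions, not vis4d's implementation: the class SchemeRegistry, the toy loaders load_local and load_http, and the longest-prefix-first matching rule are all hypothetical, and the loaders return small dicts instead of calling torch.load().

```python
from typing import Any, Callable, Dict, Optional, Sequence, Union

# Hypothetical loader signature: (filename, map_location) -> checkpoint dict.
Loader = Callable[[str, Optional[str]], Dict[str, Any]]


class SchemeRegistry:
    """Minimal prefix -> loader registry (illustrative, not vis4d's class)."""

    _schemes: Dict[str, Loader] = {}

    @classmethod
    def register_scheme(
        cls, prefixes: Union[str, Sequence[str]], force: bool = False
    ) -> Callable[[Loader], Loader]:
        if isinstance(prefixes, str):
            prefixes = [prefixes]

        def decorator(loader: Loader) -> Loader:
            for prefix in prefixes:
                # Refuse to overwrite an existing prefix unless force=True.
                if prefix in cls._schemes and not force:
                    raise KeyError(f"{prefix!r} is already registered")
                cls._schemes[prefix] = loader
            return loader  # returned unchanged so it stays directly callable

        return decorator

    @classmethod
    def load_checkpoint(
        cls, filename: str, map_location: Optional[str] = None
    ) -> Dict[str, Any]:
        # Assumed rule: try longer prefixes first, so a specific scheme wins
        # over the empty catch-all prefix used for local paths.
        for prefix in sorted(cls._schemes, key=len, reverse=True):
            if filename.startswith(prefix):
                return cls._schemes[prefix](filename, map_location)
        raise ValueError(f"no loader registered for {filename}")


@SchemeRegistry.register_scheme(prefixes="")
def load_local(filename: str, map_location: Optional[str] = None) -> Dict[str, Any]:
    return {"source": "local", "path": filename}


@SchemeRegistry.register_scheme(prefixes=("http://", "https://"))
def load_http(filename: str, map_location: Optional[str] = None) -> Dict[str, Any]:
    return {"source": "http", "path": filename}
```

Returning the loader unchanged from the decorator keeps the decorated function usable on its own, which matches the documented return type: a callable that maps a loader to a loader.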
- get_torchvision_models()
Get full URLs of all torchvision paths.
Requires torchvision >= 0.14.0a0.
- Return type:
dict[str, str]
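A name-to-URL table like the one get_torchvision_models() returns could be assembled from torchvision-style weight enums. The sketch below is hypothetical: ResNet18_Weights is a stand-in Enum with placeholder example.com URLs rather than the real torchvision class, build_model_urls is an illustrative helper, and the lower-cased "model.weight_tag" key format is an assumption borrowed from mmcv, not a documented vis4d format.

```python
from enum import Enum
from typing import Dict, Iterable, Type


class ResNet18_Weights(Enum):
    # Hypothetical stand-in for a torchvision weights enum (placeholder URLs).
    IMAGENET1K_V1 = "https://example.com/resnet18-v1.pth"
    IMAGENET1K_V2 = "https://example.com/resnet18-v2.pth"


def build_model_urls(weight_enums: Iterable[Type[Enum]]) -> Dict[str, str]:
    """Map "<model>.<weight tag>" (lower-case) to each weight's download URL."""
    urls: Dict[str, str] = {}
    for enum_cls in weight_enums:
        # "ResNet18_Weights" -> "resnet18"
        model = enum_cls.__name__.replace("_Weights", "").lower()
        for member in enum_cls:
            urls[f"{model}.{member.name.lower()}"] = member.value
    return urls
```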
- load_from_http(filename, map_location=None)
Load checkpoint through HTTP or HTTPS scheme path.
In a distributed setting, this function downloads the checkpoint only on local rank 0.
- Parameters:
filename (str) – checkpoint URL with an http:// or https:// prefix
map_location (str, optional) – Same as torch.load().
- Returns:
The loaded checkpoint.
- Return type:
TorchCheckpoint
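Downloading only on local rank 0 usually means: rank 0 fetches the file into a shared cache, all ranks meet at a barrier, and the remaining ranks then read the cached copy. The sketch below simulates ranks as threads with threading.Barrier so it runs without torch.distributed; simulate_rank0_download and its fetch helper are hypothetical, and the exact sequence inside vis4d is assumed from the mmcv code this module is based on.

```python
import os
import tempfile
import threading
from typing import Dict, List, Tuple


def simulate_rank0_download(url: str, world_size: int) -> Tuple[List[str], Dict[int, str]]:
    """Simulate ranks as threads; return ("network" downloads, path per rank)."""
    cache_dir = tempfile.mkdtemp()
    barrier = threading.Barrier(world_size)
    downloads: List[str] = []
    paths: Dict[int, str] = {}

    def fetch(u: str) -> str:
        # Cached download: only touches the "network" on a cache miss.
        path = os.path.join(cache_dir, os.path.basename(u))
        if not os.path.exists(path):
            downloads.append(u)
            with open(path, "w") as f:
                f.write("fake weights")
        return path

    def worker(rank: int) -> None:
        if rank == 0:
            paths[rank] = fetch(url)  # only local rank 0 downloads
        if world_size > 1:
            barrier.wait()  # every rank waits until rank 0 is done
            if rank > 0:
                paths[rank] = fetch(url)  # served from the cache

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return downloads, paths
```

The barrier is what makes the cache hit safe: no rank other than 0 checks the cache before rank 0 has finished writing the file.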
- load_from_local(filename, map_location=None)
Load checkpoint by local file path.
- Parameters:
filename (str) – local checkpoint file path
map_location (str, optional) – Same as torch.load().
- Raises:
FileNotFoundError – If the file does not exist.
- Returns:
The loaded checkpoint.
- Return type:
TorchCheckpoint
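The documented FileNotFoundError suggests the path is validated before anything is deserialized. A minimal sketch, assuming that order of operations and assuming "~" expansion (not stated above); load_from_local_sketch and pickle_load are hypothetical, and pickle stands in for torch.load() so the example runs without torch.

```python
import os
import pickle
import tempfile
from typing import Any, Callable


def pickle_load(path: str) -> Any:
    # Stand-in for torch.load(path, map_location=...) so no torch is needed.
    with open(path, "rb") as f:
        return pickle.load(f)


def load_from_local_sketch(
    filename: str, load_fn: Callable[[str], Any] = pickle_load
) -> Any:
    """Validate a local checkpoint path, then delegate deserialization."""
    filename = os.path.expanduser(filename)  # assumed: "~/..." paths allowed
    if not os.path.isfile(filename):
        raise FileNotFoundError(f"{filename} can not be found.")
    return load_fn(filename)


# Usage: round-trip a small pickled "checkpoint" through a temporary file.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "ckpt.pkl")
    with open(path, "wb") as f:
        pickle.dump({"weight": [1.0, 2.0]}, f)
    ckpt = load_from_local_sketch(path)
```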
- load_from_torchvision(filename, map_location=None)
Load checkpoint through the file path prefixed with torchvision.
- Parameters:
filename (str) – checkpoint file path with modelzoo or torchvision prefix
map_location (str, optional) – Same as torch.load().
- Returns:
The loaded checkpoint.
- Return type:
dict or OrderedDict
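Resolving a torchvision:// (or modelzoo://) path amounts to stripping the prefix and looking the remaining name up in a name-to-URL mapping; the resulting URL would then go to the HTTP loader. resolve_torchvision_url and the single-entry mapping with an example.com URL are hypothetical; the real mapping is presumably the one returned by get_torchvision_models().

```python
from typing import Dict

TORCHVISION_PREFIXES = ("modelzoo://", "torchvision://")


def resolve_torchvision_url(filename: str, model_urls: Dict[str, str]) -> str:
    """Strip a modelzoo:// or torchvision:// prefix and look up the URL."""
    for prefix in TORCHVISION_PREFIXES:
        if filename.startswith(prefix):
            name = filename[len(prefix):]
            if name not in model_urls:
                raise KeyError(f"unknown torchvision model name: {name}")
            return model_urls[name]
    raise ValueError(f"{filename} has no modelzoo:// or torchvision:// prefix")


# Hypothetical mapping with a placeholder URL.
MODEL_URLS = {"resnet50": "https://example.com/resnet50.pth"}
url = resolve_torchvision_url("torchvision://resnet50", MODEL_URLS)
# The resolved URL would then be handed to the HTTP loader.
```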
- load_model_checkpoint(model, weights, strict=False, rev_keys=None, map_location='cpu')
Load checkpoint from a file or URI.
- Parameters:
model (Module) – Module to load the checkpoint into.
weights (str) – Local file path, URL, or a prefixed path such as torchvision://xxx.
strict (bool) – Whether to raise an error if the parameters of the model and the checkpoint do not match exactly. If False, mismatched parameters are allowed. Default: False.
rev_keys (tuple[tuple[str, str]]) – A tuple of customized keywords to modify the state_dict in the checkpoint. Each item is a (pattern, replacement) pair of regular-expression operations. Default: strip the prefix 'module.' with [(r'^module.', '')].
map_location (str | torch.device | None) – Same as torch.load(). Default: 'cpu'.
- Return type:
None
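The rev_keys option rewrites every checkpoint key with each (pattern, replacement) pair in turn, using regular-expression substitution. The helper revise_keys below is a hypothetical sketch of that step only; it escapes the dot in its default pattern so that only a literal "module." prefix (as added by DataParallel-style wrappers) is stripped, which may differ from the exact default pattern vis4d uses.

```python
import re
from collections import OrderedDict
from typing import Any, Mapping, Optional, Sequence, Tuple


def revise_keys(
    state_dict: Mapping[str, Any],
    rev_keys: Optional[Sequence[Tuple[str, str]]] = None,
) -> "OrderedDict[str, Any]":
    """Apply each (pattern, replacement) regex pair to every key, in order."""
    if rev_keys is None:
        # Assumed default: strip a literal leading "module.".
        rev_keys = [(r"^module\.", "")]
    revised = OrderedDict(state_dict)
    for pattern, replacement in rev_keys:
        # Values are kept as-is; only the keys are rewritten.
        revised = OrderedDict(
            (re.sub(pattern, replacement, key), value)
            for key, value in revised.items()
        )
    return revised
```

Because the pairs are applied sequentially, a later pattern sees the keys already rewritten by earlier ones.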
- load_state_dict(module, state_dict, strict=False)
Load state_dict to a module.
This method is modified from torch.nn.Module.load_state_dict(). The default value of strict is set to False, and the message for parameter mismatches is shown even if strict is False.
- Raises:
RuntimeError – If strict is True and the keys of module and state_dict do not match exactly.
- Parameters:
module (Module) – Module that receives the state_dict.
state_dict (dict or OrderedDict) – Weights.
strict (bool) – Whether to strictly enforce that the keys in state_dict match the keys returned by this module’s state_dict() function. Default: False.
- Return type:
None
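The non-strict behaviour can be illustrated by comparing key sets: keys the module expects but the checkpoint lacks are missing, and keys present only in the checkpoint are unexpected. check_keys is a hypothetical pure-Python sketch that works on key names only (it ignores tensor shapes). For reference, plain PyTorch's nn.Module.load_state_dict(..., strict=False) returns a named tuple with missing_keys and unexpected_keys.

```python
from typing import Iterable, List, Tuple


def check_keys(
    model_keys: Iterable[str], ckpt_keys: Iterable[str], strict: bool = False
) -> Tuple[List[str], List[str]]:
    """Report key mismatches; raise only when strict is True."""
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    missing = sorted(model_set - ckpt_set)  # expected by the module, absent in ckpt
    unexpected = sorted(ckpt_set - model_set)  # in ckpt, unknown to the module
    messages = []
    if unexpected:
        messages.append(f"unexpected keys in source state_dict: {', '.join(unexpected)}")
    if missing:
        messages.append(f"missing keys in source state_dict: {', '.join(missing)}")
    if messages:
        report = "\n".join(messages)
        if strict:
            raise RuntimeError(report)
        print(report)  # mismatches are still reported when strict is False
    return missing, unexpected
```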