"""This module contains convenience functions for checkpoint loading.

The code is based on the checkpoint utilities of mmcv
(https://github.com/open-mmlab/mmcv/).
"""

from __future__ import annotations

import os.path as osp
import re
from collections import OrderedDict
from typing import Callable, Union

import torch
import torchvision
from torch import nn
from torch.hub import load_state_dict_from_url as load_url

from vis4d.common import TorchCheckpoint
from vis4d.common.distributed import (
    get_rank,
    get_world_size,
    is_module_wrapper,
    synchronize,
)
from vis4d.common.logging import rank_zero_info, rank_zero_warn

# Signature shared by every scheme loader registered with CheckpointLoader:
# (filename, map_location) -> loaded checkpoint.
CheckpointLoadFunc = Callable[
    [str, Union[str, torch.device, None]], TorchCheckpoint
]
# Define mapping for specific model checkpoints
# NOTE(review): this span appears truncated. `load_model_checkpoint` reads
# `MM_MODEL_MAP[...]` and `BDD100K_MODEL_PREFIX`, but neither assignment is
# visible here, the `NAME = {` / `}` lines of the three mappings below are
# missing, and the first mapping's values are empty strings (presumably URL
# prefixes that were lost). TODO: restore the original assignments; the
# module cannot be imported as-is.
# presumably MM_MODEL_MAP: scheme -> checkpoint download prefix — TODO confirm
    "mmdet://": "",
    "mmseg://": "",
# presumably scheme -> repository config directory — TODO confirm
    "mmdet://": "syscv/mmdetection/master/configs/",
    "mmseg://": "open-mmlab/mmsegmentation/master/configs/",
# presumably scheme -> config directory inside an extracted repo archive
# — TODO confirm
    "mmdet://": "mmdetection-master/configs/",
    "mmseg://": "mmsegmentation-master/configs/",

def load_model_checkpoint(
    model: nn.Module,
    weights: str,
    strict: bool = False,
    rev_keys: None | list[tuple[str, str]] = None,
    map_location: str | torch.device | None = "cpu",
) -> None:
    """Load checkpoint from a file or URI.

    Two shorthand schemes are expanded into full locations before loading:
    ``mmdet://`` / ``mmseg://`` are replaced by the matching entry of
    ``MM_MODEL_MAP`` and ``bdd100k://`` by ``BDD100K_MODEL_PREFIX``. The
    resulting location is then resolved by ``CheckpointLoader``.

    Args:
        model (Module): Module to load checkpoint.
        weights (str): Accept local filepath, URL, or e.g.
            ``torchvision://xxx``, ``mmdet://xxx``, ``mmseg://xxx`` or
            ``bdd100k://xxx``.
        strict (bool): Whether to raise an error when the params of the
            model and the checkpoint do not match exactly. Default: False.
        rev_keys (list[tuple[str, str]], optional): Customized keywords to
            modify the state_dict in checkpoint. Each item is a
            (pattern, replacement) pair of the regular expression
            operations. Default: strip the prefix 'module.' by
            [(r'^module.', '')].
        map_location (str | torch.device | None): Same as
            :func:`torch.load`. Default: 'cpu'.
    """
    if rev_keys is None:  # pragma: no cover
        rev_keys = [(r"^module\.", "")]

    mm_match = re.match(r"mm(det|seg)://", weights)
    if mm_match is not None:
        prefix = mm_match.group(0)
        # Slice off the leading prefix instead of splitting on it, so that a
        # later occurrence of the prefix inside the path is kept intact.
        weights = MM_MODEL_MAP[prefix] + weights[len(prefix) :]
    elif weights.startswith("bdd100k://"):
        weights = BDD100K_MODEL_PREFIX + weights[len("bdd100k://") :]

    _load_checkpoint(
        model, weights, map_location, strict=strict, revise_keys=rev_keys
    )
class CheckpointLoader:
    """A general checkpoint loader to manage all schemes.

    Loaders are registered per path prefix (e.g. ``"http://"``). A path is
    dispatched to the first registered prefix that matches its start. A
    prefix is always tried before any shorter prefix of itself, so the local
    loader registered with the empty prefix ``""`` acts as the fallback.
    """

    # prefix -> loader, kept in reverse lexicographic order (see
    # `_register_scheme`) so the most specific prefix is tried first.
    _schemes: dict[str, CheckpointLoadFunc] = {}

    @classmethod
    def _register_scheme(
        cls,
        prefixes: str | tuple[str, ...],
        loader: CheckpointLoadFunc,
        force: bool = False,
    ) -> None:
        """Register ``loader`` for every prefix in ``prefixes``.

        Registration is all-or-nothing: if any prefix is already registered
        and ``force`` is False, no prefix of this call is registered.

        Args:
            prefixes (str | tuple[str, ...]): Prefix or prefixes to register.
            loader (CheckpointLoadFunc): Loader to use for these prefixes.
            force (bool): Whether to override already registered prefixes.

        Raises:
            TypeError: If prefixes is not a str, list or tuple.
            KeyError: If a prefix is already registered and force is False.
        """
        if isinstance(prefixes, str):
            prefixes = (prefixes,)
        if not isinstance(prefixes, (list, tuple)):
            raise TypeError(
                "prefixes must be a str, list or tuple, got "
                f"{type(prefixes).__name__}"
            )

        # Validate everything first so a conflict cannot leave the registry
        # partially updated (and unsorted).
        if not force:
            for prefix in prefixes:
                if prefix in cls._schemes:
                    raise KeyError(
                        f"{prefix} is already registered as a loader "
                        'backend, add "force=True" if you want to override it'
                    )

        schemes = dict(cls._schemes)
        for prefix in prefixes:
            schemes[prefix] = loader
        # If p is a proper prefix of q then q > p lexicographically, so a
        # reverse sort puts longer, more specific prefixes before their
        # shorter prefixes (and "" last).
        cls._schemes = OrderedDict(
            sorted(schemes.items(), key=lambda t: t[0], reverse=True)
        )

    @classmethod
    def register_scheme(
        cls,
        prefixes: str | tuple[str, ...],
        force: bool = False,
    ) -> Callable[[CheckpointLoadFunc], CheckpointLoadFunc]:
        """Register a loader to CheckpointLoader.

        This method should be used as a decorator.

        Args:
            prefixes (str or Sequence[str]): The register prefix of the
                loader.
            force (bool, optional): Whether to override the loader if the
                prefix has already been registered. Defaults to False.

        Returns:
            Callable: Decorator that registers and returns the loader
                unchanged.
        """

        def _register(
            loader_cls: CheckpointLoadFunc,
        ) -> CheckpointLoadFunc:
            cls._register_scheme(prefixes, loader_cls, force=force)
            return loader_cls

        return _register

    @classmethod
    def _get_checkpoint_loader(cls, path: str) -> CheckpointLoadFunc:
        """Finds a loader that supports the given path.

        Registered prefixes are matched against the start of ``path`` with
        :func:`re.match`. Falls back to the local loader if no other loader
        is found, since it is registered with an empty prefix.

        Args:
            path (str): checkpoint path.

        Raises:
            ValueError: If the path cannot be matched to any prefix. This
                should usually not happen, since the local loader is
                registered with an empty prefix.

        Returns:
            CheckpointLoadFunc: checkpoint loader.
        """
        for prefix, func in cls._schemes.items():
            if re.match(prefix, path) is not None:
                return func
        raise ValueError("Invalid path! No prefix matched.")

    @classmethod
    def load_checkpoint(
        cls,
        filename: str,
        map_location: str | torch.device | None = None,
    ) -> TorchCheckpoint:
        """Load checkpoint through URL scheme path.

        Args:
            filename (str): checkpoint file name with given prefix
            map_location (str | torch.device | None, optional): Same as
                :func:`torch.load`. Default: None

        Returns:
            TorchCheckpoint: The loaded checkpoint.
        """
        checkpoint_loader = cls._get_checkpoint_loader(filename)
        loader_name = checkpoint_loader.__name__
        # Loaders are named `load_from_<scheme>`; log only the scheme part.
        if loader_name.startswith("load_from_"):
            loader_name = loader_name[len("load_from_") :]
        rank_zero_info(f"Load checkpoint from {loader_name} path: {filename}")
        return checkpoint_loader(filename, map_location)
@CheckpointLoader.register_scheme(prefixes="")
def load_from_local(
    filename: str,
    map_location: str | torch.device | None = None,
) -> TorchCheckpoint:
    """Load a checkpoint from the local filesystem.

    Registered with the empty prefix, so it handles every path that no other
    registered scheme claims.

    Args:
        filename (str): Path to the checkpoint file. A leading ``~`` is
            expanded to the user's home directory.
        map_location (str | torch.device | None, optional): Forwarded to
            :func:`torch.load`.

    Raises:
        FileNotFoundError: If the (expanded) path is not an existing file.

    Returns:
        TorchCheckpoint: The loaded checkpoint.
    """
    path = osp.expanduser(filename)
    if osp.isfile(path):
        # NOTE: torch.load relies on pickle; only load trusted checkpoints.
        return torch.load(path, map_location=map_location)
    raise FileNotFoundError(f"{path} can not be found.")
@CheckpointLoader.register_scheme(prefixes=("http://", "https://"))
def load_from_http(
    filename: str, map_location: str | torch.device | None = None
) -> TorchCheckpoint:
    """Load checkpoint through HTTP or HTTPS scheme path.

    In a distributed setting the process with ``get_rank() == 0`` calls
    ``load_url`` first. All processes then meet at a ``synchronize()``
    barrier, after which the remaining ranks call ``load_url`` themselves.
    This lets them reuse the file already downloaded into the torch hub
    cache, assuming that cache directory is shared with rank 0.

    Args:
        filename (str): ``http://`` or ``https://`` URL of the checkpoint.
        map_location (str | torch.device | None, optional): Same as
            :func:`torch.load`.

    Returns:
        TorchCheckpoint: The loaded checkpoint.
    """
    rank, world_size = get_rank(), get_world_size()
    if rank == 0:
        checkpoint = load_url(filename, map_location=map_location)
        if world_size > 1:
            synchronize()  # release the other ranks once the download is done
        return checkpoint

    if world_size > 1:
        synchronize()  # wait until rank 0 has finished downloading
    return load_url(filename, map_location=map_location)
def get_torchvision_models() -> dict[str, str]:
    """Get full URLs of all torchvision weights.

    Keys have the form ``"<model>.<weights>"`` in lower case, e.g.
    ``"resnet50.imagenet1k_v1"`` for ``ResNet50_Weights.IMAGENET1K_V1``.
    Each weights class additionally gets a ``"<model>.default"`` entry that
    points to its ``DEFAULT`` weights.

    Requires torchvision >= 0.14.0a0.

    Returns:
        dict[str, str]: Mapping from weight key to download URL.
    """
    model_urls: dict[str, str] = {}
    weights_list = [
        torchvision.models.get_model_weights(model)
        for model in torchvision.models.list_models(torchvision.models)
    ]
    for model_cls in weights_list:
        # The name of torchvision model weights classes ends with
        # `_Weights` such as `ResNet18_Weights`. However, some model weight
        # classes, such as `MNASNet0_75_Weights` does not have any urls in
        # torchvision 0.13.0 and cannot be iterated. Here we simply check
        # `DEFAULT` attribute to ensure the class is not empty.
        if not hasattr(model_cls, "DEFAULT"):
            continue
        cls_key = model_cls.__name__.replace("_Weights", "").lower()
        # Since `cls.DEFAULT` can not be accessed by iterating cls, we set
        # default urls explicitly.
        model_urls[f"{cls_key}.default"] = model_cls.DEFAULT.url
        for weight_enum in model_cls:
            weight_key = f"{cls_key}.{weight_enum.name.lower()}"
            model_urls[weight_key] = weight_enum.url
    return model_urls
@CheckpointLoader.register_scheme(prefixes="torchvision://")
def load_from_torchvision(
    filename: str, map_location: str | torch.device | None = None
) -> TorchCheckpoint:
    """Load checkpoint through the file path prefixed with torchvision.

    The part after ``torchvision://`` selects the weights, either as
    ``<model>.<weights>`` (e.g. ``resnet50.imagenet1k_v1``,
    ``resnet50.default``) or in torchvision's enum notation (e.g.
    ``ResNet50_Weights.IMAGENET1K_V1``). Matching is case-insensitive.

    Args:
        filename (str): checkpoint path with the ``torchvision://`` prefix.
        map_location (str | torch.device | None, optional): Same as
            :func:`torch.load`.

    Raises:
        KeyError: If no torchvision weights match the given name.

    Returns:
        TorchCheckpoint: The loaded checkpoint.
    """
    prefix = "torchvision://"
    model_urls = get_torchvision_models()
    model_name = filename[len(prefix) :]
    # Support getting model urls in the same way as torchvision
    # `ResNet50_Weights.IMAGENET1K_V1` will be mapped to
    # resnet50.imagenet1k_v1.
    model_name = model_name.lower().replace("_weights", "")
    if model_name not in model_urls:
        raise KeyError(
            f"Unknown torchvision weights '{model_name}'. Expected a name of "
            "the form '<model>.<weights>', e.g. 'resnet50.default' or "
            "'resnet50.imagenet1k_v1'."
        )
    return load_from_http(model_urls[model_name], map_location)
def load_state_dict(
    module: nn.Module, state_dict: TorchCheckpoint, strict: bool = False
) -> None:
    """Load state_dict to a module.

    This method is modified from :meth:`torch.nn.Module.load_state_dict`.
    Default value for ``strict`` is set to ``False`` and the message for
    param mismatch will be shown even if strict is False.

    Args:
        module (Module): Module that receives the state_dict.
        state_dict (dict or OrderedDict): Weights.
        strict (bool): whether to strictly enforce that the keys
            in :attr:`state_dict` match the keys returned by this module's
            :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.

    Raises:
        RuntimeError: If ``strict`` is True and the module and state_dict do
            not match completely. Raised on every rank, so all processes of
            a distributed run fail consistently.
    """
    unexpected_keys: list[str] = []
    all_missing_keys: list[str] = []
    err_msg: list[str] = []

    metadata = getattr(state_dict, "_metadata", None)
    state_dict = state_dict.copy()
    if metadata is not None:
        # The copy does not carry the `_metadata` attribute over; restore it.
        # pylint: disable=protected-access
        state_dict._metadata = metadata  # type: ignore

    # use _load_from_state_dict to enable checkpoint version control
    def load(module: nn.Module, prefix: str = "") -> None:
        # recursively check parallel module in case that the model has a
        # complicated structure, e.g., nn.Module(nn.Module(DDP))
        if is_module_wrapper(module):
            module = module.module
        local_metadata = (
            {} if metadata is None else metadata.get(prefix[:-1], {})
        )
        # strict=True so torch records missing/unexpected keys in the lists;
        # whether to raise is decided after the whole traversal.
        module._load_from_state_dict(  # pylint: disable=protected-access
            state_dict,
            prefix,
            local_metadata,
            True,
            all_missing_keys,
            unexpected_keys,
            err_msg,
        )
        # pylint: disable=protected-access
        for name, child in module._modules.items():
            if child is not None:
                # pylint: disable=not-callable
                load(child, prefix + name + ".")

    load(module)
    # break load->load reference cycle
    load = None  # type: ignore

    # ignore "num_batches_tracked" of BN layers
    missing_keys = [
        key for key in all_missing_keys if "num_batches_tracked" not in key
    ]

    if unexpected_keys:
        err_msg.append(
            "unexpected key in source "
            f'state_dict: {", ".join(unexpected_keys)}\n'
        )
    if missing_keys:
        err_msg.append(
            f'missing keys in source state_dict: {", ".join(missing_keys)}\n'
        )

    if not err_msg:
        return

    message = "\n".join(
        ["The model and loaded state dict do not match exactly\n", *err_msg]
    )
    if strict:
        # Raise on every rank: raising only on rank 0 would let the other
        # processes continue with mismatched weights.
        raise RuntimeError(message)
    if get_rank() == 0:
        rank_zero_warn(message)
def _load_checkpoint(
    model: torch.nn.Module,
    filename: str,
    map_location: str | torch.device | None = None,
    strict: bool = False,
    revise_keys: tuple[tuple[str, str]] | list[tuple[str, str]] = (
        (r"^module\.", ""),
    ),
) -> TorchCheckpoint:
    """Load checkpoint from a file or URI.

    The weights are taken from the checkpoint's ``"state_dict"`` entry, else
    its ``"model"`` entry, else the checkpoint itself.

    Args:
        model (Module): Module to load checkpoint.
        filename (str): Checkpoint location understood by
            ``CheckpointLoader``, e.g. a local filepath, an ``http(s)://``
            URL or ``torchvision://xxx``.
        map_location (str | torch.device | None): Same as
            :func:`torch.load`.
        strict (bool): Whether to raise an error when the params of the
            model and the checkpoint do not match exactly.
        revise_keys (tuple[tuple[str, str]]): A tuple of customized keywords
            to modify the state_dict in checkpoint. Each item is a
            (pattern, replacement) pair of the regular expression
            operations. Default: strip the prefix 'module.' by
            [(r'^module.', '')].

    Raises:
        RuntimeError: If the loaded checkpoint is not a dict.

    Returns:
        TorchCheckpoint: The loaded checkpoint.
    """
    checkpoint = CheckpointLoader.load_checkpoint(filename, map_location)
    # OrderedDict is a subclass of dict
    if not isinstance(checkpoint, dict):
        raise RuntimeError(
            f"No state_dict found in checkpoint file {filename}"
        )
    # get state_dict from checkpoint
    if "state_dict" in checkpoint:
        state_dict = checkpoint["state_dict"]
    elif "model" in checkpoint:
        state_dict = checkpoint["model"]
    else:
        state_dict = checkpoint

    metadata = getattr(state_dict, "_metadata", OrderedDict())
    # Always rebuild as an OrderedDict: a plain dict cannot carry the
    # `_metadata` attribute assigned below (e.g. when revise_keys is empty).
    state_dict = OrderedDict(state_dict)
    # strip prefix of state_dict
    for pattern, replacement in revise_keys:
        state_dict = OrderedDict(
            (re.sub(pattern, replacement, key), value)
            for key, value in state_dict.items()
        )
    # Keep metadata in state_dict
    state_dict._metadata = metadata  # pylint: disable=protected-access

    # load state_dict
    load_state_dict(model, state_dict, strict)
    return checkpoint