Source code for vis4d.common.distributed

# mypy: disable-error-code=misc
"""This module contains utilities for multiprocess parallelism."""
from __future__ import annotations

import logging
import os
import pickle
import shutil
import tempfile
from collections import OrderedDict
from functools import wraps
from typing import Any

import cloudpickle
import torch
import torch.distributed as dist
from torch import nn
from torch.distributed import broadcast_object_list
from torch.nn.parallel import DataParallel, DistributedDataParallel

from vis4d.common import ArgsType, DictStrAny, GenericFunc


class PicklableWrapper:  # mypy: disable=line-too-long
    """Wrap an object to make it more picklable.

    Note that it uses heavy weight serialization libraries that are slower
    than pickle. It's best to use it only on closures (which are usually not
    picklable).

    This is a simplified version of
    https://github.com/joblib/joblib/blob/master/joblib/externals/loky/cloudpickle_wrapper.py
    """

    def __init__(self, obj: Any | PicklableWrapper) -> None:  # type: ignore
        """Creates an instance of the class."""
        while isinstance(obj, PicklableWrapper):
            # Wrapping an object twice is no-op
            obj = obj._obj
        self._obj: Any = obj

    def __reduce__(self) -> tuple[Any, tuple[bytes]]:
        """Reduce."""
        s = cloudpickle.dumps(self._obj)
        return cloudpickle.loads, (s,)

    def __call__(self, *args: ArgsType, **kwargs: ArgsType) -> Any:
        """Call."""
        return self._obj(*args, **kwargs)

    def __getattr__(self, attr: str) -> Any:
        """Get attribute.

        Ensure that the wrapped object can be used seamlessly as the previous
        object.
        """
        if attr not in ["_obj"]:
            return getattr(self._obj, attr)
        return getattr(self, attr)
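
# Usage sketch (illustrative, not part of this module): closures and lambdas
# are not picklable with the stdlib alone, but wrapping them defers
# serialization to cloudpickle via __reduce__ above.
#
#     square = PicklableWrapper(lambda x: x * x)
#     payload = pickle.dumps(square)    # succeeds despite the lambda
#     restored = pickle.loads(payload)  # cloudpickle rebuilds the closure
#     assert restored(3) == 9
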
# no coverage for these functions, since we don't unittest distributed setting
def get_world_size() -> int:  # pragma: no cover
    """Get the world size (number of processes) of torch.distributed.

    Returns:
        int: The world size.
    """
    if os.environ.get("WORLD_SIZE", None):
        return int(os.environ["WORLD_SIZE"])

    # In interactive job not using slurm ntasks
    if os.environ.get("SLURM_JOB_NAME", None) != "bash":
        if os.environ.get("SLURM_NTASKS", None):
            return int(os.environ["SLURM_NTASKS"])

    return 1

def get_rank() -> int:  # pragma: no cover
    """Get the global rank of the current process in torch.distributed.

    Returns:
        int: The global rank.
    """
    # For torchrun
    if os.environ.get("RANK", None):
        return int(os.environ["RANK"])

    # Because PL doesn't set the global rank, use the local rank for
    # interactive jobs and the SLURM process id for submitted jobs.
    if os.environ.get("SLURM_JOB_NAME", None) == "bash":
        return get_local_rank()
    if os.environ.get("SLURM_PROCID", None):
        return int(os.environ["SLURM_PROCID"])

    # Fall back to the local rank
    return get_local_rank()

def get_local_rank() -> int:  # pragma: no cover
    """Get the local rank of the current process in torch.distributed.

    Returns:
        int: The local rank.
    """
    if os.environ.get("LOCAL_RANK", None):
        return int(os.environ["LOCAL_RANK"])
    if os.environ.get("SLURM_LOCALID", None):
        return int(os.environ["SLURM_LOCALID"])
    return 0
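
# Usage sketch (illustrative, not part of this module): the three rank helpers
# only read torchrun/SLURM environment variables, so they can be called before
# the process group is initialized, e.g. to gate setup work.
#
#     if get_rank() == 0:
#         print(f"world size {get_world_size()}, local rank {get_local_rank()}")
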

def distributed_available() -> bool:  # pragma: no cover
    """Check if torch.distributed is available.

    Returns:
        bool: Whether torch.distributed is available.
    """
    return dist.is_available() and dist.is_initialized()

def synchronize() -> None:  # pragma: no cover
    """Sync (barrier) among all processes when using distributed training."""
    if not distributed_available():
        return
    if get_world_size() == 1:
        return
    dist.barrier(group=dist.group.WORLD, device_ids=[get_local_rank()])

def broadcast(obj: Any, src: int = 0) -> Any:  # pragma: no cover
    """Broadcast an object from a source to all processes."""
    if not distributed_available():
        return obj
    obj = [obj]
    rank = get_rank()
    if rank != src:
        obj = [None]
    broadcast_object_list(obj, src, group=dist.group.WORLD)
    return obj[0]
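
# Usage sketch (illustrative, not part of this module): share a value produced
# on the source rank with every other process; the object only needs to be
# picklable.
#
#     seed = torch.seed() if get_rank() == 0 else None
#     seed = broadcast(seed, src=0)  # every rank now holds the rank-0 seed
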

def serialize_to_tensor(data: Any) -> torch.Tensor:  # pragma: no cover
    """Serialize arbitrary picklable data to a torch.Tensor.

    Args:
        data (Any): The data to serialize.

    Returns:
        torch.Tensor: The serialized data as a torch.Tensor.

    Raises:
        AssertionError: If the backend of torch.distributed is not gloo or
            nccl.
    """
    backend = dist.get_backend()
    assert backend in {
        "gloo",
        "nccl",
    }, "_serialize_to_tensor only supports gloo and nccl backends."
    device = torch.device("cpu" if backend == "gloo" else "cuda")

    buffer = pickle.dumps(data)
    if len(buffer) > 1024**3:
        logger = logging.getLogger(__name__)
        logger.warning(
            "Rank %s tries all-gather %.2f GB of data on device %s",
            get_rank(),
            len(buffer) / (1024**3),
            device,
        )
    storage = torch.ByteStorage.from_buffer(buffer)
    tensor = torch.ByteTensor(storage).to(device=device)
    return tensor

def rank_zero_only(func: GenericFunc) -> GenericFunc:
    """Allows the decorated function to be called only on global rank 0.

    Args:
        func (GenericFunc): The function to decorate.

    Returns:
        GenericFunc: The decorated function.
    """

    @wraps(func)
    def wrapped_fn(*args: ArgsType, **kwargs: ArgsType) -> Any:
        rank = get_rank()
        if rank == 0:
            return func(*args, **kwargs)
        return None

    return wrapped_fn
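
# Usage sketch (illustrative, not part of this module): restrict side effects
# such as logging or checkpoint writing to the main process; all other ranks
# get None as the return value.
#
#     @rank_zero_only
#     def log_info(message: str) -> None:
#         print(message)
#
#     log_info("only executed on global rank 0")
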

def pad_to_largest_tensor(
    tensor: torch.Tensor,
) -> tuple[list[int], torch.Tensor]:  # pragma: no cover
    """Pad tensor to largest size among the tensors in each process.

    Args:
        tensor: tensor to be padded.

    Returns:
        list[int]: size of the tensor, on each rank.
        Tensor: padded tensor that has the max size.
    """
    world_size = get_world_size()
    assert (
        world_size >= 1
    ), "_pad_to_largest_tensor requires distributed setting!"
    local_size = torch.tensor(
        [tensor.numel()], dtype=torch.int64, device=tensor.device
    )
    local_size_list = [local_size.clone() for _ in range(world_size)]
    dist.all_gather_object(local_size_list, local_size)
    size_list = [int(size.item()) for size in local_size_list]
    max_size = max(size_list)

    # we pad the tensor because torch all_gather does not support
    # gathering tensors of different shapes
    if local_size != max_size:
        padding = torch.zeros(
            (max_size - local_size,), dtype=torch.uint8, device=tensor.device
        )
        tensor = torch.cat((tensor, padding), dim=0)
    return size_list, tensor

def all_gather_object_gpu(  # type: ignore
    data: Any, rank_zero_return_only: bool = True
) -> list[Any] | None:  # pragma: no cover
    """Run pl_module.all_gather on arbitrary picklable data.

    Args:
        data: any picklable object.
        rank_zero_return_only: if results should only be returned on rank 0.

    Returns:
        list[Any]: list of data gathered from each process.
    """
    rank, world_size = get_rank(), get_world_size()
    if world_size == 1:
        return [data]

    # encode
    tensor = serialize_to_tensor(data)
    size_list, tensor = pad_to_largest_tensor(tensor)
    tensor_list = [tensor.clone() for _ in range(world_size)]
    dist.all_gather_object(tensor_list, tensor)  # (world_size, N)

    if rank_zero_return_only and not rank == 0:
        return None

    # decode
    data_list = []
    for size, tensor in zip(size_list, tensor_list):
        buffer = tensor.cpu().numpy().tobytes()[:size]
        data_list.append(pickle.loads(buffer))
    return data_list
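
# Usage sketch (illustrative, not part of this module): gather arbitrary
# per-rank results through the collective backend; by default only global
# rank 0 receives the gathered list, every other rank receives None.
#
#     local_results = {"rank": get_rank(), "num_samples": 128}  # any picklable payload
#     gathered = all_gather_object_gpu(local_results)
#     if gathered is not None:  # only on rank 0 by default
#         assert len(gathered) == get_world_size()
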

def create_tmpdir(
    rank: int, tmpdir: None | str = None
) -> str:  # pragma: no cover
    """Create and distribute a temporary directory across all processes."""
    if tmpdir is not None:
        os.makedirs(tmpdir, exist_ok=True)
        return tmpdir
    if rank == 0:
        # create a temporary directory
        default_tmpdir = tempfile.gettempdir()
        if default_tmpdir is not None:
            dist_tmpdir = os.path.join(default_tmpdir, ".dist_tmp")
        else:
            dist_tmpdir = ".dist_tmp"
        os.makedirs(dist_tmpdir, exist_ok=True)
        tmpdir = tempfile.mkdtemp(dir=dist_tmpdir)
    else:
        tmpdir = None
    return broadcast(tmpdir)
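
# Usage sketch (illustrative, not part of this module): call with the same
# arguments on every rank; rank 0 creates the directory and its path is
# broadcast so all ranks agree on it.
#
#     tmpdir = create_tmpdir(get_rank())
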

def all_gather_object_cpu(  # type: ignore
    data: Any,
    tmpdir: None | str = None,
    rank_zero_return_only: bool = True,
) -> list[Any] | None:  # pragma: no cover
    """Share arbitrary picklable data via file system caching.

    Args:
        data: any picklable object.
        tmpdir: Save path for temporary files. If None, safely create tmpdir.
        rank_zero_return_only: if results should only be returned on rank 0.

    Returns:
        list[Any]: list of data gathered from each process.
    """
    rank, world_size = get_rank(), get_world_size()
    if world_size == 1:
        return [data]

    # make tmp dir
    tmpdir = create_tmpdir(rank, tmpdir)

    # encode & save
    with open(os.path.join(tmpdir, f"part_{rank}.pkl"), "wb") as f:
        pickle.dump(data, f)
    synchronize()

    if rank_zero_return_only and not rank == 0:
        return None

    # load & decode
    data_list = []
    for i in range(world_size):
        with open(os.path.join(tmpdir, f"part_{i}.pkl"), "rb") as f:
            data_list.append(pickle.load(f))

    # remove dir
    if not rank_zero_return_only:
        # wait for all processes to finish loading before removing tmpdir
        synchronize()
    if rank == 0:
        shutil.rmtree(tmpdir)
    return data_list
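
# Usage sketch (illustrative, not part of this module): the CPU variant
# exchanges the same data through a shared filesystem directory instead of a
# collective call, which is handy when GPU memory is tight.
#
#     local_results = [...]  # per-rank list of picklable results
#     gathered = all_gather_object_cpu(local_results, rank_zero_return_only=False)
#     merged = [item for part in gathered for item in part]
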

def reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """Obtain the mean of tensor on different GPUs."""
    if not (dist.is_available() and dist.is_initialized()):
        return tensor
    tensor = tensor.clone()
    dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
    return tensor
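
# Usage sketch (illustrative, not part of this module): average a per-rank
# scalar (e.g. a loss) so every process logs the same value; `criterion`,
# `outputs` and `targets` below are hypothetical.
#
#     loss = criterion(outputs, targets)
#     logged_loss = reduce_mean(loss.detach())
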

def obj2tensor(
    pyobj: Any, device: torch.device = torch.device("cuda")
) -> torch.Tensor:
    """Serialize picklable python object to tensor.

    Args:
        pyobj (Any): Any picklable python object.
        device (torch.device): Device to put on. Defaults to "cuda".
    """
    storage = torch.ByteStorage.from_buffer(pickle.dumps(pyobj))
    return torch.ByteTensor(storage).to(device=device)

def tensor2obj(tensor: torch.Tensor) -> Any:
    """Deserialize tensor to picklable python object.

    Args:
        tensor (torch.Tensor): Tensor to be deserialized.
    """
    return pickle.loads(tensor.cpu().numpy().tobytes())
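
# Usage sketch (illustrative, not part of this module): obj2tensor/tensor2obj
# round-trip any picklable object through a byte tensor, e.g. to broadcast the
# dict keys in all_reduce_dict below.
#
#     keys = ["loss", "acc"]
#     assert tensor2obj(obj2tensor(keys, device=torch.device("cpu"))) == keys
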

def all_reduce_dict(
    py_dict: DictStrAny, reduce_op: str = "sum", to_float: bool = True
) -> DictStrAny:  # pragma: no cover
    """Apply all reduce function for python dict object.

    The code is modified from
    https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py.

    NOTE: make sure that py_dict in different ranks has the same keys and
    the values should be in the same shape. Currently only supports NCCL
    backend.

    Args:
        py_dict (DictStrAny): Dict to be applied all reduce op.
        reduce_op (str): Operator, could be 'sum' or 'mean'. Default: 'sum'.
        to_float (bool): Whether to convert all values of dict to float.
            Default: True.

    Returns:
        DictStrAny: reduced python dict object.
    """
    world_size = get_world_size()
    if world_size == 1:
        return py_dict

    # all reduce logic across different devices.
    py_key = list(py_dict.keys())
    if not isinstance(py_dict, OrderedDict):
        py_key_tensor = obj2tensor(py_key)
        dist.broadcast(py_key_tensor, src=0)
        py_key = tensor2obj(py_key_tensor)

    tensor_shapes = [py_dict[k].shape for k in py_key]
    tensor_numels = [py_dict[k].numel() for k in py_key]

    if to_float:
        flatten_tensor = torch.cat(
            [py_dict[k].flatten().float() for k in py_key]
        )
    else:
        flatten_tensor = torch.cat([py_dict[k].flatten() for k in py_key])

    dist.all_reduce(flatten_tensor, op=dist.ReduceOp.SUM)
    if reduce_op == "mean":
        flatten_tensor /= world_size

    split_tensors = [
        x.reshape(shape)
        for x, shape in zip(
            torch.split(flatten_tensor, tensor_numels), tensor_shapes
        )
    ]
    out_dict: DictStrAny = dict(zip(py_key, split_tensors))
    if isinstance(py_dict, OrderedDict):
        out_dict = OrderedDict(out_dict)
    return out_dict
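
# Usage sketch (illustrative, not part of this module): average a dict of
# metric tensors across ranks; keys and shapes must match on every rank and
# the NCCL backend (CUDA tensors) is assumed.
#
#     metrics = {
#         "loss": torch.tensor(0.4, device="cuda"),
#         "acc": torch.tensor(0.9, device="cuda"),
#     }
#     metrics = all_reduce_dict(metrics, reduce_op="mean")
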

def is_module_wrapper(module: nn.Module) -> bool:
    """Checks recursively if a module is wrapped.

    Two module types are regarded as wrappers: DataParallel and
    DistributedDataParallel.

    Args:
        module (nn.Module): The module to be checked.

    Returns:
        bool: True if the input module is a module wrapper.
    """
    if isinstance(module, (DataParallel, DistributedDataParallel)):
        return True
    if any(is_module_wrapper(child) for child in module.children()):
        return True
    return False
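
# Usage sketch (illustrative, not part of this module): strip the wrapper
# before saving so checkpoint keys carry no "module." prefix; `model` below is
# hypothetical and assumed to be wrapped at the top level, if at all.
#
#     if is_module_wrapper(model):
#         state_dict = model.module.state_dict()
#     else:
#         state_dict = model.state_dict()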