vis4d.common.distributed

This module contains utilities for multiprocess parallelism.

Functions

all_gather_object_cpu(data[, tmpdir, ...])

Share arbitrary picklable data via file system caching.

all_gather_object_gpu(data[, ...])

Gather arbitrary picklable data from all processes using all_gather on GPU.

all_reduce_dict(py_dict[, reduce_op, to_float])

Apply an all-reduce operation to a Python dict.

broadcast(obj[, src])

Broadcast an object from a source to all processes.

create_tmpdir(rank[, tmpdir])

Create and distribute a temporary directory across all processes.

distributed_available()

Check if torch.distributed is available.

get_local_rank()

Get the local rank of the current process in torch.distributed.

get_rank()

Get the global rank of the current process in torch.distributed.

get_world_size()

Get the world size (number of processes) of torch.distributed.

is_module_wrapper(module)

Check recursively whether a module is wrapped.

obj2tensor(pyobj[, device])

Serialize picklable python object to tensor.

pad_to_largest_tensor(tensor)

Pad a tensor to the largest size among the tensors across all processes.

rank_zero_only(func)

Allows the decorated function to be called only on global rank 0.

reduce_mean(tensor)

Obtain the mean of a tensor across GPUs.

serialize_to_tensor(data)

Serialize arbitrary picklable data to a torch.Tensor.

synchronize()

Sync (barrier) among all processes when using distributed training.

tensor2obj(tensor)

Deserialize tensor to picklable python object.

Classes

PicklableWrapper(obj)

Wrap an object to make it more picklable.

class PicklableWrapper(obj)[source]

Wrap an object to make it more picklable.

Note that it uses heavyweight serialization libraries that are slower than pickle. It’s best to use it only on closures (which are usually not picklable). This is a simplified version of https://github.com/joblib/joblib/blob/master/joblib/externals/loky/cloudpickle_wrapper.py

Creates an instance of the class.

__call__(*args, **kwargs)[source]

Call the wrapped object.

Return type:

Any

__getattr__(attr)[source]

Get attribute.

Ensure that the wrapped object can be used seamlessly in place of the original object.

Return type:

Any

__reduce__()[source]

Reduce for pickling.

Return type:

tuple[Any, tuple[bytes]]
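
A minimal usage sketch; the lambda and the scale value are illustrative:

    from vis4d.common.distributed import PicklableWrapper

    scale = 2.0
    # Closures are usually not picklable; wrapping one makes it safe to
    # send across process boundaries while staying callable.
    scaled = PicklableWrapper(lambda x: x * scale)
    assert scaled(3.0) == 6.0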

all_gather_object_cpu(data, tmpdir=None, rank_zero_return_only=True)[source]

Share arbitrary picklable data via file system caching.

Parameters:
  • data (Any) – any picklable object.

  • tmpdir (Optional[str]) – Save path for temporary files. If None, safely create tmpdir.

  • rank_zero_return_only (bool) – Whether results should only be returned on rank 0.

Returns:

list of data gathered from each process.

Return type:

list[Any]
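
A minimal usage sketch, assuming a multi-process run has been launched; the payload contents are illustrative:

    from vis4d.common.distributed import all_gather_object_cpu, get_rank

    # Each process contributes its local result; with the default
    # rank_zero_return_only=True, only rank 0 receives the gathered list.
    local = {"rank": get_rank(), "num_samples": 128}
    gathered = all_gather_object_cpu(local)
    if get_rank() == 0:
        print(len(gathered))  # one entry per process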

all_gather_object_gpu(data, rank_zero_return_only=True)[source]

Gather arbitrary picklable data from all processes using all_gather on GPU.

Parameters:
  • data (Any) – any picklable object

  • rank_zero_return_only (bool) – Whether results should only be returned on rank 0.

Returns:

list of data gathered from each process

Return type:

list[Any]

all_reduce_dict(py_dict, reduce_op='sum', to_float=True)[source]

Apply an all-reduce operation to a Python dict.

The code is modified from https://github.com/Megvii-BaseDetection/YOLOX/blob/main/yolox/utils/allreduce_norm.py.

NOTE: Make sure that py_dict has the same keys on every rank and that the corresponding values have the same shape. Currently only the NCCL backend is supported.

Parameters:
  • py_dict (DictStrAny) – Dict to apply the all-reduce op to.

  • reduce_op (str) – Operator, could be ‘sum’ or ‘mean’. Default: ‘sum’.

  • to_float (bool) – Whether to convert all values of dict to float. Default: True.

Returns:

reduced python dict object.

Return type:

DictStrAny
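
A minimal usage sketch, assuming an initialized NCCL process group; the loss names and values are illustrative:

    import torch

    from vis4d.common.distributed import all_reduce_dict

    # Every rank must hold the same keys with identically shaped values.
    losses = {
        "loss_cls": torch.tensor(0.5, device="cuda"),
        "loss_box": torch.tensor(1.2, device="cuda"),
    }
    reduced = all_reduce_dict(losses, reduce_op="mean")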

broadcast(obj, src=0)[source]

Broadcast an object from a source to all processes.

Return type:

Any
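
A minimal usage sketch; the config dict is illustrative:

    from vis4d.common.distributed import broadcast, get_rank

    # Only rank 0 needs to build the object; broadcast hands every rank a copy.
    cfg = {"lr": 0.01, "batch_size": 8} if get_rank() == 0 else None
    cfg = broadcast(cfg, src=0)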

create_tmpdir(rank, tmpdir=None)[source]

Create and distribute a temporary directory across all processes.

Return type:

str

distributed_available()[source]

Check if torch.distributed is available.

Returns:

Whether torch.distributed is available.

Return type:

bool

get_local_rank()[source]

Get the local rank of the current process in torch.distributed.

Returns:

The local rank.

Return type:

int

get_rank()[source]

Get the global rank of the current process in torch.distributed.

Returns:

The global rank.

Return type:

int

get_world_size()[source]

Get the world size (number of processes) of torch.distributed.

Returns:

The world size.

Return type:

int
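
A minimal sketch of the usual sharding pattern, assuming these helpers fall back to rank 0 and world size 1 when torch.distributed is not initialized:

    from vis4d.common.distributed import get_rank, get_world_size

    # Split 1000 sample indices into disjoint per-process subsets.
    rank, world_size = get_rank(), get_world_size()
    indices = list(range(rank, 1000, world_size))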

is_module_wrapper(module)[source]

Check recursively whether a module is wrapped.

Two module types are regarded as wrappers: DataParallel and DistributedDataParallel.

Parameters:

module (nn.Module) – The module to be checked.

Returns:

True if the input module is a module wrapper.

Return type:

bool
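
A minimal usage sketch:

    import torch.nn as nn

    from vis4d.common.distributed import is_module_wrapper

    model = nn.Linear(4, 2)
    assert not is_module_wrapper(model)
    # DataParallel (and DistributedDataParallel) count as wrappers.
    assert is_module_wrapper(nn.DataParallel(model))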

obj2tensor(pyobj, device=device(type='cuda'))[source]

Serialize picklable python object to tensor.

Parameters:
  • pyobj (Any) – Any picklable python object.

  • device (torch.device) – Device to put the tensor on. Defaults to “cuda”.

Return type:

Tensor

pad_to_largest_tensor(tensor)[source]

Pad a tensor to the largest size among the tensors across all processes.

Parameters:

tensor (Tensor) – tensor to be padded.

Returns:

list[int]: size of the tensor on each rank.

Tensor: padded tensor that has the max size.

Return type:

list[int]

rank_zero_only(func)[source]

Allows the decorated function to be called only on global rank 0.

Parameters:

func (GenericFunc) – The function to decorate.

Returns:

The decorated function.

Return type:

GenericFunc
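
A minimal usage sketch; the function name and metric value are illustrative:

    from vis4d.common.distributed import rank_zero_only

    @rank_zero_only
    def log_metrics(metrics: dict) -> None:
        # Executed on global rank 0 only; other ranks skip the body.
        print(metrics)

    log_metrics({"mAP": 0.42})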

reduce_mean(tensor)[source]

Obtain the mean of a tensor across GPUs.

Return type:

Tensor
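
A minimal usage sketch, assuming an initialized process group; the loss value is illustrative:

    import torch

    from vis4d.common.distributed import reduce_mean

    # Average a per-rank scalar loss across all GPUs before logging it.
    loss = torch.tensor(0.7, device="cuda")
    mean_loss = reduce_mean(loss)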

serialize_to_tensor(data)[source]

Serialize arbitrary picklable data to a torch.Tensor.

Parameters:

data (Any) – The data to serialize.

Returns:

The serialized data as a torch.Tensor.

Return type:

torch.Tensor

Raises:

AssertionError – If the backend of torch.distributed is not gloo or nccl.

synchronize()[source]

Sync (barrier) among all processes when using distributed training.

Return type:

None
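
A minimal usage sketch; the directory path is illustrative:

    import os

    from vis4d.common.distributed import get_rank, synchronize

    # Rank 0 prepares a shared directory; all other ranks wait at the
    # barrier before relying on it.
    if get_rank() == 0:
        os.makedirs("shared_cache", exist_ok=True)
    synchronize()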

tensor2obj(tensor)[source]

Deserialize tensor to picklable python object.

Parameters:

tensor (torch.Tensor) – Tensor to be deserialized.

Return type:

Any
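
A round-trip sketch, assuming tensor2obj is the inverse of obj2tensor; the payload and the CPU device are illustrative:

    import torch

    from vis4d.common.distributed import obj2tensor, tensor2obj

    payload = {"ids": [1, 2, 3], "split": "val"}
    encoded = obj2tensor(payload, device=torch.device("cpu"))
    assert tensor2obj(encoded) == payload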