
Multi-Scale Deformable Attention Module.

Modified from Deformable DETR.



Check if a number is a power of 2.

ms_deformable_attention_cpu(value, ...)

CPU version of multi-scale deformable attention.


MSDeformAttention([d_model, n_levels, ...])

Multi-Scale Deformable Attention Module.

MSDeformAttentionFunction(*args, **kwargs)

Multi-Scale Deformable Attention Function module.

class MSDeformAttention(d_model=256, n_levels=4, n_heads=8, n_points=4, im2col_step=64)[source]

Multi-Scale Deformable Attention Module.

Creates an instance of the class.

  • d_model (int) – Hidden dimensions.

  • n_levels (int) – Number of feature levels.

  • n_heads (int) – Number of attention heads.

  • n_points (int) – Number of sampling points per attention head per feature level.

  • im2col_step (int) – The step used in image_to_column. Default: 64.
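As a rough illustration of how these arguments relate, the sketch below computes two derived quantities for the default values: the channel width handled by each head and the number of sampling points gathered per query. The helper name `derived_sizes` is hypothetical and not part of this module. The per-head width `d_model // n_heads` assumes channels are split evenly across heads, as the `embed_dims // num_heads` shape of `value` in ms_deformable_attention_cpu suggests.

```python
def derived_sizes(d_model=256, n_levels=4, n_heads=8, n_points=4):
    """Hypothetical helper: sizes implied by the constructor arguments."""
    # Assumption: the hidden dimension is divided evenly among the heads.
    head_dim = d_model // n_heads
    # n_points are sampled per head per level, so one query gathers
    # n_heads * n_levels * n_points samples in total.
    points_per_query = n_heads * n_levels * n_points
    return {"head_dim": head_dim, "points_per_query": points_per_query}


sizes = derived_sizes()
print(sizes)
```

With the defaults, each head works on 32 channels and each query aggregates 128 sampled values.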

forward(query, reference_points, input_flatten, input_spatial_shapes, input_level_start_index, input_padding_mask=None)[source]

Forward function.

  • query (Tensor) – (n, length_{query}, C).

  • reference_points (Tensor) – Either (n, length_{query}, n_levels, 2), normalized to [0, 1] with top-left (0, 0) and bottom-right (1, 1), including the padding area; or (n, length_{query}, n_levels, 4), where the additional (w, h) turns each point into a reference box.

  • input_flatten (Tensor) – (n, sum_{l=0}^{L-1} H_l * W_l, C).

  • input_spatial_shapes (Tensor) – (n_levels, 2), [(H_0, W_0), (H_1, W_1), …, (H_{L-1}, W_{L-1})]

  • input_level_start_index (Tensor) – (n_levels,), the offset of each level's first element in input_flatten: [0, H_0*W_0, H_0*W_0+H_1*W_1, H_0*W_0+H_1*W_1+H_2*W_2, …, H_0*W_0+H_1*W_1+…+H_{L-2}*W_{L-2}]

  • input_padding_mask (Tensor) – (n, sum_{l=0}^{L-1} H_l * W_l), True for padding elements, False for non-padding elements.

Returns:

output (Tensor) – (n, length_{query}, C).
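The relationship between input_spatial_shapes, input_level_start_index and the flattened length of input_flatten can be sketched in plain Python. This is an illustrative sketch only: the function name `level_start_indices` and the example feature-map sizes are assumptions, and real inputs would be tensors rather than lists.

```python
from itertools import accumulate


def level_start_indices(spatial_shapes):
    """Return the offset of each level's first element in the flattened input.

    spatial_shapes: list of (H_l, W_l) pairs, one per feature level.
    """
    sizes = [h * w for h, w in spatial_shapes]
    # Running totals give the end of each level; shifting them right by one
    # (and starting at 0) gives the start of each level.
    return [0] + list(accumulate(sizes))[:-1]


# Hypothetical four-level pyramid.
shapes = [(32, 32), (16, 16), (8, 8), (4, 4)]
starts = level_start_indices(shapes)
flattened_length = sum(h * w for h, w in shapes)
print(starts, flattened_length)
```

Here the start indices have one entry per level, and the flattened length (the second dimension of input_flatten and input_padding_mask) is the sum of all H_l * W_l.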

class MSDeformAttentionFunction(*args, **kwargs)[source]

Multi-Scale Deformable Attention Function module.

static backward(ctx, grad_output)[source]

Backward pass.

Return type:

tuple[Tensor, None, None, Tensor, Tensor, None]

static forward(ctx, value, value_spatial_shapes, value_level_start_index, sampling_locations, attention_weights, im2col_step)[source]

Forward pass.


Check if a number is a power of 2.
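One common way to perform such a check is a bit trick: a positive power of two has exactly one set bit, so clearing the lowest set bit leaves zero. The sketch below uses a hypothetical name, `is_power_of_two`, and returns False for non-positive or non-integer input; the module's own helper may handle those cases differently.

```python
def is_power_of_two(n):
    """Return True if n is a positive integer power of 2 (1, 2, 4, 8, ...)."""
    # Assumption: invalid input yields False rather than raising an error.
    if not isinstance(n, int) or n <= 0:
        return False
    # n & (n - 1) clears the lowest set bit; the result is 0 only when
    # n had a single set bit, i.e. n is a power of two.
    return (n & (n - 1)) == 0


print(is_power_of_two(32), is_power_of_two(48))
```

A check like this is typically applied to the per-head channel width, since power-of-two sizes tend to suit GPU kernels better.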

ms_deformable_attention_cpu(value, value_spatial_shapes, sampling_locations, attention_weights)[source]

CPU version of multi-scale deformable attention.

  • value (Tensor) – The value tensor, with shape (bs, num_keys, num_heads, embed_dims // num_heads).

  • value_spatial_shapes (Tensor) – Spatial shape of each feature map, with shape (num_levels, 2), where the last dimension holds (h, w).

  • sampling_locations (Tensor) – Locations of the sampling points, with shape (bs, num_queries, num_heads, num_levels, num_points, 2), where the last dimension holds (x, y).

  • attention_weights (Tensor) – Weights of the sampling points used when computing attention, with shape (bs, num_queries, num_heads, num_levels, num_points).


Returns:

Output tensor of shape (bs, num_queries, embed_dims).
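To make the computation concrete, here is a minimal pure-Python sketch for a single query, a single head, and a single scalar channel: each sampling location is read from its feature map by bilinear interpolation, and the samples are combined with the attention weights. The names `bilinear_sample` and `deformable_attention_single` are hypothetical. The sketch assumes normalized (x, y) locations in [0, 1], pixel centers at half-integer positions, and zero padding outside the map; the actual function operates on batched tensors and may differ in these details.

```python
import math


def bilinear_sample(feature, x, y):
    """Sample a 2D map (list of rows) at a normalized (x, y) in [0, 1]."""
    h, w = len(feature), len(feature[0])
    # Assumption: pixel (i, j) is centered at ((j + 0.5) / w, (i + 0.5) / h).
    px, py = x * w - 0.5, y * h - 0.5
    x0, y0 = math.floor(px), math.floor(py)
    fx, fy = px - x0, py - y0
    total = 0.0
    for yi, wy in ((y0, 1.0 - fy), (y0 + 1, fy)):
        for xi, wx in ((x0, 1.0 - fx), (x0 + 1, fx)):
            # Neighbors outside the map contribute zero (zero padding).
            if 0 <= yi < h and 0 <= xi < w:
                total += wy * wx * feature[yi][xi]
    return total


def deformable_attention_single(levels, locations, weights):
    """Weighted sum of samples over all levels and points.

    levels[l] is a 2D map, locations[l][p] is (x, y), weights[l][p] is a float.
    """
    return sum(
        weights[l][p] * bilinear_sample(levels[l], *locations[l][p])
        for l in range(len(levels))
        for p in range(len(locations[l]))
    )


levels = [[[1.0, 2.0], [3.0, 4.0]], [[10.0]]]
locations = [[(0.25, 0.25), (0.5, 0.5)], [(0.5, 0.5), (0.5, 0.5)]]
weights = [[0.25, 0.25], [0.25, 0.25]]
result = deformable_attention_single(levels, locations, weights)
print(result)
```

The weights for one query and head sum to 1 across all levels and points in this example, so the output is a convex combination of the sampled values.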
