# Source code for vis4d.model.detect3d.bevformer
"""BEVFromer model implementation.
This file composes the operations associated with BEVFormer
`https://arxiv.org/abs/2203.17270` into the full model implementation.
"""
from __future__ import annotations
import copy
from typing import TypedDict
import torch
from torch import Tensor, nn
from vis4d.common.ckpt import load_model_checkpoint
from vis4d.op.base import BaseModel
from vis4d.op.detect3d.bevformer import BEVFormerHead, GridMask
from vis4d.op.detect3d.common import Detect3DOut
from vis4d.op.fpp.fpn import FPN, ExtraFPNBlock
# Regex rewrite rules mapping reference (mmdet3d-style) checkpoint keys to
# this model's attribute names; consumed by load_model_checkpoint below.
REV_KEYS = [
    (r"^img_backbone\.", "basemodel."),
    (r"^img_neck.lateral_convs\.", "fpn.inner_blocks."),
    (r"^img_neck.fpn_convs\.", "fpn.layer_blocks."),
    (r"^fpn.layer_blocks.3\.", "fpn.extra_blocks.convs.0."),
    (r"\.conv.weight", ".weight"),
    (r"\.conv.bias", ".bias"),
]
class PrevFrameInfo(TypedDict):
    """Previous frame information."""

    # Name of the scene the previous frame belongs to.
    scene_name: str
    # Cached BEV features from the previous frame; None at a scene start.
    prev_bev: Tensor | None
    # Ego position (x, y, z) of the previous frame.
    prev_pos: Tensor
    # Ego yaw angle of the previous frame.
    prev_angle: Tensor
class BEVFormer(nn.Module):
    """BEVFormer 3D Detector.

    Composes an image backbone, an FPN neck and a BEVFormer head, and caches
    the BEV features and ego pose of the previous frame for temporal fusion.
    """

    def __init__(
        self,
        basemodel: BaseModel,
        fpn: FPN | None = None,
        pts_bbox_head: BEVFormerHead | None = None,
        weights: str | None = None,
    ) -> None:
        """Creates an instance of the class.

        Args:
            basemodel (BaseModel): Base model network.
            fpn (FPN, optional): Feature Pyramid Network. Defaults to None.
                If None, a default FPN will be used.
            pts_bbox_head (BEVFormerHead, optional): BEVFormer head. Defaults
                to None. If None, a default BEVFormer head will be used.
            weights (str, optional): Path to the checkpoint to load. Defaults
                to None.
        """
        super().__init__()
        self.basemodel = basemodel
        # Default neck: FPN over the last backbone stages with one extra
        # output level.
        self.fpn = fpn or FPN(
            self.basemodel.out_channels[3:],
            256,
            extra_blocks=ExtraFPNBlock(
                extra_levels=1, in_channels=256, out_channels=256
            ),
            start_index=3,
        )
        # GridMask augmentation; applied to the input images only in
        # training mode (see extract_feat).
        self.grid_mask = GridMask(
            True, True, rotate=1, offset=False, ratio=0.5, mode=1, prob=0.7
        )
        self.pts_bbox_head = pts_bbox_head or BEVFormerHead()
        # Temporal state carried across consecutive forward calls.
        self.prev_frame_info = PrevFrameInfo(
            scene_name="",
            prev_bev=None,
            prev_pos=torch.zeros(3),
            prev_angle=torch.zeros(1),
        )
        if weights is not None:
            load_model_checkpoint(self, weights, rev_keys=REV_KEYS)

    def extract_feat(self, images_list: list[Tensor]) -> list[Tensor]:
        """Extract multi-scale features of the multi-camera images.

        Args:
            images_list (list[Tensor]): Per-camera image batches, each of
                shape [B, C, H, W]; one list entry per camera.

        Returns:
            list[Tensor]: Feature maps, each of shape [B, N, C', H', W'],
                where N is the number of cameras.
        """
        num_cams = len(images_list)  # N
        batch_size = images_list[0].shape[0]  # B
        # Fold the camera dimension into the batch so the backbone sees a
        # single [B*N, C, H, W] batch.
        images = torch.stack(images_list, dim=1)  # [B, N, C, H, W]
        images = images.view(-1, *images.shape[2:])  # [B*N, C, H, W]

        # GridMask is a train-time regularizer only.
        if self.training:
            images = self.grid_mask(images)

        features = self.basemodel(images)
        features = self.fpn(features)[self.fpn.start_index :]

        # Restore the camera dimension for each pyramid level.
        img_feats = []
        for img_feat in features:
            _, c, h, w = img_feat.size()
            img_feats.append(img_feat.view(batch_size, num_cams, c, h, w))
        return img_feats

    def forward(
        self,
        images: list[Tensor],
        can_bus: list[list[float]],
        scene_names: list[str],
        cam_intrinsics: list[Tensor],
        cam_extrinsics: list[Tensor],
        lidar_extrinsics: list[Tensor],
    ) -> Detect3DOut:
        """Forward pass of the detector.

        Args:
            images (list[Tensor]): Per-camera image batches.
            can_bus (list[list[float]]): CAN bus signals per sample; indices
                [:3] hold the ego position and [-1] the ego yaw angle.
            scene_names (list[str]): Scene name per sample, used to detect
                scene boundaries and reset the temporal state.
            cam_intrinsics (list[Tensor]): Camera intrinsic matrices.
            cam_extrinsics (list[Tensor]): Camera extrinsic matrices.
            lidar_extrinsics (list[Tensor]): LiDAR extrinsic matrices.

        Returns:
            Detect3DOut: 3D detection output of the BEVFormer head.
        """
        # NOTE(review): only index 0 of each batched argument is used below,
        # so this implementation assumes batch size 1 — confirm with callers.
        lidar_extrinsics_tensor = lidar_extrinsics[0]
        can_bus_tensor = torch.tensor(
            can_bus, dtype=torch.float32, device=images[0].device
        )

        # The first sample of each scene has no temporal history.
        if scene_names[0] != self.prev_frame_info["scene_name"]:
            self.prev_frame_info["prev_bev"] = None
        self.prev_frame_info["scene_name"] = scene_names[0]

        # Save the current ego pose, then convert the CAN bus position/angle
        # in-place to deltas w.r.t. the previous frame. clone() (instead of
        # copy.deepcopy) is the correct tensor copy here: deepcopy fails on
        # non-leaf tensors that require grad and is needlessly expensive.
        tmp_pos = can_bus_tensor[0][:3].clone()
        tmp_angle = can_bus_tensor[0][-1].clone()
        if self.prev_frame_info["prev_bev"] is not None:
            can_bus_tensor[0][:3] -= self.prev_frame_info["prev_pos"]
            can_bus_tensor[0][-1] -= self.prev_frame_info["prev_angle"]
        else:
            # No history: zero ego-motion delta.
            can_bus_tensor[0][:3] = 0
            can_bus_tensor[0][-1] = 0

        images_hw = (int(images[0].shape[-2]), int(images[0].shape[-1]))
        img_feats = self.extract_feat(images)
        out, bev_embed = self.pts_bbox_head(
            img_feats,
            can_bus_tensor,
            images_hw,
            cam_intrinsics,
            cam_extrinsics,
            lidar_extrinsics_tensor,
            prev_bev=self.prev_frame_info["prev_bev"],
        )

        # Save the BEV features and ego pose of this timestamp for the next
        # frame's temporal attention.
        self.prev_frame_info["prev_pos"] = tmp_pos
        self.prev_frame_info["prev_angle"] = tmp_angle
        self.prev_frame_info["prev_bev"] = bev_embed
        return out