Source code for vis4d.op.detect3d.bevformer.grid_mask
"""Grid mask for BEVFormer."""
import numpy as np
import torch
from PIL import Image
from torch import Tensor, nn
[docs]
class GridMask(nn.Module):
    """Grid Mask Layer.

    Multiplies the input by a randomly sampled binary grid of stripes. The
    same 2D mask is applied to every sample and channel of a call. Masked
    positions become zero, or uniform noise in [-1, 1) when ``offset`` is set.

    NOTE: ``forward`` does not look at ``self.training``; the mask is applied
    whenever the layer is called (subject to ``prob``).
    """

    def __init__(
        self,
        use_h: bool,
        use_w: bool,
        rotate: int = 1,
        offset: bool = False,
        ratio: float = 0.5,
        mode: int = 0,
        prob: float = 1.0,
    ) -> None:
        """Init.

        Args:
            use_h (bool): Mask periodic bands of rows.
            use_w (bool): Mask periodic bands of columns.
            rotate (int, optional): Exclusive upper bound of the random
                rotation angle (as passed to ``PIL.Image.rotate``) applied to
                the mask. Values <= 0 disable rotation. Defaults to 1, which
                always yields angle 0.
            offset (bool, optional): Fill masked positions with uniform noise
                in [-1, 1) instead of zeros. Defaults to False.
            ratio (float, optional): Fraction of each grid period that is
                masked; the resulting band width is clamped to
                [1, period - 1]. Defaults to 0.5.
            mode (int, optional): If 1, the mask is inverted (bands are kept,
                the rest is masked). Any other value leaves it as sampled.
                Defaults to 0.
            prob (float, optional): Probability of applying the mask in a
                forward call. Defaults to 1.0.
        """
        super().__init__()
        self.use_h = use_h
        self.use_w = use_w
        self.rotate = rotate
        self.offset = offset
        self.ratio = ratio
        self.mode = mode
        # Copy of the initial probability; not read anywhere in this class.
        self.st_prob = prob
        self.prob = prob

    @staticmethod
    def _stripe_profile(
        size: int, period: int, width: int, start: int
    ) -> np.ndarray:
        """Return a 1D keep-profile (1 = keep, 0 = masked) of length size.

        ``size // period`` bands of ``width`` zeros are placed at
        ``start, start + period, ...``; a band reaching past the end is
        truncated by the slice.
        """
        keep = np.ones(size, np.float32)
        for begin in range(start, start + period * (size // period), period):
            keep[begin : begin + width] = 0.0
        return keep

    def _sample_mask(self, h: int, w: int) -> np.ndarray:
        """Sample a (h, w) array of zeros and ones; requires ``h >= 3``.

        The grid is drawn on a 1.5x larger canvas so that the centre crop
        stays inside the canvas after rotation.
        """
        canvas_h = int(1.5 * h)
        canvas_w = int(1.5 * w)

        # Random draws are made in a fixed order (period, row start, column
        # start, angle) so that seeded runs are reproducible.
        period = np.random.randint(2, h)
        width = min(max(int(period * self.ratio + 0.5), 1), period - 1)
        start_h = np.random.randint(period)
        start_w = np.random.randint(period)

        mask = np.ones((canvas_h, canvas_w), np.float32)
        if self.use_h:
            rows = self._stripe_profile(canvas_h, period, width, start_h)
            mask *= rows[:, None]
        if self.use_w:
            cols = self._stripe_profile(canvas_w, period, width, start_w)
            mask *= cols[None, :]

        # randint(0) would raise, so a non-positive bound means "no rotation".
        angle = np.random.randint(self.rotate) if self.rotate > 0 else 0
        if angle % 360 != 0:
            # Rotating by a multiple of 360 is the identity, so PIL is only
            # involved for real rotations. np.array copies into a writable
            # buffer, which torch.from_numpy expects.
            mask = np.array(Image.fromarray(np.uint8(mask)).rotate(angle))

        top = (canvas_h - h) // 2
        left = (canvas_w - w) // 2
        return mask[top : top + h, left : left + w]

    def forward(self, x: Tensor) -> Tensor:
        """Forward.

        Args:
            x (Tensor): Input with four dimensions; the last two are treated
                as height and width.

        Returns:
            Tensor: Masked tensor with the same shape as ``x``. ``x`` itself
                is returned when the mask is skipped, either because of
                ``prob`` or because the height is below 3 (no valid grid
                period can be sampled).
        """
        if np.random.rand() > self.prob:
            return x

        n, c, h, w = x.size()
        if h < 3:
            # The grid period is drawn from [2, h), which is empty here.
            return x

        mask = torch.from_numpy(self._sample_mask(h, w)).to(
            device=x.device, dtype=x.dtype
        )
        if self.mode == 1:
            mask = 1 - mask

        # reshape (not view) so inputs whose batch and channel dimensions
        # cannot be merged without a copy are accepted as well.
        flat = x.reshape(-1, h, w)
        if self.offset:
            noise = torch.from_numpy(2 * (np.random.rand(h, w) - 0.5)).to(
                device=x.device, dtype=x.dtype
            )
            flat = flat * mask + noise * (1 - mask)
        else:
            flat = flat * mask
        return flat.reshape(n, c, h, w)