Source code for vis4d.model.motion.velo_lstm

"""VeloLSTM 3D motion model."""

from __future__ import annotations

from typing import NamedTuple

import torch
from torch import Tensor, nn

from vis4d.common.ckpt import load_model_checkpoint
from vis4d.op.geometry.rotation import acute_angle, normalize_angle
from vis4d.op.layer.weight_init import xavier_init


class VeloLSTMOut(NamedTuple):
    """VeloLSTM output."""

    loc_preds: Tensor
    loc_refines: Tensor


class VeloLSTM(nn.Module):
    """Estimating object location in world coordinates.

    Prediction LSTM:
        Input: 5 frames velocity
        Output: Next frame location
    Updating LSTM:
        Input: predicted location and observed location
        Output: Refined location
    """

    def __init__(
        self,
        num_frames: int = 5,
        feature_dim: int = 64,
        hidden_size: int = 128,
        num_layers: int = 2,
        loc_dim: int = 7,
        dropout: float = 0.1,
        weights: str | None = None,
    ) -> None:
        """Init."""
        super().__init__()
        self.num_frames = num_frames
        self.feature_dim = feature_dim
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.loc_dim = loc_dim

        self.vel2feat = nn.Linear(loc_dim, feature_dim)
        self.pred_lstm = nn.LSTM(
            input_size=feature_dim,
            hidden_size=hidden_size,
            dropout=dropout,
            num_layers=num_layers,
        )
        self.pred2atten = nn.Linear(hidden_size, loc_dim, bias=False)

        self.conf2feat = nn.Linear(1, feature_dim, bias=False)
        self.refine_lstm = nn.LSTM(
            input_size=3 * feature_dim,
            hidden_size=hidden_size,
            dropout=dropout,
            num_layers=num_layers,
        )
        self.conf2atten = nn.Linear(hidden_size, loc_dim, bias=False)

        self._init_weights()

        if weights is not None:
            load_model_checkpoint(
                self,
                weights,
                map_location="cpu",
                rev_keys=[(r"^model\.", ""), (r"^module\.", "")],
            )

    def _init_weights(self) -> None:
        """Initialize model weights."""
        xavier_init(self.vel2feat)
        xavier_init(self.pred2atten)
        xavier_init(self.conf2feat)
        xavier_init(self.conf2atten)
        init_lstm_module(self.pred_lstm)
        init_lstm_module(self.refine_lstm)

    def init_hidden(
        self, device: torch.device, batch_size: int = 1
    ) -> tuple[Tensor, Tensor]:
        """Initialize hidden state.

        The axes semantics are (num_layers, minibatch_size, hidden_dim).
        """
        return (
            torch.zeros(self.num_layers, batch_size, self.hidden_size).to(
                device
            ),
            torch.zeros(self.num_layers, batch_size, self.hidden_size).to(
                device
            ),
        )

    def refine(
        self,
        location: Tensor,
        observation: Tensor,
        prev_location: Tensor,
        confidence: Tensor,
        hc_0: tuple[Tensor, Tensor],
    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
        """Refine predicted location using single frame estimation at t+1.

        Input:
            location: (num_batch x loc_dim), location from prediction.
            observation: (num_batch x loc_dim), location from single frame
                estimation.
            prev_location: (num_batch x loc_dim), refined location at t.
            confidence: (num_batch x 1), depth estimation confidence.
            hc_0: (num_layers, num_batch, hidden_size), tuple of hidden and
                cell state.
        Middle:
            loc_embed: (num_batch x feature_dim), predicted location feature.
            obs_embed: (num_batch x feature_dim), single frame location
                feature.
            conf_embed: (num_batch x feature_dim), depth estimation
                confidence feature.
            embed: (1, num_batch, 3*feature_dim), concatenated location
                feature.
            out: (1 x num_batch x hidden_size), LSTM output.
        Output:
            output_pred: (num_batch x loc_dim), refined location.
            hc_n: (num_layers, num_batch, hidden_size), tuple of updated
                hidden and cell state.
        """
        num_batch = location.shape[0]
        pred_vel = location - prev_location
        obsv_vel = observation - prev_location

        # Embed feature to hidden_size
        loc_embed = self.vel2feat(pred_vel).view(num_batch, self.feature_dim)
        obs_embed = self.vel2feat(obsv_vel).view(num_batch, self.feature_dim)
        conf_embed = self.conf2feat(confidence).view(
            num_batch, self.feature_dim
        )

        embed = torch.cat([loc_embed, obs_embed, conf_embed], dim=1).view(
            1, num_batch, 3 * self.feature_dim
        )

        out, (h_n, c_n) = self.refine_lstm(embed, hc_0)

        delta_vel_atten = torch.sigmoid(self.conf2atten(out)).view(
            num_batch, self.loc_dim
        )

        output_pred = (
            delta_vel_atten * obsv_vel
            + (1.0 - delta_vel_atten) * pred_vel
            + prev_location
        )

        return output_pred, (h_n, c_n)

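    # Reading of refine() above: the output is a per-dimension convex
    # combination of the observed and predicted velocities, gated by the
    # LSTM output,
    #     x_t = a * v_obs + (1 - a) * v_pred + x_prev,
    # with a = sigmoid(conf2atten(out)), so a -> 1 trusts the single-frame
    # observation and a -> 0 trusts the motion prediction.
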
    def predict(
        self,
        vel_history: Tensor,
        location: Tensor,
        hc_0: tuple[Tensor, Tensor],
    ) -> tuple[Tensor, tuple[Tensor, Tensor]]:
        """Predict location at t+1 using updated location at t.

        Input:
            vel_history: (num_seq, num_batch, loc_dim), velocity from
                previous num_seq updates.
            location: (num_batch, loc_dim), location from previous update.
            hc_0: (num_layers, num_batch, hidden_size), tuple of hidden and
                cell state.
        Middle:
            embed: (num_seq, num_batch, feature_dim), location feature.
            out: (num_seq x num_batch x hidden_size), LSTM output.
            attention_logit: (num_seq x num_batch x loc_dim), the predicted
                residual.
        Output:
            output_pred: (num_batch x loc_dim), predicted location.
            hc_n: (num_layers, num_batch, hidden_size), tuple of updated
                hidden and cell state.
        """
        num_seq, num_batch, _ = vel_history.shape

        # Embed feature to hidden_size
        embed = self.vel2feat(vel_history).view(
            num_seq, num_batch, self.feature_dim
        )

        out, (h_n, c_n) = self.pred_lstm(embed, hc_0)

        attention_logit = self.pred2atten(out).view(
            num_seq, num_batch, self.loc_dim
        )
        attention = torch.softmax(attention_logit, dim=0)

        output_pred = torch.sum(attention * vel_history, dim=0) + location

        return output_pred, (h_n, c_n)

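    # Reading of predict() above: a softmax attention over the time axis of
    # the velocity history yields an expected velocity, integrated once,
    #     x_{t+1} = x_t + sum_k softmax(pred2atten(out))_k * vel_history_k.
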
    def forward(self, pred_traj: Tensor) -> VeloLSTMOut:
        """Forward pass of VeloLSTM in the training stage."""
        loc_preds_list = []
        loc_refines_list = []

        hidden_predict = self.init_hidden(
            pred_traj.device, batch_size=pred_traj.shape[0]
        )
        hidden_refine = self.init_hidden(
            pred_traj.device, batch_size=pred_traj.shape[0]
        )

        vel_history = pred_traj.new_zeros(
            self.num_frames, pred_traj.shape[0], self.loc_dim
        )

        # Starting condition
        pred_traj[:, :, 6] = normalize_angle(pred_traj[:, :, 6])
        prev_refine = pred_traj[:, 0, : self.loc_dim]
        loc_pred = pred_traj[:, 1, : self.loc_dim]

        # LSTM
        for i in range(1, pred_traj.shape[1]):
            # Update
            loc_pred[:, 6] = normalize_angle(loc_pred[:, 6])

            for batch_id in range(pred_traj.shape[0]):
                # acute angle
                loc_pred[batch_id, 6] = acute_angle(
                    loc_pred[batch_id, 6], pred_traj[batch_id, i, 6]
                )

            loc_refine, hidden_refine = self.refine(
                loc_pred.detach().clone(),
                pred_traj[:, i, : self.loc_dim],
                prev_refine.detach().clone(),
                pred_traj[:, i, -1].unsqueeze(-1),
                hidden_refine,
            )
            loc_refine[:, 6] = normalize_angle(loc_refine[:, 6])

            if i == 1:
                vel_history = torch.cat(
                    [(loc_refine - prev_refine).unsqueeze(0)]
                    * self.num_frames
                )
            else:
                vel_history = torch.cat(
                    [
                        vel_history[1:],
                        (loc_refine - prev_refine).unsqueeze(0),
                    ],
                    dim=0,
                )
            prev_refine = loc_refine

            # Predict
            loc_pred, hidden_predict = self.predict(
                vel_history, loc_refine.detach().clone(), hidden_predict
            )
            loc_pred[:, 6] = normalize_angle(loc_pred[:, 6])

            loc_refines_list.append(loc_refine)
            loc_preds_list.append(loc_pred)

        loc_refines = torch.cat(loc_refines_list, dim=1).view(
            pred_traj.shape[0], -1, self.loc_dim
        )
        loc_preds = torch.cat(loc_preds_list, dim=1).view(
            pred_traj.shape[0], -1, self.loc_dim
        )
        return VeloLSTMOut(loc_preds=loc_preds, loc_refines=loc_refines)


def init_lstm_module(layer: nn.Module) -> None:
    """Initialize LSTM weights and biases."""
    for name, param in layer.named_parameters():
        if "weight_ih" in name:  # input-to-hidden weights
            torch.nn.init.xavier_uniform_(param.data)
        elif "weight_hh" in name:  # hidden-to-hidden weights
            torch.nn.init.orthogonal_(param.data)
        elif "bias" in name:  # all biases start at zero
            param.data.fill_(0)


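# Minimal usage sketch (not part of the original module): exercise predict()
# and the full forward pass on random inputs. The trajectory layout is an
# assumption read off forward(): (num_batch, num_frames, loc_dim + 1), with
# yaw at channel 6 and depth estimation confidence in the last channel.
if __name__ == "__main__":
    torch.manual_seed(0)
    model = VeloLSTM()
    num_batch, num_frames = 2, 10

    # Single prediction step from a synthetic velocity history.
    hidden = model.init_hidden(torch.device("cpu"), batch_size=num_batch)
    vel_history = torch.zeros(model.num_frames, num_batch, model.loc_dim)
    location = torch.zeros(num_batch, model.loc_dim)
    next_loc, hidden = model.predict(vel_history, location, hidden)
    print(next_loc.shape)  # torch.Size([2, 7])

    # Full forward pass; outputs cover frames 1..num_frames - 1.
    pred_traj = torch.rand(num_batch, num_frames, model.loc_dim + 1)
    out = model(pred_traj)
    print(out.loc_preds.shape)  # torch.Size([2, 9, 7])
    print(out.loc_refines.shape)  # torch.Size([2, 9, 7])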