"""Motion model base class."""

from torch import Tensor
class BaseMotionModel:
    """Base class for motion model.

    The constructor only records the bookkeeping values it is given.
    ``update``, ``predict`` and ``get_state`` are placeholders that raise
    ``NotImplementedError`` and are meant to be overridden by subclasses.
    """

    def __init__(
        self,
        num_frames: int,
        motion_dims: int,
        hits: int = 1,
        hit_streak: int = 0,
        time_since_update: int = 0,
        age: int = 0,
        fps: int = 1,
    ) -> None:
        """Creates an instance of the class.

        Every argument is stored unchanged on the instance under the same
        name; no validation or conversion is performed here.

        Args:
            num_frames (int): Stored as ``self.num_frames``.
            motion_dims (int): Stored as ``self.motion_dims``.
            hits (int): Stored as ``self.hits``. Defaults to 1.
            hit_streak (int): Stored as ``self.hit_streak``. Defaults to 0.
            time_since_update (int): Stored as ``self.time_since_update``.
                Defaults to 0.
            age (int): Stored as ``self.age``. Defaults to 0.
            fps (int): Stored as ``self.fps``. Defaults to 1.
        """
        # NOTE(review): the exact meaning of these counters is defined by
        # the subclasses that consume them -- confirm there before relying
        # on any particular interpretation.
        self.num_frames = num_frames
        self.motion_dims = motion_dims
        self.hits = hits
        self.hit_streak = hit_streak
        self.time_since_update = time_since_update
        self.age = age
        self.fps = fps

    def update(self, obs_3d: Tensor, info: Tensor) -> None:
        """Update the state.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError()

    def predict(self, update_state: bool = True) -> Tensor:
        """Predict the state.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError()

    def get_state(self) -> Tensor:
        """Get the state.

        Raises:
            NotImplementedError: Always; subclasses must override.
        """
        raise NotImplementedError()
def update_array(origin_array: Tensor, input_array: Tensor) -> Tensor:
    """Update array according to the input.

    Returns a copy of ``origin_array`` in which every entry along the first
    dimension has moved one position towards the front (the first entry is
    dropped) and the last position holds ``input_array``.

    Args:
        origin_array (Tensor): Array to shift. It is cloned first, so the
            caller's tensor is left unmodified.
        input_array (Tensor): Value written into the last position along
            the first dimension; it is assigned to the slice
            ``new_array[-1:]``, so it must be assignable to that slice.

    Returns:
        Tensor: The shifted copy with ``input_array`` in the last slot.
    """
    new_array = origin_array.clone()
    # Read from the untouched original so the shift never sees values that
    # were already overwritten in the copy.
    new_array[:-1] = origin_array[1:]
    new_array[-1:] = input_array
    return new_array