"""ToTensor transformation."""

import numpy as np
import torch

from vis4d.data.const import CommonKeys as K
from vis4d.data.typing import DictData

from .base import Transform


def _array_to_tensor(array: np.ndarray) -> torch.Tensor:
    """Wrap a numpy array as a torch tensor.

    ``torch.from_numpy`` refuses arrays that have a negative stride (e.g.
    views created by ``np.flip`` or ``arr[:, ::-1]``). Such arrays are copied
    into a fresh C-ordered buffer first. Every other array is wrapped without
    copying, so the returned tensor shares memory with ``array``.

    Args:
        array: The numpy array to convert.

    Returns:
        A tensor holding the same data as ``array``.
    """
    if any(stride < 0 for stride in array.strides):
        # ndarray.copy() defaults to C order, which yields positive strides.
        array = array.copy()
    return torch.from_numpy(array)


def _images_to_tensor(images: np.ndarray) -> torch.Tensor:
    """Convert a batch of images from NHWC layout to a contiguous NCHW tensor.

    Args:
        images: A 4-D numpy array laid out as (N, H, W, C).

    Returns:
        A contiguous tensor laid out as (N, C, H, W).
    """
    if not images.flags.c_contiguous:
        # Transpose and compact on the numpy side. This also removes any
        # negative strides, which torch.from_numpy would reject.
        return torch.from_numpy(
            np.ascontiguousarray(images.transpose(0, 3, 1, 2))
        )
    return torch.from_numpy(images).permute(0, 3, 1, 2).contiguous()


def _replace_arrays(data: DictData) -> None:
    """Replace numpy arrays with tensors, in place.

    Conversion rules per entry of ``data``:

    - ``K.images`` / ``K.original_images`` are converted from NHWC to a
      contiguous NCHW tensor. NOTE(review): these entries are assumed to be
      4-D numpy arrays; any other value fails, as before.
    - Any other ``np.ndarray`` value is converted to a tensor.
    - Nested dicts are processed recursively.
    - Lists have their direct ``np.ndarray`` elements converted. Other
      elements, including nested containers, are left untouched.

    Args:
        data: The sample dictionary to modify in place.
    """
    # Only existing keys are reassigned, so the dict size never changes and
    # iterating over items() while writing back is safe.
    for key, value in data.items():
        if key in (K.images, K.original_images):
            data[key] = _images_to_tensor(value)
        elif isinstance(value, np.ndarray):
            data[key] = _array_to_tensor(value)
        elif isinstance(value, dict):
            _replace_arrays(value)
        elif isinstance(value, list):
            for i, entry in enumerate(value):
                if isinstance(entry, np.ndarray):
                    value[i] = _array_to_tensor(entry)
@Transform("data", "data")
class ToTensor:
    """Convert the numpy arrays of every DictData in a batch to torch tensors.

    The image entries (K.images and K.original_images) are additionally
    reordered from NHWC to NCHW layout.
    """

    def __call__(self, batch: list[DictData]) -> list[DictData]:
        """Convert every sample of the batch to tensors.

        Each sample dictionary is modified in place. The same list object is
        returned.

        Args:
            batch: The samples to convert.

        Returns:
            The input ``batch`` with its arrays replaced by tensors.
        """
        for sample in batch:
            _replace_arrays(sample)
        return batch