"""Data module composing the data loading pipeline."""from__future__importannotationsimportlightning.pytorchasplfromtorch.utils.dataimportDataLoaderfromvis4d.configimportinstantiate_classesfromvis4d.config.typingimportDataConfigfromvis4d.data.typingimportDictData
class DataModule(pl.LightningDataModule):
    """Lightning data module backed by a vis4d data configuration.

    Exposes the dataloaders described by a vis4d ``DataConfig`` through the
    ``LightningDataModule`` interface, so that pytorch-lightning can be used
    for training and testing.
    """

    def __init__(self, data_cfg: DataConfig) -> None:
        """Store the data configuration for later dataloader creation.

        Args:
            data_cfg (DataConfig): Configuration whose ``train_dataloader``
                and ``test_dataloader`` entries are instantiated on demand
                by the corresponding methods of this class.
        """
        super().__init__()
        self.data_cfg = data_cfg

    def train_dataloader(self) -> DataLoader[DictData]:
        """Build the dataloader used for training.

        The ``seed`` attribute of the attached trainer is forwarded to the
        instantiation call; ``None`` is forwarded when no trainer is attached
        or the trainer has no ``seed`` attribute.

        Returns:
            DataLoader[DictData]: The instantiated training dataloader.
        """
        # A single getattr with a default covers both "trainer is None" and
        # "trainer lacks a seed attribute".
        trainer_seed = getattr(self.trainer, "seed", None)
        return instantiate_classes(
            self.data_cfg.train_dataloader, seed=trainer_seed
        )

    def test_dataloader(self) -> list[DataLoader[DictData]]:
        """Build the dataloaders used for testing.

        Returns:
            list[DataLoader[DictData]]: The result of instantiating the
                ``test_dataloader`` entry of the data configuration.
        """
        test_cfg = self.data_cfg.test_dataloader
        return instantiate_classes(test_cfg)

    def val_dataloader(self) -> list[DataLoader[DictData]]:
        """Build the dataloaders used for validation.

        Validation reuses the testing dataloaders; this simply delegates to
        ``test_dataloader``.

        Returns:
            list[DataLoader[DictData]]: Same value as ``test_dataloader()``.
        """
        return self.test_dataloader()