vis4d.engine.trainer
Trainer for running training and testing.
Classes
Trainer – Trainer class.
- class Trainer(device, output_dir, train_dataloader, test_dataloader, train_data_connector, test_data_connector, callbacks, num_epochs=1000, num_steps=-1, epoch=0, global_step=0, check_val_every_n_epoch=1, val_check_interval=None, log_every_n_steps=50)
Trainer class.
Initialize the trainer.
- Parameters:
device (torch.device) – Device that should be used for training.
output_dir (str) – Output directory for saving TensorBoard logs.
train_dataloader (DataLoader[DictData] | None, optional) – Dataloader for training.
test_dataloader (list[DataLoader[DictData]] | None, optional) – Dataloaders for testing.
train_data_connector (DataConnector | None) – Data connector used for generating training inputs from a batch of data.
test_data_connector (DataConnector | None) – Data connector used for generating testing inputs from a batch of data.
callbacks (list[Callback]) – Callbacks that should be used during training.
num_epochs (int, optional) – Number of training epochs. Defaults to 1000.
num_steps (int, optional) – Number of training steps. Defaults to -1.
epoch (int, optional) – Starting epoch. Defaults to 0.
global_step (int, optional) – Starting step. Defaults to 0.
check_val_every_n_epoch (int | None, optional) – Evaluate the model every n epochs during training. Defaults to 1.
val_check_interval (int | None, optional) – Interval for evaluating the model during training. Defaults to None.
log_every_n_steps (int, optional) – Log the training status every n steps. Defaults to 50.
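The scheduling parameters above (num_epochs, num_steps, check_val_every_n_epoch, val_check_interval, log_every_n_steps) can be pictured with a small pure-Python simulation. This is only a sketch under stated assumptions, not the Trainer's actual implementation: it assumes that num_steps=-1 means "no step limit", that val_check_interval is counted in global steps, and that check_val_every_n_epoch is counted in completed epochs. The helper names (training_done, should_log, should_validate_step, should_validate_epoch, simulate) are hypothetical and do not exist in vis4d.

```python
from __future__ import annotations


def training_done(epoch: int, global_step: int,
                  num_epochs: int = 1000, num_steps: int = -1) -> bool:
    """Assumed stopping rule: a non-negative num_steps caps the step count."""
    if num_steps >= 0 and global_step >= num_steps:
        return True
    return epoch >= num_epochs


def should_log(global_step: int, log_every_n_steps: int = 50) -> bool:
    """Assumed logging rule: log on every n-th completed step."""
    return log_every_n_steps > 0 and global_step % log_every_n_steps == 0


def should_validate_step(global_step: int,
                         val_check_interval: int | None = None) -> bool:
    """Assumed step-based validation rule; None disables it."""
    if val_check_interval is None:
        return False
    return global_step % val_check_interval == 0


def should_validate_epoch(epoch: int,
                          check_val_every_n_epoch: int | None = 1) -> bool:
    """Assumed epoch-based rule; `epoch` is the zero-based epoch just finished."""
    if check_val_every_n_epoch is None:
        return False
    return (epoch + 1) % check_val_every_n_epoch == 0


def simulate(steps_per_epoch: int, num_epochs: int = 1000, num_steps: int = -1,
             log_every_n_steps: int = 50,
             check_val_every_n_epoch: int | None = 1,
             val_check_interval: int | None = None) -> list[tuple[str, int]]:
    """Return the (event, index) pairs a loop with these settings would emit."""
    events: list[tuple[str, int]] = []
    epoch, global_step = 0, 0
    while not training_done(epoch, global_step, num_epochs, num_steps):
        for _ in range(steps_per_epoch):
            global_step += 1  # one optimizer step completed
            if should_log(global_step, log_every_n_steps):
                events.append(("log", global_step))
            if should_validate_step(global_step, val_check_interval):
                events.append(("val_step", global_step))
            if training_done(epoch, global_step, num_epochs, num_steps):
                break  # step budget exhausted mid-epoch
        if should_validate_epoch(epoch, check_val_every_n_epoch):
            events.append(("val_epoch", epoch))
        epoch += 1
    return events


if __name__ == "__main__":
    # 4 steps per epoch, at most 3 epochs, but capped at 10 steps in total.
    print(simulate(steps_per_epoch=4, num_epochs=3, num_steps=10,
                   log_every_n_steps=5, check_val_every_n_epoch=2))
```

With these assumed rules, the step cap ends the run during the third epoch, so the only epoch-based validation happens after the second epoch. Whether the real Trainer validates after a truncated final epoch is not stated in this reference.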
- fit(model, optimizers, lr_schedulers, loss_module)
Training loop.
- Parameters:
model (nn.Module) – Model that should be trained.
optimizers (list[Optimizer]) – Optimizers that should be used for training.
lr_schedulers (list[LRSchedulerWrapper]) – Learning rate schedulers that should be used for training.
loss_module (LossModule) – Loss module that should be used for training.
- Raises:
TypeError – If the loss value is not a torch.Tensor or a dict of torch.Tensor.
- Return type:
None
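The TypeError condition documented for fit suggests a type check on the value produced by the loss module. The sketch below shows one plausible form of such a check; it is an assumption, not the code fit actually runs. The function name reduce_loss is hypothetical, and summing the dictionary entries into one total is also an assumption. The tensor type is passed in as a parameter so that the sketch runs without PyTorch; in practice it would be torch.Tensor.

```python
from __future__ import annotations

from typing import Any


def reduce_loss(losses: Any, tensor_type: type) -> Any:
    """Return a single loss value, or raise TypeError for unsupported input.

    Hypothetical helper: `tensor_type` stands in for torch.Tensor.
    """
    if isinstance(losses, tensor_type):
        # A single loss value is used as-is.
        return losses
    if isinstance(losses, dict) and all(
        isinstance(value, tensor_type) for value in losses.values()
    ):
        # Assumed reduction: add up the individual loss terms.
        return sum(losses.values())
    raise TypeError(
        "Loss must be a tensor or a dict of tensors, "
        f"got {type(losses).__name__}."
    )


if __name__ == "__main__":
    # float is used here only as a stand-in for torch.Tensor.
    print(reduce_loss(2.5, float))
    print(reduce_loss({"cls": 1.0, "box": 2.0}, float))
    try:
        reduce_loss([1.0, 2.0], float)
    except TypeError as err:
        print(err)
```

Rejecting unsupported values early means a misconfigured loss module fails with a clear message instead of an obscure error later in the backward pass.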