vis4d.engine.trainer

Trainer for running train and test.

Classes

Trainer(device, output_dir, ...[, ...])

Trainer class.

class Trainer(device, output_dir, train_dataloader, test_dataloader, train_data_connector, test_data_connector, callbacks, num_epochs=1000, num_steps=-1, epoch=0, global_step=0, check_val_every_n_epoch=1, val_check_interval=None, log_every_n_steps=50)

Trainer class.

Initialize the trainer.

Parameters:
  • device (torch.device) – Device that should be used for training.

  • output_dir (str) – Output directory for saving TensorBoard logs.

  • train_dataloader (DataLoader[DictData] | None, optional) – Dataloader for training.

  • test_dataloader (list[DataLoader[DictData]] | None, optional) – Dataloaders for testing.

  • train_data_connector (DataConnector | None) – Data connector used for generating training inputs from a batch of data.

  • test_data_connector (DataConnector | None) – Data connector used for generating testing inputs from a batch of data.

  • callbacks (list[Callback]) – Callbacks that should be used during training.

  • num_epochs (int, optional) – Number of training epochs. Defaults to 1000.

  • num_steps (int, optional) – Number of training steps. Defaults to -1.

  • epoch (int, optional) – Starting epoch. Defaults to 0.

  • global_step (int, optional) – Starting step. Defaults to 0.

  • check_val_every_n_epoch (int | None, optional) – Evaluate the model every n epochs during training. Defaults to 1.

  • val_check_interval (int | None, optional) – Interval for evaluating the model during training. Defaults to None.

  • log_every_n_steps (int, optional) – Log the training status every n steps. Defaults to 50.
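The two validation parameters can be read as independent schedules: one counted in epochs, one counted in training steps. The following is a minimal sketch, assuming check_val_every_n_epoch counts completed epochs (with a 0-based epoch counter) and val_check_interval counts completed global steps; both helper names are hypothetical and not part of vis4d.

```python
from typing import Optional


def should_validate_after_epoch(epoch: int, check_val_every_n_epoch: Optional[int]) -> bool:
    """Return True if validation is due after 0-based epoch `epoch` (assumed semantics)."""
    if check_val_every_n_epoch is None:
        # Epoch-based validation disabled.
        return False
    # epoch is 0-based, so epoch + 1 epochs have been completed at this point.
    return (epoch + 1) % check_val_every_n_epoch == 0


def should_validate_after_step(global_step: int, val_check_interval: Optional[int]) -> bool:
    """Return True if validation is due after `global_step` completed steps (assumed semantics)."""
    if val_check_interval is None or global_step <= 0:
        # Step-based validation disabled, or no step completed yet.
        return False
    return global_step % val_check_interval == 0
```

With the defaults (check_val_every_n_epoch=1, val_check_interval=None), this sketch validates after every epoch and never mid-epoch.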

done()

Return whether training is done.

Return type:

bool
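A minimal sketch of what a done() check could look like, assuming a value of -1 disables the corresponding limit (as the num_steps default of -1 suggests) and that training stops as soon as either limit is reached. is_training_done is a hypothetical helper, not the vis4d implementation.

```python
def is_training_done(
    epoch: int,
    global_step: int,
    num_epochs: int = 1000,
    num_steps: int = -1,
) -> bool:
    """Hypothetical done() logic; -1 is assumed to mean "no limit"."""
    epochs_exhausted = num_epochs != -1 and epoch >= num_epochs
    steps_exhausted = num_steps != -1 and global_step >= num_steps
    # Stop when whichever enabled limit is hit first.
    return epochs_exhausted or steps_exhausted
```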

fit(model, optimizers, lr_schedulers, loss_module)

Training loop.

Parameters:
  • model (nn.Module) – Model that should be trained.

  • optimizers (list[Optimizer]) – Optimizers that should be used for training.

  • lr_schedulers (list[LRSchedulerWrapper]) – Learning rate schedulers that should be used for training.

  • loss_module (LossModule) – Loss module that should be used for training.

Raises:

TypeError – If the loss value is not a torch.Tensor or a dict of torch.Tensor.

Return type:

None
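The sketch below illustrates how an epoch/step loop can interact with log_every_n_steps and with the documented TypeError for invalid loss values. It is a hedged, framework-free approximation: numbers.Real stands in for torch.Tensor so the code runs without PyTorch, and total_loss and fit_sketch are hypothetical names. The real fit() additionally drives the optimizers, learning rate schedulers, loss module, and callbacks.

```python
from numbers import Real
from typing import Callable, Dict, List, Union

# Stand-in for torch.Tensor so the sketch runs without PyTorch (assumption).
TensorLike = Real
LossValue = Union[TensorLike, Dict[str, TensorLike]]


def total_loss(loss: LossValue) -> TensorLike:
    """Reduce a loss value to one scalar, raising TypeError like fit() documents."""
    if isinstance(loss, Real):
        return loss
    if isinstance(loss, dict) and all(isinstance(v, Real) for v in loss.values()):
        # Sum the individual named loss terms into a single total loss.
        return sum(loss.values())
    raise TypeError(f"Loss must be a tensor or a dict of tensors, got {type(loss)!r}.")


def fit_sketch(
    batches: List[object],
    compute_loss: Callable[[object], LossValue],
    num_epochs: int,
    log_every_n_steps: int = 50,
) -> List[str]:
    """Hypothetical epoch/step loop; returns the log lines it would emit."""
    logs: List[str] = []
    global_step = 0
    for epoch in range(num_epochs):
        for batch in batches:
            loss = total_loss(compute_loss(batch))
            # With PyTorch, loss.backward(), optimizer and scheduler steps would go here.
            global_step += 1
            if global_step % log_every_n_steps == 0:
                logs.append(f"epoch {epoch} step {global_step} loss {loss}")
    return logs
```

For example, fit_sketch([1, 2, 3], lambda b: {"l": float(b)}, num_epochs=2, log_every_n_steps=2) runs six steps and emits a log line after steps 2, 4, and 6.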

get_state(metrics=None, optimizers=None, lr_schedulers=None)

Get the state of the trainer.

Return type:

TrainerState
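The fields of TrainerState are not listed here. As a hedged illustration, get_state can be thought of as bundling the trainer's progress counters with the optional metrics, optimizers, and lr_schedulers it receives. TrainerStateSketch, get_state_sketch, and all field names below are hypothetical and may differ from vis4d's actual TrainerState.

```python
from typing import Any, Dict, List, Optional, TypedDict


class TrainerStateSketch(TypedDict):
    # Hypothetical fields; vis4d's TrainerState may define others.
    current_epoch: int
    num_epochs: int
    global_step: int
    num_steps: int
    metrics: Optional[Dict[str, float]]
    optimizers: Optional[List[Any]]
    lr_schedulers: Optional[List[Any]]


def get_state_sketch(
    epoch: int,
    global_step: int,
    num_epochs: int = 1000,
    num_steps: int = -1,
    metrics: Optional[Dict[str, float]] = None,
    optimizers: Optional[List[Any]] = None,
    lr_schedulers: Optional[List[Any]] = None,
) -> TrainerStateSketch:
    """Bundle progress counters with the optional objects passed in."""
    return TrainerStateSketch(
        current_epoch=epoch,
        num_epochs=num_epochs,
        global_step=global_step,
        num_steps=num_steps,
        metrics=metrics,
        optimizers=optimizers,
        lr_schedulers=lr_schedulers,
    )
```

A plain typed dictionary keeps the state easy to pass to callbacks without coupling them to the trainer object itself.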

test(model, is_val=False)

Testing loop.

Parameters:
  • model (nn.Module) – Model that should be tested.

  • is_val (bool) – Whether the test is run on the validation set during training.

Return type:

None
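Because test_dataloader is a list, the testing loop iterates over several dataloaders in turn. A minimal sketch under that assumption, where run_test_loop is a hypothetical helper, the model is any callable, and is_val only changes the prefix used when reporting results:

```python
from typing import Callable, Dict, Iterable, List


def run_test_loop(
    model: Callable[[object], object],
    test_dataloaders: List[Iterable[object]],
    is_val: bool = False,
) -> Dict[str, int]:
    """Hypothetical testing loop; returns batches processed per dataloader."""
    prefix = "val" if is_val else "test"
    counts: Dict[str, int] = {}
    for idx, loader in enumerate(test_dataloaders):
        num_batches = 0
        for batch in loader:
            # With PyTorch this would run after model.eval() and under torch.no_grad().
            model(batch)
            num_batches += 1
        counts[f"{prefix}/{idx}"] = num_batches
    return counts
```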