vis4d.pl.callbacks.callback_wrapper

Wrapper to connect vis4d callbacks to PyTorch Lightning callbacks.

Functions

get_loss_module(loss_module)

Get the loss module from a PL module.

get_model(model)

Get the model from a PL module.

get_trainer_state(trainer, pl_module[, val])

Wrap pl.Trainer and pl.LightningModule into a TrainerState.

Classes

CallbackWrapper(callback)

Wrapper to connect vis4d callbacks to PyTorch Lightning callbacks.

class CallbackWrapper(callback)

Wrapper to connect vis4d callbacks to PyTorch Lightning callbacks.

Initialize the wrapper with the vis4d callback to wrap.
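The wrapper is an adapter: PyTorch Lightning calls its hooks with positional (trainer, pl_module, ...) arguments, and each hook delegates to the wrapped vis4d callback. The sketch below shows that delegation pattern only. RecordingCallback and CallbackWrapperSketch are hypothetical stand-ins (not vis4d classes), and the assumptions that the inner setup() takes no arguments and that hooks are re-mapped to keyword arguments are illustrative, not confirmed by this page.

```python
from __future__ import annotations

from typing import Any


class RecordingCallback:
    """Hypothetical stand-in for a vis4d callback with keyword-only hooks."""

    def __init__(self) -> None:
        self.calls: list[tuple[str, dict[str, Any]]] = []

    def setup(self) -> None:
        self.calls.append(("setup", {}))

    def on_train_batch_end(self, *, outputs: Any, batch: Any, batch_idx: int) -> None:
        self.calls.append(
            ("on_train_batch_end", {"outputs": outputs, "batch": batch, "batch_idx": batch_idx})
        )


class CallbackWrapperSketch:
    """Adapter: accepts PL-style hook calls and delegates to the wrapped callback."""

    def __init__(self, callback: RecordingCallback) -> None:
        # Store the wrapped callback; every hook below delegates to it.
        self.callback = callback

    def setup(self, trainer: Any, pl_module: Any, stage: str) -> None:
        # Assumption: the wrapped callback's setup needs none of the PL arguments.
        self.callback.setup()

    def on_train_batch_end(
        self, trainer: Any, pl_module: Any, outputs: Any, batch: Any, batch_idx: int
    ) -> None:
        # Re-map PL's positional hook signature onto keyword arguments.
        self.callback.on_train_batch_end(outputs=outputs, batch=batch, batch_idx=batch_idx)


inner = RecordingCallback()
wrapper = CallbackWrapperSketch(inner)
wrapper.setup(trainer=None, pl_module=None, stage="fit")
wrapper.on_train_batch_end(None, None, {"loss": 0.5}, [1, 2], 0)
print([name for name, _ in inner.calls])  # ['setup', 'on_train_batch_end']
```

Since CallbackWrapper exposes PL hook names, it is presumably registered like any Lightning callback, e.g. pl.Trainer(callbacks=[CallbackWrapper(my_callback)]), where my_callback is a hypothetical vis4d callback instance.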

on_test_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Wait for the on_test_batch_end PL hook to call ‘process’.

Return type:

None

on_test_epoch_end(trainer, pl_module)

Wait for the on_test_epoch_end PL hook to call ‘evaluate’.

Return type:

None
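The test hooks split evaluation into two phases: ‘process’ runs once per batch (from on_test_batch_end) and ‘evaluate’ runs once per epoch (from on_test_epoch_end). The sketch below illustrates that split with a hypothetical accuracy evaluator; AccuracyEvaluatorSketch, EvalHooksSketch, the batch["targets"] key, and the reset in on_test_epoch_start are all assumptions for illustration, not vis4d code.

```python
from __future__ import annotations

from typing import Any


class AccuracyEvaluatorSketch:
    """Hypothetical evaluator: process() accumulates, evaluate() summarizes."""

    def __init__(self) -> None:
        self.correct = 0
        self.total = 0

    def reset(self) -> None:
        self.correct = 0
        self.total = 0

    def process(self, predictions: list[int], targets: list[int]) -> None:
        # Accumulate per-batch counts; nothing is reported yet.
        self.correct += sum(int(p == t) for p, t in zip(predictions, targets))
        self.total += len(targets)

    def evaluate(self) -> dict[str, float]:
        # Summarize every batch processed since the last reset.
        return {"accuracy": self.correct / self.total if self.total else 0.0}


class EvalHooksSketch:
    """Maps the three PL test hooks onto reset / process / evaluate."""

    def __init__(self, evaluator: AccuracyEvaluatorSketch) -> None:
        self.evaluator = evaluator
        self.metrics: dict[str, float] = {}

    def on_test_epoch_start(self, trainer: Any, pl_module: Any) -> None:
        self.evaluator.reset()  # assumption: each epoch starts from a clean slate

    def on_test_batch_end(
        self, trainer: Any, pl_module: Any, outputs: Any, batch: Any,
        batch_idx: int, dataloader_idx: int = 0,
    ) -> None:
        self.evaluator.process(outputs, batch["targets"])

    def on_test_epoch_end(self, trainer: Any, pl_module: Any) -> None:
        self.metrics = self.evaluator.evaluate()


hooks = EvalHooksSketch(AccuracyEvaluatorSketch())
hooks.on_test_epoch_start(None, None)
hooks.on_test_batch_end(None, None, [1, 0, 1], {"targets": [1, 1, 1]}, 0)
hooks.on_test_batch_end(None, None, [0, 0], {"targets": [0, 1]}, 1)
hooks.on_test_epoch_end(None, None)
print(hooks.metrics)  # {'accuracy': 0.6}
```

Keeping per-batch work in ‘process’ and deferring the summary to ‘evaluate’ means metrics are computed over the whole epoch rather than averaged per batch.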

on_test_epoch_start(trainer, pl_module)

Hook to run at the start of a testing epoch.

Return type:

None

on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)

Hook to run at the end of a training batch.

Return type:

None

on_train_batch_start(trainer, pl_module, batch, batch_idx)

Called when the train batch begins.

Return type:

None

on_train_epoch_end(trainer, pl_module)

Hook to run at the end of a training epoch.

Return type:

None

on_train_epoch_start(trainer, pl_module)

Hook to run at the start of a training epoch.

Return type:

None
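The four training hooks nest: the epoch hooks bracket the epoch, and each batch is bracketed by on_train_batch_start and on_train_batch_end. The sketch below records that order. run_fake_train_epoch is a hypothetical, heavily simplified stand-in for one epoch of Lightning's fit loop (it passes None for trainer and pl_module and fabricates a placeholder output), not Lightning's implementation.

```python
from __future__ import annotations

from typing import Any


class HookOrderRecorder:
    """Records which training hook fired, and for which batch."""

    def __init__(self) -> None:
        self.events: list[str] = []

    def on_train_epoch_start(self, trainer: Any, pl_module: Any) -> None:
        self.events.append("epoch_start")

    def on_train_batch_start(
        self, trainer: Any, pl_module: Any, batch: Any, batch_idx: int
    ) -> None:
        self.events.append(f"batch_start:{batch_idx}")

    def on_train_batch_end(
        self, trainer: Any, pl_module: Any, outputs: Any, batch: Any, batch_idx: int
    ) -> None:
        self.events.append(f"batch_end:{batch_idx}")

    def on_train_epoch_end(self, trainer: Any, pl_module: Any) -> None:
        self.events.append("epoch_end")


def run_fake_train_epoch(callback: HookOrderRecorder, batches: list[list[float]]) -> None:
    """Hypothetical stand-in for one epoch of a training loop."""
    callback.on_train_epoch_start(None, None)
    for batch_idx, batch in enumerate(batches):
        callback.on_train_batch_start(None, None, batch, batch_idx)
        outputs = {"loss": sum(batch)}  # placeholder for the training step's output
        callback.on_train_batch_end(None, None, outputs, batch, batch_idx)
    callback.on_train_epoch_end(None, None)


recorder = HookOrderRecorder()
run_fake_train_epoch(recorder, [[0.5, 0.25], [1.0]])
print(recorder.events)
# ['epoch_start', 'batch_start:0', 'batch_end:0', 'batch_start:1', 'batch_end:1', 'epoch_end']
```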

on_validation_batch_end(trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0)

Wait for the on_validation_batch_end PL hook to call ‘process’.

Return type:

None

on_validation_epoch_end(trainer, pl_module)

Wait for the on_validation_epoch_end PL hook to call ‘evaluate’.

Return type:

None

on_validation_epoch_start(trainer, pl_module)

Hook to run at the start of a validation epoch.

Return type:

None

setup(trainer, pl_module, stage)

Set up the callback.

Return type:

None

get_loss_module(loss_module)

Get the loss module from a PL module.

Return type:

LossModule

get_model(model)

Get the model from a PL module.

Return type:

Module
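get_model and get_loss_module hand the wrapped callback the inner network and loss module rather than the Lightning container. A minimal sketch of such unwrapping, assuming the container stores them as attributes named model and loss_module and that anything else is passed through unchanged (both assumptions, not confirmed by this page; all *Sketch names are hypothetical):

```python
from __future__ import annotations

from typing import Any


class ModelSketch:
    """Hypothetical network."""


class LossSketch:
    """Hypothetical loss module."""


class TrainingModuleSketch:
    """Hypothetical LightningModule-like container holding a model and a loss module."""

    def __init__(self, model: ModelSketch, loss_module: LossSketch) -> None:
        self.model = model
        self.loss_module = loss_module


def get_model_sketch(module: Any) -> Any:
    # Unwrap the inner model if the container has one; otherwise pass through.
    return module.model if hasattr(module, "model") else module


def get_loss_module_sketch(module: Any) -> Any:
    # Same idea for the loss module.
    return module.loss_module if hasattr(module, "loss_module") else module


model, loss = ModelSketch(), LossSketch()
wrapped = TrainingModuleSketch(model, loss)
print(get_model_sketch(wrapped) is model)       # True
print(get_loss_module_sketch(wrapped) is loss)  # True
print(get_model_sketch(model) is model)         # True (already unwrapped)
```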

get_trainer_state(trainer, pl_module, val=False)

Wrap pl.Trainer and pl.LightningModule into a TrainerState.

Return type:

TrainerState
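get_trainer_state converts Lightning's trainer bookkeeping into a framework-neutral TrainerState. The sketch below shows one way to do that conversion under stated assumptions: FakeTrainer only mimics a few pl.Trainer attribute names (current_epoch, max_epochs, global_step, num_training_batches, num_val_batches, num_test_batches), TrainerStateSketch is a hypothetical subset of fields, and both the "val selects validation counts" behavior and the encoding of unknown values as -1 are assumptions rather than documented vis4d behavior.

```python
from __future__ import annotations

import math
from dataclasses import dataclass, field


@dataclass
class FakeTrainer:
    """Stand-in exposing a few attributes named after pl.Trainer properties."""

    current_epoch: int
    max_epochs: int | None
    global_step: int
    num_training_batches: int | float
    num_val_batches: list[int | float] = field(default_factory=list)
    num_test_batches: list[int | float] = field(default_factory=list)


@dataclass
class TrainerStateSketch:
    """Hypothetical subset of fields; the real TrainerState may differ."""

    current_epoch: int
    num_epochs: int
    global_step: int
    num_train_batches: int
    num_test_batches: list[int]


def _count_or_minus_one(value: int | float) -> int:
    # Lightning can report an unknown number of batches as float("inf"); encode it as -1.
    if isinstance(value, float) and math.isinf(value):
        return -1
    return int(value)


def get_trainer_state_sketch(trainer: FakeTrainer, val: bool = False) -> TrainerStateSketch:
    # Assumption: `val` selects validation-loop batch counts instead of test-loop ones.
    eval_batches = trainer.num_val_batches if val else trainer.num_test_batches
    return TrainerStateSketch(
        current_epoch=trainer.current_epoch,
        num_epochs=-1 if trainer.max_epochs is None else trainer.max_epochs,
        global_step=trainer.global_step,
        num_train_batches=_count_or_minus_one(trainer.num_training_batches),
        num_test_batches=[_count_or_minus_one(n) for n in eval_batches],
    )


trainer = FakeTrainer(
    current_epoch=2, max_epochs=None, global_step=150,
    num_training_batches=float("inf"),
    num_val_batches=[10, float("inf")], num_test_batches=[20],
)
print(get_trainer_state_sketch(trainer, val=True))
# TrainerStateSketch(current_epoch=2, num_epochs=-1, global_step=150, num_train_batches=-1, num_test_batches=[10, -1])
print(get_trainer_state_sketch(trainer).num_test_batches)  # [20]
```

Mapping "unknown" to a sentinel integer keeps the resulting state free of Lightning-specific float values, so callbacks can read it without knowing which engine produced it.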