Source code for vis4d.zoo.base.pl_trainer

"""Default runtime configuration for PyTorch Lightning."""

import inspect

from lightning import Trainer

from vis4d.config import FieldConfigDict
from vis4d.config.typing import ExperimentConfig


[docs] def get_default_pl_trainer_cfg(config: ExperimentConfig) -> ExperimentConfig: """Get PyTorch Lightning Trainer config.""" pl_trainer = FieldConfigDict() # PL Trainer arguments for k, v in inspect.signature(Trainer).parameters.items(): if not k in {"callbacks", "devices", "logger", "strategy"}: pl_trainer[k] = v.default # PL Trainer settings pl_trainer.benchmark = config.benchmark pl_trainer.use_distributed_sampler = False pl_trainer.num_sanity_val_steps = 0 # logger pl_trainer.enable_progress_bar = False pl_trainer.log_every_n_steps = config.log_every_n_steps # Default Trainer arguments pl_trainer.work_dir = config.work_dir pl_trainer.exp_name = config.experiment_name pl_trainer.version = config.version pl_trainer.find_unused_parameters = False pl_trainer.checkpoint_period = 1 pl_trainer.save_top_k = 1 pl_trainer.wandb = False return pl_trainer