# Source code for vis4d.config.show_connection

"""Show connected components in the config."""

from __future__ import annotations

import inspect
from typing import Any, TypedDict, get_type_hints

from absl import app  # pylint: disable=no-name-in-module
from torch import nn

from vis4d.common.typing import ArgsType
from vis4d.engine.callbacks import (
    Callback,
    EvaluatorCallback,
    VisualizerCallback,
)
from vis4d.engine.connectors import CallbackConnector, DataConnector
from vis4d.engine.flag import _CONFIG
from vis4d.engine.loss_module import LossModule
from vis4d.eval.base import Evaluator
from vis4d.vis.base import Visualizer

from .config_dict import instantiate_classes


# Types
class DataConnectionInfo(TypedDict):
    """A single component (block) of the printed data graph.

    Each block has a name, the keys it consumes and the keys it produces.
    Keys that were matched to a neighbouring block carry a ``*`` marker.
    """

    # Keys consumed by the component.
    in_keys: list[str]
    # Keys produced by the component.
    out_keys: list[str]
    # Display name of the component.
    name: str
# Private Functions def _rename_ds(name: str) -> str: """Replaces data with d and prediction with p. Use this to remap the datasources to shorter names. Args: name: Name to remap Returns: remapped name """ return name.replace("data", "d").replace("prediction", "p") def _get_model_conn_infos( model: nn.Module, ) -> dict[str, DataConnectionInfo]: """Returns the connection infos for a pytorch Model. Requires "forward_train" and "forward_test" to be defined and properly typed! Args: model: Model to extract data from Returns: train_connections, test_connections """ train_t = get_type_hints(model.forward_train)["return"] test_t = get_type_hints(model.forward_test)["return"] train_connection_info = DataConnectionInfo( in_keys=sorted( list(inspect.signature(model.forward).parameters.keys()) ), out_keys=[ "<p>-" + e for e in sorted(resolve_named_tuple(train_t, prefix="")) ], name=model.__class__.__name__, ) test_connection_info = DataConnectionInfo( in_keys=sorted( list(inspect.signature(model.forward).parameters.keys()) ), out_keys=[ "<p>-" + e for e in sorted(resolve_named_tuple(test_t, prefix="")) ], name=model.__class__.__name__, ) return {"train": train_connection_info, "test": test_connection_info} def _get_loss_connection_infos(loss: LossModule) -> list[DataConnectionInfo]: """Returns the connection infos for a loss. Args: loss (LossModule): Custom loss module with .forward() Returns: DataConnectionInfo for the loss. """ loss_connection_info = [] for l in loss.losses: loss_out = [] loss_in = [] for entry, value in l["connector"].key_mapping.items(): loss_out.append(f"{entry}") loss_in.append(f"<{_rename_ds(value['source'])}>-" + value["key"]) loss_connection_info.append( DataConnectionInfo( in_keys=loss_in, out_keys=loss_out, name=l["name"] ) ) return loss_connection_info def _get_vis_connection_infos( visualizer: Visualizer, ) -> DataConnectionInfo: """Returns the connection infos for a visualizer. 
Args: visualizer: Visualizer to extract data from Returns: DataConnectionInfo for the visualizer. """ return DataConnectionInfo( in_keys=sorted( list(inspect.signature(visualizer.process).parameters.keys()) ), out_keys=[], name=visualizer.__class__.__name__, ) def _get_evaluator_connection_infos( evaluator: Evaluator, ) -> DataConnectionInfo: """Returns the connection infos for an evaluator. Args: evaluator: Evaluator to extract data from Returns: DataConnectionInfo for the evaluator. """ return DataConnectionInfo( in_keys=sorted( list(inspect.signature(evaluator.process).parameters.keys()) ), out_keys=[], name=evaluator.__class__.__name__, ) def _get_data_connector_infos( data_connector: DataConnector, name: str ) -> DataConnectionInfo: """Returns the connection infos for a DataConnector. Args: data_connector (DataConnector): Data connector to extract data. name (str): Name of the data connector. Returns: DataConnectionInfo for the data connector. """ return DataConnectionInfo( in_keys=["<d>-" + e for e in list(data_connector.key_mapping.keys())], out_keys=list(data_connector.key_mapping.values()), name=name, ) def _get_cb_connection_infos( name: str, cb_data_connector: None | CallbackConnector = None, ) -> DataConnectionInfo | None: """Returns the connection infos for a callback.""" if cb_data_connector is not None: eval_out = [] eval_in = [] for entry, value in cb_data_connector.key_mapping.items(): eval_out.append(f"{entry}") eval_in.append(f"<{_rename_ds(value['source'])}>-" + value["key"]) return DataConnectionInfo( in_keys=eval_in, out_keys=eval_out, name=name ) return None def _get_with_color(key: str, warn_unconnected: bool = True) -> str: """Prepends colors for internal vsiualization.""" if "*" in key: # We connected this one return f"\033[94m{key}\033[00m" if "<d>" in key: # key comes from data return f"\033[90m{key}\033[00m" # comes from prediction and is not connected if warn_unconnected: return f"\u001b[33m{key}\033[00m" return 
f"\033[00m{key}\033[00m" # API Functions
def resolve_named_tuple(  # type:ignore
    clazz: Any, prefix: str = ""
) -> list[str]:
    """Returns the (nested) field paths of a named tuple class or instance.

    Any field value that itself exposes ``_fields`` is expanded recursively.
    A value without ``_fields`` is a leaf and resolves to ``[prefix]``. On a
    named tuple *class* the field attributes are descriptors rather than
    nested tuples, so only one level is expanded. Nested expansion needs an
    instance.

    Args:
        clazz: Named tuple class or instance that should be resolved.
        prefix: Prefix to prepend (results look like ``prefix.<field>``).

    Returns:
        List with all field paths, prefixes prepended.

    Examples:
        >>> Person = namedtuple("Person", ["name", "age", "gender"])
        >>> Address = namedtuple("Address", ["street", "city", "zipcode"])
        >>> resolve_named_tuple(clazz=Person, prefix="person")
        ['person.name', 'person.age', 'person.gender']
        >>> resolve_named_tuple(clazz=Address, prefix="address")
        ['address.street', 'address.city', 'address.zipcode']
        >>> resolve_named_tuple(clazz=Person, prefix="")
        ['name', 'age', 'gender']

        With nested instances:

        >>> User = namedtuple("User", ["name", "address"])
        >>> user = User(
        ...     name=Person(name="John", age=30, gender="m"),
        ...     address=Address(street="str", city="zrh", zipcode="1"),
        ... )
        >>> resolve_named_tuple(clazz=user, prefix="user")  # doctest: +NORMALIZE_WHITESPACE
        ['user.name.name', 'user.name.age', 'user.name.gender',
         'user.address.street', 'user.address.city', 'user.address.zipcode']
    """
    if not hasattr(clazz, "_fields"):
        # Leaf value: nothing left to expand.
        return [prefix]

    resolved: list[str] = []
    for field_name in clazz._fields:
        child_prefix = f"{prefix}.{field_name}" if prefix else field_name
        resolved.extend(
            resolve_named_tuple(getattr(clazz, field_name), prefix=child_prefix)
        )
    return resolved
def connect_components(
    in_info: DataConnectionInfo, out_info: DataConnectionInfo
) -> None:
    """Marks the keys shared by two components as connected.

    An output key of ``in_info`` is connected to every input key of
    ``out_info`` whose part before the first ``"."`` equals it, ignoring
    ``*`` markers. A connected output key is rewritten as ``"*" + key``. A
    connected input key has its ``*`` removed and is then rewritten as
    ``" " + key + "*"``. Both components are updated in place.

    Args:
        in_info (DataConnectionInfo): Component producing the data.
        out_info (DataConnectionInfo): Component consuming the data.
    """
    produced = in_info["out_keys"]
    consumed = out_info["in_keys"]

    # Index the consumer's inputs by their root key. This is computed once,
    # before any of the inputs is rewritten below.
    consumers_by_root: dict[str, list[int]] = {}
    for position, raw_key in enumerate(consumed):
        root = raw_key.replace("*", "").split(".")[0]
        consumers_by_root.setdefault(root, []).append(position)

    for position, raw_key in enumerate(produced):
        plain_key = raw_key.replace("*", "")
        for target in consumers_by_root.get(plain_key, []):
            produced[position] = "*" + plain_key
            consumed[target] = " " + consumed[target].replace("*", "") + "*"
def _section_banner(title: str) -> str:
    """Returns a three-line, 35-character wide banner with ``title`` centered.

    Surplus padding goes to the right. For example, ``"Testing Loop"`` gets
    10 spaces on the left and 11 on the right.
    """
    rule = "=" * 35 + "\n"
    return f"{rule}={title:^33}=\n{rule}"


def _render_components(components: list[DataConnectionInfo]) -> str:
    """Concatenates the ``print_box`` rendering of every component."""
    return "".join(
        print_box(info["name"], info["in_keys"], info["out_keys"])
        for info in components
    )


def prints_datagraph_for_config(
    model: nn.Module,
    train_data_connector: DataConnector,
    test_data_connector: DataConnector,
    loss: LossModule,
    callbacks: list[Callback],
) -> str:
    """Shows the setup of the configuration objects.

    For each component, plots which inputs are connected to which outputs.
    Connected keys are marked with "*". Use this to debug your configuration
    setup. Data loaded from the dataset is highlighted with <d> and data from
    model predictions with <p>.

    In the training loop, the train data feeds the model, and every loss is
    connected to the model's training predictions. In the testing loop, the
    test data feeds the model. Callbacks (evaluators and visualizers) are
    shown with their test connector, when they have one.

    Args:
        model (nn.Module): Model to plot.
        train_data_connector (DataConnector): Train data connector to plot.
        test_data_connector (DataConnector): Test data connector to plot.
        loss (LossModule): Loss to plot.
        callbacks (list[Callback]): Callbacks to plot.

    Returns:
        str: The datagraph as a string, that can be printed to the console.

    Raises:
        AssertionError: If either data connector is not a DataConnector.

    Example:
        The following is the datagraph for FasterRCNN on COCO. Inputs loaded
        from the dataset are marked with <d> and predictions with <p>.
        Unconnected inputs are missing a (*) sign. The exact column layout is
        produced by ``print_box``.

        >>> dg = prints_datagraph_for_config(model, train_data_connector, test_data_connector, loss, callbacks)
        >>> print(dg)
        ```
        ===================================
        =          Training Loop          =
        ===================================
                            --------------
                <d>-boxes2d |            | *boxes2d
        <d>-boxes2d_classes |            | *boxes2d_classes
                 <d>-images | Train Data | *images
               <d>-input_hw |            | *input_hw
                            --------------
                         --------------
                boxes2d* |            | <p>-proposals
        boxes2d_classes* |            | *<p>-roi
                 images* |            | *<p>-rpn
               input_hw* | FasterRCNN | *<p>-sampled_proposals
             original_hw |            | <p>-sampled_target_indices
                         |            | *<p>-sampled_targets
                         --------------
                     -----------
        <p>-rpn.cls* |         | cls_outs
        <d>-input_hw |         | images_hw
        <p>-rpn.box* | RPNLoss | reg_outs
         <d>-boxes2d |         | target_boxes
                     -----------
                                     ------------
        <p>-sampled_proposals.boxes* |          | boxes
         <p>-sampled_targets.labels* |          | boxes_mask
                  <p>-roi.cls_score* |          | class_outs
                  <p>-roi.bbox_pred* | RCNNLoss | regression_outs
          <p>-sampled_targets.boxes* |          | target_boxes
        <p>-sampled_targets.classes* |          | target_classes
                                     ------------
        ===================================
        =          Testing Loop           =
        ===================================
                        -------------
             <d>-images |           | *images
           <d>-input_hw | Test Data | *input_hw
        <d>-original_hw |           | *original_hw
                        -------------
                        --------------
                boxes2d |            | <p>-boxes
        boxes2d_classes |            | <p>-class_ids
                images* | FasterRCNN | <p>-scores
              input_hw* |            |
           original_hw* |            |
                        --------------
        ===================================
        =            Callbacks            =
        ===================================
                            -------------------------
        <d>-original_images |                       | *images
           <d>-sample_names |                       | *image_names
                  <p>-boxes | BoundingBoxVisualizer | *boxes
                 <p>-scores |                       | *scores
              <p>-class_ids |                       | *class_ids
                            -------------------------
                         -----------------------
        <d>-sample_names |                     | *coco_image_id
               <p>-boxes |                     | *pred_boxes
              <p>-scores | COCODetectEvaluator | *pred_scores
           <p>-class_ids |                     | *pred_classes
                         -----------------------
        ```
    """
    model_connection_info = _get_model_conn_infos(model)

    # TODO: support more data connectors
    assert isinstance(train_data_connector, DataConnector) and isinstance(
        test_data_connector, DataConnector
    ), "Only DataConnector is supported."
    train_data_connection_info = _get_data_connector_infos(
        train_data_connector, name="Train Data"
    )
    test_data_connection_info = _get_data_connector_infos(
        test_data_connector, name="Test Data"
    )
    loss_info = _get_loss_connection_infos(loss)

    # Training loop: the data feeds the model, and every loss consumes the
    # model's training predictions. Losses must not be chained to each
    # other, because a loss' outputs are argument names, not data.
    train_model_info = model_connection_info["train"]
    connect_components(train_data_connection_info, train_model_info)
    for loss_component in loss_info:
        connect_components(train_model_info, loss_component)
    log_str = _section_banner("Training Loop")
    log_str += _render_components(
        [train_data_connection_info, train_model_info, *loss_info]
    )

    # Testing loop: the data feeds the model.
    test_model_info = model_connection_info["test"]
    connect_components(test_data_connection_info, test_model_info)
    log_str += _section_banner("Testing Loop")
    log_str += _render_components(
        [test_data_connection_info, test_model_info]
    )

    # TODO: Add support for more callbacks and handle train_connector
    callback_components: list[DataConnectionInfo] = []
    for cb in callbacks:
        # Evaluators and visualizers declare their inputs via ``process``.
        targets: list[DataConnectionInfo] = []
        if isinstance(cb, EvaluatorCallback):
            targets.append(_get_evaluator_connection_infos(cb.evaluator))
        if isinstance(cb, VisualizerCallback):
            targets.append(_get_vis_connection_infos(cb.visualizer))
        for target in targets:
            component = _get_cb_connection_infos(
                target["name"], cb.test_connector
            )
            # Only callbacks with a matching connector are shown.
            if component is not None:
                connect_components(component, target)
                callback_components.append(component)
    log_str += _section_banner("Callbacks")
    log_str += _render_components(callback_components)
    return log_str
def main(
    argv: ArgsType,  # pylint: disable=unused-argument
) -> None:  # pragma: no cover
    """Prints the data graph of the config passed via the config flag.

    Example:
        python -m vis4d.config.show_connection --config vis4d/zoo/faster_rcnn/faster_rcnn_coco.py
    """
    cfg = _CONFIG.value

    # Instantiate connectors, loss, model and callbacks in this fixed order.
    train_connector = instantiate_classes(cfg.train_data_connector)
    test_connector = instantiate_classes(cfg.test_data_connector)
    loss_module = instantiate_classes(cfg.loss)
    net = instantiate_classes(cfg.model)
    cb_list = list(map(instantiate_classes, cfg.callbacks))

    graph = prints_datagraph_for_config(
        model=net,
        train_data_connector=train_connector,
        test_data_connector=test_connector,
        loss=loss_module,
        callbacks=cb_list,
    )
    print(graph)
if __name__ == "__main__":  # pragma: no cover
    # absl's app.run parses the command-line flags, then calls main.
    app.run(main)