"""Show connected components in the config."""
from __future__ import annotations
import inspect
from typing import Any, TypedDict, get_type_hints
from absl import app # pylint: disable=no-name-in-module
from torch import nn
from vis4d.common.typing import ArgsType
from vis4d.engine.callbacks import (
Callback,
EvaluatorCallback,
VisualizerCallback,
)
from vis4d.engine.connectors import CallbackConnector, DataConnector
from vis4d.engine.flag import _CONFIG
from vis4d.engine.loss_module import LossModule
from vis4d.eval.base import Evaluator
from vis4d.vis.base import Visualizer
from .config_dict import instantiate_classes
# Types
[docs]
class DataConnectionInfo(TypedDict):
    """Internal type def for visualization.

    This defines a block component of the printed data graph: ``print_box``
    renders ``in_keys`` on the left, ``name`` in the middle and ``out_keys``
    on the right. ``connect_components`` edits the two key lists in place to
    mark matched keys with "*".
    """

    # Keys the component consumes (left column of its box).
    in_keys: list[str]
    # Keys the component produces (right column of its box).
    out_keys: list[str]
    # Title printed in the middle of the box.
    name: str
# Private Functions
def _rename_ds(name: str) -> str:
"""Replaces data with d and prediction with p.
Use this to remap the datasources to shorter names.
Args:
name: Name to remap
Returns:
remapped name
"""
return name.replace("data", "d").replace("prediction", "p")
def _get_model_conn_infos(
    model: nn.Module,
) -> dict[str, DataConnectionInfo]:
    """Returns the connection infos for a pytorch Model.

    Requires ``forward_train`` and ``forward_test`` to be defined and to
    carry a return annotation. The fields of each annotated return type (as
    resolved by ``resolve_named_tuple``) become the out_keys, prefixed with
    "<p>-". For both modes the in_keys are the sorted parameter names of
    ``model.forward``.

    Args:
        model: Model to extract data from.

    Returns:
        Dict with the keys "train" and "test", each holding the
        DataConnectionInfo for that mode.

    Raises:
        AttributeError: If ``forward_train`` or ``forward_test`` is missing.
        KeyError: If one of them has no return annotation.
    """
    name = model.__class__.__name__
    forward_fns = {"train": model.forward_train, "test": model.forward_test}
    infos: dict[str, DataConnectionInfo] = {}
    for mode, forward_fn in forward_fns.items():
        return_type = get_type_hints(forward_fn)["return"]
        infos[mode] = DataConnectionInfo(
            # sorted() builds a new list per mode. connect_components edits
            # in_keys in place, so train and test must not share one list.
            in_keys=sorted(inspect.signature(model.forward).parameters),
            out_keys=[
                "<p>-" + field
                for field in sorted(resolve_named_tuple(return_type, prefix=""))
            ],
            name=name,
        )
    return infos
def _get_loss_connection_infos(loss: LossModule) -> list[DataConnectionInfo]:
    """Describes every loss term of a LossModule as a graph block.

    Args:
        loss (LossModule): Loss module whose ``losses`` entries each provide
            a "name" and a "connector" with a ``key_mapping``.

    Returns:
        One DataConnectionInfo per loss term, in order. Its inputs are
        "<source>-key" strings (source shortened via ``_rename_ds``) and its
        outputs are the mapping keys (the loss argument names).
    """
    infos: list[DataConnectionInfo] = []
    for loss_term in loss.losses:
        entries = list(loss_term["connector"].key_mapping.items())
        infos.append(
            DataConnectionInfo(
                in_keys=[
                    f"<{_rename_ds(spec['source'])}>-" + spec["key"]
                    for _, spec in entries
                ],
                out_keys=[f"{arg_name}" for arg_name, _ in entries],
                name=loss_term["name"],
            )
        )
    return infos
def _get_vis_connection_infos(
    visualizer: Visualizer,
) -> DataConnectionInfo:
    """Describes a visualizer as a graph block without outputs.

    Args:
        visualizer: Visualizer whose ``process`` signature lists its inputs.

    Returns:
        DataConnectionInfo with the sorted ``process`` parameter names as
        in_keys, no out_keys, and the visualizer class name as title.
    """
    process_params = inspect.signature(visualizer.process).parameters
    return DataConnectionInfo(
        in_keys=sorted(process_params),
        out_keys=[],
        name=visualizer.__class__.__name__,
    )
def _get_evaluator_connection_infos(
    evaluator: Evaluator,
) -> DataConnectionInfo:
    """Describes an evaluator as a graph block without outputs.

    Args:
        evaluator: Evaluator whose ``process`` signature lists its inputs.

    Returns:
        DataConnectionInfo with the sorted ``process`` parameter names as
        in_keys, no out_keys, and the evaluator class name as title.
    """
    process_params = inspect.signature(evaluator.process).parameters
    return DataConnectionInfo(
        in_keys=sorted(process_params),
        out_keys=[],
        name=evaluator.__class__.__name__,
    )
def _get_data_connector_infos(
    data_connector: DataConnector, name: str
) -> DataConnectionInfo:
    """Returns the connection infos for a DataConnector.

    Assumes ``key_mapping`` is ``{target_key: data_key}``. This is the same
    direction used for the loss and callback connectors in this module,
    where the mapping keys are the receiving argument names. The data keys
    are therefore the block's inputs (prefixed with "<d>-"). The target keys
    are its outputs, which ``connect_components`` matches against the
    model's ``forward`` parameter names. Listing them the other way round
    only worked when every target key equalled its data key.

    Args:
        data_connector (DataConnector): Data connector to extract data.
        name (str): Name of the data connector.

    Returns:
        DataConnectionInfo for the data connector.
    """
    key_mapping = data_connector.key_mapping
    return DataConnectionInfo(
        # Inputs: keys looked up in the loaded data.
        in_keys=["<d>-" + data_key for data_key in key_mapping.values()],
        # Outputs: names under which the values are handed on.
        out_keys=list(key_mapping.keys()),
        name=name,
    )
def _get_cb_connection_infos(
    name: str,
    cb_data_connector: None | CallbackConnector = None,
) -> DataConnectionInfo | None:
    """Describes a callback's data connector as a graph block.

    Args:
        name: Title for the block.
        cb_data_connector: Connector whose ``key_mapping`` maps each argument
            name to a dict with "source" and "key" entries. If None, no block
            is built.

    Returns:
        DataConnectionInfo with "<source>-key" inputs (source shortened via
        ``_rename_ds``) and the argument names as outputs, or None when no
        connector is given.
    """
    if cb_data_connector is None:
        return None
    entries = list(cb_data_connector.key_mapping.items())
    sources = [
        f"<{_rename_ds(spec['source'])}>-" + spec["key"] for _, spec in entries
    ]
    targets = [f"{target}" for target, _ in entries]
    return DataConnectionInfo(in_keys=sources, out_keys=targets, name=name)
def _get_with_color(key: str, warn_unconnected: bool = True) -> str:
"""Prepends colors for internal vsiualization."""
if "*" in key:
# We connected this one
return f"\033[94m{key}\033[00m"
if "<d>" in key: # key comes from data
return f"\033[90m{key}\033[00m"
# comes from prediction and is not connected
if warn_unconnected:
return f"\u001b[33m{key}\033[00m"
return f"\033[00m{key}\033[00m"
# API Functions
[docs]
def print_box(
    title: str, inputs: list[str], outputs: list[str], use_color: bool = True
) -> str:
    """Renders a box with a title and its in/outputs.

    Inputs are right-aligned on the left and outputs left-aligned on the
    right. The title is placed on the middle row. At least one row is always
    emitted, so the title also appears for a component without keys.

    Args:
        title: Title to plot in the middle.
        inputs: Inputs to plot on the left.
        outputs: Outputs to plot on the right.
        use_color: Whether to color both columns with ANSI escape codes.
            When False the result contains no escape codes.

    Returns:
        str: The box as a string, one line per row, ending with a newline.

    Example (use_color=False)::

                            --------------
                <d>-boxes2d |            | *boxes2d
        <d>-boxes2d_classes |            | *boxes2d_classes
                 <d>-images | Train Data | *images
               <d>-input_hw |            | *input_hw
                            --------------
    """
    width_in = max((len(key) for key in inputs), default=0)
    width_out = max((len(key) for key in outputs), default=0)
    n_lines = max(len(inputs), len(outputs), 1)
    # The dashes start above the left "|" and span "| " + title + " |".
    divider = (
        " " * (width_in + 1)
        + "-" * (len(title) + 4)
        + " " * (width_out + 1)
    )

    rows = [divider]
    for idx in range(n_lines):
        in_key = (inputs[idx] if idx < len(inputs) else "").rjust(width_in)
        out_key = (outputs[idx] if idx < len(outputs) else "").ljust(width_out)
        middle = title if idx == n_lines // 2 else " " * len(title)
        if use_color:
            in_key = _get_with_color(in_key)
            out_key = _get_with_color(out_key)
        rows.append(f"{in_key} | {middle} | {out_key}")
    rows.append(divider)
    return "\n".join(rows) + "\n"
[docs]
def resolve_named_tuple(  # type:ignore
    clazz: Any, prefix: str = ""
) -> list[str]:
    """Returns the (dotted) field paths of a named tuple.

    Anything exposing ``_fields`` (a named tuple class or instance) is
    expanded field by field. Every other object is a leaf and contributes
    ``prefix`` itself.

    Note that nested named tuples are only expanded for *instances*. On a
    class, ``getattr(clazz, field)`` returns the field descriptor, which has
    no ``_fields`` and is therefore treated as a leaf.

    Args:
        clazz: Named tuple class or instance to resolve.
        prefix: Prefix to prepend. Paths are "prefix.<field>", or just
            "<field>" when the prefix is empty.

    Returns:
        List with all field paths, in field order.

    Examples:
        >>> from collections import namedtuple
        >>> Person = namedtuple("Person", ["name", "age", "gender"])
        >>> resolve_named_tuple(Person, prefix="person")
        ['person.name', 'person.age', 'person.gender']
        >>> resolve_named_tuple(Person)
        ['name', 'age', 'gender']

        With nested named tuples:

        >>> Address = namedtuple("Address", ["street", "city", "zipcode"])
        >>> User = namedtuple("User", ["name", "address"])
        >>> user = User(
        ...     name=Person("John", 30, "m"),
        ...     address=Address("str", "zrh", "1"),
        ... )
        >>> resolve_named_tuple(user, prefix="user")  # doctest: +NORMALIZE_WHITESPACE
        ['user.name.name', 'user.name.age', 'user.name.gender',
         'user.address.street', 'user.address.city', 'user.address.zipcode']
        >>> resolve_named_tuple(User, prefix="user")
        ['user.name', 'user.address']
    """
    if not hasattr(clazz, "_fields"):
        # Leaf: not a named tuple, so its own path is the only entry.
        return [prefix]
    paths: list[str] = []
    for field in clazz._fields:
        path = f"{prefix}.{field}" if prefix else field
        paths.extend(resolve_named_tuple(getattr(clazz, field), prefix=path))
    return paths
[docs]
def connect_components(
    in_info: DataConnectionInfo, out_info: DataConnectionInfo
) -> None:
    """Marks two components as connected.

    Every key in ``in_info["out_keys"]`` (the producer's outputs) is compared
    with the base name of every key in ``out_info["in_keys"]`` (the
    consumer's inputs). The base name is the part before the first ".", so
    the input "<p>-rpn.cls" matches the output "<p>-rpn". On a match the
    producer key becomes "*<key>" and the consumer key becomes " <key>*".
    Both lists are updated in place.

    Existing "*" markers and surrounding whitespace are ignored when
    matching. Calling this repeatedly on the same components therefore gives
    the same result as calling it once, and one producer can be connected to
    several consumers.

    Args:
        in_info (DataConnectionInfo): Producing component whose out_keys are
            matched.
        out_info (DataConnectionInfo): Consuming component whose in_keys are
            matched.
    """
    # Base names of the consumer inputs, with earlier markers removed.
    consumer_keys = [
        key.replace("*", "").strip().split(".")[0]
        for key in out_info["in_keys"]
    ]
    for idx, raw_key in enumerate(in_info["out_keys"]):
        key = raw_key.replace("*", "")
        for o_idx, o_key in enumerate(consumer_keys):
            if key != o_key:
                continue
            in_info["out_keys"][idx] = "*" + key
            # Strip before re-marking so repeated calls do not pile up
            # leading spaces.
            clean = out_info["in_keys"][o_idx].replace("*", "").strip()
            out_info["in_keys"][o_idx] = " " + clean + "*"
[docs]
def _format_banner(text: str) -> str:
    """Returns a three-line, 35-column "=" banner around padded ``text``."""
    rule = "=" * 35 + "\n"
    return f"{rule}={text}=\n{rule}"


def _collect_callback_components(
    callbacks: list[Callback],
) -> list[DataConnectionInfo]:
    """Builds one connected block per evaluator/visualizer callback."""
    components: list[DataConnectionInfo] = []
    for cb in callbacks:
        targets: list[tuple[str, DataConnectionInfo]] = []
        if isinstance(cb, EvaluatorCallback):
            targets.append(
                (
                    cb.evaluator.__class__.__name__,
                    _get_evaluator_connection_infos(cb.evaluator),
                )
            )
        if isinstance(cb, VisualizerCallback):
            targets.append(
                (
                    cb.visualizer.__class__.__name__,
                    _get_vis_connection_infos(cb.visualizer),
                )
            )
        for name, target_info in targets:
            component = _get_cb_connection_infos(name, cb.test_connector)
            # Only callbacks with a test connector produce a block.
            if component is not None:
                connect_components(component, target_info)
                components.append(component)
    return components


def prints_datagraph_for_config(
    model: nn.Module,
    train_data_connector: DataConnector,
    test_data_connector: DataConnector,
    loss: LossModule,
    callbacks: list[Callback],
) -> str:
    """Shows the setup of the configuration objects.

    For each component, plots which inputs are connected to which outputs.
    Connected keys are marked with "*". Use this to debug your configuration
    setup. Data loaded from the dataset is prefixed with <d> and model
    predictions with <p>.

    The following connections are drawn:

    - Training: the train data connector feeds the model, and the model
      predictions feed every loss term.
    - Testing: the test data connector feeds the model.
    - Callbacks: each evaluator/visualizer callback's test connector is
      matched against the ``process`` signature.

    Args:
        model (nn.Module): Model to plot.
        train_data_connector (DataConnector): Train data connector to plot.
        test_data_connector (DataConnector): Test data connector to plot.
        loss (LossModule): Loss to plot.
        callbacks (list[Callback]): Callbacks to plot.

    Returns:
        str: The datagraph as a string, that can be printed to the console.

    Raises:
        AssertionError: If a data connector is not a ``DataConnector``.

    Example:
        Illustrative datagraph for FasterRCNN with COCO (colors omitted).
        Keys without a "*" are not connected.

        >>> dg = prints_datagraph_for_config(
        ...     model, train_data_connector, test_data_connector, loss, callbacks
        ... )
        >>> print(dg)

        ```
        ===================================
        =          Training Loop          =
        ===================================
                            --------------
                <d>-boxes2d |            | *boxes2d
        <d>-boxes2d_classes |            | *boxes2d_classes
                 <d>-images | Train Data | *images
               <d>-input_hw |            | *input_hw
                            --------------
                          --------------
                 boxes2d* |            | <p>-proposals
         boxes2d_classes* |            | *<p>-roi
                  images* |            | *<p>-rpn
                input_hw* | FasterRCNN | *<p>-sampled_proposals
              original_hw |            | <p>-sampled_target_indices
                          |            | *<p>-sampled_targets
                          --------------
                      -----------
         <p>-rpn.cls* |         | cls_outs
         <d>-input_hw |         | images_hw
         <p>-rpn.box* | RPNLoss | reg_outs
          <d>-boxes2d |         | target_boxes
                      -----------
                                      ------------
         <p>-sampled_proposals.boxes* |          | boxes
          <p>-sampled_targets.labels* |          | boxes_mask
                   <p>-roi.cls_score* |          | class_outs
                   <p>-roi.bbox_pred* | RCNNLoss | regression_outs
           <p>-sampled_targets.boxes* |          | target_boxes
         <p>-sampled_targets.classes* |          | target_classes
                                      ------------
        ===================================
        =          Testing Loop           =
        ===================================
                        -------------
             <d>-images |           | *images
           <d>-input_hw | Test Data | *input_hw
        <d>-original_hw |           | *original_hw
                        -------------
                        --------------
                boxes2d |            | <p>-boxes
        boxes2d_classes |            | <p>-class_ids
                images* | FasterRCNN | <p>-scores
              input_hw* |            |
           original_hw* |            |
                        --------------
        ===================================
        =            Callbacks            =
        ===================================
                            -------------------------
        <d>-original_images |                       | *images
           <d>-sample_names |                       | *image_names
                  <p>-boxes | BoundingBoxVisualizer | *boxes
                 <p>-scores |                       | *scores
              <p>-class_ids |                       | *class_ids
                            -------------------------
                         -----------------------
        <d>-sample_names |                     | *coco_image_id
               <p>-boxes |                     | *pred_boxes
              <p>-scores | COCODetectEvaluator | *pred_scores
           <p>-class_ids |                     | *pred_classes
                         -----------------------
        ```
    """
    model_connection_info = _get_model_conn_infos(model)

    # TODO: support more data connectors
    assert isinstance(train_data_connector, DataConnector) and isinstance(
        test_data_connector, DataConnector
    ), "Only DataConnector is supported."

    train_data_info = _get_data_connector_infos(
        train_data_connector, name="Train Data"
    )
    test_data_info = _get_data_connector_infos(
        test_data_connector, name="Test Data"
    )
    loss_infos = _get_loss_connection_infos(loss)

    # Training loop: data -> model, then model -> every loss term.
    train_model_info = model_connection_info["train"]
    connect_components(train_data_info, train_model_info)
    for loss_info in loss_infos:
        # Each loss reads the model predictions directly. Chaining a loss
        # onto the previous loss compared argument names with prediction
        # keys and left every loss after the first unconnected.
        connect_components(train_model_info, loss_info)
    train_components = [train_data_info, train_model_info, *loss_infos]

    # Testing loop: data -> model.
    test_model_info = model_connection_info["test"]
    connect_components(test_data_info, test_model_info)
    test_components = [test_data_info, test_model_info]

    # TODO: Add support for more callbacks and handle train_connector
    callback_components = _collect_callback_components(callbacks)

    sections = (
        (" " * 10 + "Training Loop" + " " * 10, train_components),
        (" " * 10 + "Testing Loop " + " " * 10, test_components),
        (" " * 12 + "Callbacks" + " " * 12, callback_components),
    )
    log_parts: list[str] = []
    for banner_text, components in sections:
        log_parts.append(_format_banner(banner_text))
        log_parts.extend(
            print_box(c["name"], c["in_keys"], c["out_keys"])
            for c in components
        )
    return "".join(log_parts)
[docs]
def main(
    argv: ArgsType,  # pylint: disable=unused-argument
) -> None:  # pragma: no cover
    """Prints the data graph of the config passed via ``--config``.

    Instantiates the train/test data connectors, the loss, the model and the
    callbacks of the config, then prints ``prints_datagraph_for_config``.

    Example (shell)::

        python -m vis4d.config.show_connection --config vis4d/zoo/faster_rcnn/faster_rcnn_coco.py
    """
    config = _CONFIG.value
    # Keyword arguments are evaluated left to right, so the components are
    # instantiated in the order: data connectors, loss, model, callbacks.
    datagraph = prints_datagraph_for_config(
        train_data_connector=instantiate_classes(config.train_data_connector),
        test_data_connector=instantiate_classes(config.test_data_connector),
        loss=instantiate_classes(config.loss),
        model=instantiate_classes(config.model),
        callbacks=[
            instantiate_classes(cb_config) for cb_config in config.callbacks
        ],
    )
    print(datagraph)


if __name__ == "__main__":  # pragma: no cover
    app.run(main)