"""This module contains utilities for pretty printing."""fromtypingimportAnyimportnumpyasnpimporttorch
class PrettyRepMixin:
    """Mixin that gives a class a readable representation built from its attributes.

    Every entry of ``vars(self)`` is rendered as ``name=str(value)``, in insertion
    order. Attributes whose name starts with an underscore, and an attribute named
    ``type``, are left out.

    Examples:
        >>> class TestClass(PrettyRepMixin):
        ...     def __init__(self, a: int, b: str):
        ...         self.a = a
        ...         self.b = b
        >>> obj = TestClass(1, 'hello')
        >>> str(obj)
        'TestClass(a=1, b=hello)'
    """

    def __repr__(self) -> str:
        """Return a string representation of the class and its parameters.

        Returns:
            The class name followed by ``name=value`` pairs for each public
            instance attribute (excluding ``type``), separated by ``", "``.

        Examples:
            >>> class TestClass(PrettyRepMixin):
            ...     def __init__(self, a: int, b: str):
            ...         self.a = a
            ...         self.b = b
            >>> obj = TestClass(1, 'hello')
            >>> obj.__repr__()
            'TestClass(a=1, b=hello)'
        """
        # Join the pairs rather than appending ", " and calling rstrip(", "):
        # rstrip treats its argument as a *set* of characters, so it also
        # removed trailing commas/spaces that belonged to the last value.
        attr_str = ", ".join(
            f"{k}={v!s}"
            for k, v in vars(self).items()
            if k != "type" and not k.startswith("_")
        )
        return f"{self.__class__.__name__}({attr_str})"
def describe_shape(obj: Any) -> str:  # type: ignore
    """Recursively output the shape of tensors in an object's structure.

    Containers are rendered by recursing into their elements:

    * dictionaries become ``{key: ..., ...}``;
    * lists become ``[...]``;
    * tuples become ``(...)``.

    Leaf values are rendered as follows:

    * ``torch.Tensor`` and ``numpy.ndarray`` values become ``shape[d0, d1, ...]``;
    * floats are formatted with four decimal places;
    * everything else falls back to ``str(obj)``.

    Args:
        obj (Any): The object to describe. May be a dictionary, list, tuple,
            torch.Tensor, numpy.ndarray, float, or any other type.

    Returns:
        str: A string representing the shapes of all tensors in the object's
        structure.

    Examples:
        >>> describe_shape({'a': torch.rand(2, 3)})
        '{a: shape[2, 3]}'
        >>> describe_shape({'a': [torch.rand(2, 3), torch.rand(4, 5)]})
        '{a: [shape[2, 3], shape[4, 5]]}'
        >>> describe_shape([torch.rand(2, 3), {'a': torch.rand(4, 5)}])
        '[shape[2, 3], {a: shape[4, 5]}]'
        >>> describe_shape((torch.rand(2), 0.5))
        '(shape[2], 0.5000)'
    """
    if isinstance(obj, dict):
        inner = ", ".join(f"{k}: {describe_shape(v)}" for k, v in obj.items())
        return "{" + inner + "}"
    if isinstance(obj, list):
        return "[" + ", ".join(describe_shape(v) for v in obj) + "]"
    if isinstance(obj, tuple):
        # Tuples previously fell through to str(), which dumps the full
        # contents of any tensors inside instead of their shapes.
        return "(" + ", ".join(describe_shape(v) for v in obj) + ")"
    if isinstance(obj, (torch.Tensor, np.ndarray)):
        return "shape[" + ", ".join(str(s) for s in obj.shape) + "]"
    if isinstance(obj, float):
        return f"{obj:.4f}"
    return str(obj)