def get_all_keys(entry: NamedTuple) -> list[str]:
    """Collect the keys of a (possibly nested) NamedTuple.

    Fields whose value is itself a named tuple are expanded recursively and
    reported as dotted paths ("outer.inner"); every other field is reported
    by its plain name. Keys appear in the order given by ``entry._fields``.

    Args:
        entry (NamedTuple): Named tuple whose keys should be listed.

    Returns:
        list[str]: Flat list of (dotted) key names.
    """
    collected: list[str] = []
    for field_name in entry._fields:
        value = getattr(entry, field_name)
        if not is_namedtuple(value):
            collected.append(field_name)
            continue
        # Nested named tuple: prefix each of its keys with the parent field.
        for nested_key in get_all_keys(value):
            collected.append(f"{field_name}.{nested_key}")
    return collected
def get_from_namedtuple(entry: NamedTuple, key: str) -> Any:  # type: ignore
    """Get a value from a nested named tuple.

    Example:
        Passing key = "test.my.data" will resolve the value of the named
        tuple at 'test' 'my' 'data'.

    Args:
        entry (NamedTuple): Named tuple to read from.
        key (str): Dot-separated path; each part is looked up as an
            attribute on the value resolved so far.

    Returns:
        Any: The value found at the end of the path.

    Raises:
        ValueError: If the key is not present in the named tuple. This also
            covers paths that continue past a value that is not a named
            tuple (e.g. "leaf.x" where "leaf" holds an int).
    """
    current: Any = entry
    for part in key.split("."):
        if not hasattr(current, part):
            # Only objects exposing ``_fields`` can be listed by
            # get_all_keys. Without this guard, a path descending past a
            # plain leaf value would raise AttributeError while building
            # the message instead of the documented ValueError.
            available = (
                get_all_keys(current) if hasattr(current, "_fields") else []
            )
            raise ValueError(
                f"Key {part} not in named tuple! Current keys: "
                f"{available}"
            )
        current = getattr(current, part)
    return current
def is_namedtuple(obj: object) -> bool:
    """Tell whether ``obj`` looks like a namedtuple instance.

    An object qualifies when it is a tuple instance that also exposes the
    ``_asdict`` and ``_fields`` attributes.

    https://github.com/pytorch/pytorch/blob/v1.8.1/torch/nn/parallel/scatter_gather.py#L4-L8

    Args:
        obj (object): Object to inspect.

    Returns:
        bool: True if all three conditions hold, False otherwise.
    """
    if not isinstance(obj, tuple):
        return False
    # Checked in the same order as the attribute names are listed.
    return all(hasattr(obj, marker) for marker in ("_asdict", "_fields"))