"""This module contains utilities for progress bar."""

from __future__ import annotations

import datetime
import math

from torch import Tensor

from .time import Timer
from .typing import MetricLogs
def _format_metric(name: str, value: object) -> str:
    """Render one metric as ``"name: value"``.

    Floats and single-element tensors are shown with 4 decimals, or in
    scientific notation when their magnitude is below 1e-3. Anything else
    (ints, strings, multi-element tensors, ...) falls back to ``str``.
    """
    number: float
    if isinstance(value, Tensor) and value.numel() == 1:
        # Unwrap so shape-[1] tensors format like 0-dim ones; formatting a
        # non-0-dim tensor with a float spec raises TypeError.
        number = value.item()
    elif isinstance(value, float):
        number = value
    else:
        return f"{name}: {value}"
    # display more digits for small values
    if abs(number) < 1e-3:
        return f"{name}: {number:.3e}"
    return f"{name}: {number:.4f}"


def _format_speed(time_sec_avg: float) -> str:
    """Render throughput as ``s/it`` when slower than 1 s/it, else ``it/s``.

    Returns ``"---"`` when the average is not a finite positive number,
    e.g. before the first iteration or when the timer reports zero.
    """
    if not math.isfinite(time_sec_avg) or time_sec_avg <= 0:
        return "---"
    if time_sec_avg > 1:
        return f"{time_sec_avg:.2f}s/it"
    return f"{1 / time_sec_avg:.2f}it/s"


def compose_log_str(
    prefix: str,
    cur_iter: int,
    total_iters: int,
    timer: Timer,
    metrics: None | MetricLogs = None,
) -> str:
    """Compose log str from given information.

    Args:
        prefix: Leading label of the line, e.g. ``"train"``.
        cur_iter: Number of iterations completed so far.
        total_iters: Total number of iterations expected.
        timer: Object whose ``time()`` returns the elapsed seconds.
        metrics: Optional mapping of metric name to value. Only the part
            after the last ``/`` of each key is displayed, and a metric
            named ``loss`` is listed first.

    Returns:
        A string such as
        ``"train: 5/10, ETA: 0:00:02, 2.00it/s, loss: 0.1234"``. ETA and
        throughput are rendered as ``---`` when they cannot be computed.
    """
    time_sec_tot = timer.time()
    # NaN marks "no completed iterations yet" and propagates to ETA/speed.
    time_sec_avg = time_sec_tot / cur_iter if cur_iter > 0 else float("nan")
    eta_sec = time_sec_avg * (total_iters - cur_iter)
    if math.isfinite(eta_sec):
        # Clamp so overshooting total_iters shows 0:00:00, not a negative delta.
        eta_str = str(datetime.timedelta(seconds=max(int(eta_sec), 0)))
    else:
        eta_str = "---"

    metrics_list: list[str] = []
    if metrics is not None:
        for key, value in metrics.items():
            name = key.split("/")[-1]  # remove prefix, e.g. train/loss
            kv_str = _format_metric(name, value)
            if name == "loss":  # put total loss first
                metrics_list.insert(0, kv_str)
            else:
                metrics_list.append(kv_str)

    time_str = f"ETA: {eta_str}, " + _format_speed(time_sec_avg)
    logging_str = f"{prefix}: {cur_iter}/{total_iters}, {time_str}"
    if metrics_list:
        logging_str += ", " + ", ".join(metrics_list)
    return logging_str