"""Utility functions for common usage."""importrandomfromdifflibimportget_close_matchesimportnumpyasnpimporttorchfrompackagingimportversionfrom.importsimportis_torch_tf32_availablefrom.loggingimportrank_zero_info,rank_zero_warn
def create_did_you_mean_msg(keys: list[str], query: str) -> str:
    """Create a did you mean message.

    Args:
        keys (list[str]): List of available keys.
        query (str): Query.

    Returns:
        str: Did you mean message.

    Examples:
        >>> keys = ["foo", "bar", "baz"]
        >>> query = "fo"
        >>> print(create_did_you_mean_msg(keys, query))
        Did you mean:
            foo
    """
    msg = ""
    if len(keys) > 0:
        msg = "Did you mean:\n\t"
        msg += "\n\t".join(get_close_matches(query, keys, cutoff=0.75))
    return msg
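# Usage sketch (not part of the original module): one way create_did_you_mean_msg
# could be attached to a lookup failure. The registry dict and the _demo_lookup
# helper are hypothetical names introduced here only for illustration.
def _demo_lookup(registry: dict, name: str):
    """Return registry[name], suggesting close matches when the key is missing."""
    if name not in registry:
        # Append the suggestion text (may be empty if nothing is close enough).
        raise KeyError(f"Unknown key '{name}'. " + create_did_you_mean_msg(list(registry.keys()), name))
    return registry[name]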
def set_tf32(use_tf32: bool, precision: str) -> None:  # pragma: no cover
    """Set torch TF32.

    Args:
        use_tf32: Whether to use torch TF32.
            Details: https://pytorch.org/docs/stable/notes/cuda.html#tf32-on-ampere
        precision: Internal precision of float32 matrix multiplications.
            Details: https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision # pylint: disable=line-too-long
    """
    if use_tf32:  # pragma: no cover
        rank_zero_info(
            "Using Torch TF32. "
            + "It might harm the performance due to the precision. "
            + "You can turn it off by setting config.use_tf32=False."
        )
        if not is_torch_tf32_available():
            rank_zero_warn("Torch TF32 is not available.")
        elif version.parse("1.11") >= version.parse(torch.__version__) >= version.parse("1.7"):
            rank_zero_info("Torch TF32 is turned on by default!")
        else:
            rank_zero_info("Turn on Torch TF32 on matmul.")
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
    else:
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False

    # Control the precision of matmul operations.
    # Equivalent to setting torch.backends.cuda.matmul.allow_tf32.
    torch.set_float32_matmul_precision(precision)
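# Usage sketch (not part of the original module): enable TF32 and inspect the
# resulting backend state. "high" is one of the values accepted by
# torch.set_float32_matmul_precision ("highest", "high", "medium"); TF32 only
# takes effect on Ampere-or-newer GPUs, but the flags can be read anywhere.
# The _demo_set_tf32 name is introduced here only for illustration.
def _demo_set_tf32() -> None:
    """Enable TF32 with 'high' matmul precision and print the backend flags."""
    set_tf32(use_tf32=True, precision="high")
    print("cuda.matmul.allow_tf32:", torch.backends.cuda.matmul.allow_tf32)
    print("cudnn.allow_tf32:", torch.backends.cudnn.allow_tf32)
    print("float32_matmul_precision:", torch.get_float32_matmul_precision())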
def init_random_seed() -> int:
    """Initialize random seed for the experiment."""
    return np.random.randint(2**31)
def set_random_seed(seed: int, deterministic: bool = False) -> None:
    """Set random seed.

    Args:
        seed (int): Seed to be used.
        deterministic (bool): Whether to set the deterministic option for CUDNN backend,
            i.e., set `torch.backends.cudnn.deterministic` to True and
            `torch.backends.cudnn.benchmark` to False. Default: False.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    if deterministic:
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
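# Usage sketch (not part of the original module): draw a seed with
# init_random_seed, apply it with set_random_seed, and check that reseeding
# reproduces the same random tensor. The _demo_reproducibility name is
# introduced here only for illustration.
def _demo_reproducibility() -> bool:
    """Return True if two identically seeded draws produce the same tensor."""
    seed = init_random_seed()
    set_random_seed(seed, deterministic=True)
    first = torch.randn(4)
    set_random_seed(seed, deterministic=True)
    second = torch.randn(4)
    return bool(torch.equal(first, second))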