Coverage for markdiffusion / utils / utils.py: 98.15%

54 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-14 19:25 +0000

1import os 

2import json 

3import torch 

4import numpy as np 

5import random 

6 

7 

# Absolute path of the package root: the parent of the directory containing this module.
_PACKAGE_ROOT = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))

9 

10 

def default_algorithm_config_path(algorithm_name: str) -> str:
    """Return the path of the default JSON config for ``algorithm_name``.

    The copy bundled in the package's ``config/`` directory is preferred.
    If that file does not exist, the legacy relative path
    ``config/<algorithm_name>.json`` is returned instead. That path is
    resolved against the current working directory by whoever opens it,
    so existing scripts relying on that layout keep working. The
    fallback path itself is not checked for existence.
    """
    file_name = f"{algorithm_name}.json"
    packaged = os.path.join(_PACKAGE_ROOT, "config", file_name)
    return packaged if os.path.isfile(packaged) else os.path.join("config", file_name)

21 

def inherit_docstring(cls):
    """
    Inherit docstrings from base classes to methods without docstrings.

    Each callable attribute defined directly on ``cls`` whose docstring is
    ``None`` is handled as follows. The decorator walks ``cls.__mro__``,
    excluding ``cls`` itself, and copies the first non-empty docstring it
    finds on a same-named callable (or property) defined in one of those
    classes.

    Walking the full MRO has two benefits:
    - An undocumented override in an intermediate base does not hide the
      documented original further up.
    - Under multiple inheritance, lookup follows the same order as normal
      attribute resolution.

    ``staticmethod`` and ``classmethod`` entries are unwrapped, so the
    docstring is set on the underlying function. That function is what
    ``cls.name.__doc__`` reports.

    Args:
        cls: The class to enhance with inherited docstrings

    Returns:
        cls: The same class object, modified in place
    """
    def _unwrap(obj):
        # staticmethod/classmethod objects keep the real function on __func__.
        return getattr(obj, "__func__", obj)

    for name, attr in vars(cls).items():
        target = _unwrap(attr)
        if not callable(target) or target.__doc__ is not None:
            continue

        # Look for the same name on the classes cls inherits from, in MRO order.
        for base in cls.__mro__[1:]:
            candidate = vars(base).get(name)
            if candidate is None:
                continue
            source = _unwrap(candidate)
            # Skip plain data attributes: their __doc__ is their type's docstring
            # (e.g. int.__doc__), not documentation for this method.
            if not (callable(source) or isinstance(source, property)):
                continue
            doc = getattr(source, "__doc__", None)
            if doc:
                target.__doc__ = doc
                break

    return cls

47 

48 

def load_config_file(path: str) -> dict:
    """Load a JSON configuration file from the specified path and return its contents.

    The file is decoded as UTF-8, the encoding required for JSON text, rather
    than the platform's locale-dependent default.

    Args:
        path: Path to the JSON file.

    Returns:
        The parsed JSON value (a dict for a well-formed config file), or
        ``None`` on failure. Failures include a missing file, invalid JSON,
        and any other error while reading. Each failure is reported with
        ``print`` rather than raised.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except FileNotFoundError:
        print(f"Error: The file '{path}' does not exist.")
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON in '{path}': {e}")
    except Exception as e:
        # Deliberate best-effort loader: callers receive None instead of an exception.
        print(f"An unexpected error occurred: {e}")
    return None

67 

68 

def load_json_as_list(input_file: str) -> list:
    """Load a JSON Lines file (one JSON value per line) into a list.

    Blank or whitespace-only lines, such as trailing empty lines, are
    skipped. Passing them to ``json.loads`` would raise.

    Args:
        input_file: Path to a UTF-8 encoded JSON Lines file.

    Returns:
        The decoded values in file order (dicts for typical JSONL records).

    Raises:
        FileNotFoundError: If ``input_file`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        # Iterate the file lazily instead of materializing readlines().
        return [json.loads(line) for line in f if line.strip()]

78 

79 

def create_directory_for_file(file_path) -> None:
    """Ensure the parent directory of ``file_path`` exists.

    Missing intermediate directories are created too. A bare filename
    (no directory component) refers to the current working directory, so
    nothing is created. If the parent path already exists, even as a
    non-directory, it is left untouched.

    Args:
        file_path: Path of a file that is about to be written.
    """
    directory = os.path.dirname(file_path)
    # dirname("out.json") == "", and os.makedirs("") raises FileNotFoundError.
    if directory and not os.path.exists(directory):
        # exist_ok tolerates the directory being created concurrently
        # between the existence check and this call.
        os.makedirs(directory, exist_ok=True)

85 

86 

def set_random_seed(seed: int):
    """Set random seeds for reproducibility.

    The following generators are seeded, each at its own fixed offset from
    ``seed``:
    - PyTorch's CPU generator (``seed``)
    - NumPy's global RNG (``seed + 3``)
    - the default generators of all CUDA devices (``seed + 4``)
    - Python's ``random`` module (``seed + 5``)

    These are exactly the seeds that took effect before, so existing runs
    reproduce unchanged.

    Args:
        seed: Base seed.
    """
    torch.manual_seed(seed + 0)
    # One manual_seed_all call seeds every CUDA device, including the current
    # one, and is silently ignored when CUDA is unavailable. The former
    # torch.cuda.manual_seed(seed + 1) / manual_seed_all(seed + 2) calls were
    # always overwritten by this seed + 4 call, so only the effective one remains.
    torch.cuda.manual_seed_all(seed + 4)
    # NumPy's legacy seeding only accepts values in [0, 2**32).
    np.random.seed((seed + 3) % 2**32)
    random.seed(seed + 5)