Coverage for markdiffusion / utils / utils.py: 98.15%
54 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 19:25 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 19:25 +0000
1import os
2import json
3import torch
4import numpy as np
5import random
# Absolute path of the package root: the parent of the directory that holds
# this module (for markdiffusion/utils/utils.py that is markdiffusion/).
_PACKAGE_ROOT = os.path.abspath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir)
)


def default_algorithm_config_path(algorithm_name: str) -> str:
    """Return the path of the default JSON config for ``algorithm_name``.

    The copy bundled with the package, ``<package root>/config/<name>.json``,
    is preferred. When that file does not exist, the relative path
    ``config/<name>.json`` is returned instead, so scripts that keep their
    configs under a ``config/`` folder next to the working directory continue
    to work. The fallback path is returned without checking that it exists.
    """
    file_name = f"{algorithm_name}.json"
    bundled_path = os.path.join(_PACKAGE_ROOT, "config", file_name)
    if not os.path.isfile(bundled_path):
        # Legacy behaviour: path relative to the current working directory.
        return os.path.join("config", file_name)
    return bundled_path
def inherit_docstring(cls):
    """
    Inherit docstrings from base classes to methods without docstrings.

    For every callable defined directly in ``cls`` whose ``__doc__`` is
    ``None``, this decorator copies the docstring of the same-named attribute
    from the first direct base (in ``cls.__bases__`` order) that has a
    non-empty one. ``getattr`` on a base also finds attributes that base
    inherited itself.

    ``staticmethod`` and ``classmethod`` wrappers are unwrapped first, and the
    docstring is written onto the underlying function. Attribute access
    (``cls.name``) and ``help()`` read the docstring from that function.

    Args:
        cls: The class to enhance with inherited docstrings

    Returns:
        cls: The same class object, modified in place
    """
    for name, attr in vars(cls).items():
        # A classmethod object is not callable. On Python 3.10+ a staticmethod
        # is callable but keeps a private copy of __doc__ that ``cls.name``
        # never returns. Both must therefore be handled via the wrapped function.
        if isinstance(attr, (staticmethod, classmethod)):
            target = attr.__func__
        else:
            target = attr
        if not callable(target) or target.__doc__ is not None:
            continue

        # Look for same method in base classes
        for base in cls.__bases__:
            base_attr = getattr(base, name, None)
            if base_attr is not None and getattr(base_attr, "__doc__", None):
                doc = base_attr.__doc__
                target.__doc__ = doc
                if target is not attr:
                    # Best effort: also keep the wrapper's own __doc__ in sync
                    # (it has a separate copy on Python 3.10+).
                    try:
                        attr.__doc__ = doc
                    except AttributeError:
                        pass
                break

    return cls
def load_config_file(path: str) -> dict:
    """Load a JSON configuration file and return the parsed content.

    The file is decoded as UTF-8, the encoding required for interchanged JSON
    text, rather than the platform's locale default. Non-ASCII values
    therefore load identically on every OS.

    Args:
        path: Filesystem path of the JSON file.

    Returns:
        The parsed JSON value, which is a ``dict`` for a JSON-object config
        file. Returns ``None`` if the file does not exist, is not valid JSON,
        or any other error occurs while reading it. Errors are printed rather
        than raised, so callers must check for ``None``.
    """
    try:
        with open(path, 'r', encoding='utf-8') as f:
            config_dict = json.load(f)
    except FileNotFoundError:
        print(f"Error: The file '{path}' does not exist.")
        return None
    except json.JSONDecodeError as e:
        print(f"Error decoding JSON in '{path}': {e}")
        return None
    except Exception as e:
        # Deliberate best-effort: any other failure (permissions, a directory
        # path, undecodable bytes, ...) is reported and mapped to None.
        print(f"An unexpected error occurred: {e}")
        return None
    return config_dict
def load_json_as_list(input_file: str) -> list:
    """Load a JSON Lines file (one JSON value per line) into a list.

    The file is read as UTF-8 and streamed line by line instead of being read
    into memory all at once. Lines that are empty or contain only whitespace
    are skipped, so a trailing blank line does not cause a decode error.

    Args:
        input_file: Path of the JSON Lines file.

    Returns:
        The decoded values in file order. These are dicts when every line is a
        JSON object.

    Raises:
        FileNotFoundError: If ``input_file`` does not exist.
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(input_file, 'r', encoding='utf-8') as f:
        return [json.loads(line) for line in f if line.strip()]
def create_directory_for_file(file_path) -> None:
    """Make sure the parent directory of ``file_path`` exists.

    Any missing intermediate directories are created as well. A bare filename
    with no directory component (e.g. ``"out.json"``) lives in the current
    directory, which always exists, so nothing is done in that case.

    Args:
        file_path: Path of the file that is about to be written.

    Raises:
        FileExistsError: If the parent path exists but is not a directory.
    """
    directory = os.path.dirname(file_path)
    # os.makedirs("") raises, so skip paths without a directory component.
    if directory:
        # exist_ok avoids the race where another process creates the directory
        # between an existence check and the makedirs call.
        os.makedirs(directory, exist_ok=True)
def set_random_seed(seed: int):
    """Set random seeds for reproducibility.

    Each generator is seeded with a fixed offset from ``seed``:

    * torch default generator (``torch.manual_seed``): ``seed``
    * NumPy global RNG: ``(seed + 3) % 2**32``. NumPy only accepts seeds that
      fit in an unsigned 32-bit integer, so the value is wrapped.
    * all CUDA devices (``torch.cuda.manual_seed_all``): ``seed + 4``
    * Python ``random``: ``seed + 5``

    Offsets +1 and +2 are intentionally unused. CUDA generators seeded with
    them would be reseeded by the ``seed + 4`` call, so only ``seed + 4`` takes
    effect. The offsets themselves are kept unchanged so that a given ``seed``
    continues to reproduce the same random streams.

    Args:
        seed: Base integer seed.
    """
    torch.manual_seed(seed)
    np.random.seed((seed + 3) % 2**32)
    # Per the PyTorch docs, this is safe to call without CUDA; in that case it
    # is ignored or deferred until CUDA is initialized.
    torch.cuda.manual_seed_all(seed + 4)
    random.seed(seed + 5)