Coverage for watermark / auto_config.py: 82.61%
23 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1import importlib
2from typing import Dict, Optional, Any
3from utils.diffusion_config import DiffusionConfig
# Registry of supported watermark algorithms.
# Each short algorithm name maps to the dotted path "<module>.<ConfigClass>" of
# that algorithm's configuration class. AutoConfig.load splits a path on its last
# "." and imports the module on demand, so registering an algorithm here does not
# import its module.
CONFIG_MAPPING_NAMES: Dict[str, str] = {
    'TR': 'watermark.tr.TRConfig',
    'GS': 'watermark.gs.GSConfig',
    'PRC': 'watermark.prc.PRCConfig',
    'VideoShield': 'watermark.videoshield.VideoShieldConfig',
    'VideoMark': 'watermark.videomark.VideoMarkConfig',
    'RI': 'watermark.ri.RIConfig',
    'SEAL': 'watermark.seal.SEALConfig',
    'ROBIN': 'watermark.robin.ROBINConfig',
    'WIND': 'watermark.wind.WINDConfig',
    'SFW': 'watermark.sfw.SFWConfig',
    'GM': 'watermark.gm.GMConfig',
}
def config_name_from_alg_name(name: str) -> str:
    """Return the dotted config-class path registered for an algorithm name.

    Args:
        name (str): Short algorithm name, i.e. a key of ``CONFIG_MAPPING_NAMES``
            such as ``'TR'`` or ``'VideoShield'``. Matching is exact
            (case-sensitive).

    Returns:
        str: The ``"<module>.<ClassName>"`` path of the algorithm's config class.
        This function never returns ``None``.

    Raises:
        ValueError: If ``name`` is not a registered algorithm name. The message
            lists the supported names.
    """
    try:
        return CONFIG_MAPPING_NAMES[name]
    except KeyError:
        supported = ', '.join(CONFIG_MAPPING_NAMES)
        # `from None` hides the internal KeyError; callers only ever see ValueError.
        raise ValueError(
            f"Invalid algorithm name: {name}. Supported algorithms: {supported}"
        ) from None
class AutoConfig:
    """
    A generic configuration class that will be instantiated as one of the configuration classes
    of the library when created with the [`AutoConfig.load`] class method.

    This class cannot be instantiated directly using `__init__()`; doing so raises
    `EnvironmentError`.
    """

    def __init__(self):
        # AutoConfig is only a factory. `load` returns an instance of the
        # algorithm-specific config class, never an AutoConfig.
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.load(algorithm_name, diffusion_config, **kwargs)` method."
        )

    @classmethod
    def load(
        cls,
        algorithm_name: str,
        diffusion_config: DiffusionConfig,
        algorithm_config_path: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """
        Load the configuration class for the specified watermark algorithm.

        The config class is looked up in ``CONFIG_MAPPING_NAMES``. Its module is
        imported only when that algorithm is requested.

        Args:
            algorithm_name (str): The name of the watermark algorithm (a key of
                ``CONFIG_MAPPING_NAMES``, e.g. ``'TR'``).
            diffusion_config (DiffusionConfig): Configuration for the diffusion model.
            algorithm_config_path (Optional[str]): Path to the algorithm configuration
                file. Defaults to the relative path ``config/<algorithm_name>.json``.
                NOTE(review): how that relative path is resolved (presumably against
                the current working directory) is up to the config class; confirm.
            **kwargs: Additional keyword arguments to pass to the configuration class.

        Returns:
            The instantiated configuration class for the specified algorithm, built as
            ``ConfigClass(algorithm_config_path, diffusion_config, **kwargs)``.

        Raises:
            ValueError: If ``algorithm_name`` is not a registered algorithm.
            ImportError: If the module holding the registered config class cannot be
                imported.
            AttributeError: If that module does not define the registered class name.
        """
        # config_name_from_alg_name raises ValueError for unknown names rather than
        # returning None, so no separate None check is needed here.
        config_name = config_name_from_alg_name(algorithm_name)

        # e.g. "watermark.tr.TRConfig" -> ("watermark.tr", "TRConfig")
        module_name, class_name = config_name.rsplit('.', 1)
        module = importlib.import_module(module_name)
        config_class = getattr(module, class_name)

        if algorithm_config_path is None:
            algorithm_config_path = f'config/{algorithm_name}.json'
        return config_class(algorithm_config_path, diffusion_config, **kwargs)