Coverage for markdiffusion / watermark / auto_config.py: 83.33%
24 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 19:25 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 19:25 +0000
1import importlib
2from typing import Dict, Optional, Any
3from markdiffusion.utils.diffusion_config import DiffusionConfig
4from markdiffusion.utils.utils import default_algorithm_config_path
# Registry of supported watermark algorithms -> dotted path of their config class.
# Every algorithm follows one naming convention: the subpackage is the
# lower-cased algorithm name and the class is the algorithm name + "Config"
# (e.g. "VideoShield" -> "markdiffusion.watermark.videoshield.VideoShieldConfig").
# Insertion order of the tuple below is preserved in the resulting dict.
CONFIG_MAPPING_NAMES = {
    algorithm: f"markdiffusion.watermark.{algorithm.lower()}.{algorithm}Config"
    for algorithm in (
        "TR",
        "GS",
        "PRC",
        "VideoShield",
        "VideoMark",
        "RI",
        "SEAL",
        "ROBIN",
        "WIND",
        "SFW",
        "GM",
    )
}
def config_name_from_alg_name(name: str) -> str:
    """Return the dotted config-class path registered for a watermark algorithm.

    Args:
        name (str): Algorithm name; must be an exact (case-sensitive) key of
            ``CONFIG_MAPPING_NAMES``.

    Returns:
        str: Fully qualified path of the configuration class registered for
        ``name`` in ``CONFIG_MAPPING_NAMES``. Never ``None``.

    Raises:
        ValueError: If ``name`` is not a registered algorithm.
    """
    try:
        return CONFIG_MAPPING_NAMES[name]
    except KeyError:
        # Surface the public ValueError contract without chaining the
        # internal KeyError into the traceback.
        raise ValueError(f"Invalid algorithm name: {name}") from None
class AutoConfig:
    """
    A generic configuration class that will be instantiated as one of the configuration classes
    of the library when created with the [`AutoConfig.load`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self, *args, **kwargs):
        # Accept any arguments so that every direct-construction attempt
        # (e.g. ``AutoConfig('TR')``) gets this explanatory error instead of
        # a TypeError about the argument count.
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.load(algorithm_name, diffusion_config, **kwargs)` method."
        )

    @classmethod
    def load(
        cls,
        algorithm_name: str,
        diffusion_config: DiffusionConfig,
        algorithm_config_path: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """
        Load the configuration class for the specified watermark algorithm.

        The config class is looked up by name via ``config_name_from_alg_name``,
        imported on demand, and constructed as
        ``config_class(algorithm_config_path, diffusion_config, **kwargs)``.

        Args:
            algorithm_name (str): The name of the watermark algorithm
            diffusion_config (DiffusionConfig): Configuration for the diffusion model
            algorithm_config_path (Optional[str]): Path to the algorithm configuration
                file. When ``None``, ``default_algorithm_config_path(algorithm_name)``
                is used.
            **kwargs: Additional keyword arguments to pass to the configuration class

        Returns:
            The instantiated configuration object for the specified algorithm

        Raises:
            ValueError: If ``algorithm_name`` is not a registered algorithm.
            ImportError: If the registered module cannot be imported.
            AttributeError: If the registered module lacks the config class.
        """
        # config_name_from_alg_name raises ValueError for unknown names, so the
        # result is always a dotted path and needs no None check.
        config_name = config_name_from_alg_name(algorithm_name)
        module_name, class_name = config_name.rsplit('.', 1)

        # Import on demand so an algorithm's package is loaded only when requested.
        config_class = getattr(importlib.import_module(module_name), class_name)

        if algorithm_config_path is None:
            algorithm_config_path = default_algorithm_config_path(algorithm_name)
        return config_class(algorithm_config_path, diffusion_config, **kwargs)