Coverage for markdiffusion / watermark / auto_config.py: 83.33%

24 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-14 19:25 +0000

1import importlib 

2from typing import Dict, Optional, Any 

3from markdiffusion.utils.diffusion_config import DiffusionConfig 

4from markdiffusion.utils.utils import default_algorithm_config_path 

5 

# Registry of supported watermark algorithms: maps the algorithm key accepted by
# ``AutoConfig.load`` to the dotted "<module>.<ClassName>" path of the config
# class that is imported lazily for that algorithm.
CONFIG_MAPPING_NAMES = dict(
    TR="markdiffusion.watermark.tr.TRConfig",
    GS="markdiffusion.watermark.gs.GSConfig",
    PRC="markdiffusion.watermark.prc.PRCConfig",
    VideoShield="markdiffusion.watermark.videoshield.VideoShieldConfig",
    VideoMark="markdiffusion.watermark.videomark.VideoMarkConfig",
    RI="markdiffusion.watermark.ri.RIConfig",
    SEAL="markdiffusion.watermark.seal.SEALConfig",
    ROBIN="markdiffusion.watermark.robin.ROBINConfig",
    WIND="markdiffusion.watermark.wind.WINDConfig",
    SFW="markdiffusion.watermark.sfw.SFWConfig",
    GM="markdiffusion.watermark.gm.GMConfig",
)

19 

def config_name_from_alg_name(name: str) -> str:
    """Return the dotted config-class path registered for an algorithm name.

    Args:
        name: Algorithm key, one of the keys of ``CONFIG_MAPPING_NAMES``
            (e.g. ``'TR'``, ``'GS'``, ``'VideoShield'``).

    Returns:
        The ``'<module>.<ClassName>'`` string registered for ``name``. This
        function never returns ``None``; unknown names raise instead.

    Raises:
        ValueError: If ``name`` is not a key of ``CONFIG_MAPPING_NAMES``.
    """
    try:
        return CONFIG_MAPPING_NAMES[name]
    except KeyError:
        # Hide the internal dict-lookup KeyError; callers only need the ValueError.
        raise ValueError(f"Invalid algorithm name: {name}") from None

26 

class AutoConfig:
    """
    A generic configuration class that will be instantiated as one of the configuration classes
    of the library when created with the [`AutoConfig.load`] class method.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        # Guard against direct construction; only the `load` factory is supported.
        raise EnvironmentError(
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.load(algorithm_name, **kwargs)` method."
        )

    @classmethod
    def load(
        cls,
        algorithm_name: str,
        diffusion_config: DiffusionConfig,
        algorithm_config_path: Optional[str] = None,
        **kwargs,
    ) -> Any:
        """
        Load the configuration class for the specified watermark algorithm.

        The config class is looked up in ``CONFIG_MAPPING_NAMES`` and its module
        is imported lazily, so only the selected algorithm's module is loaded.

        Args:
            algorithm_name (str): The name of the watermark algorithm (a key of
                ``CONFIG_MAPPING_NAMES``, e.g. ``'TR'``).
            diffusion_config (DiffusionConfig): Configuration for the diffusion model.
            algorithm_config_path (str, optional): Path to the algorithm configuration
                file. When ``None``, ``default_algorithm_config_path(algorithm_name)``
                is used.
            **kwargs: Additional keyword arguments forwarded to the configuration class.

        Returns:
            The instantiated configuration class for the specified algorithm, built as
            ``config_class(algorithm_config_path, diffusion_config, **kwargs)``.

        Raises:
            ValueError: If ``algorithm_name`` is not a registered algorithm.
            ImportError: If the registered config module cannot be imported.
            AttributeError: If the module does not define the registered class.
        """
        # Raises ValueError for unknown names; it never returns None, so no None check is needed.
        config_name = config_name_from_alg_name(algorithm_name)

        module_name, class_name = config_name.rsplit('.', 1)
        config_class = getattr(importlib.import_module(module_name), class_name)

        if algorithm_config_path is None:
            algorithm_config_path = default_algorithm_config_path(algorithm_name)

        # Config classes take (config_path, diffusion_config) positionally.
        return config_class(algorithm_config_path, diffusion_config, **kwargs)