Coverage for watermark / auto_config.py: 82.61%

23 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1import importlib 

2from typing import Dict, Optional, Any 

3from utils.diffusion_config import DiffusionConfig 

4 

# Registry of supported watermark algorithms: maps each algorithm name to the
# dotted "package.module.ClassName" path of its configuration class.
CONFIG_MAPPING_NAMES: Dict[str, str] = {
    "TR": "watermark.tr.TRConfig",
    "GS": "watermark.gs.GSConfig",
    "PRC": "watermark.prc.PRCConfig",
    "VideoShield": "watermark.videoshield.VideoShieldConfig",
    "VideoMark": "watermark.videomark.VideoMarkConfig",
    "RI": "watermark.ri.RIConfig",
    "SEAL": "watermark.seal.SEALConfig",
    "ROBIN": "watermark.robin.ROBINConfig",
    "WIND": "watermark.wind.WINDConfig",
    "SFW": "watermark.sfw.SFWConfig",
    "GM": "watermark.gm.GMConfig",
}

18 

def config_name_from_alg_name(name: str) -> str:
    """Return the dotted import path of the config class for an algorithm.

    Args:
        name: Algorithm identifier. It must be a key of
            ``CONFIG_MAPPING_NAMES``; the lookup is case-sensitive.

    Returns:
        The mapped ``"package.module.ClassName"`` string.

    Raises:
        ValueError: If ``name`` is not a key of ``CONFIG_MAPPING_NAMES``.
    """
    # Unknown names raise rather than return None, so the return type is a
    # plain ``str`` (not ``Optional[str]``).
    try:
        return CONFIG_MAPPING_NAMES[name]
    except KeyError:
        # The KeyError adds nothing beyond the message below; suppress it.
        raise ValueError(f"Invalid algorithm name: {name}") from None

25 

class AutoConfig:
    """
    Factory that builds the configuration object of a watermark algorithm.

    Use the `AutoConfig.load` class method to obtain a configuration; it
    returns an instance of the algorithm-specific configuration class.

    Calling `AutoConfig()` directly is not supported and raises an error.
    """

    def __init__(self):
        """Always raise: instances are obtained through `AutoConfig.load`."""
        message = (
            "AutoConfig is designed to be instantiated "
            "using the `AutoConfig.load(algorithm_name, **kwargs)` method."
        )
        raise EnvironmentError(message)

    @classmethod
    def load(cls, algorithm_name: str, diffusion_config: DiffusionConfig, algorithm_config_path=None, **kwargs) -> Any:
        """
        Build the configuration object for the given watermark algorithm.

        Args:
            algorithm_name (str): The name of the watermark algorithm
            diffusion_config (DiffusionConfig): Configuration for the diffusion model
            algorithm_config_path (str): Path to the algorithm configuration file;
                when None, `config/<algorithm_name>.json` is used
            **kwargs: Extra keyword arguments forwarded to the configuration class

        Returns:
            An instance of the configuration class mapped to `algorithm_name`

        Raises:
            ValueError: If no configuration class is known for `algorithm_name`
        """
        dotted_path = config_name_from_alg_name(algorithm_name)
        if dotted_path is None:
            raise ValueError(f"Unknown algorithm name: {algorithm_name}")

        # Fall back to the conventional per-algorithm JSON file.
        if algorithm_config_path is None:
            algorithm_config_path = f'config/{algorithm_name}.json'

        # Split "package.module.ClassName" and import the module on demand.
        module_path, attr_name = dotted_path.rsplit('.', 1)
        config_cls = getattr(importlib.import_module(module_path), attr_name)
        return config_cls(algorithm_config_path, diffusion_config, **kwargs)