Coverage for watermark / auto_watermark.py: 84.78%

46 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1import importlib 

2from watermark.base import BaseWatermark 

3from typing import Union, Optional 

4from utils.pipeline_utils import ( 

5 get_pipeline_type, 

6 PIPELINE_TYPE_IMAGE, 

7 PIPELINE_TYPE_TEXT_TO_VIDEO, 

8 PIPELINE_TYPE_IMAGE_TO_VIDEO 

9) 

10from watermark.auto_config import AutoConfig 

11 

# Registry of watermarking algorithms: algorithm name -> dotted "module.ClassName"
# path of the class that implements it. Insertion order is kept on purpose: it is
# the order in which the names are listed back to users.
WATERMARK_MAPPING_NAMES = dict(
    TR='watermark.tr.TR',
    GS='watermark.gs.GS',
    PRC='watermark.prc.PRC',
    VideoShield='watermark.videoshield.VideoShieldWatermark',
    VideoMark='watermark.videomark.VideoMarkWatermark',
    RI='watermark.ri.RI',
    SEAL='watermark.seal.SEAL',
    ROBIN='watermark.robin.ROBIN',
    WIND='watermark.wind.WIND',
    SFW='watermark.sfw.SFW',
    GM='watermark.gm.GM',
)

25 

# Compatibility table: pipeline type -> names (keys of WATERMARK_MAPPING_NAMES)
# of the watermarking algorithms that may be used with that pipeline type.
PIPELINE_SUPPORTED_WATERMARKS = {
    # Image pipelines.
    PIPELINE_TYPE_IMAGE: [
        "TR", "GS", "PRC", "RI", "SEAL",
        "ROBIN", "WIND", "GM", "SFW",
    ],
    # Both video pipeline types support the same algorithms, but each entry
    # owns a separate list object.
    PIPELINE_TYPE_TEXT_TO_VIDEO: ["VideoShield", "VideoMark"],
    PIPELINE_TYPE_IMAGE_TO_VIDEO: ["VideoShield", "VideoMark"],
}

32 

def watermark_name_from_alg_name(name: str) -> Optional[str]:
    """Resolve an algorithm name to the dotted path of its watermark class.

    The comparison against the keys of ``WATERMARK_MAPPING_NAMES`` ignores case,
    so ``"tr"`` and ``"TR"`` resolve to the same entry.

    Args:
        name: Algorithm name supplied by the caller.

    Returns:
        The ``"module.ClassName"`` path of the first matching entry, or ``None``
        when no registered algorithm matches.
    """
    return next(
        (
            class_path
            for registered, class_path in WATERMARK_MAPPING_NAMES.items()
            if name.lower() == registered.lower()
        ),
        None,
    )

39 

class AutoWatermark:
    """
    Factory that builds a concrete watermark object from an algorithm name.

    This class is not meant to be instantiated: ``AutoWatermark()`` raises
    ``EnvironmentError``. Use the :meth:`AutoWatermark.load` class method. It
    resolves the name against ``WATERMARK_MAPPING_NAMES`` and, when a diffusion
    pipeline is supplied, checks it against ``PIPELINE_SUPPORTED_WATERMARKS``.
    It then returns an instance of the matching watermark class.
    """

    def __init__(self):
        # Direct construction is unsupported; only `load` produces usable objects.
        raise EnvironmentError(
            "AutoWatermark is designed to be instantiated "
            "using the `AutoWatermark.load(algorithm_name, algorithm_config, diffusion_config)` method."
        )

    @staticmethod
    def _canonical_algorithm_name(name) -> Optional[str]:
        """Return the registered key of ``WATERMARK_MAPPING_NAMES`` that matches ``name``.

        Matching ignores case, so ``"tr"`` resolves to ``"TR"``. Returns ``None``
        when ``name`` is not a string or matches no registered algorithm.
        """
        if not isinstance(name, str):
            return None
        wanted = name.lower()
        for registered in WATERMARK_MAPPING_NAMES:
            if registered.lower() == wanted:
                return registered
        return None

    @staticmethod
    def _check_pipeline_compatibility(pipeline_type: Optional[str], algorithm_name: str) -> bool:
        """Return True if ``algorithm_name`` is supported for ``pipeline_type``.

        ``algorithm_name`` is matched case-insensitively, the same way the name
        is resolved in :meth:`load`. Returns False in three cases:

        * ``pipeline_type`` is None;
        * the algorithm is not registered;
        * the pipeline type has no entry in ``PIPELINE_SUPPORTED_WATERMARKS``.
        """
        if pipeline_type is None:
            return False

        canonical_name = AutoWatermark._canonical_algorithm_name(algorithm_name)
        if canonical_name is None:
            return False

        return canonical_name in PIPELINE_SUPPORTED_WATERMARKS.get(pipeline_type, [])

    @classmethod
    def load(cls, algorithm_name, algorithm_config=None, diffusion_config=None, *args, **kwargs) -> BaseWatermark:
        """Instantiate the watermark algorithm registered under ``algorithm_name``.

        Args:
            algorithm_name: Name of the algorithm, matched case-insensitively
                against the keys of ``WATERMARK_MAPPING_NAMES``.
            algorithm_config: Forwarded to ``AutoConfig.load`` as
                ``algorithm_config_path``.
            diffusion_config: Forwarded to ``AutoConfig.load``. When it is truthy
                and its ``pipe`` attribute is truthy, the pipeline type is
                checked for compatibility with the algorithm.
            *args: Accepted for backward compatibility; not used.
            **kwargs: Forwarded to ``AutoConfig.load``.

        Returns:
            An instance of the resolved watermark class, constructed with the
            config returned by ``AutoConfig.load``.

        Raises:
            ValueError: ``algorithm_name`` matches no registered algorithm.
            ValueError: the pipeline type cannot be determined.
            ValueError: the algorithm does not support the pipeline type.

            Errors raised while importing the watermark module or by
            ``AutoConfig.load`` propagate unchanged.
        """
        # Resolve once to the registered spelling so every later step (the
        # compatibility check and AutoConfig) sees the same canonical key.
        canonical_name = cls._canonical_algorithm_name(algorithm_name)
        if canonical_name is None:
            supported_algs = list(WATERMARK_MAPPING_NAMES.keys())
            raise ValueError(f"Invalid algorithm name: {algorithm_name}. Please use one of the supported algorithms: {', '.join(supported_algs)}")
        watermark_name = WATERMARK_MAPPING_NAMES[canonical_name]

        # Check pipeline compatibility
        if diffusion_config and diffusion_config.pipe:
            pipeline_type = get_pipeline_type(diffusion_config.pipe)
            if pipeline_type is None:
                # Without a pipeline type no compatibility list exists; say so
                # instead of reporting "the None pipeline type" and an empty list.
                raise ValueError(
                    f"Could not determine the pipeline type of the provided diffusion pipeline, "
                    f"so the algorithm '{algorithm_name}' cannot be used with it. "
                    f"Supported pipeline types are: {', '.join(PIPELINE_SUPPORTED_WATERMARKS.keys())}"
                )
            if not cls._check_pipeline_compatibility(pipeline_type, canonical_name):
                supported_algs = PIPELINE_SUPPORTED_WATERMARKS.get(pipeline_type, [])
                raise ValueError(
                    f"The algorithm '{algorithm_name}' is not compatible with the {pipeline_type} pipeline type. "
                    f"Supported algorithms for this pipeline type are: {', '.join(supported_algs)}"
                )

        # Import the implementing class lazily from its dotted path.
        module_name, class_name = watermark_name.rsplit('.', 1)
        module = importlib.import_module(module_name)
        watermark_class = getattr(module, class_name)
        watermark_config = AutoConfig.load(canonical_name, diffusion_config, algorithm_config_path=algorithm_config, **kwargs)
        return watermark_class(watermark_config)

    @classmethod
    def list_supported_algorithms(cls, pipeline_type: Optional[str] = None) -> list:
        """Return supported algorithm names, optionally filtered by pipeline type.

        Args:
            pipeline_type: A key of ``PIPELINE_SUPPORTED_WATERMARKS``. When None,
                every registered algorithm name is returned.

        Returns:
            A new list of algorithm names. Because it is a copy, mutating it
            does not alter the module-level registries.

        Raises:
            ValueError: If ``pipeline_type`` is not a known pipeline type.
        """
        if pipeline_type is None:
            return list(WATERMARK_MAPPING_NAMES.keys())
        if pipeline_type not in PIPELINE_SUPPORTED_WATERMARKS:
            raise ValueError(f"Unknown pipeline type: {pipeline_type}. Supported types are: {', '.join(PIPELINE_SUPPORTED_WATERMARKS.keys())}")
        # Return a copy: handing out the stored list would let callers mutate the shared table.
        return list(PIPELINE_SUPPORTED_WATERMARKS[pipeline_type])

101 

102# try all