Coverage for watermark / auto_watermark.py: 84.78%
46 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1import importlib
2from watermark.base import BaseWatermark
3from typing import Union, Optional
4from utils.pipeline_utils import (
5 get_pipeline_type,
6 PIPELINE_TYPE_IMAGE,
7 PIPELINE_TYPE_TEXT_TO_VIDEO,
8 PIPELINE_TYPE_IMAGE_TO_VIDEO
9)
10from watermark.auto_config import AutoConfig
# Registry of watermarking algorithms: algorithm name -> dotted import path
# ("package.module.ClassName") of the class implementing it. Insertion order
# is the order in which algorithms are listed in error messages.
WATERMARK_MAPPING_NAMES = {
    "TR": "watermark.tr.TR",
    "GS": "watermark.gs.GS",
    "PRC": "watermark.prc.PRC",
    "VideoShield": "watermark.videoshield.VideoShieldWatermark",
    "VideoMark": "watermark.videomark.VideoMarkWatermark",
    "RI": "watermark.ri.RI",
    "SEAL": "watermark.seal.SEAL",
    "ROBIN": "watermark.robin.ROBIN",
    "WIND": "watermark.wind.WIND",
    "SFW": "watermark.sfw.SFW",
    "GM": "watermark.gm.GM",
}
# Watermarking algorithms usable with each pipeline type. Every name listed
# here is a key of WATERMARK_MAPPING_NAMES; list order is the order shown in
# error messages.
PIPELINE_SUPPORTED_WATERMARKS = {
    PIPELINE_TYPE_IMAGE: [
        "TR",
        "GS",
        "PRC",
        "RI",
        "SEAL",
        "ROBIN",
        "WIND",
        "GM",
        "SFW",
    ],
    PIPELINE_TYPE_TEXT_TO_VIDEO: [
        "VideoShield",
        "VideoMark",
    ],
    PIPELINE_TYPE_IMAGE_TO_VIDEO: [
        "VideoShield",
        "VideoMark",
    ],
}
def watermark_name_from_alg_name(name: str) -> Optional[str]:
    """Resolve an algorithm name to the dotted path of its watermark class.

    The name is compared against the keys of `WATERMARK_MAPPING_NAMES`
    ignoring case; the value of the first matching entry is returned.

    Args:
        name: Algorithm name in any letter case.

    Returns:
        The mapped "module.ClassName" path, or `None` when no key matches.
    """
    matching_paths = (
        class_path
        for registered_name, class_path in WATERMARK_MAPPING_NAMES.items()
        if name.lower() == registered_name.lower()
    )
    # Lazily evaluated: stops at the first match, like the explicit loop would.
    return next(matching_paths, None)
class AutoWatermark:
    """
    This is a generic watermark class that will be instantiated as one of the watermark classes of the library when
    created with the [`AutoWatermark.load`] class method.

    `load` resolves the requested algorithm through `WATERMARK_MAPPING_NAMES`, imports the module that
    implements it and returns an instance of the concrete watermark class.

    This class cannot be instantiated directly using `__init__()` (throws an error).
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoWatermark is designed to be instantiated "
            "using the `AutoWatermark.load(algorithm_name, algorithm_config, diffusion_config)` method."
        )

    @staticmethod
    def _canonical_algorithm_name(algorithm_name) -> Optional[str]:
        """Return the `WATERMARK_MAPPING_NAMES` key matching `algorithm_name` ignoring case.

        Returns `None` when nothing matches or when `algorithm_name` is not a string.
        """
        if not isinstance(algorithm_name, str):
            return None
        lowered = algorithm_name.lower()
        for registered_name in WATERMARK_MAPPING_NAMES:
            if registered_name.lower() == lowered:
                return registered_name
        return None

    @staticmethod
    def _check_pipeline_compatibility(pipeline_type: str, algorithm_name: str) -> bool:
        """Check if the pipeline type is compatible with the watermarking algorithm.

        Args:
            pipeline_type: Pipeline type as returned by `get_pipeline_type`; `None` is never compatible.
            algorithm_name: Algorithm name spelled exactly as its key in `WATERMARK_MAPPING_NAMES`
                (this check is case-sensitive).

        Returns:
            True only if the algorithm is registered and is listed for `pipeline_type` in
            `PIPELINE_SUPPORTED_WATERMARKS`.
        """
        if pipeline_type is None:
            return False

        if algorithm_name not in WATERMARK_MAPPING_NAMES:
            return False

        return algorithm_name in PIPELINE_SUPPORTED_WATERMARKS.get(pipeline_type, [])

    @classmethod
    def load(cls, algorithm_name, algorithm_config=None, diffusion_config=None, *args, **kwargs) -> BaseWatermark:
        """Load the watermark algorithm instance based on the algorithm name.

        Args:
            algorithm_name: Name of the algorithm; matched against the registered names ignoring case.
            algorithm_config: Forwarded to `AutoConfig.load` as `algorithm_config_path`.
            diffusion_config: Forwarded to `AutoConfig.load`. When it is truthy and its `pipe`
                attribute is truthy, the pipeline type of `pipe` must support the algorithm.
            *args: Accepted for call compatibility; not used.
            **kwargs: Forwarded to `AutoConfig.load`.

        Returns:
            An instance of the watermark class registered for the algorithm, constructed from the
            configuration object returned by `AutoConfig.load`.

        Raises:
            ValueError: If the algorithm name is not registered, or if the algorithm is not supported
                by the pipeline type of `diffusion_config.pipe`.
        """
        # Resolve the user's spelling to the registered key once. The name lookup is
        # case-insensitive, but the compatibility tables are keyed by the exact registered
        # name, so every later lookup must use the canonical spelling; otherwise a valid
        # name such as "tr" would be rejected as incompatible.
        canonical_name = cls._canonical_algorithm_name(algorithm_name)
        if canonical_name is None:
            supported_algs = list(WATERMARK_MAPPING_NAMES.keys())
            raise ValueError(f"Invalid algorithm name: {algorithm_name}. Please use one of the supported algorithms: {', '.join(supported_algs)}")
        watermark_name = WATERMARK_MAPPING_NAMES[canonical_name]

        # Check pipeline compatibility
        if diffusion_config and diffusion_config.pipe:
            pipeline_type = get_pipeline_type(diffusion_config.pipe)
            if not cls._check_pipeline_compatibility(pipeline_type, canonical_name):
                supported_algs = PIPELINE_SUPPORTED_WATERMARKS.get(pipeline_type, [])
                raise ValueError(
                    f"The algorithm '{algorithm_name}' is not compatible with the {pipeline_type} pipeline type. "
                    f"Supported algorithms for this pipeline type are: {', '.join(supported_algs)}"
                )

        # Load the watermark module
        module_name, class_name = watermark_name.rsplit('.', 1)
        module = importlib.import_module(module_name)
        watermark_class = getattr(module, class_name)
        # NOTE(review): the canonical name is passed on the assumption that AutoConfig is keyed
        # by the same registered names as WATERMARK_MAPPING_NAMES - confirm against AutoConfig.
        watermark_config = AutoConfig.load(canonical_name, diffusion_config, algorithm_config_path=algorithm_config, **kwargs)
        return watermark_class(watermark_config)

    @classmethod
    def list_supported_algorithms(cls, pipeline_type: Optional[str] = None) -> list:
        """List all supported watermarking algorithms, optionally filtered by pipeline type.

        Args:
            pipeline_type: A key of `PIPELINE_SUPPORTED_WATERMARKS`, or `None` for every registered algorithm.

        Returns:
            A new list of algorithm names; mutating it does not affect the module-level tables.

        Raises:
            ValueError: If `pipeline_type` is given but is not a known pipeline type.
        """
        if pipeline_type is None:
            return list(WATERMARK_MAPPING_NAMES.keys())
        if pipeline_type not in PIPELINE_SUPPORTED_WATERMARKS:
            raise ValueError(f"Unknown pipeline type: {pipeline_type}. Supported types are: {', '.join(PIPELINE_SUPPORTED_WATERMARKS.keys())}")
        # Return a copy so callers cannot mutate the shared registry list by accident.
        return list(PIPELINE_SUPPORTED_WATERMARKS[pipeline_type])
102# try all