Coverage for watermark / base.py: 93.81%
194 statements
coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
import torch
from typing import Dict, List, Union, Optional, Any, Tuple
from utils.diffusion_config import DiffusionConfig
from utils.utils import load_config_file, set_random_seed
from utils.media_utils import *
from utils.pipeline_utils import (
    get_pipeline_type,
    is_image_pipeline,
    is_video_pipeline,
    is_t2v_pipeline,
    is_i2v_pipeline,
    PIPELINE_TYPE_IMAGE,
    PIPELINE_TYPE_TEXT_TO_VIDEO,
    PIPELINE_TYPE_IMAGE_TO_VIDEO
)
from PIL import Image
from diffusers import (
    StableDiffusionPipeline,
    TextToVideoSDPipeline,
    StableVideoDiffusionPipeline,
    DDIMInverseScheduler
)


class BaseConfig(ABC):
    """Base configuration class for diffusion watermarking methods."""

    def __init__(self, algorithm_config: Optional[str], diffusion_config: DiffusionConfig, *args, **kwargs) -> None:
        """Initialize base configuration with common parameters."""
        # Load config file
        self.config_dict = load_config_file(f'config/{self.algorithm_name()}.json') if algorithm_config is None else load_config_file(algorithm_config)

        # Diffusion model parameters
        if diffusion_config is None:
            raise ValueError("diffusion_config cannot be None for BaseConfig initialization")

        if kwargs:
            self.config_dict.update(kwargs)

        self.pipe = diffusion_config.pipe
        self.scheduler = diffusion_config.scheduler
        self.device = diffusion_config.device
        self.guidance_scale = diffusion_config.guidance_scale
        self.num_images = diffusion_config.num_images
        self.num_inference_steps = diffusion_config.num_inference_steps
        self.num_inversion_steps = diffusion_config.num_inversion_steps
        self.image_size = diffusion_config.image_size
        self.dtype = diffusion_config.dtype
        self.gen_seed = diffusion_config.gen_seed
        self.init_latents_seed = diffusion_config.init_latents_seed
        self.inversion_type = diffusion_config.inversion_type
        self.num_frames = diffusion_config.num_frames

        # Set inversion module
        self.inversion = set_inversion(self.pipe, self.inversion_type)
        # Set generation kwargs
        self.gen_kwargs = diffusion_config.gen_kwargs

        # Get initial latents (num_frames < 1 means an image pipeline; otherwise sample video latents)
        init_latents_rng = torch.Generator(device=self.device)
        init_latents_rng.manual_seed(self.init_latents_seed)
        if self.num_frames < 1:
            self.init_latents = get_random_latents(self.pipe, height=self.image_size[0], width=self.image_size[1], generator=init_latents_rng)
        else:
            self.init_latents = get_random_latents(self.pipe, num_frames=self.num_frames, height=self.image_size[0], width=self.image_size[1], generator=init_latents_rng)

        # Initialize algorithm-specific parameters
        self.initialize_parameters()

    @abstractmethod
    def initialize_parameters(self) -> None:
        """Initialize algorithm-specific parameters. Must be overridden by subclasses."""
        raise NotImplementedError

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        raise NotImplementedError
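
# Illustrative sketch (hedged, not part of the library): a concrete configuration subclass is
# expected to provide `algorithm_name` and `initialize_parameters`, reading its own keys from
# `self.config_dict` that was loaded above. The class name and config keys below are hypothetical.
#
#     class MyWatermarkConfig(BaseConfig):
#         @property
#         def algorithm_name(self) -> str:
#             return "MyWatermark"
#
#         def initialize_parameters(self) -> None:
#             # Keys loaded from config/MyWatermark.json (or the user-supplied algorithm_config)
#             self.channel = self.config_dict["channel"]
#             self.pattern_seed = self.config_dict["pattern_seed"]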

class BaseWatermark(ABC):
    """Base class for diffusion watermarking methods."""

    def __init__(self,
                 config: BaseConfig,
                 *args, **kwargs) -> None:
        """Initialize the watermarking algorithm."""
        self.config = config
        self.orig_watermarked_latents = None

        # Determine pipeline type
        self.pipeline_type = self._detect_pipeline_type()

        # Validate pipeline configuration
        self._validate_pipeline_config()

    def _detect_pipeline_type(self) -> str:
        """Detect the type of pipeline being used."""
        pipeline_type = get_pipeline_type(self.config.pipe)
        if pipeline_type is None:
            raise ValueError(f"Unsupported pipeline type: {type(self.config.pipe)}")
        return pipeline_type

    def _validate_pipeline_config(self) -> None:
        """Validate that the pipeline configuration is correct for the pipeline type."""
        # For video pipelines (T2V and I2V), ensure num_frames is set correctly
        if self.pipeline_type in (PIPELINE_TYPE_TEXT_TO_VIDEO, PIPELINE_TYPE_IMAGE_TO_VIDEO):
            if self.config.num_frames < 1:
                raise ValueError(f"For {self.pipeline_type} pipelines, num_frames must be >= 1, got {self.config.num_frames}")
        # For image pipelines, ensure num_frames is -1
        elif self.pipeline_type == PIPELINE_TYPE_IMAGE:
            if self.config.num_frames >= 1:
                raise ValueError(f"For {self.pipeline_type} pipelines, num_frames should be -1, got {self.config.num_frames}")

    def get_orig_watermarked_latents(self) -> torch.Tensor:
        """Get the original watermarked latents."""
        return self.orig_watermarked_latents

    def set_orig_watermarked_latents(self, value: torch.Tensor) -> None:
        """Set the original watermarked latents."""
        self.orig_watermarked_latents = value

    def generate_watermarked_media(self,
                                   input_data: Union[str, Image.Image],
                                   *args,
                                   **kwargs) -> Union[Image.Image, List[Image.Image]]:
        """
        Generate watermarked media (image or video) based on pipeline type.

        This is the main interface for generating watermarked content with any
        watermarking algorithm. It automatically routes to the appropriate generation
        method based on the pipeline type (image or video).

        Args:
            input_data: Text prompt (for T2I or T2V) or input image (for I2V)
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments, including:
                - guidance_scale: Guidance scale for generation
                - num_inference_steps: Number of inference steps
                - height, width: Dimensions of generated media
                - seed: Random seed for generation

        Returns:
            Union[Image.Image, List[Image.Image]]: Generated watermarked media
                - For image pipelines: Returns a single PIL Image
                - For video pipelines: Returns a list of PIL Images (frames)

        Examples:
            ```python
            # Image watermarking
            watermark = AutoWatermark.load('TR', diffusion_config=config)
            image = watermark.generate_watermarked_media(
                input_data="A beautiful landscape",
                guidance_scale=7.5,
                num_inference_steps=50
            )

            # Video watermarking (T2V)
            watermark = AutoWatermark.load('VideoShield', diffusion_config=config)
            frames = watermark.generate_watermarked_media(
                input_data="A dog running in a park",
                num_frames=16
            )

            # Video watermarking (I2V)
            watermark = AutoWatermark.load('VideoShield', diffusion_config=config)
            frames = watermark.generate_watermarked_media(
                input_data=reference_image,
                num_frames=16
            )
            ```
        """
        # Route to the appropriate generation method based on pipeline type
        if is_image_pipeline(self.config.pipe):
            if not isinstance(input_data, str):
                raise ValueError("For image generation, input_data must be a text prompt (string)")
            return self._generate_watermarked_image(input_data, *args, **kwargs)
        elif is_video_pipeline(self.config.pipe):
            return self._generate_watermarked_video(input_data, *args, **kwargs)
        else:
            raise ValueError(f"Unsupported pipeline type for generation: {type(self.config.pipe)}")

    def generate_unwatermarked_media(self,
                                     input_data: Union[str, Image.Image],
                                     *args,
                                     **kwargs) -> Union[Image.Image, List[Image.Image]]:
        """
        Generate unwatermarked media (image or video) based on pipeline type.

        Args:
            input_data: Text prompt (for T2I or T2V) or input image (for I2V)
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments, including:
                - save_path: Path to save the generated media

        Returns:
            Union[Image.Image, List[Image.Image]]: Generated unwatermarked media
        """
        # Route to the appropriate generation method based on pipeline type
        if is_image_pipeline(self.config.pipe):
            if not isinstance(input_data, str):
                raise ValueError("For image generation, input_data must be a text prompt (string)")
            return self._generate_unwatermarked_image(input_data, *args, **kwargs)
        elif is_video_pipeline(self.config.pipe):
            return self._generate_unwatermarked_video(input_data, *args, **kwargs)
        else:
            raise ValueError(f"Unsupported pipeline type for generation: {type(self.config.pipe)}")

    def detect_watermark_in_media(self,
                                  media: Union[Image.Image, List[Image.Image], np.ndarray, torch.Tensor],
                                  *args,
                                  **kwargs) -> Dict[str, Any]:
        """
        Detect watermark in media (image or video).

        Args:
            media: The media to detect the watermark in (PIL image, list of frames, numpy array, or tensor)
            *args: Additional positional arguments
            **kwargs: Additional keyword arguments, including:
                - prompt: Optional text prompt used to generate the media (for some algorithms)
                - num_inference_steps: Optional number of inference steps
                - guidance_scale: Optional guidance scale
                - num_frames: Optional number of frames
                - decoder_inv: Optional decoder inversion
                - inv_order: Inverse order for Exact Inversion
                - detector_type: Type of detector to use

        Returns:
            Dict[str, Any]: Detection results with metrics and possibly visualizations
        """
        # Process the input media into the right format based on pipeline type
        processed_media = self._preprocess_media_for_detection(media)

        # Route to the appropriate detection method
        if is_image_pipeline(self.config.pipe):
            return self._detect_watermark_in_image(
                processed_media,
                *args,
                **kwargs
            )
        else:
            return self._detect_watermark_in_video(
                processed_media,
                *args,
                **kwargs
            )
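
    # Illustrative sketch (hedged, not part of the library): a typical round trip pairs
    # generation with detection on the same pipeline. `AutoWatermark`, the 'TR' key, and the
    # `config` object mirror the docstring examples above and are assumptions, not guarantees.
    #
    #     watermark = AutoWatermark.load('TR', diffusion_config=config)
    #     image = watermark.generate_watermarked_media(input_data="A beautiful landscape")
    #     result = watermark.detect_watermark_in_media(image, prompt="A beautiful landscape")
    #     print(result)  # e.g. a dict of detection metrics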

    def _preprocess_media_for_detection(self,
                                        media: Union[Image.Image, List[Image.Image], np.ndarray, torch.Tensor]
                                        ) -> Union[Image.Image, List[Image.Image], torch.Tensor]:
        """
        Preprocess media for detection based on its type and the pipeline type.

        Args:
            media: The media to preprocess

        Returns:
            Union[Image.Image, List[Image.Image], torch.Tensor]: Preprocessed media
        """
        if is_image_pipeline(self.config.pipe):
            if isinstance(media, Image.Image):
                return media
            elif isinstance(media, np.ndarray):
                return cv2_to_pil(media)
            elif isinstance(media, torch.Tensor):
                # Convert tensor to PIL image
                if media.dim() == 3:  # C, H, W
                    media = media.unsqueeze(0)  # Add batch dimension
                img_np = torch_to_numpy(media)[0]  # Take first image
                return cv2_to_pil(img_np)
            elif isinstance(media, list):  # Compatible with the detection pipeline
                return media[0]
            else:
                raise ValueError(f"Unsupported media type for image pipeline: {type(media)}")
        else:
            # Video pipeline
            if isinstance(media, list):
                # List of frames
                if all(isinstance(frame, Image.Image) for frame in media):
                    return media
                elif all(isinstance(frame, np.ndarray) for frame in media):
                    return [cv2_to_pil(frame) for frame in media]
                else:
                    raise ValueError("All frames must be either PIL images or numpy arrays")
            elif isinstance(media, np.ndarray):
                # Convert numpy video to list of PIL images
                if media.ndim == 4:  # F, H, W, C
                    return [cv2_to_pil(frame) for frame in media]
                else:
                    raise ValueError(f"Unsupported numpy array shape for video: {media.shape}")
            elif isinstance(media, torch.Tensor):
                # Convert tensor to list of PIL images
                if media.dim() == 5:  # B, C, F, H, W
                    video_np = torch_to_numpy(media)[0]  # Take first batch
                    return [cv2_to_pil(frame) for frame in video_np]
                elif media.dim() == 4 and media.shape[0] > 3:  # F, C, H, W (assuming F > 3)
                    frames = []
                    for i in range(media.shape[0]):
                        frame_np = torch_to_numpy(media[i].unsqueeze(0))[0]
                        frames.append(cv2_to_pil(frame_np))
                    return frames
                else:
                    raise ValueError(f"Unsupported tensor shape for video: {media.shape}")
            else:
                raise ValueError(f"Unsupported media type for video pipeline: {type(media)}")
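
    # Illustrative sketch (hedged): per the shape comments above, a raw video tensor shaped
    # (B, C, F, H, W) or (F, C, H, W) is normalized into a list of PIL frames before detection.
    # The value range expected by `torch_to_numpy` / `cv2_to_pil` is an assumption here.
    #
    #     video = torch.rand(1, 3, 16, 256, 256)  # hypothetical B, C, F, H, W tensor
    #     frames = watermark._preprocess_media_for_detection(video)
    #     # frames -> List[PIL.Image.Image], one entry per frame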

    def _generate_watermarked_image(self,
                                    prompt: str,
                                    *args,
                                    **kwargs) -> Image.Image:
        """
        Generate watermarked image from text prompt.

        Parameters:
            prompt (str): The input prompt.

        Returns:
            Image.Image: The generated watermarked image.

        Raises:
            ValueError: If the pipeline doesn't support image generation.
        """
        if self.pipeline_type != PIPELINE_TYPE_IMAGE:
            raise ValueError(f"This pipeline ({self.pipeline_type}) does not support image generation. Use generate_watermarked_video instead.")

        # The implementation depends on the specific watermarking algorithm
        # This method should be implemented by subclasses
        raise NotImplementedError("This method is not implemented for this watermarking algorithm.")

    def _generate_watermarked_video(self,
                                    input_data: Union[str, Image.Image],
                                    *args,
                                    **kwargs) -> Union[List[Image.Image], Image.Image]:
        """
        Generate watermarked video based on text prompt or input image.

        Parameters:
            input_data (Union[str, Image.Image]): Either a text prompt (for T2V) or an input image (for I2V).
                - If the pipeline is T2V, input_data should be a string prompt.
                - If the pipeline is I2V, input_data should be an Image object or can be passed as kwargs['input_image'].
            kwargs:
                - 'input_image': The input image for I2V pipelines.
                - 'prompt': The text prompt for T2V pipelines.
                - 'image_path': The path to the input image for I2V pipelines.

        Returns:
            Union[List[Image.Image], Image.Image]: The generated watermarked video frames.

        Raises:
            ValueError: If the pipeline doesn't support video generation or if input type is incompatible.
        """
        if not is_video_pipeline(self.config.pipe):
            raise ValueError(f"This pipeline ({self.pipeline_type}) does not support video generation. Use generate_watermarked_image instead.")

        # The implementation depends on the specific watermarking algorithm
        # This method should be implemented by subclasses
        raise NotImplementedError("This method is not implemented for this watermarking algorithm.")

    def _generate_unwatermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """
        Generate unwatermarked image from text prompt.

        Parameters:
            prompt (str): The input prompt.

        Returns:
            Image.Image: The generated unwatermarked image.

        Raises:
            ValueError: If the pipeline doesn't support image generation.
        """
        if not is_image_pipeline(self.config.pipe):
            raise ValueError(f"This pipeline ({self.pipeline_type}) does not support image generation. Use generate_unwatermarked_video instead.")

        # Construct generation parameters
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": self.config.init_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Use kwargs to override default parameters
        for key, value in kwargs.items():
            generation_params[key] = value

        set_random_seed(self.config.gen_seed)
        return self.config.pipe(
            prompt,
            **generation_params
        ).images[0]
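
    # Illustrative sketch (hedged): because caller kwargs override both the defaults and
    # config.gen_kwargs above, a one-off low-step preview could be requested like this
    # (the specific values are hypothetical):
    #
    #     preview = watermark._generate_unwatermarked_image(
    #         "A beautiful landscape",
    #         num_inference_steps=10,
    #         guidance_scale=5.0,
    #     )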

    def _generate_unwatermarked_video(self, input_data: Union[str, Image.Image], *args, **kwargs) -> List[Image.Image]:
        """
        Generate unwatermarked video based on text prompt or input image.

        Parameters:
            input_data (Union[str, Image.Image]): Either a text prompt (for T2V) or an input image (for I2V).
                - If the pipeline is T2V, input_data should be a string prompt.
                - If the pipeline is I2V, input_data should be an Image object or can be passed as kwargs['input_image'].
            kwargs:
                - 'input_image': The input image for I2V pipelines.
                - 'prompt': The text prompt for T2V pipelines.
                - 'image_path': The path to the input image for I2V pipelines.

        Returns:
            List[Image.Image]: The generated unwatermarked video frames.

        Raises:
            ValueError: If the pipeline doesn't support video generation or if input type is incompatible.
        """
        if not is_video_pipeline(self.config.pipe):
            raise ValueError(f"This pipeline ({self.pipeline_type}) does not support video generation. Use generate_unwatermarked_image instead.")

        # Handle Text-to-Video pipeline
        if is_t2v_pipeline(self.config.pipe):
            # For T2V, input should be a text prompt
            if not isinstance(input_data, str):
                raise ValueError("Text-to-Video pipeline requires a text prompt as input_data")

            # Construct generation parameters
            generation_params = {
                "latents": self.config.init_latents,
                "num_frames": self.config.num_frames,
                "height": self.config.image_size[0],
                "width": self.config.image_size[1],
                "num_inference_steps": self.config.num_inference_steps,
                "guidance_scale": self.config.guidance_scale,
            }

            # Add parameters from config.gen_kwargs
            if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
                for key, value in self.config.gen_kwargs.items():
                    if key not in generation_params:
                        generation_params[key] = value

            # Use kwargs to override default parameters
            for key, value in kwargs.items():
                generation_params[key] = value

            # Generate the video
            set_random_seed(self.config.gen_seed)
            output = self.config.pipe(
                input_data,  # Use prompt
                **generation_params
            )

            # Based on test results, the TextToVideoSDPipeline output exposes a `frames` attribute
            if hasattr(output, 'frames'):
                frames = output.frames[0]
            elif hasattr(output, 'videos'):
                frames = output.videos[0]
            else:
                frames = output[0] if isinstance(output, tuple) else output

            # Convert frames to PIL images
            frame_list = [cv2_to_pil(frame) for frame in frames]
            return frame_list

        # Handle Image-to-Video pipeline
        elif is_i2v_pipeline(self.config.pipe):
            # For I2V, input should be an image; a text prompt is optional
            input_image = None
            text_prompt = None

            # Check if an input image was passed via kwargs
            if "input_image" in kwargs and isinstance(kwargs["input_image"], Image.Image):
                input_image = kwargs["input_image"]

            # Check if input_data is an image
            elif isinstance(input_data, Image.Image):
                input_image = input_data

            # If input_data is a string but we need an image, check if an image path was provided
            elif isinstance(input_data, str):
                import os
                from PIL import Image as PILImage

                if os.path.exists(input_data):
                    try:
                        input_image = PILImage.open(input_data).convert("RGB")
                    except Exception as e:
                        raise ValueError(f"Input data is neither an Image object nor a valid image path. Failed to load image from path: {e}")
                else:
                    # Treat as text prompt if no valid image path
                    text_prompt = input_data

            if input_image is None:
                raise ValueError("Input image is required for Image-to-Video pipeline")

            # Construct generation parameters
            generation_params = {
                "image": input_image,
                "height": self.config.image_size[0],
                "width": self.config.image_size[1],
                "num_frames": self.config.num_frames,
                "latents": self.config.init_latents,
                "num_inference_steps": self.config.num_inference_steps,
                "max_guidance_scale": self.config.guidance_scale,
                "output_type": "np",
            }
            # In I2VGen-XL, the text prompt is needed
            if text_prompt is not None:
                generation_params["prompt"] = text_prompt

            # Add parameters from config.gen_kwargs
            if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
                for key, value in self.config.gen_kwargs.items():
                    if key not in generation_params:
                        generation_params[key] = value

            # Use kwargs to override default parameters
            for key, value in kwargs.items():
                generation_params[key] = value

            # Generate the video
            set_random_seed(self.config.gen_seed)
            video = self.config.pipe(
                **generation_params
            ).frames[0]

            # Convert frames to PIL images
            frame_list = [cv2_to_pil(frame) for frame in video]
            return frame_list

        # This should never happen since we already checked the pipeline type
        raise NotImplementedError(f"Unsupported video pipeline type: {self.pipeline_type}")
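
    # Illustrative sketch (hedged): the I2V branch above accepts either a PIL image or a
    # path to one, so the two calls below should be equivalent (the file name is hypothetical):
    #
    #     frames = watermark._generate_unwatermarked_video(Image.open("ref.png").convert("RGB"))
    #     frames = watermark._generate_unwatermarked_video("ref.png")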

    def _detect_watermark_in_video(self,
                                   video_frames: List[Image.Image],
                                   *args,
                                   **kwargs) -> Dict[str, Any]:
        """
        Detect watermark in video frames.

        Args:
            video_frames: List of video frames as PIL images
            kwargs:
                - 'prompt': Optional text prompt used for generation (for T2V pipelines)
                - 'reference_image': Optional reference image (for I2V pipelines)
                - 'guidance_scale': The guidance scale for the detector (optional)
                - 'detector_type': The type of detector to use (optional)
                - 'num_inference_steps': Number of inference steps for inversion (optional)
                - 'num_frames': Number of frames to use for detection (optional for I2V pipelines)
                - 'decoder_inv': Whether to use decoder inversion (optional)
                - 'inv_order': Inverse order for Exact Inversion (optional)

        Returns:
            Dict[str, Any]: Detection results

        Raises:
            NotImplementedError: If the watermarking algorithm doesn't support video watermark detection
        """
        raise NotImplementedError("Video watermark detection is not implemented for this algorithm")

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """
        Detect watermark in image.

        Args:
            image (Image.Image): The input image.
            prompt (str): The prompt used for generation.
            kwargs:
                - 'guidance_scale': The guidance scale for the detector.
                - 'detector_type': The type of detector to use.
                - 'num_inference_steps': Number of inference steps for inversion.
                - 'decoder_inv': Whether to use decoder inversion.
                - 'inv_order': Inverse order for Exact Inversion.

        Returns:
            Dict[str, float]: The detection result.
        """
        raise NotImplementedError("Watermark detection in image is not implemented for this algorithm")

    @abstractmethod
    def get_data_for_visualize(self, media, *args, **kwargs):
        """Get data for visualization."""
        pass
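
# Illustrative sketch (assumptions, not part of the library): a concrete algorithm is expected
# to subclass BaseWatermark and fill in the private hooks used by the public interface above.
# The class name and the bodies below are hypothetical.
#
#     class MyWatermark(BaseWatermark):
#         def __init__(self, config: BaseConfig, *args, **kwargs) -> None:
#             super().__init__(config, *args, **kwargs)
#
#         def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
#             # Inject the watermark into self.config.init_latents, then run self.config.pipe
#             ...
#
#         def _detect_watermark_in_image(self, image: Image.Image, prompt: str = "", *args, **kwargs) -> Dict[str, float]:
#             # Invert the image back to latents with self.config.inversion and score them
#             ...
#
#         def get_data_for_visualize(self, media, *args, **kwargs):
#             ...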