Coverage for markdiffusion / watermark / robin / robin.py: 90.51%
158 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 19:25 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 19:25 +0000
1from ..base import BaseWatermark, BaseConfig
2from markdiffusion.utils.media_utils import *
3import os
4import types
5import torch
6from typing import Dict, Union, List, Optional
7from markdiffusion.utils.utils import set_random_seed, inherit_docstring
8from markdiffusion.utils.diffusion_config import DiffusionConfig
9import copy
10import numpy as np
11from PIL import Image
12from huggingface_hub import hf_hub_download
13from markdiffusion.visualize.data_for_visualization import DataForVisualization
14from markdiffusion.evaluation.dataset import StableDiffusionPromptsDataset
15from markdiffusion.utils.media_utils import get_random_latents
16from .watermark_generator import get_watermarking_mask, inject_watermark, ROBINWatermarkedImageGeneration # OptimizedDataset, optimizer_wm_prompt
17from markdiffusion.detection.robin.robin_detection import ROBINDetector
class ROBINConfig(BaseConfig):
    """Configuration for the ROBIN algorithm.

    Copies ROBIN-specific settings from the loaded ``config_dict`` onto
    instance attributes.
    """

    def initialize_parameters(self) -> None:
        """Populate the ROBIN-specific attributes from ``self.config_dict``.

        Entries read with ``[...]`` are mandatory, so a missing key raises
        ``KeyError``. Entries read with ``.get`` fall back to the default
        given in the call.
        """
        cfg = self.config_dict

        # --- Watermark description -------------------------------------
        self.w_seed = cfg['w_seed']
        self.w_channel = cfg['w_channel']
        self.w_pattern = cfg['w_pattern']
        self.w_mask_shape = cfg['w_mask_shape']
        self.w_up_radius = cfg['w_up_radius']
        self.w_low_radius = cfg['w_low_radius']
        self.w_injection = cfg['w_injection']
        self.w_pattern_const = cfg['w_pattern_const']

        # --- Detection thresholds (the last two are optional) ----------
        self.threshold = cfg['threshold']
        self.threshold_p_value = cfg.get('threshold_p_value', 0.01)
        self.threshold_cosine_similarity = cfg.get('threshold_cosine_similarity', 0.5)

        self.watermarking_step = cfg['watermarking_step']

        # Optional flag; note the key name differs from the attribute name.
        self.is_training_from_scratch = cfg.get('training_from_scratch', False)

        # --- Watermark-optimization (training) settings -----------------
        self.learning_rate = cfg['learning_rate']
        # When true, the learning rate is meant to be multiplied by
        # gradient_accumulation_steps * train_batch_size * num_processes.
        self.scale_lr = cfg['scale_lr']
        self.max_train_steps = cfg['max_train_steps']
        self.save_steps = cfg['save_steps']
        self.train_batch_size = cfg['train_batch_size']
        self.gradient_accumulation_steps = cfg['gradient_accumulation_steps']
        self.gradient_checkpointing = cfg['gradient_checkpointing']
        self.mixed_precision = cfg['mixed_precision']  # fp16, fp32, bf16
        self.train_seed = cfg['train_seed']

        # --- Guidance scales --------------------------------------------
        self.optimized_guidance_scale = cfg['optimized_guidance_scale']  # optimized prompt signal
        self.data_guidance_scale = cfg['data_guidance_scale']  # data prompt signal
        self.train_guidance_scale = cfg['train_guidance_scale']  # training prompt signal

        # --- Locations ---------------------------------------------------
        self.hf_dir = cfg['hf_dir']
        # These two are fixed relative paths, not read from the config file.
        self.output_img_dir = "watermark/robin/generated_images"
        self.ckpt_dir = 'watermark/robin/ckpts'

    @property
    def algorithm_name(self) -> str:
        """Name identifying this algorithm."""
        return 'ROBIN'
class ROBINUtils:
    """Helpers for ROBIN: parameter assembly, checkpoint loading, detector setup."""

    # Hugging Face repository that hosts the pre-optimized ROBIN checkpoints.
    _HF_REPO_ID = "Generative-Watermark-Toolkits/MarkDiffusion-robin"

    def __init__(self, config: "ROBINConfig", *args, **kwargs) -> None:
        """Store the configuration used by every helper.

        Parameters:
            config (ROBINConfig): Configuration for the ROBIN algorithm.
            *args, **kwargs: Accepted for signature compatibility; ignored.
        """
        self.config = config

    def build_generation_params(self, **kwargs) -> Dict:
        """Assemble the keyword arguments for image generation.

        Precedence, lowest to highest: ``config.gen_kwargs`` (only for keys
        that are not already a config-derived default), the config-derived
        defaults, then ``kwargs`` (which override everything).

        Returns:
            Dict: A new dictionary of generation parameters.
        """
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": self.config.init_latents,
        }

        # gen_kwargs may be absent or empty; it never replaces a default above.
        gen_kwargs = getattr(self.config, "gen_kwargs", None)
        if gen_kwargs:
            for key, value in gen_kwargs.items():
                generation_params.setdefault(key, value)

        # Explicit call-site arguments always win.
        generation_params.update(kwargs)
        return generation_params

    def build_watermarking_args(self) -> types.SimpleNamespace:
        """Bundle the ``w_*`` config fields into an attribute-style namespace."""
        field_names = (
            "w_seed",
            "w_channel",
            "w_pattern",
            "w_mask_shape",
            "w_up_radius",
            "w_low_radius",
            "w_pattern_const",
            "w_injection",
        )
        return types.SimpleNamespace(
            **{name: getattr(self.config, name) for name in field_names}
        )

    def build_hyperparameters(self) -> Dict:
        """Collect the watermark-optimization hyperparameters from the config.

        Three keys are renamed relative to the config attributes:
        ``guidance_scale`` carries ``train_guidance_scale``, ``seed`` carries
        ``train_seed`` and ``output_dir`` carries ``ckpt_dir``.
        """
        cfg = self.config
        return {
            "learning_rate": cfg.learning_rate,
            "scale_lr": cfg.scale_lr,
            "max_train_steps": cfg.max_train_steps,
            "save_steps": cfg.save_steps,
            "train_batch_size": cfg.train_batch_size,
            "gradient_accumulation_steps": cfg.gradient_accumulation_steps,
            "gradient_checkpointing": cfg.gradient_checkpointing,
            "guidance_scale": cfg.train_guidance_scale,
            "optimized_guidance_scale": cfg.optimized_guidance_scale,
            "mixed_precision": cfg.mixed_precision,
            "seed": cfg.train_seed,
            "output_dir": cfg.ckpt_dir,
        }

    def _find_local_checkpoint(self, filename: str) -> Optional[str]:
        """Return the first existing local path for ``filename``, or ``None``."""
        module_dir = os.path.dirname(os.path.abspath(__file__))
        hf_dir = self.config.hf_dir

        candidates = []
        if hf_dir:
            # hf_dir is tried relative to this module first, then as given.
            candidates.append(os.path.join(module_dir, hf_dir, filename))
            candidates.append(os.path.join(hf_dir, filename))
        candidates.append(os.path.join(self.config.ckpt_dir, filename))

        for candidate in candidates:
            if os.path.exists(candidate):
                print(f"Using existing ROBIN checkpoint: {candidate}")
                return candidate
        return None

    def _download_snapshot_checkpoint(self, filename: str) -> str:
        """Mirror the checkpoint repository into ``ckpt_dir`` and return the file path."""
        from huggingface_hub import snapshot_download

        os.makedirs(self.config.ckpt_dir, exist_ok=True)
        snapshot_download(
            repo_id=self._HF_REPO_ID,
            local_dir=self.config.ckpt_dir,
            repo_type="model",
            local_dir_use_symlinks=False,
            endpoint=os.getenv("HF_ENDPOINT", "https://huggingface.co"),
        )
        return os.path.join(self.config.ckpt_dir, filename)

    def optimize_watermark(self, dataset: "StableDiffusionPromptsDataset", watermarking_args: types.SimpleNamespace) -> tuple:
        """Build the watermarking mask and load the pre-optimized watermark.

        Despite its name, this method runs no optimization: it locates (or
        downloads) a checkpoint named after ``max_train_steps`` and reads the
        optimized watermark and watermarking signal from it.

        Parameters:
            dataset: Accepted for interface compatibility; not used here.
            watermarking_args: Accepted for interface compatibility; not used
                here. The config object itself is handed to
                ``get_watermarking_mask`` instead.

        Returns:
            tuple: ``(watermarking_mask, optimized_watermark,
            optimized_watermarking_signal)``. The mask is detached and moved
            to CPU; the other two are moved to ``config.device``.
        """
        init_latents_w = get_random_latents(pipe=self.config.pipe)
        watermarking_mask = get_watermarking_mask(init_latents_w, self.config, self.config.device).detach().cpu()

        hyperparameters = self.build_hyperparameters()
        filename = f"optimized_wm5-30_embedding-step-{hyperparameters['max_train_steps']}.pt"

        # Prefer a copy that is already on disk; otherwise fetch the single file.
        checkpoint_path = self._find_local_checkpoint(filename)
        if checkpoint_path is None:
            checkpoint_path = hf_hub_download(
                repo_id=self._HF_REPO_ID,
                filename=filename,
                cache_dir=self.config.hf_dir,
            )
            print(f"Downloaded ROBIN checkpoint from Huggingface Hub: {checkpoint_path}")

        # Last resort: mirror the whole repository into ckpt_dir. The path must
        # be repointed at that copy, otherwise the load below would still
        # target the missing file.
        if not os.path.exists(checkpoint_path):
            checkpoint_path = self._download_snapshot_checkpoint(filename)

        print(f"Loading checkpoint from {checkpoint_path}")
        # NOTE(security): torch.load unpickles the file. The checkpoint comes
        # from a downloaded artifact, so only trusted repositories should be
        # used here.
        checkpoint = torch.load(checkpoint_path, map_location=self.config.device)
        optimized_watermark = checkpoint['opt_wm'].to(self.config.device)
        optimized_watermarking_signal = checkpoint['opt_acond'].to(self.config.device)

        return watermarking_mask, optimized_watermark, optimized_watermarking_signal

    def initialize_detector(self, watermarking_mask, optimized_watermark) -> "ROBINDetector":
        """Create a ``ROBINDetector`` from the mask, the watermark and the config thresholds."""
        return ROBINDetector(
            watermarking_mask=watermarking_mask,
            gt_patch=optimized_watermark,
            threshold=self.config.threshold,
            device=self.config.device,
            threshold_p_value=self.config.threshold_p_value,
            threshold_cosine_similarity=self.config.threshold_cosine_similarity,
        )
196 # def preprocess_image_for_detection(self, image: Image.Image, prompt: str, guidance_scale: float) -> tuple:
197 # """Preprocess image and get text embeddings for detection."""
198 # # Get Text Embeddings
199 # do_classifier_free_guidance = (guidance_scale > 1.0)
200 # prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
201 # prompt=prompt,
202 # device=self.config.device,
203 # do_classifier_free_guidance=do_classifier_free_guidance,
204 # num_images_per_prompt=1,
205 # )
207 # if do_classifier_free_guidance:
208 # text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
209 # else:
210 # text_embeddings = prompt_embeds
212 # # Preprocess Image
213 # processed_image = transform_to_model_format(
214 # image,
215 # target_size=self.config.image_size[0]
216 # ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)
218 # return text_embeddings, processed_image
220 # def extract_latents_for_detection(self,
221 # image: Image.Image,
222 # prompt: str,
223 # guidance_scale: float,
224 # num_inference_steps: int,
225 # extract_latents_step: int,
226 # **kwargs) -> torch.Tensor:
227 # """Extract and reverse latents for watermark detection."""
228 # # Preprocess image and get text embeddings
229 # text_embeddings, processed_image = self.preprocess_image_for_detection(image, prompt, guidance_scale)
231 # # Get Image Latents
232 # image_latents = get_media_latents(
233 # pipe=self.config.pipe,
234 # media=processed_image,
235 # sample=False,
236 # decoder_inv=kwargs.get('decoder_inv', False)
237 # )
239 # # Reverse Image Latents
240 # inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}
242 # reversed_latents = self.config.inversion.forward_diffusion(
243 # latents=image_latents,
244 # text_embeddings=text_embeddings,
245 # guidance_scale=guidance_scale,
246 # num_inference_steps=num_inference_steps,
247 # **inversion_kwargs
248 # )[extract_latents_step]
250 # return reversed_latents
# ROBIN watermark: generates images through ROBINWatermarkedImageGeneration and
# detects the watermark by inverting an image back to latents and scoring them
# with a ROBINDetector. The class and get_data_for_visualize deliberately carry
# no docstring of their own: @inherit_docstring presumably supplies the
# base-class text — TODO confirm.
@inherit_docstring
class ROBIN(BaseWatermark):
    def __init__(self,
                 watermarking_config: ROBINConfig,
                 *args, **kwargs):
        """
        Initialize the ROBIN watermarking algorithm.

        Loads the watermarking mask, the optimized watermark and the optimized
        watermarking signal up front, then builds the detector from them.

        Parameters:
            watermarking_config (ROBINConfig): Configuration for the ROBIN algorithm.
        """
        # NOTE(review): BaseWatermark.__init__ is not invoked here (the call
        # was commented out in the original) — confirm the base class needs
        # no setup of its own.
        self.config = watermarking_config
        self.utils = ROBINUtils(self.config)

        # Prompt dataset handed to optimize_watermark.
        self.dataset = StableDiffusionPromptsDataset()
        self.watermarking_args = self.utils.build_watermarking_args()

        (
            self.watermarking_mask,
            self.optimized_watermark,
            self.optimized_watermarking_signal,
        ) = self.utils.optimize_watermark(self.dataset, self.watermarking_args)

        self.detector = self.utils.initialize_detector(self.watermarking_mask, self.optimized_watermark)

    def _run_watermarked_generation(self, prompt: str, **kwargs):
        """Run ROBINWatermarkedImageGeneration for ``prompt`` and return its result.

        Shared by image generation and visualization so both use the same
        parameter filtering, device placement and seeding.
        """
        generation_params = self.utils.build_generation_params(**kwargs)
        # Guidance scale and initial latents always come from the config,
        # even if the caller supplied them.
        generation_params["guidance_scale"] = self.config.guidance_scale
        generation_params["latents"] = self.config.init_latents

        # Only these keys are forwarded; anything else is dropped.
        supported_params = {
            'height', 'width', 'num_inference_steps', 'guidance_scale', 'optimized_guidance_scale',
            'negative_prompt', 'num_images_per_prompt', 'eta', 'generator', 'latents',
            'output_type', 'return_dict', 'callback', 'callback_steps'
        }
        filtered_params = {k: v for k, v in generation_params.items() if k in supported_params}

        # The mask is stored on CPU, so every component is moved explicitly.
        device = self.config.device
        watermarking_mask = self.watermarking_mask.to(device)
        optimized_watermark = self.optimized_watermark.to(device)
        optimized_watermarking_signal = self.optimized_watermarking_signal
        if optimized_watermarking_signal is not None:
            optimized_watermarking_signal = optimized_watermarking_signal.to(device)

        set_random_seed(self.config.gen_seed)
        return ROBINWatermarkedImageGeneration(
            pipe=self.config.pipe,
            prompt=prompt,
            watermarking_mask=watermarking_mask,
            gt_patch=optimized_watermark,
            opt_acond=optimized_watermarking_signal,
            watermarking_step=self.config.watermarking_step,
            args=self.watermarking_args,
            **filtered_params,
        )

    def _encode_prompt(self, prompt: str, guidance_scale: float) -> torch.Tensor:
        """Return text embeddings for ``prompt``.

        When ``guidance_scale`` exceeds 1.0 the negative-prompt embeddings are
        concatenated in front of the prompt embeddings.
        """
        do_classifier_free_guidance = (guidance_scale > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )
        if do_classifier_free_guidance:
            return torch.cat([negative_prompt_embeds, prompt_embeds])
        return prompt_embeds

    def _invert_image(self, image, text_embeddings, guidance_scale, num_inference_steps, decoder_inv, extra_kwargs):
        """Encode ``image`` to latents and run forward diffusion (inversion) on them.

        Returns whatever ``config.inversion.forward_diffusion`` returns.
        """
        model_input = transform_to_model_format(
            image,
            target_size=self.config.image_size[0]
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=model_input,
            sample=False,
            decoder_inv=decoder_inv
        )

        # These names are consumed by this class and must not reach the
        # inversion call a second time.
        reserved = ('prompt', 'decoder_inv', 'guidance_scale', 'num_inference_steps')
        inversion_kwargs = {k: v for k, v in extra_kwargs.items() if k not in reserved}

        return self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            **inversion_kwargs
        )

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Internal method to generate a watermarked image.

        ``kwargs`` may add or override generation parameters, except
        ``guidance_scale`` and ``latents``, which always come from the config.
        Returns the first image of the generation result.
        """
        self.set_orig_watermarked_latents(self.config.init_latents)
        result = self._run_watermarked_generation(prompt, **kwargs)
        return result.images[0]

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect the watermark in the image.

        Recognized ``kwargs``: ``guidance_scale`` and ``num_inference_steps``
        (default to the config values), ``decoder_inv`` (default ``False``)
        and ``detector_type`` (forwarded to the detector when present). All
        other entries are passed on to the inversion.
        """
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        text_embeddings = self._encode_prompt(prompt, guidance_scale_to_use)
        reversed_latents_list = self._invert_image(
            image,
            text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=num_steps_to_use,
            decoder_inv=kwargs.get('decoder_inv', False),
            extra_kwargs=kwargs,
        )

        if isinstance(reversed_latents_list, torch.Tensor):
            # The inversion returned a single tensor rather than a sequence.
            reversed_latents = reversed_latents_list
        else:
            # Pick the latents matching the watermarking step, clamped to the
            # valid index range.
            target_index = num_steps_to_use - 1 - self.config.watermarking_step
            target_index = max(0, min(target_index, len(reversed_latents_list) - 1))
            reversed_latents = reversed_latents_list[target_index]

        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        return self.detector.eval_watermark(reversed_latents)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str="",
                               guidance_scale: Optional[float]=None,
                               decoder_inv: bool=False,
                               *args,
                               **kwargs) -> DataForVisualization:
        # Regenerates a watermarked image for `prompt`, inverts it, and packs
        # the inversion result with the ROBIN mask and watermark for
        # visualization. The `image` argument is passed through unchanged.
        guidance_scale_to_use = guidance_scale if guidance_scale is not None else self.config.guidance_scale

        # Same generation path as _generate_watermarked_image.
        watermarked_image = self._run_watermarked_generation(prompt).images[0]

        # Invert the freshly generated image; the step count always comes
        # from the config here.
        text_embeddings = self._encode_prompt(prompt, guidance_scale_to_use)
        reversed_latents_list = self._invert_image(
            watermarked_image,
            text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=self.config.num_inference_steps,
            decoder_inv=decoder_inv,
            extra_kwargs=kwargs,
        )

        # NOTE(review): orig_watermarked_latents is only set by this class in
        # _generate_watermarked_image — presumably generation runs first or
        # the base class provides a default; verify.
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            reversed_latents=reversed_latents_list,
            orig_watermarked_latents=self.orig_watermarked_latents,
            image=image,
            # ROBIN-specific data
            watermarking_mask=self.watermarking_mask,
            optimized_watermark=self.optimized_watermark,
        )