Coverage for watermark/robin/robin.py: 90.38%
156 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1from ..base import BaseWatermark, BaseConfig
2from utils.media_utils import *
3import os
4import types
5import torch
6from typing import Dict, Union, List, Optional
7from utils.utils import set_random_seed, inherit_docstring
8from utils.diffusion_config import DiffusionConfig
9import copy
10import numpy as np
11from PIL import Image
12from huggingface_hub import hf_hub_download
13from visualize.data_for_visualization import DataForVisualization
14from evaluation.dataset import StableDiffusionPromptsDataset
15from utils.media_utils import get_random_latents
16from .watermark_generator import get_watermarking_mask, inject_watermark, ROBINWatermarkedImageGeneration # OptimizedDataset, optimizer_wm_prompt
17from detection.robin.robin_detection import ROBINDetector
class ROBINConfig(BaseConfig):
    """Config class for ROBIN algorithm, load config file and initialize parameters."""

    def initialize_parameters(self) -> None:
        """Initialize algorithm-specific parameters.

        Reads all ROBIN settings from ``self.config_dict`` (populated by
        ``BaseConfig``) and exposes them as attributes. All keys are required
        except ``training_from_scratch``, which defaults to False.
        """
        ## Watermarking-Specific Parameters
        self.w_seed = self.config_dict['w_seed']                    # seed for the watermark pattern
        self.w_channel = self.config_dict['w_channel']              # latent channel the watermark is injected into
        self.w_pattern = self.config_dict['w_pattern']              # watermark pattern type
        self.w_mask_shape = self.config_dict['w_mask_shape']        # shape of the watermarking mask
        self.w_up_radius = self.config_dict['w_up_radius']          # upper radius of the ring mask
        self.w_low_radius = self.config_dict['w_low_radius']        # lower radius of the ring mask
        self.w_injection = self.config_dict['w_injection']          # injection mode
        self.w_pattern_const = self.config_dict['w_pattern_const']  # constant used by const-style patterns
        self.threshold = self.config_dict['threshold']              # detection decision threshold
        self.watermarking_step = self.config_dict['watermarking_step']  # denoising step at which the watermark is applied
        # Optional flag: fall back to False when absent from the config file.
        self.is_training_from_scratch = self.config_dict.get('training_from_scratch', False)

        ## Training-Specific Parameters
        self.learning_rate = self.config_dict['learning_rate']  # learning rate for watermark optimization
        self.scale_lr = self.config_dict['scale_lr']  # if True, learning_rate will be multiplied by gradient_accumulation_steps * train_batch_size * num_processes
        self.max_train_steps = self.config_dict['max_train_steps']  # maximum number of training steps for watermark optimization
        self.save_steps = self.config_dict['save_steps']  # save steps for watermark optimization
        self.train_batch_size = self.config_dict['train_batch_size']  # batch size for watermark optimization
        self.gradient_accumulation_steps = self.config_dict['gradient_accumulation_steps']  # gradient accumulation steps for watermark optimization
        self.gradient_checkpointing = self.config_dict['gradient_checkpointing']  # if True, use gradient checkpointing for watermark optimization
        self.mixed_precision = self.config_dict['mixed_precision']  # fp16, fp32, bf16
        self.train_seed = self.config_dict['train_seed']  # seed for watermark optimization

        ## Guidance scales
        self.optimized_guidance_scale = self.config_dict['optimized_guidance_scale']  # guidance scale for optimized prompt signal
        self.data_guidance_scale = self.config_dict['data_guidance_scale']  # guidance scale for data prompt signal
        self.train_guidance_scale = self.config_dict['train_guidance_scale']  # guidance scale for training prompt signal

        ## Paths
        self.hf_dir = self.config_dict['hf_dir']  # HuggingFace cache / local checkpoint directory
        self.output_img_dir = 'watermark/robin/generated_images'
        self.ckpt_dir = 'watermark/robin/ckpts'

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        return 'ROBIN'
class ROBINUtils:
    """Utility class for ROBIN algorithm, contains helper functions."""

    def __init__(self, config: ROBINConfig, *args, **kwargs) -> None:
        """
        Initialize the ROBIN watermarking utilities.

        Parameters:
            config (ROBINConfig): Configuration for the ROBIN algorithm.
        """
        self.config = config

    def build_generation_params(self, **kwargs) -> Dict:
        """Build generation parameters from config and kwargs.

        Precedence (lowest to highest): config defaults, ``config.gen_kwargs``
        (which never overwrites a default key), explicit keyword arguments
        (which always win).
        """
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": self.config.init_latents,
        }

        # Add parameters from config.gen_kwargs without clobbering defaults
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Explicit kwargs override everything else
        generation_params.update(kwargs)

        return generation_params

    def build_watermarking_args(self) -> types.SimpleNamespace:
        """Build watermarking arguments from config as an attribute namespace."""
        watermarking_args = {
            "w_seed": self.config.w_seed,
            "w_channel": self.config.w_channel,
            "w_pattern": self.config.w_pattern,
            "w_mask_shape": self.config.w_mask_shape,
            "w_up_radius": self.config.w_up_radius,
            "w_low_radius": self.config.w_low_radius,
            "w_pattern_const": self.config.w_pattern_const,
            "w_injection": self.config.w_injection,
        }
        return types.SimpleNamespace(**watermarking_args)

    def build_hyperparameters(self) -> Dict:
        """Build hyperparameters for optimization from config."""
        return {
            "learning_rate": self.config.learning_rate,
            "scale_lr": self.config.scale_lr,
            "max_train_steps": self.config.max_train_steps,
            "save_steps": self.config.save_steps,
            "train_batch_size": self.config.train_batch_size,
            "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
            "gradient_checkpointing": self.config.gradient_checkpointing,
            "guidance_scale": self.config.train_guidance_scale,
            "optimized_guidance_scale": self.config.optimized_guidance_scale,
            "mixed_precision": self.config.mixed_precision,
            "seed": self.config.train_seed,
            "output_dir": self.config.ckpt_dir,
        }

    def optimize_watermark(self, dataset: StableDiffusionPromptsDataset, watermarking_args: types.SimpleNamespace) -> tuple:
        """Load the pre-optimized watermark and watermarking signal.

        Builds the watermarking mask from fresh random latents, then resolves
        a pre-trained checkpoint: first from local candidate paths, then via a
        single-file HuggingFace download, and finally via a full repo snapshot.

        Returns:
            tuple: (watermarking_mask, optimized_watermark,
            optimized_watermarking_signal) — mask on CPU, tensors on
            ``config.device``.
        """
        init_latents_w = get_random_latents(pipe=self.config.pipe)
        watermarking_mask = get_watermarking_mask(init_latents_w, self.config, self.config.device).detach().cpu()

        # Build hyperparameters; the checkpoint filename encodes the step count
        hyperparameters = self.build_hyperparameters()
        filename = f"optimized_wm5-30_embedding-step-{hyperparameters['max_train_steps']}.pt"

        # Check if the file already exists locally before downloading
        base_dir = os.path.dirname(os.path.abspath(__file__))
        checkpoint_path = None

        # Check multiple potential local paths
        potential_paths = [
            os.path.join(base_dir, self.config.hf_dir, filename) if self.config.hf_dir else None,
            os.path.join(self.config.hf_dir, filename) if self.config.hf_dir else None,
            os.path.join(self.config.ckpt_dir, filename),
        ]

        for path in potential_paths:
            if path and os.path.exists(path):
                checkpoint_path = path
                print(f"Using existing ROBIN checkpoint: {checkpoint_path}")
                break

        # If not found locally, download the single file from HuggingFace
        if checkpoint_path is None:
            checkpoint_path = hf_hub_download(
                repo_id="Generative-Watermark-Toolkits/MarkDiffusion-robin",
                filename=filename,
                cache_dir=self.config.hf_dir
            )
            print(f"Downloaded ROBIN checkpoint from Huggingface Hub: {checkpoint_path}")

        # Last resort: snapshot the whole repo into ckpt_dir.
        if not os.path.exists(checkpoint_path):
            os.makedirs(self.config.ckpt_dir, exist_ok=True)
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id="Generative-Watermark-Toolkits/MarkDiffusion-robin",
                local_dir=self.config.ckpt_dir,
                repo_type="model",
                local_dir_use_symlinks=False,
                endpoint=os.getenv("HF_ENDPOINT", "https://huggingface.co"),
            )
            # BUGFIX: retarget checkpoint_path at the snapshotted file. The
            # previous path is known not to exist (that is why this branch
            # ran), so loading it below would always raise.
            checkpoint_path = os.path.join(self.config.ckpt_dir, filename)

        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=self.config.device)
        optimized_watermark = checkpoint['opt_wm'].to(self.config.device)
        optimized_watermarking_signal = checkpoint['opt_acond'].to(self.config.device)

        return watermarking_mask, optimized_watermark, optimized_watermarking_signal

    def initialize_detector(self, watermarking_mask, optimized_watermark) -> ROBINDetector:
        """Initialize the ROBIN detector with the mask and ground-truth patch."""
        return ROBINDetector(
            watermarking_mask=watermarking_mask,
            gt_patch=optimized_watermark,
            threshold=self.config.threshold,
            device=self.config.device
        )
@inherit_docstring
class ROBIN(BaseWatermark):
    # NOTE(review): no class docstring is added here on purpose — the
    # @inherit_docstring decorator presumably supplies one from BaseWatermark;
    # confirm before adding a local docstring.
    def __init__(self,
                 watermarking_config: ROBINConfig,
                 *args, **kwargs):
        """
        Initialize the ROBIN watermarking algorithm.

        Loads (or downloads) the pre-optimized watermark components via
        ROBINUtils and constructs the detector up front, so generation and
        detection share the same mask/patch.

        Parameters:
            watermarking_config (ROBINConfig): Configuration for the ROBIN algorithm.
        """
        #super().__init__(algorithm_config, diffusion_config)
        self.config = watermarking_config
        self.utils = ROBINUtils(self.config)

        # === Get the optimized watermark & watermarking signal before generation ===
        self.dataset = StableDiffusionPromptsDataset()

        # Build watermarking arguments
        self.watermarking_args = self.utils.build_watermarking_args()

        # Optimize watermark and get components (mask on CPU, tensors on device)
        self.watermarking_mask, self.optimized_watermark, self.optimized_watermarking_signal = self.utils.optimize_watermark(
            self.dataset,
            self.watermarking_args
        )

        # Initialize detector
        self.detector = self.utils.initialize_detector(self.watermarking_mask, self.optimized_watermark)

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Internal method to generate a watermarked image.

        Parameters:
            prompt (str): Text prompt for generation.
            **kwargs: Extra generation parameters; only keys accepted by
                ROBINWatermarkedImageGeneration are forwarded.

        Returns:
            Image.Image: The first generated watermarked image.
        """
        self.set_orig_watermarked_latents(self.config.init_latents)

        # Build generation parameters using utils
        generation_params = self.utils.build_generation_params(**kwargs)
        # Override guidance_scale for watermarked generation
        generation_params["guidance_scale"] = self.config.guidance_scale
        # Ensure latents parameter is not overridden
        generation_params["latents"] = self.config.init_latents

        # Filter out parameters not supported by ROBINWatermarkedImageGeneration
        supported_params = {
            'height', 'width', 'num_inference_steps', 'guidance_scale', 'optimized_guidance_scale',
            'negative_prompt', 'num_images_per_prompt', 'eta', 'generator', 'latents',
            'output_type', 'return_dict', 'callback', 'callback_steps'
        }
        filtered_params = {k: v for k, v in generation_params.items() if k in supported_params}

        # Ensure watermarking components are on the correct device
        watermarking_mask = self.watermarking_mask.to(self.config.device)
        optimized_watermark = self.optimized_watermark.to(self.config.device)
        optimized_watermarking_signal = self.optimized_watermarking_signal.to(self.config.device) if self.optimized_watermarking_signal is not None else None

        # Generate watermarked image (seeded for reproducibility)
        set_random_seed(self.config.gen_seed)
        result = ROBINWatermarkedImageGeneration(
            pipe=self.config.pipe,
            prompt=prompt,
            watermarking_mask=watermarking_mask,
            gt_patch=optimized_watermark,
            opt_acond=optimized_watermarking_signal,
            watermarking_step=self.config.watermarking_step,
            args=self.watermarking_args,
            **filtered_params,
        )
        return result.images[0]

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect the watermark in the image.

        Inverts the image back to latents via forward diffusion, selects the
        latents at the watermarking step, and delegates scoring to the
        detector.

        Parameters:
            image (Image.Image): Image to test.
            prompt (str): Prompt used for text conditioning during inversion.
            **kwargs: May include 'guidance_scale', 'num_inference_steps',
                'decoder_inv', 'detector_type'; other keys are forwarded to
                forward_diffusion.

        Returns:
            Dict[str, float]: Detection results from the detector.
        """
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Step 1: Get Text Embeddings (CFG only when guidance > 1)
        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Step 2: Preprocess Image to model format on the right device/dtype
        image = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        # Step 3: Get Image Latents
        image_latents = get_media_latents(pipe=self.config.pipe, media=image, sample=False, decoder_inv=kwargs.get('decoder_inv', False))

        # Pass only known parameters to forward_diffusion, and let kwargs handle any additional parameters
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        # Extract and reverse latents for detection using utils
        reversed_latents_list = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=num_steps_to_use,
            **inversion_kwargs
        )

        # Handle case where forward_diffusion returns a single tensor instead of a list
        if isinstance(reversed_latents_list, torch.Tensor):
            reversed_latents = reversed_latents_list
        else:
            # Ensure index is within bounds; the target is the inversion step
            # corresponding to config.watermarking_step (counted from the end)
            target_index = num_steps_to_use - 1 - self.config.watermarking_step
            if target_index < 0:
                target_index = 0
            elif target_index >= len(reversed_latents_list):
                target_index = len(reversed_latents_list) - 1
            reversed_latents = reversed_latents_list[target_index]

        # Evaluate watermark (optionally with a caller-chosen detector type)
        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        else:
            return self.detector.eval_watermark(reversed_latents)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str="",
                               guidance_scale: Optional[float]=None,
                               decoder_inv: bool=False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Collect generation- and detection-side data for visualization.

        Regenerates a watermarked image, inverts it back to latents, and
        packages the intermediate tensors into a DataForVisualization object.

        Parameters:
            image (Image.Image): Image included in the visualization payload.
            prompt (str): Prompt used for generation and inversion.
            guidance_scale (Optional[float]): Overrides config.guidance_scale
                for the inversion when provided.
            decoder_inv (bool): Whether to use decoder inversion when
                extracting image latents.

        Returns:
            DataForVisualization: Bundle of config, latents, mask and patch.
        """
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = guidance_scale if guidance_scale is not None else self.config.guidance_scale

        # Step 1: Generate watermarked latents (generation process)
        set_random_seed(self.config.gen_seed)
        # For ROBIN, the watermarked latents are the init_latents (watermark is applied during generation)
        watermarked_latents = self.config.init_latents

        # Step 2: Generate actual watermarked image using the same process as _generate_watermarked_image
        generation_params = self.utils.build_generation_params()
        generation_params["guidance_scale"] = self.config.guidance_scale
        generation_params["latents"] = self.config.init_latents

        # Generate the actual watermarked image with ROBIN watermarking
        watermarked_image = ROBINWatermarkedImageGeneration(
            pipe=self.config.pipe,
            prompt=prompt,
            watermarking_mask=self.watermarking_mask,
            gt_patch=self.optimized_watermark,
            opt_acond=self.optimized_watermarking_signal,
            watermarking_step=self.config.watermarking_step,
            args=self.watermarking_args,
            **generation_params,
        ).images[0]

        # Step 3: Perform watermark detection to get inverted latents (detection process)
        reversed_latents_list = None

        # Get Text Embeddings for detection
        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Preprocess watermarked image for detection
        processed_image = transform_to_model_format(
            watermarked_image,
            target_size=self.config.image_size[0]
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        # Get Image Latents
        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=processed_image,
            sample=False,
            decoder_inv=decoder_inv
        )

        # Reverse Image Latents to get inverted noise
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['prompt', 'decoder_inv', 'guidance_scale', 'num_inference_steps']}

        reversed_latents_list = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=self.config.num_inference_steps,
            **inversion_kwargs
        )

        # Step 4: Prepare visualization data
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            reversed_latents=reversed_latents_list,
            orig_watermarked_latents=self.orig_watermarked_latents,
            image=image,
            # ROBIN-specific data
            watermarking_mask=self.watermarking_mask,
            optimized_watermark=self.optimized_watermark,
        )