Coverage for markdiffusion / watermark / robin / robin.py: 90.51%

158 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-14 19:25 +0000

1from ..base import BaseWatermark, BaseConfig 

2from markdiffusion.utils.media_utils import * 

3import os 

4import types 

5import torch 

6from typing import Dict, Union, List, Optional 

7from markdiffusion.utils.utils import set_random_seed, inherit_docstring 

8from markdiffusion.utils.diffusion_config import DiffusionConfig 

9import copy 

10import numpy as np 

11from PIL import Image 

12from huggingface_hub import hf_hub_download 

13from markdiffusion.visualize.data_for_visualization import DataForVisualization 

14from markdiffusion.evaluation.dataset import StableDiffusionPromptsDataset 

15from markdiffusion.utils.media_utils import get_random_latents 

16from .watermark_generator import get_watermarking_mask, inject_watermark, ROBINWatermarkedImageGeneration # OptimizedDataset, optimizer_wm_prompt 

17from markdiffusion.detection.robin.robin_detection import ROBINDetector 

18 

class ROBINConfig(BaseConfig):
    """Configuration for the ROBIN watermarking algorithm.

    ``initialize_parameters`` copies ROBIN-specific entries out of
    ``self.config_dict`` and exposes them as plain attributes.
    """

    def initialize_parameters(self) -> None:
        """Populate ROBIN attributes from ``self.config_dict``.

        Keys read with ``[]`` are mandatory; the few read with ``.get`` fall
        back to the defaults written inline below.

        Raises:
            KeyError: if a mandatory key is missing from ``config_dict``.
        """
        cfg = self.config_dict

        # ---- Watermark pattern, mask and injection ----
        self.w_seed = cfg['w_seed']
        self.w_channel = cfg['w_channel']
        self.w_pattern = cfg['w_pattern']
        self.w_mask_shape = cfg['w_mask_shape']
        self.w_up_radius = cfg['w_up_radius']
        self.w_low_radius = cfg['w_low_radius']
        self.w_injection = cfg['w_injection']
        self.w_pattern_const = cfg['w_pattern_const']

        # ---- Detection thresholds ----
        self.threshold = cfg['threshold']
        self.threshold_p_value = cfg.get('threshold_p_value', 0.01)
        self.threshold_cosine_similarity = cfg.get('threshold_cosine_similarity', 0.5)

        # Diffusion step at which the watermark is injected during generation.
        self.watermarking_step = cfg['watermarking_step']

        self.is_training_from_scratch = cfg.get('training_from_scratch', False)

        # ---- Watermark-optimization (training) hyperparameters ----
        self.learning_rate = cfg['learning_rate']
        # When True, learning_rate is scaled by
        # gradient_accumulation_steps * train_batch_size * num_processes.
        self.scale_lr = cfg['scale_lr']
        self.max_train_steps = cfg['max_train_steps']
        self.save_steps = cfg['save_steps']
        self.train_batch_size = cfg['train_batch_size']
        self.gradient_accumulation_steps = cfg['gradient_accumulation_steps']
        self.gradient_checkpointing = cfg['gradient_checkpointing']
        self.mixed_precision = cfg['mixed_precision']  # e.g. fp16, fp32, bf16
        self.train_seed = cfg['train_seed']

        # ---- Guidance scales for the optimized / data / training prompt signals ----
        self.optimized_guidance_scale = cfg['optimized_guidance_scale']
        self.data_guidance_scale = cfg['data_guidance_scale']
        self.train_guidance_scale = cfg['train_guidance_scale']

        # Directory used as the Hugging Face cache / local checkpoint lookup.
        self.hf_dir = cfg['hf_dir']

        # Output locations; these are relative paths, so they resolve against
        # the current working directory.
        self.output_img_dir = "watermark/robin/generated_images"
        self.ckpt_dir = 'watermark/robin/ckpts'

    @property
    def algorithm_name(self) -> str:
        """Name identifying this algorithm."""
        name = 'ROBIN'
        return name

63 

class ROBINUtils:
    """Utility class for the ROBIN algorithm.

    Holds a reference to a ``ROBINConfig`` and builds the argument
    dictionaries/namespaces used for generation, watermark loading and
    detection.
    """

    # Hugging Face Hub repository hosting the pre-optimized ROBIN checkpoints.
    HF_REPO_ID = "Generative-Watermark-Toolkits/MarkDiffusion-robin"

    def __init__(self, config: ROBINConfig, *args, **kwargs) -> None:
        """
        Initialize the ROBIN utilities.

        Parameters:
            config (ROBINConfig): Configuration for the ROBIN algorithm.
        """
        self.config = config

    def build_generation_params(self, **kwargs) -> Dict:
        """Build pipeline generation parameters.

        Precedence (lowest to highest): values derived from the config,
        then ``config.gen_kwargs`` (only for keys not already set), then the
        explicit ``kwargs`` passed here, which override everything.
        """
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": self.config.init_latents,
        }

        gen_kwargs = getattr(self.config, "gen_kwargs", None)
        if gen_kwargs:
            for key, value in gen_kwargs.items():
                # gen_kwargs never overrides the config-derived defaults above.
                generation_params.setdefault(key, value)

        generation_params.update(kwargs)
        return generation_params

    def build_watermarking_args(self) -> types.SimpleNamespace:
        """Collect the ``w_*`` watermark settings from the config into a namespace."""
        watermarking_args = {
            "w_seed": self.config.w_seed,
            "w_channel": self.config.w_channel,
            "w_pattern": self.config.w_pattern,
            "w_mask_shape": self.config.w_mask_shape,
            "w_up_radius": self.config.w_up_radius,
            "w_low_radius": self.config.w_low_radius,
            "w_pattern_const": self.config.w_pattern_const,
            "w_injection": self.config.w_injection,
        }
        return types.SimpleNamespace(**watermarking_args)

    def build_hyperparameters(self) -> Dict:
        """Build the watermark-optimization hyperparameter dict from the config."""
        return {
            "learning_rate": self.config.learning_rate,
            "scale_lr": self.config.scale_lr,
            "max_train_steps": self.config.max_train_steps,
            "save_steps": self.config.save_steps,
            "train_batch_size": self.config.train_batch_size,
            "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
            "gradient_checkpointing": self.config.gradient_checkpointing,
            "guidance_scale": self.config.train_guidance_scale,
            "optimized_guidance_scale": self.config.optimized_guidance_scale,
            "mixed_precision": self.config.mixed_precision,
            "seed": self.config.train_seed,
            "output_dir": self.config.ckpt_dir,
        }

    def optimize_watermark(self, dataset: StableDiffusionPromptsDataset, watermarking_args: types.SimpleNamespace) -> tuple:
        """Obtain the watermarking mask plus the pre-optimized watermark and signal.

        Despite its name, this method performs no optimization. It builds the
        watermarking mask for a fresh random latent and loads a checkpoint
        named after ``max_train_steps``. The checkpoint comes from a local
        path if one exists; otherwise it is downloaded from the Hugging Face
        Hub.

        Parameters:
            dataset: Unused; kept for interface compatibility.
            watermarking_args: Unused; kept for interface compatibility.

        Returns:
            tuple: ``(watermarking_mask, optimized_watermark,
            optimized_watermarking_signal)``. The mask is moved to CPU; the
            other two are moved to ``config.device``.

        Raises:
            FileNotFoundError: if the checkpoint cannot be located even after
                mirroring the Hub repository into ``config.ckpt_dir``.
        """
        init_latents_w = get_random_latents(pipe=self.config.pipe)
        watermarking_mask = get_watermarking_mask(init_latents_w, self.config, self.config.device).detach().cpu()

        hyperparameters = self.build_hyperparameters()
        filename = f"optimized_wm5-30_embedding-step-{hyperparameters['max_train_steps']}.pt"

        checkpoint_path = self._find_local_checkpoint(filename)
        if checkpoint_path is not None:
            print(f"Using existing ROBIN checkpoint: {checkpoint_path}")
        else:
            checkpoint_path = hf_hub_download(
                repo_id=self.HF_REPO_ID,
                filename=filename,
                cache_dir=self.config.hf_dir,
            )
            print(f"Downloaded ROBIN checkpoint from Huggingface Hub: {checkpoint_path}")

        if not os.path.exists(checkpoint_path):
            # Fallback: mirror the whole repo into ckpt_dir and load from there.
            # Previously the path was not updated after mirroring, so torch.load
            # still failed on the missing file.
            checkpoint_path = self._snapshot_checkpoint(filename)

        print(f"Loading checkpoint from {checkpoint_path}")
        # NOTE(review): torch.load unpickles the file; it comes from a local
        # path or the Hub repo above. Consider weights_only=True once the
        # minimum supported torch version allows it.
        checkpoint = torch.load(checkpoint_path, map_location=self.config.device)
        optimized_watermark = checkpoint['opt_wm'].to(self.config.device)
        optimized_watermarking_signal = checkpoint['opt_acond'].to(self.config.device)

        return watermarking_mask, optimized_watermark, optimized_watermarking_signal

    def _find_local_checkpoint(self, filename: str) -> Optional[str]:
        """Return the first existing local path for ``filename``, or None."""
        base_dir = os.path.dirname(os.path.abspath(__file__))
        candidates = []
        if self.config.hf_dir:
            # hf_dir relative to this module, then relative to the CWD.
            candidates.append(os.path.join(base_dir, self.config.hf_dir, filename))
            candidates.append(os.path.join(self.config.hf_dir, filename))
        candidates.append(os.path.join(self.config.ckpt_dir, filename))
        return next((path for path in candidates if os.path.exists(path)), None)

    def _snapshot_checkpoint(self, filename: str) -> str:
        """Mirror the checkpoint repo into ``config.ckpt_dir`` and return ``filename``'s path there."""
        from huggingface_hub import snapshot_download

        os.makedirs(self.config.ckpt_dir, exist_ok=True)
        snapshot_download(
            repo_id=self.HF_REPO_ID,
            local_dir=self.config.ckpt_dir,
            repo_type="model",
            # NOTE(review): deprecated (ignored) in recent huggingface_hub
            # releases; kept for compatibility with older versions.
            local_dir_use_symlinks=False,
            endpoint=os.getenv("HF_ENDPOINT", "https://huggingface.co"),
        )
        checkpoint_path = os.path.join(self.config.ckpt_dir, filename)
        if not os.path.exists(checkpoint_path):
            raise FileNotFoundError(
                f"ROBIN checkpoint '{filename}' not found in '{self.config.ckpt_dir}' "
                f"after downloading {self.HF_REPO_ID}"
            )
        return checkpoint_path

    def initialize_detector(self, watermarking_mask, optimized_watermark) -> ROBINDetector:
        """Create a ROBINDetector configured with the config's thresholds and device."""
        return ROBINDetector(
            watermarking_mask=watermarking_mask,
            gt_patch=optimized_watermark,
            threshold=self.config.threshold,
            device=self.config.device,
            threshold_p_value=self.config.threshold_p_value,
            threshold_cosine_similarity=self.config.threshold_cosine_similarity,
        )

251 

@inherit_docstring
class ROBIN(BaseWatermark):
    """ROBIN watermarking: injects a pre-optimized watermark during generation
    and detects it by inverting images back to latents."""

    # Keyword arguments accepted by ROBINWatermarkedImageGeneration; any other
    # key in the generation parameters is dropped before the call.
    _SUPPORTED_GENERATION_PARAMS = frozenset({
        'height', 'width', 'num_inference_steps', 'guidance_scale', 'optimized_guidance_scale',
        'negative_prompt', 'num_images_per_prompt', 'eta', 'generator', 'latents',
        'output_type', 'return_dict', 'callback', 'callback_steps',
    })

    # Keyword arguments consumed by ROBIN itself and never forwarded to
    # inversion.forward_diffusion ('detector_type' is meant for the detector).
    _NON_INVERSION_KWARGS = frozenset({
        'prompt', 'decoder_inv', 'guidance_scale', 'num_inference_steps', 'detector_type',
    })

    def __init__(self,
                 watermarking_config: ROBINConfig,
                 *args, **kwargs):
        """
        Initialize the ROBIN watermarking algorithm.

        Loads the pre-optimized watermark components and builds the detector.

        Parameters:
            watermarking_config (ROBINConfig): Configuration for the ROBIN algorithm.
        """
        # NOTE(review): BaseWatermark.__init__ is not called (the call was
        # commented out upstream); confirm the base class needs no setup.
        self.config = watermarking_config
        self.utils = ROBINUtils(self.config)

        self.dataset = StableDiffusionPromptsDataset()
        self.watermarking_args = self.utils.build_watermarking_args()

        # Fetch the optimized watermark and watermarking signal before any generation.
        (self.watermarking_mask,
         self.optimized_watermark,
         self.optimized_watermarking_signal) = self.utils.optimize_watermark(
            self.dataset,
            self.watermarking_args,
        )

        self.detector = self.utils.initialize_detector(self.watermarking_mask, self.optimized_watermark)

    # ------------------------------------------------------------------ helpers

    @classmethod
    def _inversion_kwargs(cls, kwargs: Dict) -> Dict:
        """Return the subset of ``kwargs`` that should be forwarded to inversion."""
        return {k: v for k, v in kwargs.items() if k not in cls._NON_INVERSION_KWARGS}

    def _build_watermarked_generation_params(self, **kwargs) -> Dict:
        """Generation params with guidance_scale and latents pinned to the config values."""
        generation_params = self.utils.build_generation_params(**kwargs)
        generation_params["guidance_scale"] = self.config.guidance_scale
        generation_params["latents"] = self.config.init_latents
        return generation_params

    def _watermark_components_on_device(self) -> tuple:
        """Return (mask, watermark, signal) moved to config.device (signal may be None)."""
        device = self.config.device
        signal = self.optimized_watermarking_signal
        return (
            self.watermarking_mask.to(device),
            self.optimized_watermark.to(device),
            signal.to(device) if signal is not None else None,
        )

    def _run_watermarked_generation(self, prompt: str, generation_params: Dict):
        """Seed the RNG and run ROBINWatermarkedImageGeneration; returns its raw result."""
        filtered_params = {
            k: v for k, v in generation_params.items() if k in self._SUPPORTED_GENERATION_PARAMS
        }
        watermarking_mask, optimized_watermark, optimized_signal = self._watermark_components_on_device()

        set_random_seed(self.config.gen_seed)
        return ROBINWatermarkedImageGeneration(
            pipe=self.config.pipe,
            prompt=prompt,
            watermarking_mask=watermarking_mask,
            gt_patch=optimized_watermark,
            opt_acond=optimized_signal,
            watermarking_step=self.config.watermarking_step,
            args=self.watermarking_args,
            **filtered_params,
        )

    def _invert_image(self,
                      image: Image.Image,
                      prompt: str,
                      guidance_scale: float,
                      num_inference_steps: int,
                      decoder_inv: bool,
                      inversion_kwargs: Dict):
        """Encode ``prompt``, map ``image`` to latents and run the forward (inversion) diffusion.

        Returns whatever ``config.inversion.forward_diffusion`` returns; the
        caller handles it being either a single tensor or a per-step sequence.
        """
        # Classifier-free guidance needs [negative, positive] embeddings.
        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )
        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        processed_image = transform_to_model_format(
            image,
            target_size=self.config.image_size[0],
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=processed_image,
            sample=False,
            decoder_inv=decoder_inv,
        )

        return self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            **inversion_kwargs,
        )

    # --------------------------------------------------------------- public API

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Internal method to generate a watermarked image.

        ``guidance_scale`` and ``latents`` are always taken from the config;
        other supported generation kwargs override the defaults.
        """
        self.set_orig_watermarked_latents(self.config.init_latents)
        generation_params = self._build_watermarked_generation_params(**kwargs)
        result = self._run_watermarked_generation(prompt, generation_params)
        return result.images[0]

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect the watermark in the image.

        Recognised kwargs: ``guidance_scale`` and ``num_inference_steps``
        (default to the config), ``decoder_inv`` (default False) and
        ``detector_type`` (passed to the detector). All other kwargs are
        forwarded to the inversion.
        """
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        reversed_latents_list = self._invert_image(
            image,
            prompt,
            guidance_scale_to_use,
            num_steps_to_use,
            kwargs.get('decoder_inv', False),
            self._inversion_kwargs(kwargs),
        )

        if isinstance(reversed_latents_list, torch.Tensor):
            reversed_latents = reversed_latents_list
        else:
            # Pick the latent corresponding to the watermarking step, clamped
            # into the bounds of the returned sequence.
            target_index = num_steps_to_use - 1 - self.config.watermarking_step
            target_index = min(max(target_index, 0), len(reversed_latents_list) - 1)
            reversed_latents = reversed_latents_list[target_index]

        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        return self.detector.eval_watermark(reversed_latents)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str="",
                               guidance_scale: Optional[float]=None,
                               decoder_inv: bool=False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Collect the data needed to visualize ROBIN watermarking.

        Regenerates the watermarked image for ``prompt`` (same process as
        ``_generate_watermarked_image``), then inverts that regenerated image
        to obtain the reversed latents.
        """
        # NOTE(review): inversion runs on the regenerated image, not on the
        # ``image`` argument (which is only forwarded for display). Confirm
        # this is intended.
        guidance_scale_to_use = guidance_scale if guidance_scale is not None else self.config.guidance_scale

        generation_params = self._build_watermarked_generation_params()
        watermarked_image = self._run_watermarked_generation(prompt, generation_params).images[0]

        reversed_latents_list = self._invert_image(
            watermarked_image,
            prompt,
            guidance_scale_to_use,
            self.config.num_inference_steps,
            decoder_inv,
            self._inversion_kwargs(kwargs),
        )

        # orig_watermarked_latents is only set by _generate_watermarked_image;
        # for ROBIN these are the config's init_latents, so fall back to them.
        orig_watermarked_latents = getattr(self, 'orig_watermarked_latents', self.config.init_latents)

        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            reversed_latents=reversed_latents_list,
            orig_watermarked_latents=orig_watermarked_latents,
            image=image,
            # ROBIN-specific data
            watermarking_mask=self.watermarking_mask,
            optimized_watermark=self.optimized_watermark,
        )