Coverage for watermark/robin/robin.py: 90.38%

156 statements

coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

from ..base import BaseWatermark, BaseConfig
from utils.media_utils import *
import os
import types
import torch
from typing import Dict, Union, List, Optional
from utils.utils import set_random_seed, inherit_docstring
from utils.diffusion_config import DiffusionConfig
import copy
import numpy as np
from PIL import Image
from huggingface_hub import hf_hub_download
from visualize.data_for_visualization import DataForVisualization
from evaluation.dataset import StableDiffusionPromptsDataset
from utils.media_utils import get_random_latents
from .watermark_generator import get_watermarking_mask, inject_watermark, ROBINWatermarkedImageGeneration  # OptimizedDataset, optimizer_wm_prompt
from detection.robin.robin_detection import ROBINDetector


class ROBINConfig(BaseConfig):
    """Config class for the ROBIN algorithm: loads the config file and initializes parameters."""

    def initialize_parameters(self) -> None:
        """Initialize algorithm-specific parameters."""
        ## Watermarking-Specific Parameters
        self.w_seed = self.config_dict['w_seed']
        self.w_channel = self.config_dict['w_channel']
        self.w_pattern = self.config_dict['w_pattern']
        self.w_mask_shape = self.config_dict['w_mask_shape']
        self.w_up_radius = self.config_dict['w_up_radius']
        self.w_low_radius = self.config_dict['w_low_radius']
        self.w_injection = self.config_dict['w_injection']
        self.w_pattern_const = self.config_dict['w_pattern_const']
        self.threshold = self.config_dict['threshold']

        self.watermarking_step = self.config_dict['watermarking_step']

        self.is_training_from_scratch = self.config_dict.get('training_from_scratch', False)

        ## Training-Specific Parameters
        self.learning_rate = self.config_dict['learning_rate']  # learning rate for watermark optimization
        self.scale_lr = self.config_dict['scale_lr']  # if True, learning_rate is multiplied by gradient_accumulation_steps * train_batch_size * num_processes
        self.max_train_steps = self.config_dict['max_train_steps']  # maximum number of training steps for watermark optimization
        self.save_steps = self.config_dict['save_steps']  # checkpoint-saving interval for watermark optimization
        self.train_batch_size = self.config_dict['train_batch_size']  # batch size for watermark optimization
        self.gradient_accumulation_steps = self.config_dict['gradient_accumulation_steps']  # gradient accumulation steps for watermark optimization
        self.gradient_checkpointing = self.config_dict['gradient_checkpointing']  # if True, use gradient checkpointing for watermark optimization
        self.mixed_precision = self.config_dict['mixed_precision']  # one of: fp16, fp32, bf16
        self.train_seed = self.config_dict['train_seed']  # seed for watermark optimization

        self.optimized_guidance_scale = self.config_dict['optimized_guidance_scale']  # guidance scale for the optimized prompt signal
        self.data_guidance_scale = self.config_dict['data_guidance_scale']  # guidance scale for the data prompt signal
        self.train_guidance_scale = self.config_dict['train_guidance_scale']  # guidance scale for the training prompt signal
        self.hf_dir = self.config_dict['hf_dir']
        self.output_img_dir = 'watermark/robin/generated_images'
        self.ckpt_dir = 'watermark/robin/ckpts'

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        return 'ROBIN'

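
# A hedged, minimal sketch of the config_dict schema that initialize_parameters
# reads above; the key names come from the code, while every value here is
# hypothetical (real values live in the algorithm's config file).
_EXAMPLE_ROBIN_CONFIG_DICT = {
    'w_seed': 10, 'w_channel': 3, 'w_pattern': 'ring', 'w_mask_shape': 'circle',
    'w_up_radius': 30, 'w_low_radius': 5, 'w_injection': 'complex',
    'w_pattern_const': 0.0, 'threshold': 0.5, 'watermarking_step': 35,
    'training_from_scratch': False,  # optional; defaults to False
    'learning_rate': 1e-4, 'scale_lr': False, 'max_train_steps': 1000,
    'save_steps': 100, 'train_batch_size': 1, 'gradient_accumulation_steps': 1,
    'gradient_checkpointing': False, 'mixed_precision': 'fp16', 'train_seed': 42,
    'optimized_guidance_scale': 7.5, 'data_guidance_scale': 7.5,
    'train_guidance_scale': 7.5, 'hf_dir': 'watermark/robin/ckpts',
}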

class ROBINUtils:
    """Utility class for the ROBIN algorithm; contains helper functions."""

    def __init__(self, config: ROBINConfig, *args, **kwargs) -> None:
        """
        Initialize the utility class for the ROBIN algorithm.

        Parameters:
            config (ROBINConfig): Configuration for the ROBIN algorithm.
        """
        self.config = config


    def build_generation_params(self, **kwargs) -> Dict:
        """Build generation parameters from config and kwargs."""
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": self.config.init_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Use kwargs to override default parameters
        for key, value in kwargs.items():
            generation_params[key] = value

        return generation_params
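
    # Worked example of the precedence above (hypothetical values): with
    # config.guidance_scale == 7.5 and config.gen_kwargs == {"eta": 0.0,
    # "guidance_scale": 3.0}, calling build_generation_params(guidance_scale=5.0)
    # takes only "eta" from gen_kwargs (its "guidance_scale" is skipped because
    # the key already exists) and then lets the explicit kwarg win, yielding
    # guidance_scale == 5.0 and eta == 0.0 in the result.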

    def build_watermarking_args(self) -> types.SimpleNamespace:
        """Build watermarking arguments from config."""
        watermarking_args = {
            "w_seed": self.config.w_seed,
            "w_channel": self.config.w_channel,
            "w_pattern": self.config.w_pattern,
            "w_mask_shape": self.config.w_mask_shape,
            "w_up_radius": self.config.w_up_radius,
            "w_low_radius": self.config.w_low_radius,
            "w_pattern_const": self.config.w_pattern_const,
            "w_injection": self.config.w_injection,
        }
        return types.SimpleNamespace(**watermarking_args)

    def build_hyperparameters(self) -> Dict:
        """Build hyperparameters for optimization from config."""
        return {
            "learning_rate": self.config.learning_rate,
            "scale_lr": self.config.scale_lr,
            "max_train_steps": self.config.max_train_steps,
            "save_steps": self.config.save_steps,
            "train_batch_size": self.config.train_batch_size,
            "gradient_accumulation_steps": self.config.gradient_accumulation_steps,
            "gradient_checkpointing": self.config.gradient_checkpointing,
            "guidance_scale": self.config.train_guidance_scale,
            "optimized_guidance_scale": self.config.optimized_guidance_scale,
            "mixed_precision": self.config.mixed_precision,
            "seed": self.config.train_seed,
            "output_dir": self.config.ckpt_dir,
        }

    def optimize_watermark(self, dataset: StableDiffusionPromptsDataset, watermarking_args: types.SimpleNamespace) -> tuple:
        """Optimize the watermark and watermarking signal."""
        # NOTE: dataset and watermarking_args are currently unused here; a
        # pre-optimized checkpoint is loaded instead of training from scratch.
        init_latents_w = get_random_latents(pipe=self.config.pipe)
        watermarking_mask = get_watermarking_mask(init_latents_w, self.config, self.config.device).detach().cpu()

        # Build hyperparameters
        hyperparameters = self.build_hyperparameters()
        filename = f"optimized_wm5-30_embedding-step-{hyperparameters['max_train_steps']}.pt"

        # Check whether the file already exists locally before downloading
        base_dir = os.path.dirname(os.path.abspath(__file__))
        checkpoint_path = None

        # Check multiple potential local paths
        potential_paths = [
            os.path.join(base_dir, self.config.hf_dir, filename) if self.config.hf_dir else None,
            os.path.join(self.config.hf_dir, filename) if self.config.hf_dir else None,
            os.path.join(self.config.ckpt_dir, filename),
        ]

        for path in potential_paths:
            if path and os.path.exists(path):
                checkpoint_path = path
                print(f"Using existing ROBIN checkpoint: {checkpoint_path}")
                break

        # If not found locally, download from the Hugging Face Hub
        if checkpoint_path is None:
            checkpoint_path = hf_hub_download(
                repo_id="Generative-Watermark-Toolkits/MarkDiffusion-robin",
                filename=filename,
                cache_dir=self.config.hf_dir
            )
            print(f"Downloaded ROBIN checkpoint from Hugging Face Hub: {checkpoint_path}")

        # Fallback: if the resolved path still does not exist, fetch the full repository
        if not os.path.exists(checkpoint_path):
            os.makedirs(self.config.ckpt_dir, exist_ok=True)
            from huggingface_hub import snapshot_download
            snapshot_download(
                repo_id="Generative-Watermark-Toolkits/MarkDiffusion-robin",
                local_dir=self.config.ckpt_dir,
                repo_type="model",
                local_dir_use_symlinks=False,
                endpoint=os.getenv("HF_ENDPOINT", "https://huggingface.co"),
            )
            # Point at the file inside the snapshot so the load below succeeds
            checkpoint_path = os.path.join(self.config.ckpt_dir, filename)

        print(f"Loading checkpoint from {checkpoint_path}")
        checkpoint = torch.load(checkpoint_path, map_location=self.config.device)
        optimized_watermark = checkpoint['opt_wm'].to(self.config.device)
        optimized_watermarking_signal = checkpoint['opt_acond'].to(self.config.device)

        return watermarking_mask, optimized_watermark, optimized_watermarking_signal
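
    # For reference, a hedged inspection sketch of the checkpoint loaded above
    # (it is a dict holding at least the two tensors read by optimize_watermark;
    # <N> stands for the max_train_steps value baked into the filename):
    #
    #     ckpt = torch.load("optimized_wm5-30_embedding-step-<N>.pt", map_location="cpu")
    #     opt_wm, opt_acond = ckpt['opt_wm'], ckpt['opt_acond']
    #     print(tuple(opt_wm.shape), tuple(opt_acond.shape))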

    def initialize_detector(self, watermarking_mask, optimized_watermark) -> ROBINDetector:
        """Initialize the ROBIN detector."""
        return ROBINDetector(
            watermarking_mask=watermarking_mask,
            gt_patch=optimized_watermark,
            threshold=self.config.threshold,
            device=self.config.device
        )


@inherit_docstring
class ROBIN(BaseWatermark):
    def __init__(self,
                 watermarking_config: ROBINConfig,
                 *args, **kwargs):
        """
        Initialize the ROBIN watermarking algorithm.

        Parameters:
            watermarking_config (ROBINConfig): Configuration for the ROBIN algorithm.
        """
        self.config = watermarking_config
        self.utils = ROBINUtils(self.config)

        # === Get the optimized watermark & watermarking signal before generation ===
        self.dataset = StableDiffusionPromptsDataset()

        # Build watermarking arguments
        self.watermarking_args = self.utils.build_watermarking_args()

        # Load/optimize the watermark and get its components
        self.watermarking_mask, self.optimized_watermark, self.optimized_watermarking_signal = self.utils.optimize_watermark(
            self.dataset,
            self.watermarking_args
        )

        # Initialize detector
        self.detector = self.utils.initialize_detector(self.watermarking_mask, self.optimized_watermark)
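
    # End-to-end sketch (hedged): the public generate/detect entry points live on
    # BaseWatermark and are not shown in this file, so this illustrates only the
    # internal methods defined below, assuming `robin_config` is a prepared
    # ROBINConfig with a loaded pipeline:
    #
    #     wm = ROBIN(watermarking_config=robin_config)
    #     img = wm._generate_watermarked_image("a photo of a cat")
    #     scores = wm._detect_watermark_in_image(img, prompt="a photo of a cat")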

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Internal method to generate a watermarked image."""
        self.set_orig_watermarked_latents(self.config.init_latents)

        # Build generation parameters using utils
        generation_params = self.utils.build_generation_params(**kwargs)
        # Override guidance_scale for watermarked generation
        generation_params["guidance_scale"] = self.config.guidance_scale
        # Ensure the latents parameter is not overridden
        generation_params["latents"] = self.config.init_latents

        # Filter out parameters not supported by ROBINWatermarkedImageGeneration
        supported_params = {
            'height', 'width', 'num_inference_steps', 'guidance_scale', 'optimized_guidance_scale',
            'negative_prompt', 'num_images_per_prompt', 'eta', 'generator', 'latents',
            'output_type', 'return_dict', 'callback', 'callback_steps'
        }
        filtered_params = {k: v for k, v in generation_params.items() if k in supported_params}

        # Ensure watermarking components are on the correct device
        watermarking_mask = self.watermarking_mask.to(self.config.device)
        optimized_watermark = self.optimized_watermark.to(self.config.device)
        optimized_watermarking_signal = self.optimized_watermarking_signal.to(self.config.device) if self.optimized_watermarking_signal is not None else None

        # Generate watermarked image
        set_random_seed(self.config.gen_seed)
        result = ROBINWatermarkedImageGeneration(
            pipe=self.config.pipe,
            prompt=prompt,
            watermarking_mask=watermarking_mask,
            gt_patch=optimized_watermark,
            opt_acond=optimized_watermarking_signal,
            watermarking_step=self.config.watermarking_step,
            args=self.watermarking_args,
            **filtered_params,
        )
        return result.images[0]

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect the watermark in the image."""
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Step 1: Get Text Embeddings
        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Step 2: Preprocess Image
        image = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        # Step 3: Get Image Latents
        image_latents = get_media_latents(pipe=self.config.pipe, media=image, sample=False, decoder_inv=kwargs.get('decoder_inv', False))

        # Strip the parameters handled explicitly above; forward the rest to forward_diffusion
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        # Step 4: Extract and reverse latents for detection
        reversed_latents_list = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=num_steps_to_use,
            **inversion_kwargs
        )

        # Handle the case where forward_diffusion returns a single tensor instead of a list
        if isinstance(reversed_latents_list, torch.Tensor):
            reversed_latents = reversed_latents_list
        else:
            # Select the inversion step that corresponds to the watermarking step,
            # clamping the index to the bounds of the returned list
            target_index = num_steps_to_use - 1 - self.config.watermarking_step
            if target_index < 0:
                target_index = 0
            elif target_index >= len(reversed_latents_list):
                target_index = len(reversed_latents_list) - 1
            reversed_latents = reversed_latents_list[target_index]

        # Step 5: Evaluate watermark
        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        else:
            return self.detector.eval_watermark(reversed_latents)
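
    # Worked example of the index selection above (hypothetical values): with
    # num_inference_steps == 50 and watermarking_step == 35, the inversion list is
    # indexed at 50 - 1 - 35 == 14, i.e. the inverted latents corresponding to the
    # step at which the watermark was injected during generation.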

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str = "",
                               guidance_scale: Optional[float] = None,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Collect generation- and detection-side data for visualization."""
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = guidance_scale if guidance_scale is not None else self.config.guidance_scale

        # Step 1: Generate watermarked latents (generation process)
        set_random_seed(self.config.gen_seed)
        # For ROBIN, the watermarked latents are the init_latents (the watermark is applied during generation)
        watermarked_latents = self.config.init_latents

        # Step 2: Generate the actual watermarked image using the same process as _generate_watermarked_image
        generation_params = self.utils.build_generation_params()
        generation_params["guidance_scale"] = self.config.guidance_scale
        generation_params["latents"] = self.config.init_latents

        # Generate the actual watermarked image with ROBIN watermarking
        watermarked_image = ROBINWatermarkedImageGeneration(
            pipe=self.config.pipe,
            prompt=prompt,
            watermarking_mask=self.watermarking_mask,
            gt_patch=self.optimized_watermark,
            opt_acond=self.optimized_watermarking_signal,
            watermarking_step=self.config.watermarking_step,
            args=self.watermarking_args,
            **generation_params,
        ).images[0]

        # Step 3: Perform watermark detection to get inverted latents (detection process)
        # Get Text Embeddings for detection
        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Preprocess the watermarked image for detection
        processed_image = transform_to_model_format(
            watermarked_image,
            target_size=self.config.image_size[0]
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        # Get Image Latents
        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=processed_image,
            sample=False,
            decoder_inv=decoder_inv
        )

        # Reverse the image latents to recover the inverted noise
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['prompt', 'decoder_inv', 'guidance_scale', 'num_inference_steps']}

        reversed_latents_list = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=self.config.num_inference_steps,
            **inversion_kwargs
        )

        # Step 4: Prepare visualization data
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            reversed_latents=reversed_latents_list,
            orig_watermarked_latents=self.orig_watermarked_latents,
            image=image,
            # ROBIN-specific data
            watermarking_mask=self.watermarking_mask,
            optimized_watermark=self.optimized_watermark,
        )
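

if __name__ == "__main__":
    # Hedged smoke test (run via `python -m watermark.robin.robin`, assuming that
    # package layout): build_watermarking_args only reads plain attributes from
    # its config, so a SimpleNamespace stand-in with hypothetical values suffices
    # to illustrate the namespace it produces. Real runs construct a ROBINConfig
    # with a loaded diffusion pipeline instead.
    _stub_config = types.SimpleNamespace(
        w_seed=10,               # all values below are hypothetical
        w_channel=3,
        w_pattern="ring",
        w_mask_shape="circle",
        w_up_radius=30,
        w_low_radius=5,
        w_pattern_const=0.0,
        w_injection="complex",
    )
    print(ROBINUtils(_stub_config).build_watermarking_args())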