Coverage for watermark / sfw / sfw.py: 91.12%

304 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1from ..base import BaseWatermark, BaseConfig 

2from utils.media_utils import * 

3import torch 

4from typing import Dict, Optional 

5from utils.utils import set_random_seed, inherit_docstring 

6from utils.diffusion_config import DiffusionConfig 

7import numpy as np 

8from PIL import Image 

9from visualize.data_for_visualization import DataForVisualization 

10from detection.sfw.sfw_detection import SFWDetector 

11import torchvision.transforms as tforms 

12import qrcode 

13import logging 

14import os 

15 

class SFWConfig(BaseConfig):
    """Configuration for the SFW watermarking algorithm.

    Copies the SFW-specific settings out of ``self.config_dict`` onto the
    instance and reports the algorithm name.
    """

    def initialize_parameters(self) -> None:
        """Populate SFW-specific attributes from ``self.config_dict``.

        Sets, in order: ``w_seed``, ``delta``, ``wm_type`` ("HSTR" or "HSQR"),
        ``threshold`` and ``w_channel``. A missing key raises ``KeyError``.
        """
        for key in ('w_seed', 'delta', 'wm_type', 'threshold', 'w_channel'):
            setattr(self, key, self.config_dict[key])

    @property
    def algorithm_name(self) -> str:
        """Name under which this algorithm is registered: ``'SFW'``."""
        name = 'SFW'
        return name

31 

class SFWUtils:
    """Utility class for the SFW algorithm: pattern/mask construction and injection.

    On construction it builds the ground-truth watermark pattern (``gt_patch``)
    and the boolean ``watermarking_mask`` from the supplied config.
    """

    def __init__(self, config: "SFWConfig", *args, **kwargs) -> None:
        """
        Initialize the SFW helper state.

        Parameters:
            config (SFWConfig): Configuration for the SFW algorithm. Attributes read
                here and by the helpers include ``w_seed``, ``wm_type``, ``w_channel``,
                ``delta``, ``image_size``, ``device``, ``pipe`` and ``init_latents``.
        """
        self.config = config
        self.gt_patch = self._get_watermarking_pattern()
        self.watermarking_mask = self._get_watermarking_mask(self.config.init_latents)

    # [Fourier transforms]
    @staticmethod
    def fft(input_tensor):
        """2-D FFT over the last two dims of a 4-D tensor, zero frequency shifted to the centre."""
        assert len(input_tensor.shape) == 4
        return torch.fft.fftshift(torch.fft.fft2(input_tensor), dim=(-1, -2))

    @staticmethod
    def ifft(input_tensor):
        """Inverse of :meth:`fft`: undo the centre shift, then 2-D inverse FFT (complex result)."""
        assert len(input_tensor.shape) == 4
        return torch.fft.ifft2(torch.fft.ifftshift(input_tensor, dim=(-1, -2)))

    @staticmethod
    @torch.no_grad()
    def rfft(input_tensor):
        """Real 2-D FFT of a 4-D tensor; only the row axis (dim -2) is fftshifted."""
        assert len(input_tensor.shape) == 4
        return torch.fft.fftshift(torch.fft.rfft2(input_tensor, dim=(-2, -1)), dim=-2)

    @staticmethod
    @torch.no_grad()
    def irfft(input_tensor):
        """Inverse of :meth:`rfft` for square signals.

        The output spatial size is ``(H, H)`` with ``H = input_tensor.shape[-2]``,
        so this is only correct when the original signal was square.
        """
        assert len(input_tensor.shape) == 4
        return torch.fft.irfft2(torch.fft.ifftshift(input_tensor, dim=-2), dim=(-2, -1),
                                s=(input_tensor.shape[-2], input_tensor.shape[-2]))

    @staticmethod
    def _center_patch_bounds(size: int):
        """Return ``(start, end)`` of the centred square crop used for HSTR.

        Latents with side >= 64 use a 44-wide crop; smaller latents use
        ``size - 2`` (or ``size`` itself when ``size <= 2``).
        """
        if size >= 64:
            patch_size = 44
        elif size > 2:
            patch_size = size - 2
        else:
            patch_size = size
        start = (size - patch_size) // 2
        return start, start + patch_size

    def circle_mask(self, size: int, r=16, x_offset=0, y_offset=0):
        """Boolean ``(size, size)`` numpy array, True inside a disc of radius ``r``.

        The disc is centred at ``(size // 2 + y_offset, size // 2 + x_offset)``
        in (row, column) order.
        """
        x0 = y0 = size // 2
        x0 += x_offset
        y0 += y_offset
        y, x = np.ogrid[:size, :size]
        return ((x - x0) ** 2 + (y - y0) ** 2) <= r ** 2

    @torch.no_grad()
    def enforce_hermitian_symmetry(self, freq_tensor):
        """Return a copy of an fftshifted square spectrum made Hermitian-symmetric.

        In fftshifted coordinates the element at index ``i`` pairs with
        ``(H - i) % H`` for even ``H`` and with ``H - 1 - i`` for odd ``H``. One
        half of every conjugate pair is overwritten by the conjugate of the other,
        and self-paired points are made real. The inverse FFT of the result is
        therefore real-valued.

        Parameters:
            freq_tensor: complex (B, C, H, W) tensor with H == W, fftshifted.
        """
        B, C, H, W = freq_tensor.shape
        assert H == W, "H != W"
        freq_tensor = freq_tensor.clone()
        freq_tensor_tmp = freq_tensor.clone()
        # DC point (no imaginary part)
        freq_tensor[:, :, H//2, W//2] = torch.real(freq_tensor_tmp[:, :, H//2, W//2])
        if H % 2 == 0:  # Even size
            # Self-paired Nyquist points (no imaginary part)
            freq_tensor[:, :, 0, 0] = torch.real(freq_tensor_tmp[:, :, 0, 0])
            freq_tensor[:, :, H//2, 0] = torch.real(freq_tensor_tmp[:, :, H//2, 0])
            freq_tensor[:, :, 0, W//2] = torch.real(freq_tensor_tmp[:, :, 0, W//2])

            # Nyquist (index 0) and DC (index H//2) rows/columns: conjugate mirror
            freq_tensor[:, :, 0, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, 0, W//2+1:], dims=[2]))
            freq_tensor[:, :, H//2, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2, W//2+1:], dims=[2]))
            freq_tensor[:, :, 1:H//2, 0] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, 0], dims=[2]))
            freq_tensor[:, :, 1:H//2, W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2], dims=[2]))
            # Off-axis quadrants: conjugate point reflection through the centre
            freq_tensor[:, :, 1:H//2, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2+1:], dims=[2, 3]))
            freq_tensor[:, :, H//2+1:, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, 1:H//2, W//2+1:], dims=[2, 3]))
        else:  # Odd size (no Nyquist frequency)
            # DC row/column: conjugate mirror about the centre
            freq_tensor[:, :, H//2, 0:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2, W//2+1:], dims=[2]))
            freq_tensor[:, :, 0:H//2, W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2], dims=[2]))
            # Off-axis quadrants: conjugate point reflection through the centre
            freq_tensor[:, :, 0:H//2, 0:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2+1:], dims=[2, 3]))
            freq_tensor[:, :, H//2+1:, 0:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, 0:H//2, W//2+1:], dims=[2, 3]))
        return freq_tensor

    @torch.no_grad()
    def make_Fourier_treering_pattern(self, pipe, shape, w_seed=999999, resolution=512):
        """Build the HSTR (Hermitian-symmetric tree-ring) Fourier pattern.

        Parameters:
            pipe: Diffusion pipeline whose ``prepare_latents`` draws the seeded noise
                that ring values are sampled from.
            shape: Target (B, C, H, W) shape of the pattern; H must equal W.
            w_seed: Seed for the noise generator.
            resolution: Pixel height/width passed to ``pipe.prepare_latents``. It
                should correspond to ``shape`` (e.g. 512 for 64x64 latents);
                otherwise the centre crop below does not line up.

        Returns:
            Complex, fftshifted tensor of ``shape`` with Hermitian symmetry enforced.
        """
        assert shape[-1] == shape[-2]
        g = torch.Generator(device=self.config.device).manual_seed(w_seed)
        gt_init = pipe.prepare_latents(1, pipe.unet.in_channels, resolution, resolution,
                                       pipe.unet.dtype, torch.device(self.config.device), g)
        # [HSTR] centre-aware design
        watermarked_latents_fft = SFWUtils.fft(torch.zeros(shape, device=self.config.device))

        start, end = self._center_patch_bounds(shape[-2])
        center_slice = (slice(None), slice(None), slice(start, end), slice(start, end))
        gt_patch_tmp = SFWUtils.fft(gt_init[center_slice]).clone().detach()
        center_len = gt_patch_tmp.shape[-1] // 2
        # Fill discs from the largest radius inwards, so each annulus keeps the
        # value sampled at (center, center + radius) of the crop's spectrum.
        for radius in range(center_len - 1, 0, -1):
            tmp_mask = torch.tensor(self.circle_mask(size=shape[-1], r=radius))
            for j in range(watermarked_latents_fft.shape[1]):  # all channels get the tree-ring
                watermarked_latents_fft[:, j, tmp_mask] = gt_patch_tmp[0, j, center_len, center_len + radius].item()
        # Gaussian noise key on channel 0 (heterogeneous watermark, as in RingID)
        watermarked_latents_fft[:, [0], start:end, start:end] = gt_patch_tmp[:, [0]]
        # [Hermitian Symmetric Fourier] HSTR
        return self.enforce_hermitian_symmetry(watermarked_latents_fft)

    # HSQR - hermitian symmetric QR pattern
    class QRCodeGenerator:
        """Thin wrapper around ``qrcode.QRCode`` that renders data to a tensor."""

        def __init__(self, box_size=2, border=1, qr_version=1):
            self.qr = qrcode.QRCode(version=qr_version, box_size=box_size, border=border,
                                    error_correction=qrcode.constants.ERROR_CORRECT_H)

        def make_qr_tensor(self, data, filename='qrcode.png', save_img=False):
            """Encode ``data`` as a QR image and return it as a 2-D tensor.

            The QR object is cleared afterwards so the generator can be reused.
            Optionally saves the rendered image to ``filename``.
            """
            self.qr.add_data(data)
            self.qr.make(fit=True)
            img = self.qr.make_image(fill_color="black", back_color="white")
            if save_img:
                img.save(filename)
            self.clear()
            img_array = np.array(img)
            tensor = torch.from_numpy(img_array)
            return tensor.clone().detach()

        def clear(self):
            """Reset the underlying QR object's data."""
            self.qr.clear()

    @torch.no_grad()
    def make_hsqr_pattern(self, idx: int):
        """Return the HSQR QR-code pattern for ``idx`` with a leading channel axis of size 1.

        With ``box_size=2``, ``border=0`` and version 1, the QR image is expected
        to be 42x42; this is not checked here.
        """
        qr_generator = self.QRCodeGenerator(box_size=2, border=0, qr_version=1)
        data = f"HSQR{idx % 10000}"
        qr_tensor = qr_generator.make_qr_tensor(data=data)
        # One copy per watermark channel; w_channel is treated as a single channel.
        qr_tensor = qr_tensor.repeat(len([self.config.w_channel]), 1, 1)
        return qr_tensor

    def _get_watermarking_pattern(self) -> torch.Tensor:
        """Build the ground-truth pattern as a 4-D tensor.

        HSTR: complex (1, 4, H/8, W/8) spectrum. HSQR: (1, c_wm, qr, qr) QR tensor.
        Seeds the global RNGs with ``w_seed`` as a side effect.
        """
        set_random_seed(self.config.w_seed)
        latent_h = self.config.image_size[0] // 8
        latent_w = self.config.image_size[1] // 8
        shape = (1, 4, latent_h, latent_w)
        if self.config.wm_type == "HSQR":
            pattern = self.make_hsqr_pattern(idx=self.config.w_seed)
        else:
            # Draw the noise at the configured resolution so the centre crop of
            # gt_init matches the latent shape (was hard-coded to 512px).
            pattern = self.make_Fourier_treering_pattern(self.config.pipe, shape, self.config.w_seed,
                                                         resolution=self.config.image_size[0])
        if pattern.dim() == 4:
            pattern_gt_batch = pattern  # (1,4,h,w) for HSTR
        elif pattern.dim() == 3:
            pattern_gt_batch = pattern.unsqueeze(0)  # (1,c_wm,qr,qr) for HSQR
        else:
            raise ValueError(f"Unexpected pattern_gt_batch shape: {pattern.shape}")
        return pattern_gt_batch

    # ring mask
    class RounderRingMask:
        """Concentric-ring label map built by rotating one radial ring profile.

        ``r_out`` distinct labels are written along the row through the centre
        of a ``size`` x ``size`` grid and rotated through 0..359 degrees. Each
        pixel is then assigned its most frequent non-zero label (``pure_bg``).
        """

        def __init__(self, size=65, r_out=14):
            assert size >= 3
            self.size = size
            self.r_out = r_out

            num_rings = r_out
            zero_bg_freq = torch.zeros(size, size)
            center = size // 2
            center_x, center_y = center, center

            # Distinct, alternating-sign label per ring index: 200, -196, 192, ...
            ring_vector = torch.tensor([(200 - i*4) * (-1)**i for i in range(num_rings)])
            zero_bg_freq[center_x, center_y:center_y+num_rings] = ring_vector
            zero_bg_freq = zero_bg_freq[None, None, ...]
            self.ring_vector_np = ring_vector.numpy()

            res = torch.zeros(360, size, size)
            res[0] = zero_bg_freq
            for angle in range(1, 360):
                zero_bg_freq_rot = tforms.functional.rotate(zero_bg_freq, angle=angle)
                res[angle] = zero_bg_freq_rot

            res = res.numpy()
            self.res = res
            self.pure_bg = np.zeros((size, size))
            for x in range(size):
                for y in range(size):
                    values, count = np.unique(res[:, x, y], return_counts=True)
                    if len(count) > 2:
                        self.pure_bg[x, y] = values[count == max(count[values != 0])][0]
                    elif len(count) == 2:
                        self.pure_bg[x, y] = values[values != 0][0]
                    else:
                        # Same label under every rotation (e.g. the centre pixel):
                        # keep it so that r_in=0 can select the centre ring.
                        self.pure_bg[x, y] = values[0]

        def get_ring_mask(self, r_out, r_in):
            """Boolean mask of pixels whose label belongs to rings ``r_in .. r_out-1``.

            For odd ``size`` the last row/column is dropped so the mask is even-sized.
            """
            assert r_out <= self.r_out
            # Slice stop is exclusive; None lets r_in=0 include ring 0 (the centre).
            right_end = None if r_in - 1 < 0 else r_in - 1
            cand_list = self.ring_vector_np[r_out-1:right_end:-1]
            mask = np.isin(self.pure_bg, cand_list)
            if self.size % 2:
                mask = mask[:self.size-1, :self.size-1]
            return mask

    def _get_watermarking_mask(self, init_latents: torch.Tensor) -> torch.Tensor:
        """Build the boolean (1, 4, L, L) mask of watermarked Fourier coefficients.

        Channel ``w_channel`` gets a filled disc (tree-ring region). Channel 0 gets
        rings 3..r-1 for the heterogeneous noise key, overriding the disc if
        ``w_channel`` is 0. ``L`` is the spatial size of ``init_latents`` (64 if it
        is None), so the mask lines up with ``gt_patch`` and the crop used in
        :meth:`inject_wm`. For L=64 this is identical to the historical 64x64 mask.
        """
        latent_size = init_latents.shape[-1] if init_latents is not None else 64
        # Radius 14 for 64x64 latents; shrunk so rings fit inside small latents.
        radius = max(0, min(14, latent_size // 2 - 1))
        shape = (1, 4, latent_size, latent_size)
        masks = torch.zeros(shape, dtype=torch.bool)
        single_channel_tree_watermark_mask = torch.tensor(self.circle_mask(size=latent_size, r=radius))
        masks[:, [self.config.w_channel]] = single_channel_tree_watermark_mask
        mask_obj = self.RounderRingMask(size=latent_size + 1, r_out=radius)
        ring_mask = mask_obj.get_ring_mask(r_out=radius, r_in=3)[:latent_size, :latent_size]
        masks[:, [0]] = torch.tensor(ring_mask)  # RounderRingMask for hetero watermark (noise)
        return masks

    @torch.no_grad()
    def inject_wm(self, init_latents: torch.Tensor):
        """Inject the HSTR pattern into the centred crop of ``init_latents``.

        Masked FFT coefficients of the centre crop are replaced by ``gt_patch``.
        The crop is transformed back to the spatial domain and a modified copy of
        the latents is returned (the input tensor is not mutated). The
        single-sample mask and pattern are broadcast over the batch locally, so
        repeated calls do not accumulate state.
        """
        assert len(self.gt_patch.shape) == 4
        assert len(self.watermarking_mask.shape) == 4
        batch_size = init_latents.shape[0]

        init_latents = init_latents.to(self.config.device)
        self.gt_patch = self.gt_patch.to(self.config.device)
        self.watermarking_mask = self.watermarking_mask.to(self.config.device)

        mask = self.watermarking_mask
        pattern = self.gt_patch
        if mask.shape[0] != batch_size:
            mask = mask.expand(batch_size, -1, -1, -1)
        if pattern.shape[0] != batch_size:
            pattern = pattern.expand(batch_size, -1, -1, -1)

        # inject watermarks in Fourier space of the centre crop
        start, end = self._center_patch_bounds(init_latents.shape[-2])
        center_slice = (slice(None), slice(None), slice(start, end), slice(start, end))
        center_latent_fft = SFWUtils.fft(init_latents[center_slice])
        temp_mask = mask[center_slice]
        center_latent_fft[temp_mask] = pattern[center_slice][temp_mask]

        center_latent_ifft = SFWUtils.ifft(center_latent_fft)
        if center_latent_ifft.numel() and center_latent_ifft.imag.abs().max() >= 1e-3:
            logging.warning("SFW HSTR injection produced a non-negligible imaginary part; discarding it.")
        center_latent_ifft = center_latent_ifft.real

        init_latents = init_latents.clone()
        init_latents[center_slice] = center_latent_ifft
        init_latents[init_latents == float("Inf")] = 4
        init_latents[init_latents == float("-Inf")] = -4
        return init_latents

    @torch.no_grad()
    def inject_hsqr(self, inverted_latent):
        """Embed the HSQR QR pattern in the real FFT of the centred latent crop.

        Within channel ``w_channel``, the left half of the QR bits sets the sign
        of the real parts and the right half sets the sign of the imaginary parts.
        Each targeted coefficient gets magnitude ``|x| + delta``. Latents smaller
        than the QR pattern plus a 1-pixel margin are returned unchanged, with a
        warning logged.
        """
        assert len(self.gt_patch.shape) == 4
        inverted_latent = inverted_latent.to(self.config.device)
        self.gt_patch = self.gt_patch.to(self.config.device)
        qr_pix_len = self.gt_patch.shape[-1]
        qr_pix_half = (qr_pix_len + 1) // 2
        qr_left = self.gt_patch[:, :, :, :qr_pix_half]
        qr_right = self.gt_patch[:, :, :, qr_pix_half:]
        h, w = inverted_latent.shape[-2:]
        # The QR pattern sits inside a crop with a 1-pixel margin on each side.
        patch_size = qr_pix_len + 2

        if h < patch_size or w < patch_size:
            logging.warning(f"Latent size ({h}x{w}) too small for SFW HSQR injection (required {patch_size}x{patch_size}). Skipping injection.")
            return inverted_latent

        start_h = (h - patch_size) // 2
        start_w = (w - patch_size) // 2
        center_slice = (slice(None), slice(None),
                        slice(start_h, start_h + patch_size), slice(start_w, start_w + patch_size))
        center_latent_rfft = SFWUtils.rfft(inverted_latent[center_slice])
        center_real_batch = center_latent_rfft.real
        center_imag_batch = center_latent_rfft.imag
        wm_slice = (slice(None), [self.config.w_channel], slice(1, 1 + qr_pix_len), slice(1, 1 + qr_pix_half))
        delta = self.config.delta
        real_abs = center_real_batch[wm_slice].abs()
        center_real_batch[wm_slice] = torch.where(qr_left, real_abs + delta, -real_abs - delta)
        imag_abs = center_imag_batch[wm_slice].abs()
        center_imag_batch[wm_slice] = torch.where(qr_right, imag_abs + delta, -imag_abs - delta)
        center_latent_ifft = SFWUtils.irfft(torch.complex(center_real_batch, center_imag_batch))
        inverted_latent = inverted_latent.clone()
        inverted_latent[center_slice] = center_latent_ifft
        return inverted_latent

322 

@inherit_docstring
class SFW(BaseWatermark):
    """SFW watermarking: embeds an HSTR or HSQR Fourier watermark in the initial latents."""

    # Keyword arguments handled by this class that must not be forwarded to
    # ``inversion.forward_diffusion``.
    _NON_INVERSION_KWARGS = frozenset(
        {'prompt', 'decoder_inv', 'guidance_scale', 'num_inference_steps', 'detector_type'}
    )

    def __init__(self,
                 watermark_config: SFWConfig,
                 *args, **kwargs):
        """
        Initialize the SFW watermarking algorithm.

        Parameters:
            watermark_config (SFWConfig): Configuration instance of the SFW algorithm.
        """
        self.config = watermark_config
        self.utils = SFWUtils(self.config)
        self.detector = SFWDetector(
            watermarking_mask=self.utils.watermarking_mask,
            gt_patch=self.utils.gt_patch,
            w_channel=self.config.w_channel,
            threshold=self.config.threshold,
            device=self.config.device,
            wm_type=self.config.wm_type
        )

    def _sfw_watermarked_latents(self) -> torch.Tensor:
        """Return ``config.init_latents`` with the configured (HSQR or HSTR) watermark injected."""
        if self.config.wm_type == "HSQR":
            return self.utils.inject_hsqr(self.config.init_latents)
        return self.utils.inject_wm(self.config.init_latents)

    def _sfw_generation_params(self, watermarked_latents, overrides=None) -> dict:
        """Assemble pipeline kwargs.

        Precedence is: config defaults, then ``config.gen_kwargs`` for keys not
        already set, then ``overrides``. ``latents`` is always forced to
        ``watermarked_latents``.
        """
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }
        gen_kwargs = getattr(self.config, "gen_kwargs", None)
        if gen_kwargs:
            for key, value in gen_kwargs.items():
                generation_params.setdefault(key, value)
        if overrides:
            generation_params.update(overrides)
        # The watermark lives in the latents; never let callers replace them.
        generation_params["latents"] = watermarked_latents
        return generation_params

    def _sfw_invert(self, image, prompt, guidance_scale, num_inference_steps, decoder_inv, inversion_kwargs):
        """Encode ``image`` to latents and run ``inversion.forward_diffusion`` on them.

        Classifier-free guidance embeddings are used when ``guidance_scale > 1``.
        Returns whatever ``forward_diffusion`` returns; callers index ``[-1]``.
        """
        do_classifier_free_guidance = (guidance_scale > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,  # TODO: Multiple image generation to be supported
        )
        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        processed_image = transform_to_model_format(
            image, target_size=self.config.image_size[0]
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        image_latents = get_media_latents(
            pipe=self.config.pipe, media=processed_image, sample=False, decoder_inv=decoder_inv
        )

        return self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            **inversion_kwargs
        )

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Internal method to generate a watermarked image.

        Extra ``kwargs`` override the pipeline parameters, except ``latents``.
        """
        watermarked_latents = self._sfw_watermarked_latents()
        self.set_orig_watermarked_latents(watermarked_latents)
        generation_params = self._sfw_generation_params(watermarked_latents, overrides=kwargs)
        return self.config.pipe(prompt, **generation_params).images[0]

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect the watermark in the image.

        Recognised kwargs: ``guidance_scale``, ``num_inference_steps``,
        ``decoder_inv`` and ``detector_type``. Any others are forwarded to the
        inversion.
        """
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in self._NON_INVERSION_KWARGS}

        reversed_latents = self._sfw_invert(
            image, prompt, guidance_scale_to_use, num_steps_to_use,
            kwargs.get('decoder_inv', False), inversion_kwargs,
        )[-1]

        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        return self.detector.eval_watermark(reversed_latents)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str = "",
                               guidance_scale: Optional[float] = None,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Get data for visualization, including the inversion used for detection.

        A fresh watermarked image is generated with ``gen_seed`` and inverted.
        If the inversion fails, a warning is logged and ``reversed_latents`` is
        passed as None.
        """
        guidance_scale_to_use = guidance_scale if guidance_scale is not None else self.config.guidance_scale

        # Step 1: watermarked latents (generation process)
        set_random_seed(self.config.gen_seed)
        watermarked_latents = self._sfw_watermarked_latents()

        # Step 2: generate the watermarked image exactly as _generate_watermarked_image does
        generation_params = self._sfw_generation_params(watermarked_latents)
        watermarked_image = self.config.pipe(prompt, **generation_params).images[0]

        # Step 3: invert the watermarked image (detection process)
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in self._NON_INVERSION_KWARGS}
        reversed_latents_list = None
        try:
            reversed_latents_list = self._sfw_invert(
                watermarked_image, prompt, guidance_scale_to_use,
                self.config.num_inference_steps, decoder_inv, inversion_kwargs,
            )
        except Exception as e:  # best-effort: visualization should not fail on inversion
            logging.warning("Could not perform inversion for visualization: %s", e)

        # Step 4: visualization data
        # NOTE(review): orig_watermarked_latents comes from the last
        # _generate_watermarked_image call, not from the latents built above —
        # confirm this is intended.
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            reversed_latents=reversed_latents_list,
            orig_watermarked_latents=self.orig_watermarked_latents,
            image=image
        )