Coverage for markdiffusion / watermark / sfw / sfw.py: 91.15%

305 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-14 19:25 +0000

1from ..base import BaseWatermark, BaseConfig 

2from markdiffusion.utils.media_utils import * 

3import torch 

4from typing import Dict, Optional 

5from markdiffusion.utils.utils import set_random_seed, inherit_docstring 

6from markdiffusion.utils.diffusion_config import DiffusionConfig 

7import numpy as np 

8from PIL import Image 

9from markdiffusion.visualize.data_for_visualization import DataForVisualization 

10from markdiffusion.detection.sfw.sfw_detection import SFWDetector 

11import torchvision.transforms as tforms 

12import qrcode 

13import logging 

14import os 

15 

class SFWConfig(BaseConfig):
    """Config class for SFW algorithm, load config file and initialize parameters."""

    def initialize_parameters(self) -> None:
        """Copy the SFW-specific settings out of ``self.config_dict``.

        ``w_seed``, ``delta``, ``wm_type``, ``threshold`` and ``w_channel``
        are mandatory (a missing key raises ``KeyError``);
        ``threshold_p_value`` is optional and defaults to ``0.01``.
        """
        settings = self.config_dict

        # Mandatory keys, read in the same order as before so that the first
        # missing key reported by KeyError does not change.
        for required_key in ('w_seed', 'delta', 'wm_type', 'threshold'):
            # wm_type is expected to be "HSTR" or "HSQR".
            setattr(self, required_key, settings[required_key])

        self.threshold_p_value = settings.get('threshold_p_value', 0.01)
        self.w_channel = settings['w_channel']

    @property
    def algorithm_name(self) -> str:
        """Name under which this algorithm is identified."""
        return 'SFW'

32 

class SFWUtils:
    """Utility class for SFW algorithm, contains helper functions.

    On construction the ground-truth watermark pattern (``gt_patch``) and the
    injection mask (``watermarking_mask``) are built from the configuration.
    ``inject_wm`` writes the tree-ring style pattern into the Fourier spectrum
    of the latents; ``inject_hsqr`` writes the QR-code pattern into the
    real-FFT spectrum.
    """

    def __init__(self, config: "SFWConfig", *args, **kwargs) -> None:
        """
        Initialize the SFW watermarking algorithm.

        Parameters:
            config (SFWConfig): Configuration for the SFW algorithm.
        """
        self.config = config
        self.gt_patch = self._get_watermarking_pattern()
        self.watermarking_mask = self._get_watermarking_mask(self.config.init_latents)

    # [Fourier transforms]
    @staticmethod
    def fft(input_tensor):
        """Return the centred (fftshifted) 2-D FFT of a 4-D tensor."""
        assert len(input_tensor.shape) == 4
        return torch.fft.fftshift(torch.fft.fft2(input_tensor), dim=(-1, -2))

    @staticmethod
    def ifft(input_tensor):
        """Inverse of :meth:`fft`: undo the shift, then apply the inverse 2-D FFT."""
        assert len(input_tensor.shape) == 4
        return torch.fft.ifft2(torch.fft.ifftshift(input_tensor, dim=(-1, -2)))

    @staticmethod
    @torch.no_grad()
    def rfft(input_tensor):
        """Return the real 2-D FFT of a 4-D tensor, shifted along the row axis only."""
        assert len(input_tensor.shape) == 4
        return torch.fft.fftshift(torch.fft.rfft2(input_tensor, dim=(-2, -1)), dim=-2)

    @staticmethod
    @torch.no_grad()
    def irfft(input_tensor):
        """Inverse of :meth:`rfft`.

        NOTE: the output size is taken as (H, H) where H is the row count of
        the spectrum, so the original spatial signal is assumed to be square.
        """
        assert len(input_tensor.shape) == 4
        side = input_tensor.shape[-2]
        return torch.fft.irfft2(
            torch.fft.ifftshift(input_tensor, dim=-2),
            dim=(-2, -1),
            s=(side, side),
        )

    @staticmethod
    def _center_patch_bounds(h: int):
        """Return ``(start, end)`` of the centred square patch used for injection.

        A side of 44 is used when ``h >= 64``; smaller inputs use ``h - 2``
        (or ``h`` itself when ``h <= 2``).
        """
        if h >= 64:
            patch_size = 44
        else:
            patch_size = h - 2 if h > 2 else h  # Use almost full size for small latents
        start = (h - patch_size) // 2
        return start, start + patch_size

    def circle_mask(self, size: int, r=16, x_offset=0, y_offset=0):
        """Return a ``(size, size)`` boolean numpy mask of a filled disc.

        The disc has radius ``r`` and is centred at ``size // 2`` plus the
        given offsets.
        """
        x0 = y0 = size // 2
        x0 += x_offset
        y0 += y_offset
        y, x = np.ogrid[:size, :size]
        return ((x - x0)**2 + (y - y0)**2) <= r**2

    @torch.no_grad()
    def enforce_hermitian_symmetry(self, freq_tensor):
        """Return a copy of a centred spectrum made Hermitian-symmetric.

        The left half of the (fftshifted) spectrum is overwritten with the
        conjugate of the point-mirrored right half, and the self-conjugate
        points are made purely real, so that the inverse FFT is real-valued.
        The input must be a square ``(B, C, H, W)`` complex tensor.
        """
        B, C, H, W = freq_tensor.shape  # fftshifted frequency (complex tensor)
        assert H == W, "H != W"
        freq_tensor = freq_tensor.clone()
        freq_tensor_tmp = freq_tensor.clone()
        # DC point (no imaginary)
        freq_tensor[:, :, H//2, W//2] = torch.real(freq_tensor_tmp[:, :, H//2, W//2])
        if H % 2 == 0:  # Even
            # Nyquist Points (no imaginary)
            freq_tensor[:, :, 0, 0] = torch.real(freq_tensor_tmp[:, :, 0, 0])
            freq_tensor[:, :, H//2, 0] = torch.real(freq_tensor_tmp[:, :, H//2, 0])
            freq_tensor[:, :, 0, W//2] = torch.real(freq_tensor_tmp[:, :, 0, W//2])

            # Nyquist axis - conjugate
            freq_tensor[:, :, 0, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, 0, W//2+1:], dims=[2]))
            freq_tensor[:, :, H//2, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2, W//2+1:], dims=[2]))
            freq_tensor[:, :, 1:H//2, 0] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, 0], dims=[2]))
            freq_tensor[:, :, 1:H//2, W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2], dims=[2]))
            # Square quadrants - conjugate
            freq_tensor[:, :, 1:H//2, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2+1:], dims=[2, 3]))
            freq_tensor[:, :, H//2+1:, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, 1:H//2, W//2+1:], dims=[2, 3]))
        else:  # Odd
            # Nyquist axis - conjugate
            freq_tensor[:, :, H//2, 0:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2, W//2+1:], dims=[2]))
            freq_tensor[:, :, 0:H//2, W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2], dims=[2]))
            # Square quadrants - conjugate
            freq_tensor[:, :, 0:H//2, 0:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2+1:], dims=[2, 3]))
            freq_tensor[:, :, H//2+1:, 0:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, 0:H//2, W//2+1:], dims=[2, 3]))
        return freq_tensor

    @torch.no_grad()
    def make_Fourier_treering_pattern(self, pipe, shape, w_seed=999999, resolution=512):
        """Build the Hermitian-symmetric tree-ring (HSTR) pattern in Fourier space.

        Seeded latents are drawn through ``pipe.prepare_latents``; the FFT of
        their centre patch supplies one constant value per ring radius (all
        channels) plus a full noise key written into channel 0. The result is
        passed through :meth:`enforce_hermitian_symmetry`.

        Parameters:
            pipe: Diffusion pipeline providing ``prepare_latents`` and ``unet``.
            shape: 4-D shape of the returned spectrum; the last two entries must match.
            w_seed (int): Seed of the generator used to draw the latents.
            resolution (int): Image resolution handed to ``prepare_latents``.
        """
        assert shape[-1] == shape[-2]
        g = torch.Generator(device=self.config.device).manual_seed(w_seed)
        gt_init = pipe.prepare_latents(1, pipe.unet.in_channels, resolution, resolution, pipe.unet.dtype, torch.device(self.config.device), g)
        # [HSTR] center-aware design: start from an all-zero spectrum.
        watermarked_latents_fft = SFWUtils.fft(torch.zeros(shape, device=self.config.device))

        start, end = self._center_patch_bounds(shape[-2])
        center_slice = (slice(None), slice(None), slice(start, end), slice(start, end))
        gt_patch_tmp = SFWUtils.fft(gt_init[center_slice]).clone().detach()
        center_len = gt_patch_tmp.shape[-1] // 2
        # Paint discs from the largest radius down to 1 so that every smaller
        # disc overwrites the interior of the previous one, leaving rings.
        for radius in range(center_len - 1, 0, -1):
            tmp_mask = torch.tensor(self.circle_mask(size=shape[-1], r=radius))
            for j in range(watermarked_latents_fft.shape[1]):  # GT : all channel Tree-Ring
                # The ring value is the spectrum element `radius` columns right of the centre.
                watermarked_latents_fft[:, j, tmp_mask] = gt_patch_tmp[0, j, center_len, center_len + radius].item()
        # Gaussian noise key (Heterogenous watermark in RingID)
        watermarked_latents_fft[:, [0], start:end, start:end] = gt_patch_tmp[:, [0]]
        # [Hermitian Symmetric Fourier] HSTR
        return self.enforce_hermitian_symmetry(watermarked_latents_fft)

    # HSQR - hermitian symmetric QR pattern
    class QRCodeGenerator:
        """Thin wrapper around ``qrcode.QRCode`` that renders QR codes as tensors."""

        def __init__(self, box_size=2, border=1, qr_version=1):
            self.qr = qrcode.QRCode(version=qr_version, box_size=box_size, border=border,
                                    error_correction=qrcode.constants.ERROR_CORRECT_H)

        def make_qr_tensor(self, data, filename='qrcode.png', save_img=False):
            """Render ``data`` as a QR code and return it as a 2-D tensor.

            The image is optionally saved to ``filename``; the generator is
            cleared afterwards so it can be reused.
            """
            self.qr.add_data(data)
            self.qr.make(fit=True)
            img = self.qr.make_image(fill_color="black", back_color="white")
            if save_img:
                img.save(filename)
            self.clear()
            img_array = np.array(img)
            tensor = torch.from_numpy(img_array)
            return tensor.clone().detach()  # boolean (h,w)

        def clear(self):
            """Drop the data accumulated in the underlying QR object."""
            self.qr.clear()

    @torch.no_grad()
    def make_hsqr_pattern(self, idx: int):
        """Return the QR-code pattern for ``idx`` as a ``(1, h, w)`` tensor.

        The encoded payload is ``"HSQR"`` followed by ``idx % 10000``.
        """
        qr_generator = self.QRCodeGenerator(box_size=2, border=0, qr_version=1)
        data = f"HSQR{idx % 10000}"
        qr_tensor = qr_generator.make_qr_tensor(data=data)  # (h,w) boolean tensor
        # A single watermark channel is used: the previous expression
        # `len([self.config.w_channel])` always evaluated to 1.
        num_wm_channels = 1
        qr_tensor = qr_tensor.repeat(num_wm_channels, 1, 1)  # (c_wm,h,w) boolean tensor
        return qr_tensor

    def _get_watermarking_pattern(self) -> torch.Tensor:
        """Get the ground truth watermarking pattern.

        Returns a 4-D tensor: the QR pattern with a leading batch axis for
        ``wm_type == "HSQR"``, otherwise the tree-ring Fourier pattern.

        Raises:
            ValueError: If the generated pattern is neither 3-D nor 4-D.
        """
        set_random_seed(self.config.w_seed)
        latent_h = self.config.image_size[0] // 8
        latent_w = self.config.image_size[1] // 8
        shape = (1, 4, latent_h, latent_w)
        if self.config.wm_type == "HSQR":
            pattern = self.make_hsqr_pattern(idx=self.config.w_seed)
        else:
            pattern = self.make_Fourier_treering_pattern(self.config.pipe, shape, self.config.w_seed)

        # adjust dims so that the result always carries a leading batch axis
        if len(pattern.shape) == 4:
            pattern_gt_batch = torch.cat([pattern], dim=0)  # HSTR: already batched
        elif len(pattern.shape) == 3:
            pattern_gt_batch = torch.stack([pattern], dim=0)  # HSQR: add batch axis
        else:
            raise ValueError(f"Unexpected pattern_gt_batch shape: {pattern.shape}")
        assert len(pattern_gt_batch.shape) == 4

        return pattern_gt_batch

    # ring mask
    class RounderRingMask:
        """Rasterised concentric-ring mask built by rotating a labelled radius.

        Each ring index is given a distinct signed label; the labelled radius
        is rotated through 360 one-degree steps and every pixel keeps the
        label it received most often (``pure_bg``).
        """

        def __init__(self, size=65, r_out=14):
            assert size >= 3
            self.size = size
            self.r_out = r_out

            num_rings = r_out
            zero_bg_freq = torch.zeros(size, size)
            center = size // 2
            center_x, center_y = center, center

            # Distinct, alternating-sign label per ring index.
            ring_vector = torch.tensor([(200 - i*4) * (-1)**i for i in range(num_rings)])
            zero_bg_freq[center_x, center_y:center_y+num_rings] = ring_vector
            zero_bg_freq = zero_bg_freq[None, None, ...]
            self.ring_vector_np = ring_vector.numpy()

            res = torch.zeros(360, size, size)
            res[0] = zero_bg_freq
            for angle in range(1, 360):
                zero_bg_freq_rot = tforms.functional.rotate(zero_bg_freq, angle=angle)
                res[angle] = zero_bg_freq_rot

            res = res.numpy()
            self.res = res
            self.pure_bg = np.zeros((size, size))
            for x in range(size):
                for y in range(size):
                    values, count = np.unique(res[:, x, y], return_counts=True)
                    if len(count) > 2:
                        # Several labels hit this pixel: keep the most frequent non-zero one.
                        self.pure_bg[x, y] = values[count == max(count[values!=0])][0]
                    elif len(count) == 2:
                        self.pure_bg[x, y] = values[values!=0][0]

        def get_ring_mask(self, r_out, r_in):
            """Return a boolean mask of the pixels labelled with the selected rings.

            The selected ring labels are ``ring_vector[r_out-1 : right_end : -1]``
            with ``right_end = max(r_in - 1, 0)``. Because the slice stop is
            exclusive, the label at index 0 is never included, even when
            ``r_in`` is 0. For odd ``size`` the last row and column are cropped.
            """
            # get mask from pure_bg
            assert r_out <= self.r_out
            if r_in - 1 < 0:
                right_end = 0
            else:
                right_end = r_in - 1
            cand_list = self.ring_vector_np[r_out-1:right_end:-1]
            mask = np.isin(self.pure_bg, cand_list)
            if self.size % 2:
                mask = mask[:self.size-1, :self.size-1]
            return mask

    def _get_watermarking_mask(self, init_latents: torch.Tensor) -> torch.Tensor:
        """Get the watermarking mask.

        NOTE: the mask is always built for a fixed ``(1, 4, 64, 64)`` layout;
        ``init_latents`` is accepted but not consulted. Channel ``w_channel``
        receives a filled disc of radius 14, and channel 0 is then set to the
        ring mask (overwriting the disc if ``w_channel`` is 0).
        """
        shape = (1, 4, 64, 64)
        tree_masks = torch.zeros(shape, dtype=torch.bool)
        single_channel_tree_watermark_mask = torch.tensor(self.circle_mask(size=shape[-1], r=14))
        tree_masks[:, [self.config.w_channel]] = single_channel_tree_watermark_mask
        masks = tree_masks
        mask_obj = self.RounderRingMask(size=65, r_out=14)
        single_channel_heter_watermark_mask = torch.tensor(mask_obj.get_ring_mask(r_out=14, r_in=3))
        masks[:, [0]] = single_channel_heter_watermark_mask  # RounderRingMask for Hetero Watermark (noise)
        return masks

    @torch.no_grad()
    def inject_wm(self, init_latents: torch.Tensor):
        """Inject the tree-ring (HSTR) pattern into the centre patch of the latents.

        The centre patch is transformed with :meth:`fft`, the masked
        frequencies are replaced by the ground-truth pattern, and the patch is
        transformed back. Mask and pattern are broadcast over the batch axis,
        so the stored ``watermarking_mask`` keeps its shape no matter how
        often, or with which batch size, this method is called.

        Parameters:
            init_latents (torch.Tensor): 4-D latents ``(N, C, H, W)``.

        Returns:
            torch.Tensor: A watermarked copy of the latents on ``config.device``.
        """
        # for HSTR
        assert len(self.gt_patch.shape) == 4
        assert len(self.watermarking_mask.shape) == 4

        init_latents = init_latents.to(self.config.device)
        self.gt_patch = self.gt_patch.to(self.config.device)
        self.watermarking_mask = self.watermarking_mask.to(self.config.device)

        # inject watermarks in fourier space
        start, end = self._center_patch_bounds(init_latents.shape[-2])
        center_slice = (slice(None), slice(None), slice(start, end), slice(start, end))
        assert len(init_latents[center_slice].shape) == 4
        center_latent_fft = SFWUtils.fft(init_latents[center_slice])

        # injection: broadcasting replaces the former in-place `repeat` of the
        # stored mask, which grew on every call and could not be combined with
        # the batch-1 pattern for N > 1.
        temp_mask = self.watermarking_mask[center_slice]
        temp_pattern = self.gt_patch[center_slice].to(center_latent_fft.dtype)
        center_latent_fft = torch.where(temp_mask, temp_pattern, center_latent_fft)

        # IFFT
        assert len(center_latent_fft.shape) == 4
        center_latent_ifft = SFWUtils.ifft(center_latent_fft)
        # Drop the imaginary part only when it is numerically negligible.
        if center_latent_ifft.imag.abs().max() < 1e-3:
            center_latent_ifft = center_latent_ifft.real

        init_latents = init_latents.clone()
        init_latents[center_slice] = center_latent_ifft
        # Clamp infinities to finite values.
        init_latents[init_latents == float("Inf")] = 4
        init_latents[init_latents == float("-Inf")] = -4
        return init_latents

    @torch.no_grad()
    def inject_hsqr(self, inverted_latent):
        """Inject the QR-code (HSQR) pattern into the centre patch of the latents.

        The left half of the QR code sets the sign of the real part, and the
        right half the sign of the imaginary part, of the real-FFT
        coefficients of channel ``w_channel``; magnitudes are increased by
        ``config.delta``. If the latents are smaller than the required patch,
        a warning is logged and the latents are returned without injection.

        Parameters:
            inverted_latent (torch.Tensor): 4-D latents ``(N, C, H, W)``.

        Returns:
            torch.Tensor: The (possibly watermarked) latents on ``config.device``.
        """
        assert len(self.gt_patch.shape) == 4
        inverted_latent = inverted_latent.to(self.config.device)
        self.gt_patch = self.gt_patch.to(self.config.device)
        qr_pix_len = self.gt_patch.shape[-1]
        qr_pix_half = (qr_pix_len + 1) // 2
        qr_left = self.gt_patch[:, :, :, :qr_pix_half]
        qr_right = self.gt_patch[:, :, :, qr_pix_half:]

        h, w = inverted_latent.shape[-2:]
        # The spatial patch is the QR side plus a one-pixel margin on each side.
        patch_size = qr_pix_len + 2

        if h < patch_size or w < patch_size:
            logging.warning(f"Latent size ({h}x{w}) too small for SFW HSQR injection (required {patch_size}x{patch_size}). Skipping injection.")
            return inverted_latent

        start_h = (h - patch_size) // 2
        start_w = (w - patch_size) // 2
        end_h = start_h + patch_size
        end_w = start_w + patch_size

        center_slice = (slice(None), slice(None), slice(start_h, end_h), slice(start_w, end_w))
        center_latent_rfft = SFWUtils.rfft(inverted_latent[center_slice])
        center_real_batch = center_latent_rfft.real
        center_imag_batch = center_latent_rfft.imag
        # Real and imaginary parts are written over the same coefficient window.
        wm_slice = (slice(None), [self.config.w_channel], slice(1, 1 + qr_pix_len), slice(1, 1 + qr_pix_half))
        delta = self.config.delta
        center_real_batch[wm_slice] = torch.where(
            qr_left,
            center_real_batch[wm_slice].abs() + delta,
            -center_real_batch[wm_slice].abs() - delta,
        )
        center_imag_batch[wm_slice] = torch.where(
            qr_right,
            center_imag_batch[wm_slice].abs() + delta,
            -center_imag_batch[wm_slice].abs() - delta,
        )
        center_latent_ifft = SFWUtils.irfft(torch.complex(center_real_batch, center_imag_batch))
        inverted_latent = inverted_latent.clone()
        inverted_latent[center_slice] = center_latent_ifft
        return inverted_latent

323 

@inherit_docstring
class SFW(BaseWatermark):
    # SFW watermarking algorithm: wires the configuration, the injection
    # utilities (SFWUtils) and the detector (SFWDetector) together.

    def __init__(self,
                 watermark_config: SFWConfig,
                 *args, **kwargs):
        """
        Initialize the SFW watermarking algorithm.

        Parameters:
            watermark_config (SFWConfig): Configuration instance of the SFW algorithm.
        """
        self.config = watermark_config
        self.utils = SFWUtils(self.config)
        self.detector = SFWDetector(
            watermarking_mask=self.utils.watermarking_mask,
            gt_patch=self.utils.gt_patch,
            w_channel=self.config.w_channel,
            threshold=self.config.threshold,
            device=self.config.device,
            wm_type=self.config.wm_type,
            threshold_p_value=self.config.threshold_p_value,
        )

    def _inject_sfw_watermark(self) -> torch.Tensor:
        """Watermark ``config.init_latents`` with the routine selected by ``wm_type``."""
        if self.config.wm_type == "HSQR":
            return self.utils.inject_hsqr(self.config.init_latents)
        return self.utils.inject_wm(self.config.init_latents)

    def _sfw_generation_params(self, watermarked_latents: torch.Tensor) -> dict:
        """Build the pipeline call arguments from the config.

        Entries of ``config.gen_kwargs`` are added only for keys that are not
        already set by the defaults.
        """
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }
        gen_kwargs = getattr(self.config, "gen_kwargs", None)
        if gen_kwargs:
            for key, value in gen_kwargs.items():
                generation_params.setdefault(key, value)
        return generation_params

    def _sfw_text_embeddings(self, prompt: str, guidance_scale: float) -> torch.Tensor:
        """Encode ``prompt`` for inversion.

        With classifier-free guidance (``guidance_scale > 1.0``) the negative
        and positive embeddings are concatenated, in that order.
        """
        do_classifier_free_guidance = (guidance_scale > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,  # TODO: Multiple image generation to be supported
        )
        if do_classifier_free_guidance:
            return torch.cat([negative_prompt_embeds, prompt_embeds])
        return prompt_embeds

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Internal method to generate a watermarked image.

        Keyword arguments override the config-derived generation parameters,
        except for ``latents``, which is always the watermarked latents.
        """
        watermarked_latents = self._inject_sfw_watermark()

        # save watermarked latents
        self.set_orig_watermarked_latents(watermarked_latents)

        generation_params = self._sfw_generation_params(watermarked_latents)

        # Use kwargs to override default parameters
        generation_params.update(kwargs)

        # Ensure latents parameter is not overridden
        generation_params["latents"] = watermarked_latents

        return self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect the watermark in the image.

        The image is encoded to latents, inverted with
        ``config.inversion.forward_diffusion`` and the final inverted latents
        are handed to the detector. ``guidance_scale``, ``num_inference_steps``,
        ``decoder_inv`` and ``detector_type`` are read from ``kwargs``; any
        other keyword argument is forwarded to the inversion.
        """
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Step 1: Get Text Embeddings
        text_embeddings = self._sfw_text_embeddings(prompt, guidance_scale_to_use)

        # Step 2: Preprocess Image
        image = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        # Step 3: Get Image Latents
        image_latents = get_media_latents(pipe=self.config.pipe, media=image, sample=False, decoder_inv=kwargs.get('decoder_inv', False))

        # Step 4: Reverse Image Latents
        # Pass only known parameters to forward_diffusion, and let kwargs handle any additional parameters
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=num_steps_to_use,
            **inversion_kwargs
        )[-1]

        # Step 5: Evaluate Watermark
        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        return self.detector.eval_watermark(reversed_latents)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str = "",
                               guidance_scale: Optional[float] = None,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Get data for visualization including detection inversion.

        A watermarked image is regenerated from ``config.init_latents`` and
        then inverted. Inversion is best-effort: if it fails, a warning is
        printed and ``reversed_latents`` is passed on as ``None``.
        """
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = guidance_scale if guidance_scale is not None else self.config.guidance_scale

        # Step 1: Generate watermarked latents (generation process)
        set_random_seed(self.config.gen_seed)
        watermarked_latents = self._inject_sfw_watermark()

        # Step 2: Generate actual watermarked image using the same process as _generate_watermarked_image
        generation_params = self._sfw_generation_params(watermarked_latents)
        watermarked_image = self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

        # Step 3: Perform watermark detection to get inverted latents (detection process)
        # Bound before the try block so that a failed inversion cannot leave
        # the name undefined at the return statement below.
        reversed_latents_list = None
        try:
            # Get Text Embeddings for detection
            text_embeddings = self._sfw_text_embeddings(prompt, guidance_scale_to_use)

            # Preprocess watermarked image for detection
            processed_image = transform_to_model_format(
                watermarked_image,
                target_size=self.config.image_size[0]
            ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

            # Get Image Latents
            image_latents = get_media_latents(
                pipe=self.config.pipe,
                media=processed_image,
                sample=False,
                decoder_inv=decoder_inv
            )

            # Reverse Image Latents to get inverted noise
            inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['prompt', 'decoder_inv', 'guidance_scale', 'num_inference_steps']}

            reversed_latents_list = self.config.inversion.forward_diffusion(
                latents=image_latents,
                text_embeddings=text_embeddings,
                guidance_scale=guidance_scale_to_use,
                num_inference_steps=self.config.num_inference_steps,
                **inversion_kwargs
            )

        except Exception as e:
            print(f"Warning: Could not perform inversion for visualization: {e}")
            reversed_latents_list = None

        # Fall back to the latents computed above when no generation call has
        # stored them yet.
        orig_watermarked_latents = getattr(self, "orig_watermarked_latents", None)
        if orig_watermarked_latents is None:
            orig_watermarked_latents = watermarked_latents

        # Step 4: Prepare visualization data
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            reversed_latents=reversed_latents_list,
            orig_watermarked_latents=orig_watermarked_latents,
            image=image
        )