Coverage for watermark / gm / gm.py: 95.77%

331 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1from __future__ import annotations 

2 

3import copy 

4import random 

5from dataclasses import dataclass 

6from functools import reduce 

7from pathlib import Path 

8from typing import Dict, Optional, Tuple, Union 

9 

10import numpy as np 

11import torch 

12from PIL import Image 

13from Crypto.Cipher import ChaCha20 

14from Crypto.Random import get_random_bytes 

15from scipy.special import betainc 

16from scipy.stats import norm, truncnorm 

17from huggingface_hub import hf_hub_download 

18 

19from ..base import BaseConfig, BaseWatermark 

20from utils.media_utils import get_random_latents, get_media_latents, transform_to_model_format 

21from utils.utils import set_random_seed 

22from visualize.data_for_visualization import DataForVisualization 

23from .gnr import GNRRestorer 

24 

25# ----------------------------------------------------------------------------- 

26# Helper utilities adapted from the official GaussMarker implementation 

27# ----------------------------------------------------------------------------- 

28 

def _bytes_from_seed(seed: Optional[int], length: int) -> bytes:
    """Return ``length`` bytes, reproducible whenever ``seed`` is given.

    Args:
        seed: Seed for a private ``random.Random`` instance. ``None`` hands
            the request to ``get_random_bytes`` instead, so the result is
            not reproducible.
        length: Number of bytes to produce.

    Returns:
        A ``bytes`` object holding ``length`` bytes.
    """
    if seed is not None:
        prng = random.Random(seed)
        # One 8-bit draw per output byte, in order.
        return bytes([prng.getrandbits(8) for _ in range(length)])
    return get_random_bytes(length)

35 

def circle_mask(size: int, radius: int, x_offset: int = 0, y_offset: int = 0) -> np.ndarray:
    """Build a boolean ``size`` x ``size`` mask of a filled disc.

    The disc is centred at ``size // 2`` on both axes, shifted by the given
    offsets. The row coordinate is reversed before the distance test, so a
    positive ``y_offset`` moves the disc towards row 0.

    Args:
        size: Side length of the square mask.
        radius: Disc radius; points at exactly this distance are included.
        x_offset: Shift of the centre along the column axis.
        y_offset: Shift of the centre along the (reversed) row axis.

    Returns:
        Boolean array of shape ``(size, size)``, ``True`` inside the disc.
    """
    centre_x = size // 2 + x_offset
    centre_y = size // 2 + y_offset
    rows, cols = np.ogrid[:size, :size]
    flipped_rows = rows[::-1]  # row 0 of the output carries the largest y value
    squared_distance = (cols - centre_x) ** 2 + (flipped_rows - centre_y) ** 2
    return squared_distance <= radius ** 2

44 

def extract_complex_sign(complex_tensor: torch.Tensor) -> torch.Tensor:
    """Encode the signs of the real and imaginary parts as a code in 0..3.

    Each part contributes one flag that is 1 when the part is ``<= 0`` and 0
    otherwise; the result is ``2 * real_flag + imag_flag``.

    Args:
        complex_tensor: Tensor exposing ``.real`` and ``.imag``.

    Returns:
        Integer (``long``) tensor with the same shape as the input.
    """
    real_flag = complex_tensor.real.le(0).long()
    imag_flag = complex_tensor.imag.le(0).long()
    return real_flag * 2 + imag_flag

53 

54# ----------------------------------------------------------------------------- 

55# Gaussian Shading watermark with ChaCha20 encryption (generalised dimensions) 

56# ----------------------------------------------------------------------------- 

@dataclass
class GaussianShadingChaCha:
    """Gaussian Shading watermark sampler whose bits are encrypted with ChaCha20.

    A small binary watermark is tiled over the latent volume, encrypted with a
    ChaCha20 stream, and each encrypted bit selects the half of the standard
    normal distribution (negative for 0, positive for 1) from which the
    corresponding latent value is drawn. Recovery thresholds the latent at
    zero, decrypts, and majority-votes over the tiled copies.

    Note on naming: ``width_copy`` tiles tensor dimension 2 (sized by
    ``latent_height``) and ``height_copy`` tiles dimension 3 (sized by
    ``latent_width``). Sampling, tiling and voting all use this pairing.
    """

    channel_copy: int  # number of copies along the channel axis
    width_copy: int  # number of copies along tensor dim 2
    height_copy: int  # number of copies along tensor dim 3
    fpr: float  # target false-positive rate used to derive the tau thresholds
    user_number: int  # multiplier applied to the one-bit FPR to obtain ``tau_bits``
    latent_channels: int
    latent_height: int
    latent_width: int
    dtype: torch.dtype  # stored only; not read by any method of this class
    device: torch.device
    watermark_seed: Optional[int] = None
    key_seed: Optional[int] = None
    nonce_seed: Optional[int] = None
    watermark: Optional[torch.Tensor] = None
    key: Optional[bytes] = None
    nonce: Optional[bytes] = None
    message_bits: Optional[np.ndarray] = None

    def __post_init__(self) -> None:
        """Derive the watermark length, the voting threshold and the tau thresholds.

        Raises:
            ValueError: If the latent volume is not divisible by the product
                of the three copy factors.
        """
        self.latentlength = self.latent_channels * self.latent_height * self.latent_width
        copies = self.channel_copy * self.width_copy * self.height_copy
        if self.latentlength % copies != 0:
            raise ValueError(
                "Latent volume is not divisible by channel/width/height copies. "
                "Please adjust w_copy/h_copy/channel_copy."
            )
        self.marklength = self.latentlength // copies

        # Voting thresholds identical to official implementation.
        # NOTE(review): with a single copy on every axis the per-position sum of
        # 0/1 bits is at most 1, so ``sum > 1`` never holds and
        # ``diffusion_inverse`` returns all zeros. Kept for parity with the
        # reference; confirm whether this is intended.
        if self.channel_copy == 1 and self.width_copy == 1 and self.height_copy == 1:
            self.threshold = 1
        else:
            self.threshold = copies // 2

        # Smallest fraction i / marklength whose regularised incomplete beta
        # value falls to ``fpr`` or below; both stay ``None`` if no i qualifies.
        self.tau_onebit: Optional[float] = None
        self.tau_bits: Optional[float] = None
        for i in range(self.marklength):
            fpr_onebit = betainc(i + 1, self.marklength - i, 0.5)
            fpr_bits = fpr_onebit * self.user_number
            if fpr_onebit <= self.fpr and self.tau_onebit is None:
                self.tau_onebit = i / self.marklength
            if fpr_bits <= self.fpr and self.tau_bits is None:
                self.tau_bits = i / self.marklength

    # ------------------------------------------------------------------
    # Key/nonce helpers
    # ------------------------------------------------------------------

    def _ensure_key_nonce(self) -> None:
        """Create the 32-byte key and 12-byte nonce on first use."""
        if self.key is None:
            self.key = _bytes_from_seed(self.key_seed, 32)
        if self.nonce is None:
            self.nonce = _bytes_from_seed(self.nonce_seed, 12)

    # ------------------------------------------------------------------
    # Sampling helpers
    # ------------------------------------------------------------------

    def _truncated_sampling(self, message_bits: np.ndarray) -> torch.Tensor:
        """Draw one truncated-normal latent value per message bit.

        Args:
            message_bits: Sequence of at least ``latentlength`` values in {0, 1}.

        Returns:
            Float32 tensor of shape ``(1, latent_channels, latent_height,
            latent_width)`` on ``self.device``.
        """
        samples = np.zeros(self.latentlength, dtype=np.float32)
        denominator = 2.0
        # Standard-normal quantiles at 0, 1/2 and 1: bit 0 selects the first
        # interval, bit 1 the second.
        edges = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)]
        # Values are drawn one at a time; batching the call would presumably
        # change the order in which random numbers are consumed. Verify before
        # vectorising.
        for idx in range(self.latentlength):
            bit = int(message_bits[idx])
            samples[idx] = truncnorm.rvs(edges[bit], edges[bit + 1])
        tensor = torch.from_numpy(samples).reshape(
            1, self.latent_channels, self.latent_height, self.latent_width
        )
        return tensor.to(self.device, dtype=torch.float32)

    def _sample_watermark(self) -> torch.Tensor:
        """Draw the un-tiled binary watermark on the CPU.

        With ``watermark_seed`` set the draw is reproducible. Without it the
        generator is seeded non-deterministically: a freshly constructed
        ``torch.Generator`` otherwise starts from a fixed default seed and
        would return the same "random" watermark on every run.
        """
        generator = torch.Generator(device="cpu")
        if self.watermark_seed is not None:
            generator.manual_seed(self.watermark_seed)
        else:
            generator.seed()
        return torch.randint(
            low=0,
            high=2,
            size=(
                1,
                self.latent_channels // self.channel_copy,
                self.latent_height // self.width_copy,
                self.latent_width // self.height_copy,
            ),
            generator=generator,
            dtype=torch.int64,
        )

    def _generate_watermark(self) -> None:
        """Sample the watermark, tile it over the latent volume and encrypt it.

        Sets ``self.watermark`` (un-tiled bits on ``self.device``) and
        ``self.message_bits`` (encrypted, flattened, tiled bits).
        """
        self.watermark = self._sample_watermark().to(self.device)
        tiled = self.watermark.repeat(1, self.channel_copy, self.width_copy, self.height_copy)
        self.message_bits = self._stream_key_encrypt(tiled.flatten().cpu().numpy())

    # ------------------------------------------------------------------
    # Encryption helpers
    # ------------------------------------------------------------------
    def _stream_key_encrypt(self, plaintext_bits: np.ndarray) -> np.ndarray:
        """Encrypt plaintext bits using ChaCha20 stream cipher.

        Returns:
            The first ``latentlength`` encrypted bits as a ``uint8`` array.
        """
        self._ensure_key_nonce()
        cipher = ChaCha20.new(key=self.key, nonce=self.nonce)  # initialise the cipher
        packed = np.packbits(plaintext_bits).tobytes()
        encrypted = cipher.encrypt(packed)
        unpacked = np.unpackbits(np.frombuffer(encrypted, dtype=np.uint8))
        # packbits pads to a whole byte; drop any padding bits again.
        return unpacked[: self.latentlength]

    def _stream_key_decrypt(self, encrypted_bits: np.ndarray) -> torch.Tensor:
        """Decrypt a flat bit array with the stored key and nonce.

        Returns:
            ``uint8`` tensor of shape ``(1, latent_channels, latent_height,
            latent_width)`` on ``self.device``.
        """
        self._ensure_key_nonce()
        cipher = ChaCha20.new(key=self.key, nonce=self.nonce)
        packed = np.packbits(encrypted_bits).tobytes()
        decrypted = cipher.decrypt(packed)
        bits = np.unpackbits(np.frombuffer(decrypted, dtype=np.uint8))
        bits = bits[: self.latentlength]
        tensor = torch.from_numpy(bits.astype(np.uint8)).reshape(
            1, self.latent_channels, self.latent_height, self.latent_width
        )
        return tensor.to(self.device)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def create_watermark_and_return_w_m(self) -> Tuple[torch.Tensor, torch.Tensor]:
        """Sample watermarked latents and return them with the encrypted message.

        The watermark is generated on the first call and reused afterwards;
        the latent values themselves are re-sampled on every call.

        Returns:
            ``(latents, message)``: both float32 tensors of shape
            ``(1, latent_channels, latent_height, latent_width)`` on
            ``self.device``; ``message`` holds the encrypted bits as floats.
        """
        if self.watermark is None or self.message_bits is None:
            self._generate_watermark()
        message_bits = self.message_bits
        sampled = self._truncated_sampling(message_bits)
        sampled = sampled.to(self.device, dtype=torch.float32)
        m_tensor = torch.from_numpy(message_bits.astype(np.float32)).reshape(
            1, self.latent_channels, self.latent_height, self.latent_width
        ).to(self.device)
        return sampled, m_tensor

    def diffusion_inverse(self, spread_tensor: torch.Tensor) -> torch.Tensor:
        """Collapse the tiled copies back to one watermark by voting.

        Args:
            spread_tensor: Tensor with ``latentlength`` elements (decrypted bits).

        Returns:
            Int64 tensor of shape ``(1, C // channel_copy, H // width_copy,
            W // height_copy)`` that is 1 where the copy sum exceeds
            ``self.threshold`` and 0 elsewhere.
        """
        copies_split = spread_tensor.to(self.device).reshape(
            1,
            self.channel_copy,
            self.latent_channels // self.channel_copy,
            self.width_copy,
            self.latent_height // self.width_copy,
            self.height_copy,
            self.latent_width // self.height_copy,
        )
        # Dimensions 1, 3 and 5 enumerate the copies; summing them counts the
        # votes for "1" at every watermark position.
        votes = copies_split.sum(dim=(1, 3, 5))
        return (votes > self.threshold).to(torch.int64)

    def pred_m_from_latent(self, reversed_latents: torch.Tensor) -> torch.Tensor:
        """Recover the encrypted bits: 1 where the latent is positive, else 0."""
        return (reversed_latents > 0).int().to(self.device)

    def pred_w_from_latent(self, reversed_latents: torch.Tensor) -> torch.Tensor:
        """Recover the watermark from reversed latents (threshold, decrypt, vote)."""
        return self.pred_w_from_m(self.pred_m_from_latent(reversed_latents))

    def pred_w_from_m(self, reversed_m: torch.Tensor) -> torch.Tensor:
        """Recover the watermark from already-thresholded encrypted bits."""
        spread_bits = reversed_m.flatten().detach().cpu().numpy().astype(np.uint8)
        decrypted = self._stream_key_decrypt(spread_bits)
        return self.diffusion_inverse(decrypted)

    def watermark_tensor(self, device: Optional[torch.device] = None) -> torch.Tensor:
        """Return the un-tiled watermark, generating it first if necessary.

        Args:
            device: Target device; defaults to ``self.device``.
        """
        if self.watermark is None:
            self._generate_watermark()
        device = device or self.device
        return self.watermark.to(device)

225 

226# ----------------------------------------------------------------------------- 

227# Utility helpers for GaussMarker 

228# ----------------------------------------------------------------------------- 

229 

class GMUtils:
    """Builds the GaussMarker latent watermark and its detector from a ``GMConfig``.

    The constructor prepares the Gaussian Shading sampler, the Fourier-domain
    target patch and mask, and a ``GMDetector`` that receives all of them.
    """

    def __init__(self, config: "GMConfig") -> None:
        """Set up sampler, pattern, mask and detector from ``config``."""
        self.config = config
        self.device = config.device
        self.latent_shape = (
            1,
            config.latent_channels,
            config.latent_height,
            config.latent_width,
        )

        # Take the dtype of the first UNet parameter; fall back to the
        # configured dtype when the UNet yields no parameters.
        try:
            first_param = next(config.pipe.unet.parameters())
        except StopIteration:
            self.pipeline_dtype = config.dtype
        else:
            self.pipeline_dtype = first_param.dtype

        self.watermark_generator = GaussianShadingChaCha(
            channel_copy=config.channel_copy,
            width_copy=config.w_copy,
            height_copy=config.h_copy,
            fpr=config.fpr,
            user_number=config.user_number,
            latent_channels=config.latent_channels,
            latent_height=config.latent_height,
            latent_width=config.latent_width,
            dtype=torch.float32,
            device=torch.device(config.device),
            watermark_seed=config.watermark_seed,
            key_seed=config.chacha_key_seed,
            nonce_seed=config.chacha_nonce_seed,
        )

        # Pre-initialize watermark to keep deterministic behaviour
        set_random_seed(config.watermark_seed)
        latents, message = self.watermark_generator.create_watermark_and_return_w_m()
        self.base_watermark_latents = latents
        self.base_message = message.to(self.device, dtype=torch.float32)

        # Radii from w_radius down to 1: larger discs are painted first.
        self.radius_list = list(range(config.w_radius, 0, -1))
        self.gt_patch = self._build_watermarking_pattern()
        self.watermarking_mask = self._build_watermarking_mask()

        # Import GMDetector lazily to avoid a circular import.
        from detection.gm.gm_detection import GMDetector

        # Build detector (delegates all detection work)
        cfg = self.config
        sampler = self.watermark_generator
        self.detector = GMDetector(
            watermark_generator=sampler,
            watermarking_mask=self.watermarking_mask,
            gt_patch=self.gt_patch,
            w_measurement=cfg.w_measurement,
            device=self.device,
            bit_threshold=sampler.tau_bits,
            message_threshold=sampler.tau_onebit,
            l1_threshold=None,
            gnr_checkpoint=cfg.gnr_checkpoint,
            gnr_classifier_type=cfg.gnr_classifier_type,
            gnr_model_nf=cfg.gnr_model_nf,
            gnr_binary_threshold=cfg.gnr_binary_threshold,
            gnr_use_for_decision=cfg.gnr_use_for_decision,
            gnr_threshold=cfg.gnr_threshold,
            fuser_checkpoint=cfg.fuser_checkpoint,
            fuser_threshold=cfg.fuser_threshold,
            fuser_frequency_scale=cfg.fuser_frequency_scale,
            huggingface_repo=cfg.huggingface_repo,
            hf_dir=cfg.hf_dir,
        )

    # ------------------------------------------------------------------
    # Pattern / mask construction
    # ------------------------------------------------------------------
    def _paint_rings(self, patch: torch.Tensor) -> torch.Tensor:
        """Overwrite concentric discs of ``patch`` in place, largest radius first.

        For every radius and channel the disc is filled with the value found
        at ``[0, channel, 0, radius]`` of an untouched copy of ``patch``.
        """
        reference = patch.clone()
        side = patch.shape[-1]
        for radius in self.radius_list:
            disc = torch.tensor(circle_mask(side, radius), device=self.device, dtype=torch.bool)
            for ch in range(patch.shape[1]):
                patch[:, ch, disc] = reference[0, ch, 0, radius].item()
        return patch

    def _build_watermarking_pattern(self) -> torch.Tensor:
        """Create the target patch selected by ``config.w_pattern``.

        Random latents are drawn after seeding with ``config.w_seed``. The
        pattern name is matched by substring in a fixed order; names that match
        nothing else fall through to the default Fourier-domain ring pattern.
        The ``signal_ring`` branch writes ``config.w_length`` when it is ``None``.
        """
        set_random_seed(self.config.w_seed)
        base_latents = get_random_latents(
            pipe=self.config.pipe,
            height=self.config.image_size[0],
            width=self.config.image_size[1],
        ).to(self.device, dtype=torch.float32)

        def shifted_spectrum() -> torch.Tensor:
            # Centred 2-D FFT of the random latents.
            return torch.fft.fftshift(torch.fft.fft2(base_latents), dim=(-1, -2))

        pattern = self.config.w_pattern.lower()
        if "seed_ring" in pattern:
            gt_patch = self._paint_rings(base_latents.clone())
        elif "seed_zeros" in pattern:
            gt_patch = torch.zeros_like(base_latents)
        elif "seed_rand" in pattern:
            gt_patch = base_latents.clone()
        elif "rand" in pattern:
            gt_patch = shifted_spectrum()
            gt_patch[:] = gt_patch[0]
        elif "zeros" in pattern:
            gt_patch = shifted_spectrum() * 0
        elif "const" in pattern:
            gt_patch = shifted_spectrum() * 0
            gt_patch += self.config.w_pattern_const
        elif "signal_ring" in pattern:
            gt_patch = torch.randint_like(base_latents, low=0, high=2, dtype=torch.int64)
            if self.config.w_length is None:
                self.config.w_length = len(self.radius_list) * base_latents.shape[1]
            watermark_signal = torch.randint(low=0, high=4, size=(self.config.w_length,))
            n_channels = gt_patch.shape[1]
            n_symbols = len(watermark_signal)
            for ring_index, radius in enumerate(self.radius_list):
                disc = torch.tensor(circle_mask(base_latents.shape[-1], radius), device=self.device)
                for ch in range(n_channels):
                    # One symbol per (ring, channel) pair, cycling through the signal.
                    position = ring_index * n_channels + ch
                    gt_patch[:, ch, disc] = watermark_signal[position % n_symbols].item()
        else:  # default ring
            gt_patch = self._paint_rings(shifted_spectrum())
        return gt_patch.to(self.device)

    def _build_watermarking_mask(self) -> torch.Tensor:
        """Create the boolean mask of latent positions to overwrite.

        Raises:
            NotImplementedError: If ``config.w_mask_shape`` is not ``"circle"``.
        """
        mask = torch.zeros(self.latent_shape, dtype=torch.bool, device=self.device)
        shape = self.config.w_mask_shape.lower()
        if shape != "circle":
            raise NotImplementedError(f"Unsupported watermark mask shape: {shape}")

        disc = torch.tensor(circle_mask(self.latent_shape[-1], self.config.w_radius), device=self.device)
        # w_channel == -1 marks the disc in every channel, otherwise only in one.
        channels = slice(None) if self.config.w_channel == -1 else self.config.w_channel
        mask[:, channels, disc] = True
        return mask

    # ------------------------------------------------------------------
    # Watermark injection / detection helpers
    # ------------------------------------------------------------------
    def _inject_complex(self, latents: torch.Tensor) -> torch.Tensor:
        """Replace the masked Fourier coefficients of ``latents`` with the target patch.

        Returns:
            Real part of the inverse transform of the modified spectrum.
        """
        spectrum = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))

        patch = self.gt_patch
        if not torch.is_complex(patch):
            # Promote a real-valued patch to complex with zero imaginary part.
            real_part = patch.to(torch.float32)
            patch = torch.complex(real_part, torch.zeros_like(real_part))
        patch = patch.to(spectrum.dtype)

        mask = self.watermarking_mask
        selector = mask if mask.dtype == torch.bool else (mask != 0)
        spectrum[selector] = patch[selector].clone()

        restored = torch.fft.ifft2(torch.fft.ifftshift(spectrum, dim=(-1, -2)))
        return restored.real

    def inject_watermark(self, base_latents: torch.Tensor) -> torch.Tensor:
        """Inject the Fourier watermark and cast the result to ``config.dtype``.

        Raises:
            NotImplementedError: If ``config.w_injection`` does not contain
                ``"complex"``.
        """
        base_latents = base_latents.to(self.device, dtype=torch.float32)
        if "complex" not in self.config.w_injection.lower():
            raise NotImplementedError(f"Unsupported injection mode: {self.config.w_injection}")
        return self._inject_complex(base_latents).to(self.config.dtype)

    def generate_watermarked_latents(self, seed: Optional[int] = None) -> torch.Tensor:
        """Sample Gaussian Shading latents and inject the Fourier watermark.

        Args:
            seed: Random seed; ``config.gen_seed`` is used when ``None``.

        Returns:
            Watermarked latents in the pipeline dtype (or ``config.dtype`` when
            the pipeline dtype is falsy).
        """
        set_random_seed(self.config.gen_seed if seed is None else seed)
        sampled_latents, _ = self.watermark_generator.create_watermark_and_return_w_m()
        sampled_latents = sampled_latents.to(self.device, dtype=torch.float32)
        watermarked = self.inject_watermark(sampled_latents)
        return watermarked.to(self.pipeline_dtype or self.config.dtype)

403 

404# ----------------------------------------------------------------------------- 

405# Configuration for GaussMarker 

406# ----------------------------------------------------------------------------- 

407 

class GMConfig(BaseConfig):
    """Configuration for the GaussMarker ("GM") watermark."""

    def initialize_parameters(self) -> None:
        """Read GM options from ``config_dict`` and derive the latent geometry.

        Options missing from the dictionary take the defaults written below
        (``None`` where no default is given). Latent sizes come from the
        pipeline's UNet input channels and VAE scale factor.

        Raises:
            ValueError: If a copy factor does not divide the latent dimension
                it is applied to.
        """
        option = self.config_dict.get

        # Gaussian Shading tiling and false-positive-rate settings.
        self.channel_copy = option("channel_copy", 1)
        self.w_copy = option("w_copy", 8)
        self.h_copy = option("h_copy", 8)
        self.user_number = option("user_number", 1_000_000)
        self.fpr = option("fpr", 1e-6)

        # Seeds for the ChaCha20 key/nonce and for the watermark bits.
        self.chacha_key_seed = option("chacha_key_seed")
        self.chacha_nonce_seed = option("chacha_nonce_seed")
        self.watermark_seed = option("watermark_seed", self.gen_seed)

        # Fourier-domain pattern settings.
        self.w_seed = option("w_seed", 999_999)
        self.w_channel = option("w_channel", -1)
        self.w_pattern = option("w_pattern", "ring")
        self.w_mask_shape = option("w_mask_shape", "circle")
        self.w_radius = option("w_radius", 4)
        self.w_measurement = option("w_measurement", "l1_complex")
        self.w_injection = option("w_injection", "complex")
        self.w_pattern_const = option("w_pattern_const", 0.0)
        self.w_length = option("w_length")

        # GNR / fuser settings, handed to the detector unchanged.
        self.gnr_checkpoint = option("gnr_checkpoint")
        self.gnr_classifier_type = option("gnr_classifier_type", 0)
        self.gnr_model_nf = option("gnr_model_nf", 128)
        self.gnr_binary_threshold = option("gnr_binary_threshold", 0.5)
        self.gnr_use_for_decision = option("gnr_use_for_decision", True)
        self.gnr_threshold = option("gnr_threshold")
        self.huggingface_repo = option("huggingface_repo")
        self.fuser_checkpoint = option("fuser_checkpoint")
        self.fuser_threshold = option("fuser_threshold")
        self.fuser_frequency_scale = option("fuser_frequency_scale", 0.01)
        self.hf_dir = option("hf_dir")

        # Latent geometry.
        image_height, image_width = self.image_size[0], self.image_size[1]
        self.latent_channels = self.pipe.unet.config.in_channels
        self.latent_height = image_height // self.pipe.vae_scale_factor
        self.latent_width = image_width // self.pipe.vae_scale_factor

        if self.latent_channels % self.channel_copy:
            raise ValueError("channel_copy must divide latent channels")
        # w_copy is checked against the latent height and h_copy against the width.
        if not (self.latent_height % self.w_copy == 0 and self.latent_width % self.h_copy == 0):
            raise ValueError("w_copy and h_copy must divide latent spatial dimensions")

    @property
    def algorithm_name(self) -> str:
        """Short identifier of this algorithm."""
        return "GM"

453 

454 

455# ----------------------------------------------------------------------------- 

456# Main GaussMarker watermark class 

457# ----------------------------------------------------------------------------- 

class GM(BaseWatermark):
    """GaussMarker watermark: generation, detection and visualisation data."""

    def __init__(self, watermark_config: GMConfig, *args, **kwargs) -> None:
        """Store the config and build ``GMUtils`` before initialising the base class."""
        self.config = watermark_config
        self.utils = GMUtils(self.config)
        super().__init__(self.config)

    def _gm_text_embeddings(self, prompt: str, guidance_scale: float) -> torch.Tensor:
        """Encode ``prompt``; prepend negative embeddings when guidance exceeds 1."""
        use_cfg = guidance_scale > 1.0
        prompt_embeds, negative_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=use_cfg,
            num_images_per_prompt=1,
        )
        if use_cfg:
            return torch.cat([negative_embeds, prompt_embeds])
        return prompt_embeds

    def _gm_image_latents(self, image: Image.Image, dtype: torch.dtype, decoder_inv: bool) -> torch.Tensor:
        """Convert ``image`` to model format and obtain its latents without sampling."""
        model_input = transform_to_model_format(image, target_size=self.config.image_size[0])
        model_input = model_input.unsqueeze(0).to(dtype).to(self.config.device)
        return get_media_latents(
            pipe=self.config.pipe,
            media=model_input,
            sample=False,
            decoder_inv=decoder_inv,
        )

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Generate one image from watermarked initial latents.

        Recognised keyword arguments: ``seed``, ``guidance_scale`` and
        ``num_inference_steps`` (each defaulting to the config value). Any
        remaining keyword arguments are passed to the pipeline and override
        ``config.gen_kwargs``; ``latents`` is always the watermarked latents.

        Returns:
            The first image produced by the pipeline.
        """
        seed = kwargs.pop("seed", self.config.gen_seed)
        latents = self.utils.generate_watermarked_latents(seed=seed)
        self.set_orig_watermarked_latents(latents)

        guidance = kwargs.pop("guidance_scale", self.config.guidance_scale)
        steps = kwargs.pop("num_inference_steps", self.config.num_inference_steps)

        # Precedence, lowest to highest: config.gen_kwargs, the explicit
        # settings below, caller kwargs, and finally the watermarked latents.
        params = dict(self.config.gen_kwargs)
        params.update(
            {
                "num_images_per_prompt": self.config.num_images,
                "guidance_scale": guidance,
                "num_inference_steps": steps,
                "height": self.config.image_size[0],
                "width": self.config.image_size[1],
            }
        )
        params.update(kwargs)
        params["latents"] = latents

        return self.config.pipe(prompt, **params).images[0]

    def _detect_watermark_in_image(
        self,
        image: Image.Image,
        prompt: str = "",
        *args,
        **kwargs,
    ) -> Dict[str, Union[float, bool]]:
        """Invert ``image`` to latents and evaluate the watermark.

        Recognised keyword arguments: ``guidance_scale``,
        ``num_inference_steps``, ``decoder_inv`` (default ``False``) and
        ``detector_type`` (default ``"bit_acc"``). All other keyword arguments
        are forwarded to the inversion's ``forward_diffusion``.

        Returns:
            Whatever the detector's ``eval_watermark`` returns for the last
            element of the inversion result.
        """
        guidance_scale = kwargs.get("guidance_scale", self.config.guidance_scale)
        num_steps = kwargs.get("num_inference_steps", self.config.num_inference_steps)

        text_embeddings = self._gm_text_embeddings(prompt, guidance_scale)
        image_latents = self._gm_image_latents(
            image, text_embeddings.dtype, kwargs.get("decoder_inv", False)
        )

        # Options consumed here must not reach the inversion call a second time.
        consumed = {"decoder_inv", "guidance_scale", "num_inference_steps", "detector_type"}
        inversion_kwargs = {name: value for name, value in kwargs.items() if name not in consumed}

        reversed_series = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
            **inversion_kwargs,
        )

        # Delegate detection to GMDetector
        return self.utils.detector.eval_watermark(
            reversed_latents=reversed_series[-1],
            detector_type=kwargs.get("detector_type", "bit_acc"),
        )

    def get_data_for_visualize(
        self,
        image: Image.Image,
        prompt: str = "",
        guidance_scale: Optional[float] = None,
        decoder_inv: bool = False,
        *args,
        **kwargs,
    ) -> DataForVisualization:
        """Collect the data needed to visualise the watermark.

        A watermarked image is regenerated with ``config.gen_seed`` and then
        inverted with ``config.num_inversion_steps``. The ``image`` argument is
        not inverted; it is only passed through to the returned object.

        Args:
            image: Image stored in the returned ``DataForVisualization``.
            prompt: Prompt used for both generation and inversion.
            guidance_scale: Guidance scale; the config value when ``None``.
            decoder_inv: Forwarded to ``get_media_latents``.
        """
        guidance = self.config.guidance_scale if guidance_scale is None else guidance_scale
        set_random_seed(self.config.gen_seed)
        latents = self.utils.generate_watermarked_latents(seed=self.config.gen_seed)

        # Explicit settings take precedence over config.gen_kwargs.
        params = dict(self.config.gen_kwargs)
        params.update(
            {
                "num_images_per_prompt": self.config.num_images,
                "guidance_scale": guidance,
                "num_inference_steps": self.config.num_inference_steps,
                "height": self.config.image_size[0],
                "width": self.config.image_size[1],
                "latents": latents,
            }
        )
        watermarked_image = self.config.pipe(prompt, **params).images[0]

        text_embeddings = self._gm_text_embeddings(prompt, guidance)
        image_latents = self._gm_image_latents(watermarked_image, text_embeddings.dtype, decoder_inv)

        reversed_series = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance,
            num_inference_steps=self.config.num_inversion_steps,
        )

        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            orig_watermarked_latents=self.get_orig_watermarked_latents(),
            reversed_latents=reversed_series,
            image=image,
        )

597