Coverage for watermark/gm/gm.py: 95.77% (331 statements)
from __future__ import annotations

import copy
import random
from dataclasses import dataclass
from functools import reduce
from pathlib import Path
from typing import Dict, Optional, Tuple, Union

import numpy as np
import torch
from PIL import Image
from Crypto.Cipher import ChaCha20
from Crypto.Random import get_random_bytes
from scipy.special import betainc
from scipy.stats import norm, truncnorm
from huggingface_hub import hf_hub_download

from ..base import BaseConfig, BaseWatermark
from utils.media_utils import get_random_latents, get_media_latents, transform_to_model_format
from utils.utils import set_random_seed
from visualize.data_for_visualization import DataForVisualization
from .gnr import GNRRestorer

# -----------------------------------------------------------------------------
# Helper utilities adapted from the official GaussMarker implementation
# -----------------------------------------------------------------------------


def _bytes_from_seed(seed: Optional[int], length: int) -> bytes:
    """Generate deterministic bytes using a Python PRNG seed."""
    if seed is None:
        return get_random_bytes(length)
    rng = random.Random(seed)
    return bytes(rng.getrandbits(8) for _ in range(length))


def circle_mask(size: int, radius: int, x_offset: int = 0, y_offset: int = 0) -> np.ndarray:
    """Create a binary circle mask with optional offset."""
    x0 = y0 = size // 2
    x0 += x_offset
    y0 += y_offset
    grid_y, grid_x = np.ogrid[:size, :size]
    grid_y = grid_y[::-1]
    return ((grid_x - x0) ** 2 + (grid_y - y0) ** 2) <= radius ** 2


def extract_complex_sign(complex_tensor: torch.Tensor) -> torch.Tensor:
    """Extract complex-valued sign encoding (4-way) from a complex tensor."""
    real = complex_tensor.real
    imag = complex_tensor.imag

    sign_map_real = (real <= 0).long()
    sign_map_imag = (imag <= 0).long()
    return 2 * sign_map_real + sign_map_imag
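
# Illustrative note (not part of the original file): with the encoding above, the four
# sign quadrants map to distinct symbols -- e.g. 1+1j -> 0, 1-1j -> 1, -1+1j -> 2,
# -1-1j -> 3 -- so each complex frequency coefficient carries two bits of signal.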


# -----------------------------------------------------------------------------
# Gaussian Shading watermark with ChaCha20 encryption (generalised dimensions)
# -----------------------------------------------------------------------------
@dataclass
class GaussianShadingChaCha:
    channel_copy: int
    width_copy: int
    height_copy: int
    fpr: float
    user_number: int
    latent_channels: int
    latent_height: int
    latent_width: int
    dtype: torch.dtype
    device: torch.device
    watermark_seed: Optional[int] = None
    key_seed: Optional[int] = None
    nonce_seed: Optional[int] = None
    watermark: Optional[torch.Tensor] = None
    key: Optional[bytes] = None
    nonce: Optional[bytes] = None
    message_bits: Optional[np.ndarray] = None

    def __post_init__(self) -> None:
        self.latentlength = self.latent_channels * self.latent_height * self.latent_width
        divisor = self.channel_copy * self.width_copy * self.height_copy
        if self.latentlength % divisor != 0:
            raise ValueError(
                "Latent volume is not divisible by channel/width/height copies. "
                "Please adjust w_copy/h_copy/channel_copy."
            )
        self.marklength = self.latentlength // divisor

        # Voting thresholds identical to official implementation
        if self.channel_copy == 1 and self.width_copy == 1 and self.height_copy == 1:
            self.threshold = 1
        else:
            self.threshold = self.channel_copy * self.width_copy * self.height_copy // 2

        self.tau_onebit: Optional[float] = None
        self.tau_bits: Optional[float] = None
        for i in range(self.marklength):
            fpr_onebit = betainc(i + 1, self.marklength - i, 0.5)
            fpr_bits = fpr_onebit * self.user_number
            if fpr_onebit <= self.fpr and self.tau_onebit is None:
                self.tau_onebit = i / self.marklength
            if fpr_bits <= self.fpr and self.tau_bits is None:
                self.tau_bits = i / self.marklength
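
        # Illustrative reading of the loop above: betainc(i + 1, marklength - i, 0.5)
        # is the probability that a random watermark matches more than i of the
        # marklength bits, so tau_onebit / tau_bits are the smallest bit-accuracy
        # thresholds whose single-user / user_number-scaled false positive rate
        # stays below fpr.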

    # ------------------------------------------------------------------
    # Key/nonce helpers
    # ------------------------------------------------------------------

    def _ensure_key_nonce(self) -> None:
        if self.key is None:
            self.key = _bytes_from_seed(self.key_seed, 32)
        if self.nonce is None:
            self.nonce = _bytes_from_seed(self.nonce_seed, 12)

    # ------------------------------------------------------------------
    # Sampling helpers
    # ------------------------------------------------------------------

    def _truncated_sampling(self, message_bits: np.ndarray) -> torch.Tensor:
        z = np.zeros(self.latentlength, dtype=np.float32)
        denominator = 2.0
        ppf = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)]
        for idx in range(self.latentlength):
            dec_mes = reduce(lambda a, b: 2 * a + b, message_bits[idx : idx + 1])
            dec_mes = int(dec_mes)
            z[idx] = truncnorm.rvs(ppf[dec_mes], ppf[dec_mes + 1])
        tensor = torch.from_numpy(z).reshape(1, self.latent_channels, self.latent_height, self.latent_width)
        return tensor.to(self.device, dtype=torch.float32)

    def _generate_watermark(self) -> None:
        generator = torch.Generator(device="cpu")
        if self.watermark_seed is not None:
            generator.manual_seed(self.watermark_seed)

        watermark = torch.randint(
            low=0,
            high=2,
            size=(
                1,
                self.latent_channels // self.channel_copy,
                self.latent_height // self.width_copy,
                self.latent_width // self.height_copy,
            ),
            generator=generator,
            dtype=torch.int64,
        )
        self.watermark = watermark.to(self.device)

        tiled = self.watermark.repeat(1, self.channel_copy, self.width_copy, self.height_copy)
        self.message_bits = self._stream_key_encrypt(tiled.flatten().cpu().numpy())

    # ------------------------------------------------------------------
    # Encryption helpers
    # ------------------------------------------------------------------
    def _stream_key_encrypt(self, plaintext_bits: np.ndarray) -> np.ndarray:
        """Encrypt plaintext bits using the ChaCha20 stream cipher."""
        self._ensure_key_nonce()
        cipher = ChaCha20.new(key=self.key, nonce=self.nonce)  # initialise the cipher
        packed = np.packbits(plaintext_bits).tobytes()
        encrypted = cipher.encrypt(packed)
        unpacked = np.unpackbits(np.frombuffer(encrypted, dtype=np.uint8))
        return unpacked[: self.latentlength]
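
    # Note (illustrative): ChaCha20 is a stream cipher, so encryption is a bitwise XOR
    # with a keystream derived from (key, nonce). Packing the bit array into bytes,
    # encrypting, and unpacking therefore flips exactly the keystream bits, and
    # _stream_key_decrypt inverts it by re-running the same cipher.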

    def _stream_key_decrypt(self, encrypted_bits: np.ndarray) -> torch.Tensor:
        self._ensure_key_nonce()
        cipher = ChaCha20.new(key=self.key, nonce=self.nonce)
        packed = np.packbits(encrypted_bits).tobytes()
        decrypted = cipher.decrypt(packed)
        bits = np.unpackbits(np.frombuffer(decrypted, dtype=np.uint8))
        bits = bits[: self.latentlength]
        tensor = torch.from_numpy(bits.astype(np.uint8)).reshape(
            1, self.latent_channels, self.latent_height, self.latent_width
        )
        return tensor.to(self.device)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def create_watermark_and_return_w_m(self) -> Tuple[torch.Tensor, torch.Tensor]:
        if self.watermark is None or self.message_bits is None:
            self._generate_watermark()
        message_bits = self.message_bits
        sampled = self._truncated_sampling(message_bits)
        sampled = sampled.to(self.device, dtype=torch.float32)
        m_tensor = torch.from_numpy(message_bits.astype(np.float32)).reshape(
            1, self.latent_channels, self.latent_height, self.latent_width
        ).to(self.device)
        return sampled, m_tensor

    def diffusion_inverse(self, spread_tensor: torch.Tensor) -> torch.Tensor:
        tensor = spread_tensor.to(self.device).reshape(
            1,
            self.channel_copy,
            self.latent_channels // self.channel_copy,
            self.width_copy,
            self.latent_height // self.width_copy,
            self.height_copy,
            self.latent_width // self.height_copy,
        )
        # Majority vote: sum the repeated bits over the channel/height/width copy axes
        tensor = tensor.sum(dim=(1, 3, 5))
        vote = tensor.clone()
        vote[vote <= self.threshold] = 0
        vote[vote > self.threshold] = 1
        return vote.to(torch.int64)

    def pred_m_from_latent(self, reversed_latents: torch.Tensor) -> torch.Tensor:
        return (reversed_latents > 0).int().to(self.device)

    def pred_w_from_latent(self, reversed_latents: torch.Tensor) -> torch.Tensor:
        reversed_m = self.pred_m_from_latent(reversed_latents)
        spread_bits = reversed_m.flatten().detach().cpu().numpy().astype(np.uint8)
        decrypted = self._stream_key_decrypt(spread_bits)
        return self.diffusion_inverse(decrypted)

    def pred_w_from_m(self, reversed_m: torch.Tensor) -> torch.Tensor:
        spread_bits = reversed_m.flatten().detach().cpu().numpy().astype(np.uint8)
        decrypted = self._stream_key_decrypt(spread_bits)
        return self.diffusion_inverse(decrypted)

    def watermark_tensor(self, device: Optional[torch.device] = None) -> torch.Tensor:
        if self.watermark is None:
            self._generate_watermark()
        device = device or self.device
        return self.watermark.to(device)
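
# A minimal round-trip sketch (illustrative only; the shapes, seeds and device below
# are assumptions, not defaults of this module):
#
#     gs = GaussianShadingChaCha(
#         channel_copy=1, width_copy=8, height_copy=8, fpr=1e-6, user_number=1_000_000,
#         latent_channels=4, latent_height=64, latent_width=64,
#         dtype=torch.float32, device=torch.device("cpu"),
#         watermark_seed=0, key_seed=0, nonce_seed=0,
#     )
#     z, m = gs.create_watermark_and_return_w_m()   # watermarked latents + message bits
#     w_hat = gs.pred_w_from_latent(z)              # equals gs.watermark when z is undistorted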


# -----------------------------------------------------------------------------
# Utility helpers for GaussMarker
# -----------------------------------------------------------------------------
class GMUtils:
    def __init__(self, config: "GMConfig") -> None:
        self.config = config
        self.device = config.device
        self.latent_shape = (
            1,
            config.latent_channels,
            config.latent_height,
            config.latent_width,
        )
        try:
            self.pipeline_dtype = next(config.pipe.unet.parameters()).dtype
        except StopIteration:
            self.pipeline_dtype = config.dtype

        watermark_cls = GaussianShadingChaCha
        self.watermark_generator = watermark_cls(
            channel_copy=config.channel_copy,
            width_copy=config.w_copy,
            height_copy=config.h_copy,
            fpr=config.fpr,
            user_number=config.user_number,
            latent_channels=config.latent_channels,
            latent_height=config.latent_height,
            latent_width=config.latent_width,
            dtype=torch.float32,
            device=torch.device(config.device),
            watermark_seed=config.watermark_seed,
            key_seed=config.chacha_key_seed,
            nonce_seed=config.chacha_nonce_seed,
        )

        # Pre-initialize the watermark to keep behaviour deterministic
        set_random_seed(config.watermark_seed)
        self.base_watermark_latents, self.base_message = self.watermark_generator.create_watermark_and_return_w_m()
        self.base_message = self.base_message.to(self.device, dtype=torch.float32)

        self.radius_list = list(range(config.w_radius, 0, -1))
        self.gt_patch = self._build_watermarking_pattern()
        self.watermarking_mask = self._build_watermarking_mask()

        # Import GMDetector lazily to avoid a circular import
        from detection.gm.gm_detection import GMDetector

        # Build the detector (delegates all detection work)
        self.detector = GMDetector(
            watermark_generator=self.watermark_generator,
            watermarking_mask=self.watermarking_mask,
            gt_patch=self.gt_patch,
            w_measurement=self.config.w_measurement,
            device=self.device,
            bit_threshold=self.watermark_generator.tau_bits,
            message_threshold=self.watermark_generator.tau_onebit,
            l1_threshold=None,
            gnr_checkpoint=self.config.gnr_checkpoint,
            gnr_classifier_type=self.config.gnr_classifier_type,
            gnr_model_nf=self.config.gnr_model_nf,
            gnr_binary_threshold=self.config.gnr_binary_threshold,
            gnr_use_for_decision=self.config.gnr_use_for_decision,
            gnr_threshold=self.config.gnr_threshold,
            fuser_checkpoint=self.config.fuser_checkpoint,
            fuser_threshold=self.config.fuser_threshold,
            fuser_frequency_scale=self.config.fuser_frequency_scale,
            huggingface_repo=self.config.huggingface_repo,
            hf_dir=self.config.hf_dir,
        )

    # ------------------------------------------------------------------
    # Pattern / mask construction
    # ------------------------------------------------------------------
    def _build_watermarking_pattern(self) -> torch.Tensor:
        set_random_seed(self.config.w_seed)
        base_latents = get_random_latents(
            pipe=self.config.pipe,
            height=self.config.image_size[0],
            width=self.config.image_size[1],
        ).to(self.device, dtype=torch.float32)

        pattern = self.config.w_pattern.lower()
        if "seed_ring" in pattern:
            gt_patch = base_latents.clone()
            tmp = copy.deepcopy(gt_patch)
            for radius in self.radius_list:
                mask = torch.tensor(circle_mask(gt_patch.shape[-1], radius), device=self.device, dtype=torch.bool)
                for ch in range(gt_patch.shape[1]):
                    gt_patch[:, ch, mask] = tmp[0, ch, 0, radius].item()
        elif "seed_zeros" in pattern:
            gt_patch = torch.zeros_like(base_latents)
        elif "seed_rand" in pattern:
            gt_patch = base_latents.clone()
        elif "rand" in pattern:
            gt_patch = torch.fft.fftshift(torch.fft.fft2(base_latents), dim=(-1, -2))
            gt_patch[:] = gt_patch[0]
        elif "zeros" in pattern:
            gt_patch = torch.fft.fftshift(torch.fft.fft2(base_latents), dim=(-1, -2)) * 0
        elif "const" in pattern:
            gt_patch = torch.fft.fftshift(torch.fft.fft2(base_latents), dim=(-1, -2)) * 0
            gt_patch += self.config.w_pattern_const
        elif "signal_ring" in pattern:
            gt_patch = torch.randint_like(base_latents, low=0, high=2, dtype=torch.int64)
            if self.config.w_length is None:
                self.config.w_length = len(self.radius_list) * base_latents.shape[1]
            watermark_signal = torch.randint(low=0, high=4, size=(self.config.w_length,))
            idx = 0
            for radius in self.radius_list:
                mask = torch.tensor(circle_mask(base_latents.shape[-1], radius), device=self.device)
                for ch in range(gt_patch.shape[1]):
                    signal = watermark_signal[idx % len(watermark_signal)].item()
                    gt_patch[:, ch, mask] = signal
                    idx += 1
        else:  # default ring
            gt_patch = torch.fft.fftshift(torch.fft.fft2(base_latents), dim=(-1, -2))
            tmp = gt_patch.clone()
            for radius in self.radius_list:
                mask = torch.tensor(circle_mask(gt_patch.shape[-1], radius), device=self.device, dtype=torch.bool)
                for ch in range(gt_patch.shape[1]):
                    gt_patch[:, ch, mask] = tmp[0, ch, 0, radius].item()
        return gt_patch.to(self.device)

    def _build_watermarking_mask(self) -> torch.Tensor:
        mask = torch.zeros(self.latent_shape, dtype=torch.bool, device=self.device)
        shape = self.config.w_mask_shape.lower()

        if shape == "circle":
            base_mask = torch.tensor(circle_mask(self.latent_shape[-1], self.config.w_radius), device=self.device)
            if self.config.w_channel == -1:
                mask[:, :, base_mask] = True
            else:
                mask[:, self.config.w_channel, base_mask] = True
        else:
            raise NotImplementedError(f"Unsupported watermark mask shape: {shape}")

        return mask

    # ------------------------------------------------------------------
    # Watermark injection / detection helpers
    # ------------------------------------------------------------------
    def _inject_complex(self, latents: torch.Tensor) -> torch.Tensor:
        fft_latents = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))
        target_patch = self.gt_patch
        if not torch.is_complex(target_patch):
            real = target_patch.to(torch.float32)
            imag = torch.zeros_like(real)
            target_patch = torch.complex(real, imag)
        target_patch = target_patch.to(fft_latents.dtype)

        mask = self.watermarking_mask
        if mask.dtype != torch.bool:
            fft_latents[mask != 0] = target_patch[mask != 0].clone()
        else:
            fft_latents[mask] = target_patch[mask].clone()
        injected = torch.fft.ifft2(torch.fft.ifftshift(fft_latents, dim=(-1, -2))).real
        return injected
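
    # Note (illustrative): _inject_complex writes the watermark in the frequency domain:
    # it FFT-shifts the latents, overwrites the masked coefficients with gt_patch,
    # inverts the FFT and keeps only the real part, so any imaginary residue created by
    # the overwrite is discarded.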

    def inject_watermark(self, base_latents: torch.Tensor) -> torch.Tensor:
        base_latents = base_latents.to(self.device, dtype=torch.float32)
        injection = self.config.w_injection.lower()
        if "complex" in injection:
            watermarked = self._inject_complex(base_latents)
        else:
            raise NotImplementedError(f"Unsupported injection mode: {self.config.w_injection}")
        return watermarked.to(self.config.dtype)

    def generate_watermarked_latents(self, seed: Optional[int] = None) -> torch.Tensor:
        if seed is None:
            seed = self.config.gen_seed
        set_random_seed(seed)
        sampled_latents, _ = self.watermark_generator.create_watermark_and_return_w_m()
        sampled_latents = sampled_latents.to(self.device, dtype=torch.float32)
        watermarked = self.inject_watermark(sampled_latents)
        target_dtype = self.pipeline_dtype or self.config.dtype
        return watermarked.to(target_dtype)


# -----------------------------------------------------------------------------
# Configuration for GaussMarker
# -----------------------------------------------------------------------------
class GMConfig(BaseConfig):
    def initialize_parameters(self) -> None:
        cfg = self.config_dict
        self.channel_copy = cfg.get("channel_copy", 1)
        self.w_copy = cfg.get("w_copy", 8)
        self.h_copy = cfg.get("h_copy", 8)
        self.user_number = cfg.get("user_number", 1_000_000)
        self.fpr = cfg.get("fpr", 1e-6)
        self.chacha_key_seed = cfg.get("chacha_key_seed")
        self.chacha_nonce_seed = cfg.get("chacha_nonce_seed")
        self.watermark_seed = cfg.get("watermark_seed", self.gen_seed)
        self.w_seed = cfg.get("w_seed", 999_999)
        self.w_channel = cfg.get("w_channel", -1)
        self.w_pattern = cfg.get("w_pattern", "ring")
        self.w_mask_shape = cfg.get("w_mask_shape", "circle")
        self.w_radius = cfg.get("w_radius", 4)
        self.w_measurement = cfg.get("w_measurement", "l1_complex")
        self.w_injection = cfg.get("w_injection", "complex")
        self.w_pattern_const = cfg.get("w_pattern_const", 0.0)
        self.w_length = cfg.get("w_length")

        self.gnr_checkpoint = cfg.get("gnr_checkpoint")
        self.gnr_classifier_type = cfg.get("gnr_classifier_type", 0)
        self.gnr_model_nf = cfg.get("gnr_model_nf", 128)
        self.gnr_binary_threshold = cfg.get("gnr_binary_threshold", 0.5)
        self.gnr_use_for_decision = cfg.get("gnr_use_for_decision", True)
        self.gnr_threshold = cfg.get("gnr_threshold")
        self.huggingface_repo = cfg.get("huggingface_repo")
        self.fuser_checkpoint = cfg.get("fuser_checkpoint")
        self.fuser_threshold = cfg.get("fuser_threshold")
        self.fuser_frequency_scale = cfg.get("fuser_frequency_scale", 0.01)
        self.hf_dir = cfg.get("hf_dir")

        self.latent_channels = self.pipe.unet.config.in_channels
        self.latent_height = self.image_size[0] // self.pipe.vae_scale_factor
        self.latent_width = self.image_size[1] // self.pipe.vae_scale_factor

        if self.latent_channels % self.channel_copy != 0:
            raise ValueError("channel_copy must divide latent channels")
        if self.latent_height % self.w_copy != 0 or self.latent_width % self.h_copy != 0:
            raise ValueError("w_copy and h_copy must divide latent spatial dimensions")

    @property
    def algorithm_name(self) -> str:
        return "GM"
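
# A minimal config sketch (illustrative; the values simply echo the defaults read above,
# and GMConfig/BaseConfig construction details live outside this file):
#
#     config_dict = {
#         "channel_copy": 1, "w_copy": 8, "h_copy": 8,
#         "fpr": 1e-6, "user_number": 1_000_000,
#         "w_pattern": "ring", "w_mask_shape": "circle", "w_radius": 4,
#         "w_measurement": "l1_complex", "w_injection": "complex",
#     }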


# -----------------------------------------------------------------------------
# Main GaussMarker watermark class
# -----------------------------------------------------------------------------
class GM(BaseWatermark):
    def __init__(self, watermark_config: GMConfig, *args, **kwargs) -> None:
        self.config = watermark_config
        self.utils = GMUtils(self.config)
        super().__init__(self.config)

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        seed = kwargs.pop("seed", self.config.gen_seed)
        watermarked_latents = self.utils.generate_watermarked_latents(seed=seed)
        self.set_orig_watermarked_latents(watermarked_latents)

        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": kwargs.pop("guidance_scale", self.config.guidance_scale),
            "num_inference_steps": kwargs.pop("num_inference_steps", self.config.num_inference_steps),
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        for key, value in self.config.gen_kwargs.items():
            generation_params.setdefault(key, value)
        generation_params.update(kwargs)
        generation_params["latents"] = watermarked_latents

        images = self.config.pipe(prompt, **generation_params).images
        return images[0]

    def _detect_watermark_in_image(
        self,
        image: Image.Image,
        prompt: str = "",
        *args,
        **kwargs,
    ) -> Dict[str, Union[float, bool]]:
        guidance_scale = kwargs.get("guidance_scale", self.config.guidance_scale)
        num_steps = kwargs.get("num_inference_steps", self.config.num_inference_steps)

        do_cfg = guidance_scale > 1.0
        prompt_embeds, negative_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_cfg,
            num_images_per_prompt=1,
        )
        if do_cfg:
            text_embeddings = torch.cat([negative_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        processed = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(
            text_embeddings.dtype
        ).to(self.config.device)
        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=processed,
            sample=False,
            decoder_inv=kwargs.get("decoder_inv", False),
        )

        inversion_kwargs = {
            key: val
            for key, val in kwargs.items()
            if key not in {"decoder_inv", "guidance_scale", "num_inference_steps", "detector_type"}
        }

        reversed_series = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=num_steps,
            **inversion_kwargs,
        )
        reversed_latents = reversed_series[-1]

        # Delegate detection to GMDetector
        return self.utils.detector.eval_watermark(
            reversed_latents=reversed_latents,
            detector_type=kwargs.get("detector_type", "bit_acc"),
        )

    def get_data_for_visualize(
        self,
        image: Image.Image,
        prompt: str = "",
        guidance_scale: Optional[float] = None,
        decoder_inv: bool = False,
        *args,
        **kwargs,
    ) -> DataForVisualization:
        guidance = guidance_scale if guidance_scale is not None else self.config.guidance_scale
        set_random_seed(self.config.gen_seed)
        watermarked_latents = self.utils.generate_watermarked_latents(seed=self.config.gen_seed)

        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": guidance,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }
        for key, value in self.config.gen_kwargs.items():
            generation_params.setdefault(key, value)

        watermarked_image = self.config.pipe(prompt, **generation_params).images[0]

        do_cfg = guidance > 1.0
        prompt_embeds, negative_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_cfg,
            num_images_per_prompt=1,
        )
        text_embeddings = torch.cat([negative_embeds, prompt_embeds]) if do_cfg else prompt_embeds

        processed = transform_to_model_format(watermarked_image, target_size=self.config.image_size[0]).unsqueeze(0)
        processed = processed.to(text_embeddings.dtype).to(self.config.device)
        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=processed,
            sample=False,
            decoder_inv=decoder_inv,
        )

        reversed_series = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance,
            num_inference_steps=self.config.num_inversion_steps,
        )

        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            orig_watermarked_latents=self.get_orig_watermarked_latents(),
            reversed_latents=reversed_series,
            image=image,
        )
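
# End-to-end sketch (illustrative; constructing GMConfig requires a loaded diffusion
# pipeline and the BaseConfig machinery defined outside this module, and the public
# entry points on BaseWatermark may wrap the private methods used here):
#
#     gm = GM(watermark_config=config)
#     image = gm._generate_watermarked_image("an astronaut riding a horse")
#     result = gm._detect_watermark_in_image(image, prompt="")
#     # `result` is the dict returned by GMDetector.eval_watermark for the chosen
#     # detector_type (default "bit_acc").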