Coverage for markdiffusion / watermark / sfw / sfw.py: 91.15%
305 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 19:25 +0000
1from ..base import BaseWatermark, BaseConfig
2from markdiffusion.utils.media_utils import *
3import torch
4from typing import Dict, Optional
5from markdiffusion.utils.utils import set_random_seed, inherit_docstring
6from markdiffusion.utils.diffusion_config import DiffusionConfig
7import numpy as np
8from PIL import Image
9from markdiffusion.visualize.data_for_visualization import DataForVisualization
10from markdiffusion.detection.sfw.sfw_detection import SFWDetector
11import torchvision.transforms as tforms
12import qrcode
13import logging
14import os
class SFWConfig(BaseConfig):
    """Configuration holder for the SFW watermarking algorithm.

    Algorithm-specific values are pulled out of ``self.config_dict`` by
    :meth:`initialize_parameters`.
    """

    def initialize_parameters(self) -> None:
        """Copy the SFW settings out of ``self.config_dict``.

        Required keys: ``w_seed``, ``delta``, ``wm_type``, ``threshold`` and
        ``w_channel``. ``threshold_p_value`` is optional and falls back to 0.01.
        A missing required key raises ``KeyError``.
        """
        cfg = self.config_dict
        self.w_seed = cfg['w_seed']
        self.delta = cfg['delta']
        # Watermark flavour: "HSTR" (tree-ring) or "HSQR" (QR code).
        self.wm_type = cfg['wm_type']
        self.threshold = cfg['threshold']
        self.threshold_p_value = cfg.get('threshold_p_value', 0.01)
        self.w_channel = cfg['w_channel']

    @property
    def algorithm_name(self) -> str:
        """Identifier of this algorithm (always ``'SFW'``)."""
        name = 'SFW'
        return name
class SFWUtils:
    """Helper routines for the SFW watermarking algorithm.

    On construction the ground-truth Fourier pattern (``gt_patch``) and the
    boolean injection mask (``watermarking_mask``) are built from ``config``.
    """

    def __init__(self, config: "SFWConfig", *args, **kwargs) -> None:
        """
        Initialize the SFW watermarking helpers.

        Parameters:
            config (SFWConfig): Configuration for the SFW algorithm. Its
                ``pipe``, ``device``, ``image_size``, ``w_seed``, ``wm_type``,
                ``w_channel``, ``delta`` and ``init_latents`` attributes are read.
        """
        self.config = config
        self.gt_patch = self._get_watermarking_pattern()
        self.watermarking_mask = self._get_watermarking_mask(self.config.init_latents)

    # [Fourier transforms]
    @staticmethod
    def fft(input_tensor):
        """Centered (fftshifted) 2-D FFT over the last two dims of a 4-D tensor."""
        assert len(input_tensor.shape) == 4
        return torch.fft.fftshift(torch.fft.fft2(input_tensor), dim=(-1, -2))

    @staticmethod
    def ifft(input_tensor):
        """Inverse of :meth:`fft`: un-shift, then 2-D inverse FFT (complex result)."""
        assert len(input_tensor.shape) == 4
        return torch.fft.ifft2(torch.fft.ifftshift(input_tensor, dim=(-1, -2)))

    @staticmethod
    @torch.no_grad()
    def rfft(input_tensor):
        """Real 2-D FFT of a 4-D tensor, shifted along the row (-2) axis only."""
        assert len(input_tensor.shape) == 4
        return torch.fft.fftshift(torch.fft.rfft2(input_tensor, dim=(-2, -1)), dim=-2)

    @staticmethod
    @torch.no_grad()
    def irfft(input_tensor):
        """Inverse of :meth:`rfft`; the output is square (``H x H``)."""
        assert len(input_tensor.shape) == 4
        side = input_tensor.shape[-2]
        return torch.fft.irfft2(torch.fft.ifftshift(input_tensor, dim=-2), dim=(-2, -1), s=(side, side))

    @staticmethod
    def _center_patch_bounds(h: int):
        """Return ``(start, end)`` of the centered square patch used for injection.

        Latents of side >= 64 use a 44-wide patch; smaller latents use ``h - 2``
        (or ``h`` itself when ``h <= 2``).
        """
        if h >= 64:
            patch_size = 44
        elif h > 2:
            patch_size = h - 2
        else:
            patch_size = h
        start = (h - patch_size) // 2
        return start, start + patch_size

    def circle_mask(self, size: int, r=16, x_offset=0, y_offset=0):
        """Boolean ``(size, size)`` numpy disc of radius ``r``.

        The disc is centred at column ``size // 2 + x_offset`` and row
        ``size // 2 + y_offset``.
        """
        x0 = size // 2 + x_offset
        y0 = size // 2 + y_offset
        y, x = np.ogrid[:size, :size]
        return ((x - x0) ** 2 + (y - y0) ** 2) <= r ** 2

    @torch.no_grad()
    def enforce_hermitian_symmetry(self, freq_tensor):
        """Return a copy of a square, fftshifted spectrum made Hermitian symmetric.

        Coefficients in one half of the spectrum are overwritten with the
        complex conjugate of their mirrored counterparts. Self-mirrored points
        (DC, plus the Nyquist points for even sizes) are made real. The inverse
        (shifted) FFT of the result is therefore real-valued.

        Parameters:
            freq_tensor: complex ``(B, C, H, W)`` tensor with ``H == W`` and DC
                at ``[H//2, W//2]``.
        """
        B, C, H, W = freq_tensor.shape
        assert H == W, "H != W"
        freq_tensor = freq_tensor.clone()
        freq_tensor_tmp = freq_tensor.clone()
        # DC point (no imaginary)
        freq_tensor[:, :, H//2, W//2] = torch.real(freq_tensor_tmp[:, :, H//2, W//2])
        if H % 2 == 0:  # Even
            # Nyquist points (no imaginary)
            freq_tensor[:, :, 0, 0] = torch.real(freq_tensor_tmp[:, :, 0, 0])
            freq_tensor[:, :, H//2, 0] = torch.real(freq_tensor_tmp[:, :, H//2, 0])
            freq_tensor[:, :, 0, W//2] = torch.real(freq_tensor_tmp[:, :, 0, W//2])

            # Nyquist / centre axes - conjugate
            freq_tensor[:, :, 0, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, 0, W//2+1:], dims=[2]))
            freq_tensor[:, :, H//2, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2, W//2+1:], dims=[2]))
            freq_tensor[:, :, 1:H//2, 0] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, 0], dims=[2]))
            freq_tensor[:, :, 1:H//2, W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2], dims=[2]))
            # Square quadrants - conjugate
            freq_tensor[:, :, 1:H//2, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2+1:], dims=[2, 3]))
            freq_tensor[:, :, H//2+1:, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, 1:H//2, W//2+1:], dims=[2, 3]))
        else:  # Odd
            # Centre axes - conjugate
            freq_tensor[:, :, H//2, 0:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2, W//2+1:], dims=[2]))
            freq_tensor[:, :, 0:H//2, W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2], dims=[2]))
            # Square quadrants - conjugate
            freq_tensor[:, :, 0:H//2, 0:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2+1:], dims=[2, 3]))
            freq_tensor[:, :, H//2+1:, 0:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, 0:H//2, W//2+1:], dims=[2, 3]))
        return freq_tensor

    @torch.no_grad()
    def make_Fourier_treering_pattern(self, pipe, shape, w_seed=999999, resolution=512):
        """Build the Hermitian-symmetric tree-ring (HSTR) Fourier pattern.

        Parameters:
            pipe: Diffusion pipeline; ``pipe.prepare_latents`` draws the seeded key latent.
            shape: Pattern shape ``(1, C, H, W)`` with ``H == W``.
            w_seed: Seed of the generator used for the key latent.
            resolution: Pixel size passed to ``prepare_latents``. It should match
                ``shape`` so that the key latent and the pattern agree in size.

        Returns:
            Complex tensor of ``shape`` (fftshifted layout) made Hermitian symmetric.
        """
        assert shape[-1] == shape[-2]
        g = torch.Generator(device=self.config.device).manual_seed(w_seed)
        gt_init = pipe.prepare_latents(1, pipe.unet.in_channels, resolution, resolution,
                                       pipe.unet.dtype, torch.device(self.config.device), g)
        # [HSTR] center-aware design
        watermarked_latents_fft = SFWUtils.fft(torch.zeros(shape, device=self.config.device))

        start, end = SFWUtils._center_patch_bounds(shape[-2])
        center_slice = (slice(None), slice(None), slice(start, end), slice(start, end))
        gt_patch_tmp = SFWUtils.fft(gt_init[center_slice]).clone().detach()
        center_len = gt_patch_tmp.shape[-1] // 2
        # Discs are painted from the largest radius down, so each smaller disc
        # overwrites the interior of the previous one and leaves a ring whose value
        # is the key coefficient at (center, center + radius).
        for radius in range(center_len - 1, 0, -1):
            ring_mask = torch.from_numpy(self.circle_mask(size=shape[-1], r=radius)).to(watermarked_latents_fft.device)
            for j in range(watermarked_latents_fft.shape[1]):
                watermarked_latents_fft[:, j, ring_mask] = gt_patch_tmp[0, j, center_len, center_len + radius].item()
        # Channel 0 carries the raw Gaussian-noise key spectrum (heterogeneous watermark, as in RingID).
        watermarked_latents_fft[:, [0], start:end, start:end] = gt_patch_tmp[:, [0]]
        return self.enforce_hermitian_symmetry(watermarked_latents_fft)

    # HSQR - hermitian symmetric QR pattern
    class QRCodeGenerator:
        """Thin wrapper around ``qrcode.QRCode`` with high error correction."""

        def __init__(self, box_size=2, border=1, qr_version=1):
            self.qr = qrcode.QRCode(version=qr_version, box_size=box_size, border=border,
                                    error_correction=qrcode.constants.ERROR_CORRECT_H)

        def make_qr_tensor(self, data, filename='qrcode.png', save_img=False):
            """Encode ``data`` and return the rendered QR image as a tensor.

            The image is optionally saved to ``filename``. The generator is
            always cleared afterwards, even on failure, so stale data is never
            carried into the next call.
            """
            try:
                self.qr.add_data(data)
                self.qr.make(fit=True)
                img = self.qr.make_image(fill_color="black", back_color="white")
                if save_img:
                    img.save(filename)
            finally:
                self.clear()
            return torch.from_numpy(np.array(img)).clone().detach()

        def clear(self):
            """Drop any data added to the underlying QR code object."""
            self.qr.clear()

    @torch.no_grad()
    def make_hsqr_pattern(self, idx: int):
        """Return the QR code for payload ``"HSQR{idx % 10000}"`` as a ``(1, h, w)`` tensor."""
        qr_generator = self.QRCodeGenerator(box_size=2, border=0, qr_version=1)
        qr_tensor = qr_generator.make_qr_tensor(data=f"HSQR{idx % 10000}")
        # One watermark channel (config.w_channel): add a leading channel dim.
        return qr_tensor.unsqueeze(0)

    def _get_watermarking_pattern(self) -> torch.Tensor:
        """Build the ground-truth watermark pattern as a 4-D tensor.

        Returns the complex ``(1, 4, H/8, W/8)`` tree-ring spectrum for HSTR, or
        the ``(1, 1, h, w)`` QR code for HSQR.
        """
        set_random_seed(self.config.w_seed)
        latent_h = self.config.image_size[0] // 8
        latent_w = self.config.image_size[1] // 8
        if self.config.wm_type == "HSQR":
            pattern = self.make_hsqr_pattern(idx=self.config.w_seed)
        else:
            # Draw the key latent at the configured resolution so it matches `shape`.
            pattern = self.make_Fourier_treering_pattern(
                self.config.pipe,
                (1, 4, latent_h, latent_w),
                self.config.w_seed,
                resolution=self.config.image_size[0],
            )
        if pattern.dim() == 4:
            return pattern
        if pattern.dim() == 3:
            return pattern.unsqueeze(0)
        raise ValueError(f"Unexpected pattern_gt_batch shape: {pattern.shape}")

    # ring mask
    class RounderRingMask:
        """Rotation-based ring labels used to build ring-shaped masks.

        A row of alternating-sign labels (one per radius ``0..r_out-1``) is
        placed to the right of the centre of a ``size x size`` grid and rotated
        through every integer angle. For each pixel, ``pure_bg`` keeps a non-zero
        label seen across rotations: the most frequent one when more than two
        distinct values occur.
        """

        def __init__(self, size=65, r_out=14):
            assert size >= 3
            self.size = size
            self.r_out = r_out

            num_rings = r_out
            zero_bg_freq = torch.zeros(size, size)
            center = size // 2
            center_x, center_y = center, center

            ring_vector = torch.tensor([(200 - i*4) * (-1)**i for i in range(num_rings)])
            zero_bg_freq[center_x, center_y:center_y+num_rings] = ring_vector
            zero_bg_freq = zero_bg_freq[None, None, ...]
            self.ring_vector_np = ring_vector.numpy()

            res = torch.zeros(360, size, size)
            res[0] = zero_bg_freq
            for angle in range(1, 360):
                res[angle] = tforms.functional.rotate(zero_bg_freq, angle=angle)

            res = res.numpy()
            self.res = res
            self.pure_bg = np.zeros((size, size))
            for x in range(size):
                for y in range(size):
                    values, count = np.unique(res[:, x, y], return_counts=True)
                    if len(count) > 2:
                        self.pure_bg[x, y] = values[count == max(count[values != 0])][0]
                    elif len(count) == 2:
                        self.pure_bg[x, y] = values[values != 0][0]
                    elif values[0] != 0:
                        # Same non-zero label under every rotation (the rotation centre).
                        self.pure_bg[x, y] = values[0]

        def get_ring_mask(self, r_out, r_in):
            """Boolean mask of pixels whose label belongs to rings ``r_in..r_out-1``.

            For odd ``size`` the mask is cropped to ``(size-1, size-1)``.
            """
            assert r_out <= self.r_out
            # A stop of None (not 0) keeps ring index 0, the centre, when r_in == 0.
            right_end = None if r_in - 1 < 0 else r_in - 1
            cand_list = self.ring_vector_np[r_out-1:right_end:-1]
            mask = np.isin(self.pure_bg, cand_list)
            if self.size % 2:
                mask = mask[:self.size-1, :self.size-1]
            return mask

    def _get_watermarking_mask(self, init_latents: torch.Tensor) -> torch.Tensor:
        """Build the boolean injection mask of shape ``(1, 4, 64, 64)``.

        Channel ``w_channel`` gets a disc of radius 14. Channel 0 is then set to
        the ring mask for ring indices 3..13 (the noise-key region).

        NOTE(review): the mask size is fixed at 64x64 and ``init_latents`` is not
        used, so latents of another size will not line up with this mask.
        """
        shape = (1, 4, 64, 64)
        masks = torch.zeros(shape, dtype=torch.bool)
        masks[:, [self.config.w_channel]] = torch.tensor(self.circle_mask(size=shape[-1], r=14))
        ring_obj = self.RounderRingMask(size=65, r_out=14)
        masks[:, [0]] = torch.tensor(ring_obj.get_ring_mask(r_out=14, r_in=3))
        return masks

    @torch.no_grad()
    def inject_wm(self, init_latents: torch.Tensor):
        """Inject the HSTR pattern into the centred patch of ``init_latents``.

        The centred patch is taken to shifted Fourier space. Masked coefficients
        are replaced by ``gt_patch``, and the patch is transformed back. The
        stored single-sample mask and pattern are broadcast over the batch, so
        any batch size works and repeated calls do not alter the stored mask.

        Parameters:
            init_latents: 4-D latent tensor ``(N, C, H, W)``.

        Returns:
            A new latent tensor on ``config.device`` with ``+inf``/``-inf``
            replaced by ``4``/``-4``.
        """
        assert len(self.gt_patch.shape) == 4
        assert len(self.watermarking_mask.shape) == 4

        device = self.config.device
        init_latents = init_latents.to(device)
        self.gt_patch = self.gt_patch.to(device)
        self.watermarking_mask = self.watermarking_mask.to(device)

        start, end = SFWUtils._center_patch_bounds(init_latents.shape[-2])
        center_slice = (slice(None), slice(None), slice(start, end), slice(start, end))
        assert len(init_latents[center_slice].shape) == 4
        center_latent_fft = SFWUtils.fft(init_latents[center_slice])
        # Broadcasting the (1, C, p, p) mask/pattern avoids per-batch copies
        # (the old in-place repeat grew the stored mask on every call).
        center_latent_fft = torch.where(
            self.watermarking_mask[center_slice],
            self.gt_patch[center_slice],
            center_latent_fft,
        )
        center_latent_ifft = SFWUtils.ifft(center_latent_fft)
        # A Hermitian-symmetric spectrum inverts to a (numerically) real signal.
        if center_latent_ifft.imag.abs().max() < 1e-3:
            center_latent_ifft = center_latent_ifft.real

        init_latents = init_latents.clone()
        init_latents[center_slice] = center_latent_ifft
        init_latents[init_latents == float("Inf")] = 4
        init_latents[init_latents == float("-Inf")] = -4
        return init_latents

    @torch.no_grad()
    def inject_hsqr(self, inverted_latent):
        """Embed the HSQR (QR-code) pattern into a latent tensor.

        The QR bits are written as signs of real-FFT coefficients of the
        centred ``(q + 2) x (q + 2)`` patch of channel ``w_channel``. The left
        half of the code drives the real parts and the right half drives the
        imaginary parts. Each coefficient becomes ``±(|value| + delta)``, with the
        sign given by the bit. The inverse ``irfft`` yields real latents.

        If the latents are smaller than the patch, a warning is logged and the
        latents are returned unchanged (moved to ``config.device``).
        """
        assert len(self.gt_patch.shape) == 4
        inverted_latent = inverted_latent.to(self.config.device)
        self.gt_patch = self.gt_patch.to(self.config.device)
        qr_pix_len = self.gt_patch.shape[-1]
        qr_pix_half = (qr_pix_len + 1) // 2
        qr_left = self.gt_patch[:, :, :, :qr_pix_half]
        qr_right = self.gt_patch[:, :, :, qr_pix_half:]

        h, w = inverted_latent.shape[-2:]
        # One coefficient of padding on each side of the QR block.
        patch_size = qr_pix_len + 2
        if h < patch_size or w < patch_size:
            logging.warning(f"Latent size ({h}x{w}) too small for SFW HSQR injection (required {patch_size}x{patch_size}). Skipping injection.")
            return inverted_latent

        start_h = (h - patch_size) // 2
        start_w = (w - patch_size) // 2
        center_slice = (slice(None), slice(None),
                        slice(start_h, start_h + patch_size), slice(start_w, start_w + patch_size))
        center_latent_rfft = SFWUtils.rfft(inverted_latent[center_slice])
        center_real_batch = center_latent_rfft.real
        center_imag_batch = center_latent_rfft.imag
        rows = slice(1, 1 + qr_pix_len)
        real_slice = (slice(None), [self.config.w_channel], rows, slice(1, 1 + qr_left.shape[-1]))
        imag_slice = (slice(None), [self.config.w_channel], rows, slice(1, 1 + qr_right.shape[-1]))
        delta = self.config.delta
        center_real_batch[real_slice] = torch.where(
            qr_left, center_real_batch[real_slice].abs() + delta, -center_real_batch[real_slice].abs() - delta)
        center_imag_batch[imag_slice] = torch.where(
            qr_right, center_imag_batch[imag_slice].abs() + delta, -center_imag_batch[imag_slice].abs() - delta)
        center_latent_ifft = SFWUtils.irfft(torch.complex(center_real_batch, center_imag_batch))
        inverted_latent = inverted_latent.clone()
        inverted_latent[center_slice] = center_latent_ifft
        return inverted_latent
@inherit_docstring
class SFW(BaseWatermark):
    """SFW watermark: a Hermitian-symmetric Fourier pattern (HSTR or HSQR) in the initial latents.

    Detection inverts an image back to latents via ``config.inversion`` and
    hands them to :class:`SFWDetector`.
    """

    def __init__(self,
                 watermark_config: SFWConfig,
                 *args, **kwargs):
        """
        Initialize the SFW watermarking algorithm.

        Parameters:
            watermark_config (SFWConfig): Configuration instance of the SFW algorithm.
        """
        self.config = watermark_config
        self.utils = SFWUtils(self.config)
        self.detector = SFWDetector(
            watermarking_mask=self.utils.watermarking_mask,
            gt_patch=self.utils.gt_patch,
            w_channel=self.config.w_channel,
            threshold=self.config.threshold,
            device=self.config.device,
            wm_type=self.config.wm_type,
            threshold_p_value=self.config.threshold_p_value,
        )

    # -- private helpers ---------------------------------------------------

    def _inject_watermark(self) -> torch.Tensor:
        """Return ``config.init_latents`` with the configured watermark type embedded."""
        if self.config.wm_type == "HSQR":
            return self.utils.inject_hsqr(self.config.init_latents)
        return self.utils.inject_wm(self.config.init_latents)

    def _build_generation_params(self, watermarked_latents: torch.Tensor,
                                 overrides: Optional[dict] = None) -> dict:
        """Assemble pipeline kwargs.

        Order of precedence: ``overrides``, then the defaults below, then
        ``config.gen_kwargs`` (only for keys not already set). ``latents`` is
        always forced to ``watermarked_latents``.
        """
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }
        gen_kwargs = getattr(self.config, "gen_kwargs", None)
        if gen_kwargs:
            for key, value in gen_kwargs.items():
                generation_params.setdefault(key, value)
        if overrides:
            generation_params.update(overrides)
        # The watermark lives in the latents; never let them be overridden.
        generation_params["latents"] = watermarked_latents
        return generation_params

    def _get_text_embeddings(self, prompt: str, guidance_scale: float) -> torch.Tensor:
        """Encode ``prompt``, prepending negative embeddings when CFG is active (scale > 1)."""
        do_classifier_free_guidance = guidance_scale > 1.0
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,  # TODO: Multiple image generation to be supported
        )
        if do_classifier_free_guidance:
            return torch.cat([negative_prompt_embeds, prompt_embeds])
        return prompt_embeds

    def _invert_image(self, image, text_embeddings, guidance_scale,
                      num_inference_steps, decoder_inv, inversion_kwargs):
        """Encode ``image`` to latents and run ``config.inversion.forward_diffusion``.

        Returns whatever ``forward_diffusion`` returns; callers index ``[-1]``
        for the final inverted latents.
        """
        processed_image = transform_to_model_format(
            image, target_size=self.config.image_size[0]
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)
        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=processed_image,
            sample=False,
            decoder_inv=decoder_inv,
        )
        return self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            **inversion_kwargs
        )

    # -- public API ----------------------------------------------------------

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Internal method to generate a watermarked image."""
        watermarked_latents = self._inject_watermark()
        # save watermarked latents
        self.set_orig_watermarked_latents(watermarked_latents)
        # kwargs override defaults/gen_kwargs, but never the watermarked latents.
        generation_params = self._build_generation_params(watermarked_latents, overrides=kwargs)
        return self.config.pipe(prompt, **generation_params).images[0]

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect the watermark in the image."""
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        text_embeddings = self._get_text_embeddings(prompt, guidance_scale_to_use)
        # `detector_type` is consumed by the detector, not by the inversion.
        inversion_kwargs = {
            k: v for k, v in kwargs.items()
            if k not in ('decoder_inv', 'guidance_scale', 'num_inference_steps', 'detector_type')
        }
        reversed_latents = self._invert_image(
            image,
            text_embeddings,
            guidance_scale_to_use,
            num_steps_to_use,
            kwargs.get('decoder_inv', False),
            inversion_kwargs,
        )[-1]

        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        return self.detector.eval_watermark(reversed_latents)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str = "",
                               guidance_scale: Optional[float] = None,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Get data for visualization including detection inversion.

        A watermarked image is generated from ``config.init_latents`` and then
        inverted. If the inversion fails, a warning is logged and
        ``reversed_latents`` is ``None``.
        """
        guidance_scale_to_use = guidance_scale if guidance_scale is not None else self.config.guidance_scale

        # Step 1: watermarked latents (generation process)
        set_random_seed(self.config.gen_seed)
        watermarked_latents = self._inject_watermark()
        # Keep orig_watermarked_latents in sync with the image generated below.
        self.set_orig_watermarked_latents(watermarked_latents)

        # Step 2: generate the watermarked image exactly as _generate_watermarked_image does
        generation_params = self._build_generation_params(watermarked_latents)
        watermarked_image = self.config.pipe(prompt, **generation_params).images[0]

        # Step 3: invert the watermarked image (detection process)
        reversed_latents_list = None
        try:
            text_embeddings = self._get_text_embeddings(prompt, guidance_scale_to_use)
            inversion_kwargs = {
                k: v for k, v in kwargs.items()
                if k not in ('prompt', 'decoder_inv', 'guidance_scale', 'num_inference_steps')
            }
            reversed_latents_list = self._invert_image(
                watermarked_image,
                text_embeddings,
                guidance_scale_to_use,
                self.config.num_inference_steps,
                decoder_inv,
                inversion_kwargs,
            )
        except Exception as e:
            # Visualization is best-effort: degrade to "no inversion" instead of failing.
            logging.warning("Could not perform inversion for visualization: %s", e)
            reversed_latents_list = None

        # Step 4: visualization data
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            reversed_latents=reversed_latents_list,
            orig_watermarked_latents=self.orig_watermarked_latents,
            image=image
        )