Coverage for watermark/sfw/sfw.py: 91.12% (304 statements)
coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

from ..base import BaseWatermark, BaseConfig
from utils.media_utils import *
import torch
from typing import Dict, Optional
from utils.utils import set_random_seed, inherit_docstring
from utils.diffusion_config import DiffusionConfig
import numpy as np
from PIL import Image
from visualize.data_for_visualization import DataForVisualization
from detection.sfw.sfw_detection import SFWDetector
import torchvision.transforms as tforms
import qrcode
import logging
import os


class SFWConfig(BaseConfig):
    """Config class for the SFW algorithm; loads the config file and initializes parameters."""

    def initialize_parameters(self) -> None:
        """Initialize algorithm-specific parameters."""
        self.w_seed = self.config_dict['w_seed']
        self.delta = self.config_dict['delta']
        self.wm_type = self.config_dict['wm_type']  # "HSTR" or "HSQR"
        self.threshold = self.config_dict['threshold']
        self.w_channel = self.config_dict['w_channel']

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        return 'SFW'


class SFWUtils:
    """Utility class for the SFW algorithm; contains helper functions."""

    def __init__(self, config: SFWConfig, *args, **kwargs) -> None:
        """
        Initialize the SFW watermarking utilities.

        Parameters:
            config (SFWConfig): Configuration for the SFW algorithm.
        """
        self.config = config
        self.gt_patch = self._get_watermarking_pattern()
        self.watermarking_mask = self._get_watermarking_mask(self.config.init_latents)

    # [Fourier transforms]
    @staticmethod
    def fft(input_tensor):
        assert len(input_tensor.shape) == 4
        return torch.fft.fftshift(torch.fft.fft2(input_tensor), dim=(-1, -2))

    @staticmethod
    def ifft(input_tensor):
        assert len(input_tensor.shape) == 4
        return torch.fft.ifft2(torch.fft.ifftshift(input_tensor, dim=(-1, -2)))

    @staticmethod
    @torch.no_grad()
    def rfft(input_tensor):
        assert len(input_tensor.shape) == 4
        return torch.fft.fftshift(torch.fft.rfft2(input_tensor, dim=(-2, -1)), dim=-2)

    @staticmethod
    @torch.no_grad()
    def irfft(input_tensor):
        assert len(input_tensor.shape) == 4
        return torch.fft.irfft2(torch.fft.ifftshift(input_tensor, dim=-2), dim=(-2, -1),
                                s=(input_tensor.shape[-2], input_tensor.shape[-2]))
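
    # A minimal round-trip sketch for the helpers above (illustrative only, not used by
    # the algorithm): fft followed by ifft reconstructs the input up to float error.
    #
    #   x = torch.randn(1, 4, 64, 64)
    #   assert torch.allclose(SFWUtils.ifft(SFWUtils.fft(x)).real, x, atol=1e-4)
    #   assert SFWUtils.rfft(x).shape == (1, 4, 64, 33)  # half-spectrum of a real input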

    def circle_mask(self, size: int, r=16, x_offset=0, y_offset=0):
        x0 = y0 = size // 2
        x0 += x_offset
        y0 += y_offset
        y, x = np.ogrid[:size, :size]
        return ((x - x0) ** 2 + (y - y0) ** 2) <= r ** 2
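
    # Illustrative sketch (hedged): circle_mask returns a boolean disk. For size=5, r=1
    # the center pixel plus its four axis neighbors are True:
    #
    #   m = utils.circle_mask(size=5, r=1)   # assumes a built `utils` instance
    #   assert m.sum() == 5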

    @torch.no_grad()
    def enforce_hermitian_symmetry(self, freq_tensor):
        B, C, H, W = freq_tensor.shape  # fftshifted frequency (complex tensor) - center (32,32)
        assert H == W, "H != W"
        freq_tensor = freq_tensor.clone()
        freq_tensor_tmp = freq_tensor.clone()
        # DC point (no imaginary part)
        freq_tensor[:, :, H//2, W//2] = torch.real(freq_tensor_tmp[:, :, H//2, W//2])
        if H % 2 == 0:  # even
            # Nyquist points (no imaginary part)
            freq_tensor[:, :, 0, 0] = torch.real(freq_tensor_tmp[:, :, 0, 0])
            freq_tensor[:, :, H//2, 0] = torch.real(freq_tensor_tmp[:, :, H//2, 0])  # (32, 0)
            freq_tensor[:, :, 0, W//2] = torch.real(freq_tensor_tmp[:, :, 0, W//2])  # (0, 32)

            # Nyquist axes - conjugate
            freq_tensor[:, :, 0, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, 0, W//2+1:], dims=[2]))
            freq_tensor[:, :, H//2, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2, W//2+1:], dims=[2]))
            freq_tensor[:, :, 1:H//2, 0] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, 0], dims=[2]))
            freq_tensor[:, :, 1:H//2, W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2], dims=[2]))
            # Square quadrants - conjugate
            freq_tensor[:, :, 1:H//2, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2+1:], dims=[2, 3]))
            freq_tensor[:, :, H//2+1:, 1:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, 1:H//2, W//2+1:], dims=[2, 3]))
        else:  # odd
            # Nyquist axes - conjugate
            freq_tensor[:, :, H//2, 0:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2, W//2+1:], dims=[2]))
            freq_tensor[:, :, 0:H//2, W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2], dims=[2]))
            # Square quadrants - conjugate
            freq_tensor[:, :, 0:H//2, 0:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, H//2+1:, W//2+1:], dims=[2, 3]))
            freq_tensor[:, :, H//2+1:, 0:W//2] = torch.conj(torch.flip(freq_tensor_tmp[:, :, 0:H//2, W//2+1:], dims=[2, 3]))
        return freq_tensor
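
    # Illustrative check (assumes a built `utils` instance): after enforcing Hermitian
    # symmetry on an arbitrary complex spectrum, its inverse FFT is numerically real.
    #
    #   spec = torch.randn(1, 4, 64, 64, dtype=torch.complex64)
    #   sym = utils.enforce_hermitian_symmetry(spec)
    #   assert SFWUtils.ifft(sym).imag.abs().max() < 1e-3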

    @torch.no_grad()
    def make_Fourier_treering_pattern(self, pipe, shape, w_seed=999999, resolution=512):
        assert shape[-1] == shape[-2]  # 64 == 64
        g = torch.Generator(device=self.config.device).manual_seed(w_seed)
        gt_init = pipe.prepare_latents(1, pipe.unet.in_channels, resolution, resolution,
                                       pipe.unet.dtype, torch.device(self.config.device), g)  # (1,4,64,64)
        # [HSTR] center-aware design
        watermarked_latents_fft = SFWUtils.fft(torch.zeros(shape, device=self.config.device))  # (1,4,64,64) complex64

        h = shape[-2]
        # Use a 44x44 patch when the latent is large enough; otherwise fall back to
        # (almost) the full latent size.
        if h >= 64:
            patch_size = 44
        else:
            patch_size = h - 2 if h > 2 else h

        start = (h - patch_size) // 2
        end = start + patch_size

        center_slice = (slice(None), slice(None), slice(start, end), slice(start, end))
        gt_patch_tmp = SFWUtils.fft(gt_init[center_slice]).clone().detach()  # (1,4,44,44) complex64
        center_len = gt_patch_tmp.shape[-1] // 2  # 22
        for radius in range(center_len - 1, 0, -1):  # [21, 20, ..., 1]
            tmp_mask = torch.tensor(self.circle_mask(size=shape[-1], r=radius))  # (64,64)
            for j in range(watermarked_latents_fft.shape[1]):  # GT: Tree-Ring on all channels
                watermarked_latents_fft[:, j, tmp_mask] = gt_patch_tmp[0, j, center_len, center_len + radius].item()  # use the (22, 22+radius) element
        # Gaussian-noise key (heterogeneous watermark, as in RingID)
        watermarked_latents_fft[:, [0], start:end, start:end] = gt_patch_tmp[:, [0]]  # (1,1,44,44) complex64
        # [Hermitian-symmetric Fourier] HSTR
        return self.enforce_hermitian_symmetry(watermarked_latents_fft)
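
    # Usage sketch (hedged; assumes a diffusers-style pipeline in `config.pipe`):
    #
    #   pattern = utils.make_Fourier_treering_pattern(config.pipe, (1, 4, 64, 64))
    #   assert pattern.shape == (1, 4, 64, 64) and pattern.is_complex()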

    # HSQR - Hermitian-symmetric QR pattern
    class QRCodeGenerator:
        def __init__(self, box_size=2, border=1, qr_version=1):
            self.qr = qrcode.QRCode(version=qr_version, box_size=box_size, border=border,
                                    error_correction=qrcode.constants.ERROR_CORRECT_H)

        def make_qr_tensor(self, data, filename='qrcode.png', save_img=False):
            self.qr.add_data(data)
            self.qr.make(fit=True)
            img = self.qr.make_image(fill_color="black", back_color="white")
            if save_img:
                img.save(filename)
            self.clear()
            img_array = np.array(img)
            tensor = torch.from_numpy(img_array)
            return tensor.clone().detach()  # boolean (h,w)

        def clear(self):
            self.qr.clear()
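
        # Usage sketch: a version-1 QR code is 21x21 modules, so box_size=2 with no
        # border yields a (42,42) boolean tensor (alphanumeric data such as "HSQR1234"
        # still fits version 1 at error-correction level H):
        #
        #   gen = SFWUtils.QRCodeGenerator(box_size=2, border=0, qr_version=1)
        #   t = gen.make_qr_tensor("HSQR1234")
        #   assert t.shape == (42, 42)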

    @torch.no_grad()
    def make_hsqr_pattern(self, idx: int):
        qr_generator = self.QRCodeGenerator(box_size=2, border=0, qr_version=1)
        data = f"HSQR{idx % 10000}"
        qr_tensor = qr_generator.make_qr_tensor(data=data)  # (42,42) boolean tensor
        qr_tensor = qr_tensor.repeat(len([self.config.w_channel]), 1, 1)  # (c_wm,42,42); c_wm is 1 here
        return qr_tensor  # (c_wm,42,42) boolean tensor

    def _get_watermarking_pattern(self) -> torch.Tensor:
        """Get the ground-truth watermarking pattern."""
        set_random_seed(self.config.w_seed)
        latent_h = self.config.image_size[0] // 8
        latent_w = self.config.image_size[1] // 8
        shape = (1, 4, latent_h, latent_w)
        if self.config.wm_type == "HSQR":
            Fourier_watermark_pattern_list = [self.make_hsqr_pattern(idx=self.config.w_seed)]
        else:
            Fourier_watermark_pattern_list = [self.make_Fourier_treering_pattern(self.config.pipe, shape, self.config.w_seed)]
        pattern_gt_batch = [Fourier_watermark_pattern_list[0]]
        # Adjust dims of pattern_gt_batch.
        if len(pattern_gt_batch[0].shape) == 4:
            pattern_gt_batch = torch.cat(pattern_gt_batch, dim=0)  # (N,4,64,64) for HSTR
        elif len(pattern_gt_batch[0].shape) == 3:
            pattern_gt_batch = torch.stack(pattern_gt_batch, dim=0)  # (N,c_wm,42,42) for HSQR
        else:
            raise ValueError(f"Unexpected pattern_gt_batch shape: {pattern_gt_batch[0].shape}")
        assert len(pattern_gt_batch.shape) == 4

        return pattern_gt_batch

    # ring mask
    class RounderRingMask:
        def __init__(self, size=65, r_out=14):
            assert size >= 3
            self.size = size
            self.r_out = r_out

            num_rings = r_out
            zero_bg_freq = torch.zeros(size, size)
            center = size // 2
            center_x, center_y = center, center

            ring_vector = torch.tensor([(200 - i * 4) * (-1) ** i for i in range(num_rings)])
            zero_bg_freq[center_x, center_y:center_y + num_rings] = ring_vector
            zero_bg_freq = zero_bg_freq[None, None, ...]
            self.ring_vector_np = ring_vector.numpy()

            res = torch.zeros(360, size, size)
            res[0] = zero_bg_freq
            for angle in range(1, 360):
                zero_bg_freq_rot = tforms.functional.rotate(zero_bg_freq, angle=angle)
                res[angle] = zero_bg_freq_rot

            res = res.numpy()
            self.res = res
            self.pure_bg = np.zeros((size, size))
            for x in range(size):
                for y in range(size):
                    values, count = np.unique(res[:, x, y], return_counts=True)
                    if len(count) > 2:
                        self.pure_bg[x, y] = values[count == max(count[values != 0])][0]
                    elif len(count) == 2:
                        self.pure_bg[x, y] = values[values != 0][0]

        def get_ring_mask(self, r_out, r_in):
            # Build the mask from pure_bg.
            assert r_out <= self.r_out
            if r_in - 1 < 0:
                right_end = None  # open slice end, so the center element is included
            else:
                right_end = r_in - 1
            cand_list = self.ring_vector_np[r_out - 1:right_end:-1]
            mask = np.isin(self.pure_bg, cand_list)
            if self.size % 2:
                mask = mask[:self.size - 1, :self.size - 1]  # (64,64)
            return mask
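
        # Usage sketch: build the (64,64) annulus mask used for the heterogeneous
        # (noise) watermark channel.
        #
        #   rm = SFWUtils.RounderRingMask(size=65, r_out=14)
        #   ring = rm.get_ring_mask(r_out=14, r_in=3)   # boolean annulus between r_in and r_out
        #   assert ring.shape == (64, 64)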

    def _get_watermarking_mask(self, init_latents: torch.Tensor) -> torch.Tensor:
        """Get the watermarking mask."""
        shape = (1, 4, 64, 64)
        tree_masks = torch.zeros(shape, dtype=torch.bool)  # (1,4,64,64)
        single_channel_tree_watermark_mask = torch.tensor(self.circle_mask(size=shape[-1], r=14))  # (64,64)
        tree_masks[:, [self.config.w_channel]] = single_channel_tree_watermark_mask  # (64,64)
        masks = tree_masks
        mask_obj = self.RounderRingMask(size=65, r_out=14)
        single_channel_heter_watermark_mask = torch.tensor(mask_obj.get_ring_mask(r_out=14, r_in=3))  # (64,64)
        masks[:, [0]] = single_channel_heter_watermark_mask  # (64,64) RounderRingMask for the heterogeneous (noise) watermark
        return masks

    @torch.no_grad()
    def inject_wm(self, init_latents: torch.Tensor):
        # for HSTR
        assert len(self.gt_patch.shape) == 4
        assert len(self.watermarking_mask.shape) == 4
        batch_size = init_latents.shape[0]
        self.watermarking_mask = self.watermarking_mask.repeat(batch_size, 1, 1, 1)

        init_latents = init_latents.to(self.config.device)
        self.gt_patch = self.gt_patch.to(self.config.device)
        self.watermarking_mask = self.watermarking_mask.to(self.config.device)

        # Inject the watermark in Fourier space.
        h = init_latents.shape[-2]
        if h >= 64:
            patch_size = 44
        else:
            patch_size = h - 2 if h > 2 else h

        start = (h - patch_size) // 2
        end = start + patch_size

        center_slice = (slice(None), slice(None), slice(start, end), slice(start, end))
        assert len(init_latents[center_slice].shape) == 4
        center_latent_fft = torch.fft.fftshift(torch.fft.fft2(init_latents[center_slice]), dim=(-1, -2))  # (N,4,44,44) complex64
        # injection
        temp_mask = self.watermarking_mask[center_slice]  # (N,4,44,44) boolean
        temp_pattern = self.gt_patch[center_slice]  # (N,4,44,44) complex64
        center_latent_fft[temp_mask] = temp_pattern[temp_mask].clone()  # (N,4,44,44) complex64
        # IFFT
        assert len(center_latent_fft.shape) == 4
        center_latent_ifft = torch.fft.ifft2(torch.fft.ifftshift(center_latent_fft, dim=(-1, -2)))  # (N,4,44,44)
        center_latent_ifft = center_latent_ifft.real if center_latent_ifft.imag.abs().max() < 1e-3 else center_latent_ifft

        init_latents = init_latents.clone()
        init_latents[center_slice] = center_latent_ifft
        init_latents[init_latents == float("Inf")] = 4
        init_latents[init_latents == float("-Inf")] = -4
        return init_latents  # float32
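
    # Usage sketch for HSTR injection (assumes `utils` was built from a config whose
    # init_latents are (N,4,64,64) float latents):
    #
    #   wm_latents = utils.inject_wm(config.init_latents)
    #   assert wm_latents.shape == config.init_latents.shape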

    @torch.no_grad()
    def inject_hsqr(self, inverted_latent):  # (N,4,64,64) -> (N,4,64,64)
        assert len(self.gt_patch.shape) == 4  # (N,c_wm,42,42)
        inverted_latent = inverted_latent.to(self.config.device)
        self.gt_patch = self.gt_patch.to(self.config.device)
        qr_pix_len = self.gt_patch.shape[-1]  # 42
        qr_pix_half = (qr_pix_len + 1) // 2  # 21
        qr_left = self.gt_patch[:, :, :, :qr_pix_half]  # (N,c_wm,42,21) boolean
        qr_right = self.gt_patch[:, :, :, qr_pix_half:]  # (N,c_wm,42,21) boolean
        # rfft over a 44x44 center slice: the 42x42 QR patch plus a one-pixel margin on each side.
        h, w = inverted_latent.shape[-2:]
        patch_size = qr_pix_len + 2  # 44

        if h < patch_size or w < patch_size:
            logging.warning(f"Latent size ({h}x{w}) too small for SFW HSQR injection (required {patch_size}x{patch_size}). Skipping injection.")
            return inverted_latent

        start_h = (h - patch_size) // 2
        start_w = (w - patch_size) // 2
        end_h = start_h + patch_size
        end_w = start_w + patch_size

        center_slice = (slice(None), slice(None), slice(start_h, end_h), slice(start_w, end_w))
        center_latent_rfft = SFWUtils.rfft(inverted_latent[center_slice])  # (N,4,44,44) -> (N,4,44,23) complex64
        center_real_batch = center_latent_rfft.real  # (N,4,44,23) f32
        center_imag_batch = center_latent_rfft.imag  # (N,4,44,23) f32
        real_slice = (slice(None), [self.config.w_channel], slice(1, 1 + qr_pix_len), slice(1, 1 + qr_pix_half))
        imag_slice = (slice(None), [self.config.w_channel], slice(1, 1 + qr_pix_len), slice(1, 1 + qr_pix_half))
        center_real_batch[real_slice] = torch.where(qr_left, center_real_batch[real_slice].abs() + self.config.delta,
                                                    -center_real_batch[real_slice].abs() - self.config.delta)
        center_imag_batch[imag_slice] = torch.where(qr_right, center_imag_batch[imag_slice].abs() + self.config.delta,
                                                    -center_imag_batch[imag_slice].abs() - self.config.delta)
        center_latent_ifft = SFWUtils.irfft(torch.complex(center_real_batch, center_imag_batch))  # (N,4,44,44) f32
        inverted_latent = inverted_latent.clone()
        inverted_latent[center_slice] = center_latent_ifft
        return inverted_latent  # (N,4,64,64)
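
    # Usage sketch for HSQR injection (mirrors inject_wm above):
    #
    #   wm_latents = utils.inject_hsqr(config.init_latents)   # (N,4,64,64) -> (N,4,64,64)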


@inherit_docstring
class SFW(BaseWatermark):
    def __init__(self,
                 watermark_config: SFWConfig,
                 *args, **kwargs):
        """
        Initialize the SFW watermarking algorithm.

        Parameters:
            watermark_config (SFWConfig): Configuration instance of the SFW algorithm.
        """
        self.config = watermark_config
        self.utils = SFWUtils(self.config)
        self.detector = SFWDetector(
            watermarking_mask=self.utils.watermarking_mask,
            gt_patch=self.utils.gt_patch,
            w_channel=self.config.w_channel,
            threshold=self.config.threshold,
            device=self.config.device,
            wm_type=self.config.wm_type
        )

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Internal method to generate a watermarked image."""
        if self.config.wm_type == "HSQR":
            watermarked_latents = self.utils.inject_hsqr(self.config.init_latents)
        else:
            watermarked_latents = self.utils.inject_wm(self.config.init_latents)

        # Save the watermarked latents.
        self.set_orig_watermarked_latents(watermarked_latents)

        # Construct generation parameters.
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # Add parameters from config.gen_kwargs.
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Use kwargs to override default parameters.
        for key, value in kwargs.items():
            generation_params[key] = value

        # Ensure the latents parameter is not overridden.
        generation_params["latents"] = watermarked_latents

        return self.config.pipe(
            prompt,
            **generation_params
        ).images[0]
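
    # Override sketch (hedged; assumes an `sfw = SFW(config)` instance): extra keyword
    # arguments flow straight into the pipeline call and take precedence over the
    # config defaults, e.g.
    #
    #   img = sfw._generate_watermarked_image("a photo of a cat", num_inference_steps=30)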

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect the watermark in the image."""
        # Use config values as defaults if not explicitly provided.
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Step 1: Get text embeddings.
        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,  # TODO: multiple-image generation to be supported
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Step 2: Preprocess the image.
        image = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        # Step 3: Get image latents.
        image_latents = get_media_latents(pipe=self.config.pipe, media=image, sample=False, decoder_inv=kwargs.get('decoder_inv', False))

        # Step 4: Reverse the image latents.
        # Filter out parameters already handled above; forward the rest to forward_diffusion.
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=num_steps_to_use,
            **inversion_kwargs
        )[-1]

        # Step 5: Evaluate the watermark.
        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        else:
            return self.detector.eval_watermark(reversed_latents)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str = "",
                               guidance_scale: Optional[float] = None,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Get data for visualization, including detection inversion."""
        # Use config values as defaults if not explicitly provided.
        guidance_scale_to_use = guidance_scale if guidance_scale is not None else self.config.guidance_scale

        # Step 1: Generate watermarked latents (generation process).
        set_random_seed(self.config.gen_seed)
        if self.config.wm_type == "HSQR":
            watermarked_latents = self.utils.inject_hsqr(self.config.init_latents)
        else:
            watermarked_latents = self.utils.inject_wm(self.config.init_latents)

        # Step 2: Generate the actual watermarked image using the same process as
        # _generate_watermarked_image.
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # Add parameters from config.gen_kwargs.
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Generate the actual watermarked image.
        watermarked_image = self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

        # Step 3: Perform watermark detection to get inverted latents (detection process).
        # Initialize both results up front so the return below is safe even if inversion fails.
        reversed_latents_list = None
        inverted_latents = None
        try:
            # Get text embeddings for detection.
            do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
            prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
                prompt=prompt,
                device=self.config.device,
                do_classifier_free_guidance=do_classifier_free_guidance,
                num_images_per_prompt=1,  # TODO: multiple-image generation to be supported
            )

            if do_classifier_free_guidance:
                text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
            else:
                text_embeddings = prompt_embeds

            # Preprocess the watermarked image for detection.
            processed_image = transform_to_model_format(
                watermarked_image,
                target_size=self.config.image_size[0]
            ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

            # Get image latents.
            image_latents = get_media_latents(
                pipe=self.config.pipe,
                media=processed_image,
                sample=False,
                decoder_inv=decoder_inv
            )

            # Reverse the image latents to get the inverted noise.
            inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['prompt', 'decoder_inv', 'guidance_scale', 'num_inference_steps']}

            reversed_latents_list = self.config.inversion.forward_diffusion(
                latents=image_latents,
                text_embeddings=text_embeddings,
                guidance_scale=guidance_scale_to_use,
                num_inference_steps=self.config.num_inference_steps,
                **inversion_kwargs
            )

            inverted_latents = reversed_latents_list[-1]

        except Exception as e:
            print(f"Warning: Could not perform inversion for visualization: {e}")
            inverted_latents = None

        # Step 4: Prepare visualization data.
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            reversed_latents=reversed_latents_list,
            orig_watermarked_latents=self.orig_watermarked_latents,
            image=image
        )
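
# End-to-end usage sketch (hedged: SFWConfig construction depends on BaseConfig and the
# surrounding framework, so the lines below are illustrative, not a confirmed API):
#
#   config = SFWConfig(...)                       # loads w_seed, delta, wm_type, etc.
#   sfw = SFW(config)
#   image = sfw._generate_watermarked_image("a photo of a cat")
#   scores = sfw._detect_watermark_in_image(image)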