Coverage for watermark / tr / tr.py: 91.67%
144 statements
coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
from ..base import BaseWatermark, BaseConfig
from utils.media_utils import *  # provides get_random_latents, transform_to_model_format, get_media_latents used below
import torch
from typing import Dict, Union, List, Optional
from utils.utils import set_random_seed, inherit_docstring
from utils.diffusion_config import DiffusionConfig
import copy
import numpy as np
from PIL import Image
from visualize.data_for_visualization import DataForVisualization
from detection.tr.tr_detection import TRDetector


class TRConfig(BaseConfig):
    """Config class for the TR algorithm; loads the config file and initializes parameters."""

    def initialize_parameters(self) -> None:
        """Initialize algorithm-specific parameters."""
        self.w_seed = self.config_dict['w_seed']
        self.w_channel = self.config_dict['w_channel']
        self.w_pattern = self.config_dict['w_pattern']
        # self.w_mask_shape = self.config_dict['w_mask_shape']
        self.w_radius = self.config_dict['w_radius']
        self.w_pattern_const = self.config_dict['w_pattern_const']
        self.threshold = self.config_dict['threshold']

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        return 'TR'
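

# Illustrative shape of the config dict consumed by TRConfig above (comment
# only, not original code; the values shown are hypothetical placeholders):
#
#     {"w_seed": 0, "w_channel": 0, "w_pattern": "ring",
#      "w_radius": 10, "w_pattern_const": 0, "threshold": 77}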


class TRUtils:
    """Utility class for the TR algorithm; contains helper functions."""

    def __init__(self, config: TRConfig, *args, **kwargs) -> None:
        """
        Initialize the Tree-Ring watermarking algorithm.

        Parameters:
            config (TRConfig): Configuration for the Tree-Ring algorithm.
        """
        self.config = config
        self.gt_patch = self._get_watermarking_pattern()
        self.watermarking_mask = self._get_watermarking_mask(self.config.init_latents)

    def _circle_mask(self, size: int = 64, r: int = 10, x_offset: int = 0, y_offset: int = 0) -> np.ndarray:
        """Generate a circular mask."""
        x0 = y0 = size // 2
        x0 += x_offset
        y0 += y_offset
        y, x = np.ogrid[:size, :size]
        y = y[::-1]

        return ((x - x0) ** 2 + (y - y0) ** 2) <= r ** 2
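
    # What `_circle_mask` produces (illustrative comment, not original code):
    # a (size, size) boolean array that is True inside a disk of radius r
    # around the (offset) grid center, e.g. for size=8, r=2:
    #
    #     mask = self._circle_mask(size=8, r=2)
    #     mask[4, 4]   # True  (center)
    #     mask[0, 0]   # False (corner)
    #
    # The mask later indexes the last two (frequency) dimensions of the
    # latents, selecting Fourier coefficients near the spectrum center.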

    def _get_watermarking_pattern(self) -> torch.Tensor:
        """Get the ground-truth watermarking pattern."""
        set_random_seed(self.config.w_seed)

        gt_init = get_random_latents(pipe=self.config.pipe, height=self.config.image_size[0], width=self.config.image_size[1])

        if 'seed_ring' in self.config.w_pattern:
            gt_patch = gt_init

            gt_patch_tmp = copy.deepcopy(gt_patch)
            for i in range(self.config.w_radius, 0, -1):
                tmp_mask = self._circle_mask(gt_init.shape[-1], r=i)
                tmp_mask = torch.tensor(tmp_mask).to(self.config.device)

                for j in range(gt_patch.shape[1]):
                    gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item()
        elif 'seed_zeros' in self.config.w_pattern:
            gt_patch = gt_init * 0
        elif 'seed_rand' in self.config.w_pattern:
            gt_patch = gt_init
        elif 'rand' in self.config.w_pattern:
            gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))
            gt_patch[:] = gt_patch[0]
        elif 'zeros' in self.config.w_pattern:
            gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0
        elif 'const' in self.config.w_pattern:
            gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0
            gt_patch += self.config.w_pattern_const
        elif 'ring' in self.config.w_pattern:
            gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))

            gt_patch_tmp = copy.deepcopy(gt_patch)
            for i in range(self.config.w_radius, 0, -1):
                tmp_mask = self._circle_mask(gt_init.shape[-1], r=i)
                tmp_mask = torch.tensor(tmp_mask).to(self.config.device)

                for j in range(gt_patch.shape[1]):
                    gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item()

        return gt_patch
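
    # How the 'ring' branches above work (explanatory comment, not original
    # code): radii are visited from w_radius down to 1, and each pass fills
    # the whole disk of radius i with one reference coefficient per channel,
    # gt_patch_tmp[0, j, 0, i]. Smaller disks overwrite the centers of larger
    # ones, so what survives is a set of concentric rings, each carrying one
    # constant value (the "tree rings" that detection later looks for).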

    def _get_watermarking_mask(self, init_latents: torch.Tensor) -> torch.Tensor:
        """Get the watermarking mask."""
        watermarking_mask = torch.zeros(init_latents.shape, dtype=torch.bool).to(self.config.device)

        # if self.config.w_mask_shape == 'circle':
        np_mask = self._circle_mask(init_latents.shape[-1], r=self.config.w_radius)
        torch_mask = torch.tensor(np_mask).to(self.config.device)

        if self.config.w_channel == -1:
            # all channels
            watermarking_mask[:, :] = torch_mask
        else:
            watermarking_mask[:, self.config.w_channel] = torch_mask
        # elif self.config.w_mask_shape == 'square':
        #     anchor_p = init_latents.shape[-1] // 2
        #     if self.config.w_channel == -1:
        #         # all channels
        #         watermarking_mask[:, :, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius] = True
        #     else:
        #         watermarking_mask[:, self.config.w_channel, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius] = True
        # elif self.config.w_mask_shape == 'no':
        #     pass
        # else:
        #     raise NotImplementedError(f'w_mask_shape: {self.config.w_mask_shape}')

        return watermarking_mask
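
    # Channel selection (explanatory comment, not original code): the mask has
    # the same (B, C, H, W) shape as the latents; w_channel == -1 broadcasts
    # the circular mask over every channel, otherwise only the chosen channel
    # carries the watermark.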

    def inject_watermark(self, init_latents: torch.Tensor) -> torch.Tensor:
        """Inject the watermark pattern into the masked Fourier coefficients of the initial latents."""
        init_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(init_latents), dim=(-1, -2))
        target_patch = self.gt_patch

        # Ensure the pattern is complex so it can overwrite FFT coefficients.
        if not torch.is_complex(target_patch):
            real = target_patch.to(torch.float32)
            imag = torch.zeros_like(real)
            target_patch = torch.complex(real, imag)
        target_patch = target_patch.to(init_latents_w_fft.dtype)

        init_latents_w_fft[self.watermarking_mask] = target_patch[self.watermarking_mask].clone()

        init_latents_w = torch.fft.ifft2(torch.fft.ifftshift(init_latents_w_fft, dim=(-1, -2))).real
        return init_latents_w
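
    # Injection caveat (explanatory comment, not original code): taking .real
    # after the inverse FFT discards any imaginary residue introduced by
    # overwriting coefficients without enforcing conjugate symmetry, so the
    # injected pattern is only approximately preserved in the latents;
    # detection therefore compares a distance against a threshold rather than
    # expecting an exact match.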


@inherit_docstring
class TR(BaseWatermark):
    def __init__(self,
                 watermark_config: TRConfig,
                 *args, **kwargs):
        """
        Initialize the TR watermarking algorithm.

        Parameters:
            watermark_config (TRConfig): Configuration instance of the Tree-Ring algorithm.
        """
        self.config = watermark_config
        self.utils = TRUtils(self.config)

        self.detector = TRDetector(
            watermarking_mask=self.utils.watermarking_mask,
            gt_patch=self.utils.gt_patch,
            threshold=self.config.threshold,
            device=self.config.device
        )

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Internal method to generate a watermarked image."""
        watermarked_latents = self.utils.inject_watermark(self.config.init_latents)

        # Save watermarked latents
        self.set_orig_watermarked_latents(watermarked_latents)

        # Construct generation parameters
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Use kwargs to override default parameters
        for key, value in kwargs.items():
            generation_params[key] = value

        # Ensure the latents parameter is not overridden
        generation_params["latents"] = watermarked_latents

        return self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect the watermark in the image."""
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Step 1: Get text embeddings
        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,  # TODO: Multiple image generation to be supported
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Step 2: Preprocess image
        image = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        # Step 3: Get image latents
        image_latents = get_media_latents(pipe=self.config.pipe, media=image, sample=False, decoder_inv=kwargs.get('decoder_inv', False))

        # Step 4: Reverse image latents
        # Pass only known parameters to forward_diffusion, and let kwargs handle any additional parameters
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=num_steps_to_use,
            **inversion_kwargs
        )[-1]

        # Step 5: Evaluate watermark
        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        else:
            return self.detector.eval_watermark(reversed_latents)
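
    # Illustrative end-to-end usage (hedged comment, not original code; it
    # assumes a TRConfig wired to a loaded diffusion pipeline, with device,
    # init_latents, and inversion prepared by the surrounding framework):
    #
    #     config = TRConfig(...)      # hypothetical construction
    #     tr = TR(config)
    #     img = tr._generate_watermarked_image("a photo of a cat")
    #     scores = tr._detect_watermark_in_image(img, prompt="a photo of a cat")
    #
    # `scores` is the Dict[str, float] returned by TRDetector.eval_watermark.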

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str = "",
                               guidance_scale: Optional[float] = None,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Get data for visualization, including the detection inversion (similar to the GS logic)."""
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = guidance_scale if guidance_scale is not None else self.config.guidance_scale

        # Step 1: Generate watermarked latents (generation process)
        set_random_seed(self.config.gen_seed)
        watermarked_latents = self.utils.inject_watermark(self.config.init_latents)

        # Step 2: Generate the actual watermarked image using the same process as _generate_watermarked_image
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Generate the actual watermarked image
        watermarked_image = self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

        # Step 3: Perform watermark detection to get inverted latents (detection process)
        inverted_latents = None
        reversed_latents_list = None  # initialized so the return below is safe if inversion fails
        try:
            # Get text embeddings for detection
            do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
            prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
                prompt=prompt,
                device=self.config.device,
                do_classifier_free_guidance=do_classifier_free_guidance,
                num_images_per_prompt=1,  # TODO: Multiple image generation to be supported
            )

            if do_classifier_free_guidance:
                text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
            else:
                text_embeddings = prompt_embeds

            # Preprocess the watermarked image for detection
            processed_image = transform_to_model_format(
                watermarked_image,
                target_size=self.config.image_size[0]
            ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

            # Get image latents
            image_latents = get_media_latents(
                pipe=self.config.pipe,
                media=processed_image,
                sample=False,
                decoder_inv=decoder_inv
            )

            # Reverse image latents to get the inverted noise
            inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['prompt', 'decoder_inv', 'guidance_scale', 'num_inference_steps']}

            reversed_latents_list = self.config.inversion.forward_diffusion(
                latents=image_latents,
                text_embeddings=text_embeddings,
                guidance_scale=guidance_scale_to_use,
                num_inference_steps=self.config.num_inference_steps,
                **inversion_kwargs
            )

            inverted_latents = reversed_latents_list[-1]

        except Exception as e:
            print(f"Warning: Could not perform inversion for visualization: {e}")
            inverted_latents = None

        # Step 4: Prepare visualization data
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            reversed_latents=reversed_latents_list,
            orig_watermarked_latents=self.orig_watermarked_latents,
            image=image,
        )
# try tr
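

# ---------------------------------------------------------------------------
# Self-contained demo of the Tree-Ring mechanism (illustrative sketch, not
# part of the original module): it reproduces the Fourier-domain injection
# and a masked L1 check on random latents, with no diffusion pipeline or
# inversion involved. All demo_* names are hypothetical; the demo reuses the
# module-level `torch` and `np` imports.
if __name__ == "__main__":
    def demo_circle_mask(size: int = 64, r: int = 10) -> np.ndarray:
        # Same boolean disk as TRUtils._circle_mask, centered in the grid.
        x0 = y0 = size // 2
        y, x = np.ogrid[:size, :size]
        return ((x - x0) ** 2 + (y[::-1] - y0) ** 2) <= r ** 2

    latents = torch.randn(1, 4, 64, 64)            # stand-in initial latents
    mask = torch.zeros_like(latents, dtype=torch.bool)
    mask[:, 0] = torch.tensor(demo_circle_mask())  # watermark channel 0 only

    # Inject: overwrite the masked Fourier coefficients with a constant
    # pattern (a stand-in for gt_patch).
    fft = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))
    fft[mask] = 42.0 + 0j
    latents_w = torch.fft.ifft2(torch.fft.ifftshift(fft, dim=(-1, -2))).real

    # Detect: FFT the watermarked latents again and measure the mean L1
    # distance to the pattern inside the mask. Taking .real above makes the
    # round-trip approximate, so the distance is small but nonzero; real
    # detection compares such a distance against config.threshold.
    fft_back = torch.fft.fftshift(torch.fft.fft2(latents_w), dim=(-1, -2))
    dist = torch.abs(fft_back[mask] - (42.0 + 0j)).mean().item()
    print(f"masked L1 distance after round-trip: {dist:.4f}")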