Coverage for watermark / ri / ri.py: 93.78%
193 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1# Copyright 2025 THU-BPM MarkDiffusion.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16import itertools
18import torch
19from ..base import BaseWatermark, BaseConfig
20import numpy as np
21from typing import Dict
22from torchvision import transforms
23from torchvision.transforms import functional as F
24from PIL import Image
25import random
26from detection.ri.ri_detection import RIDetector
27from utils.media_utils import *
28from utils.utils import set_random_seed
29from visualize.data_for_visualization import DataForVisualization
class RIConfig(BaseConfig):
    """Configuration class for the RI (Ring-ID) algorithm."""

    def initialize_parameters(self) -> None:
        """Copy the RI-specific settings out of ``self.config_dict`` onto the instance."""
        cfg = self.config_dict

        self.ring_width = cfg['ring_width']
        self.quantization_levels = cfg['quantization_levels']
        self.ring_value_range = cfg['ring_value_range']

        self.fix_gt = cfg['fix_gt']
        self.time_shift = cfg['time_shift']
        self.time_shift_factor = cfg['time_shift_factor']
        self.assigned_keys = cfg['assigned_keys']
        self.channel_min = cfg['channel_min']

        self.radius = cfg['radius']
        self.anchor_x_offset = cfg['anchor_x_offset']
        self.anchor_y_offset = cfg['anchor_y_offset']
        self.radius_cutoff = cfg['radius_cutoff']

        heter_channels = cfg["heter_watermark_channel"]
        ring_channels = cfg["ring_watermark_channel"]
        self.heter_watermark_channel = heter_channels
        self.ring_watermark_channel = ring_channels
        # Every watermarked channel: the two lists concatenated, then sorted.
        self.watermark_channel = sorted(heter_channels + ring_channels)
        self.threshold = cfg['threshold']

    @property
    def algorithm_name(self) -> str:
        """Identifier of this algorithm ("RI")."""
        return "RI"
class RIUtils:
    """Utility class for the Ring-ID algorithm.

    On construction it samples base latents and precomputes the Fourier-domain
    ring patterns and the per-channel watermark region masks
    (see ``_prepare_fourier_pattern_and_mask``).
    """

    def __init__(self, config: RIConfig, *args, **kwargs) -> None:
        """Initialize the Ring-ID utilities and precompute patterns and masks.

        Args:
            config: RI configuration. Besides the RI fields it must provide
                ``pipe``, ``image_size`` and ``device``.
        """
        self.config = config
        # _ring_mask results keyed by their arguments: the same masks are requested
        # once per key combination and each computation performs 359 rotations.
        self._ring_mask_cache = {}
        self.latents, self.pattern, self.mask, self.pattern_list = self._prepare_fourier_pattern_and_mask()

    def fft(self, input_tensor):
        """Return the 2-D FFT of a 4-D tensor over its last two dims, zero frequency centred."""
        assert len(input_tensor.shape) == 4
        return torch.fft.fftshift(torch.fft.fft2(input_tensor), dim=(-1, -2))

    def ifft(self, input_tensor):
        """Inverse of :meth:`fft`: undo the centring shift, then apply the 2-D inverse FFT."""
        assert len(input_tensor.shape) == 4
        return torch.fft.ifft2(torch.fft.ifftshift(input_tensor, dim=(-1, -2)))

    def _ring_mask(self, size=65, r_out=16, r_in=8, x_offset=0, y_offset=0, mode='full'):
        """
        Construct a rotationally symmetric ring mask (replaces the RounderRingMask class).

        A ray of distinct labels (one per ring index) is rotated through all integer
        angles. Each pixel gets the label it receives most often. The mask selects
        pixels whose ring index lies in ``[max(r_in, 1), r_out - 1]``, where ``r_out``
        is first clamped so the ray fits inside the grid.

        Results are cached per argument tuple. A fresh copy is returned on every call,
        so callers may modify it freely.

        Args:
            size: Side length of the square grid (>= 3).
            r_out: Outer ring index (exclusive).
            r_in: Inner ring index (see the interval above).
            x_offset, y_offset: Offsets of the ray origin from the grid centre.
            mode: Only ``'full'`` is implemented.

        Returns:
            Boolean ``np.ndarray`` of shape ``(size, size)`` for even ``size``, or
            ``(size - 1, size - 1)`` for odd ``size``.
        """
        assert size >= 3
        assert mode == 'full', f"mode '{mode}' not implemented"

        cache = getattr(self, '_ring_mask_cache', None)
        if cache is None:  # instance built without __init__
            cache = self._ring_mask_cache = {}
        key = (size, r_out, r_in, x_offset, y_offset, mode)
        if key not in cache:
            cache[key] = self._compute_ring_mask(size, r_out, r_in, x_offset, y_offset)
        return cache[key].copy()

    @staticmethod
    def _compute_ring_mask(size, r_out, r_in, x_offset, y_offset):
        """Uncached implementation of :meth:`_ring_mask`; see there for the semantics."""
        # Step 1: Initialize the frequency domain image and ring vector
        center = size // 2
        center_x, center_y = center + x_offset, center - y_offset

        # Adjust r_out to fit within the image boundaries
        if center_y + r_out > size:
            r_out = max(0, size - center_y)

        num_rings = r_out
        zero_bg_freq = torch.zeros(size, size)
        # Alternating signs with decreasing magnitudes give ring indices 0..49 distinct
        # non-zero labels (index 50 would be labelled 0, i.e. background).
        ring_vector = torch.tensor([(200 - i * 4) * (-1) ** i for i in range(num_rings)])
        zero_bg_freq[center_x, center_y:center_y + num_rings] = ring_vector
        zero_bg_freq = zero_bg_freq[None, None, ...]
        ring_vector_np = ring_vector.numpy()

        # Step 2: Rotate the ray through 360 degrees and label each pixel
        res = torch.zeros(360, size, size)
        res[0] = zero_bg_freq
        for angle in range(1, 360):
            res[angle] = F.rotate(zero_bg_freq, angle=angle)

        res = res.numpy()
        pure_bg = np.zeros((size, size))
        for x in range(size):
            for y in range(size):
                values, count = np.unique(res[:, x, y], return_counts=True)
                if len(count) > 2:
                    # Most frequent non-zero label (ties -> smallest value).
                    nonzero = values != 0
                    pure_bg[x, y] = values[nonzero][np.argmax(count[nonzero])]
                elif len(count) == 2:
                    # Smallest non-zero label among the two seen.
                    pure_bg[x, y] = values[values != 0][0]
                # A single distinct value leaves the pixel at 0.

        # Step 3: Keep labels of ring indices r_out - 1 down to max(r_in, 1)
        right_end = max(r_in - 1, 0)
        cand_list = ring_vector_np[r_out - 1:right_end:-1]
        mask = np.isin(pure_bg, cand_list)

        # Step 4: Crop odd sizes by one row/column (e.g. 65x65 -> 64x64)
        if size % 2:
            mask = mask[:size - 1, :size - 1]
        return mask

    def _make_Fourier_ringid_pattern(
        self,
        device,
        key_value_combination,
        no_watermark_latents,
        radius,
        radius_cutoff,
        ring_watermark_channel,
        heter_watermark_channel,
        heter_watermark_region_mask=None,
        ring_width=1,
    ):
        """Build one Fourier-domain Ring-ID pattern for a single key combination.

        Starting from an all-zero spectrum, the ring of each radius ``radius,
        radius - 1, ..., radius_cutoff + 1`` is set (real and imaginary parts alike)
        to that slot's key value on every ring channel. Inside their region masks,
        heterogeneous channels are then filled with the spectrum of fresh Gaussian
        noise.

        Args:
            device: Device for the masks and the noise.
            key_value_combination: One entry per ring slot. Entry ``i`` is indexed
                by position within ``ring_watermark_channel``.
            no_watermark_latents: 4-D latents. Only their shape and the dtype/device
                of their spectrum are used.
            radius: Outermost ring radius.
            radius_cutoff: Exclusive lower bound of the ring radii.
            ring_watermark_channel: Channel indices that receive ring key values.
            heter_watermark_channel: Channel indices that receive noise content.
            heter_watermark_region_mask: One mask per heterogeneous channel.
            ring_width: Ring thickness; only 1 is supported.

        Returns:
            Complex tensor shaped like ``self.fft(no_watermark_latents)``.

        Raises:
            NotImplementedError: If ``ring_width != 1``.
            ValueError: On a key/slot count mismatch, non-4-D latents, or a missing
                or mis-sized heterogeneous region mask.
        """
        if ring_width != 1:
            raise NotImplementedError('Proposed watermark generation only implemented for ring width = 1.')

        radius_list = list(range(radius, radius_cutoff, -1))
        # Validate against the radii actually iterated below, not the config values.
        if len(key_value_combination) != len(radius_list):
            raise ValueError('Mismatch between #key values and #slots')

        shape = no_watermark_latents.shape
        if len(shape) != 4:
            raise ValueError(f'Invalid shape for initial latent: {shape}')

        # The pattern starts from zeros; the spectrum only supplies dtype and device.
        watermarked_latents_fft = torch.zeros_like(self.fft(no_watermark_latents))

        # put ring (vectorized over the batch dimension)
        for radius_index, this_r_out in enumerate(radius_list):
            this_r_in = this_r_out - ring_width
            mask = torch.tensor(self._ring_mask(size=shape[-1], r_out=this_r_out, r_in=this_r_in)).to(device).to(
                torch.float64)
            keep = 1 - mask
            for channel_index, channel in enumerate(ring_watermark_channel):
                key_value = key_value_combination[radius_index][channel_index]
                watermarked_latents_fft[:, channel].real = keep * watermarked_latents_fft[:, channel].real + mask * key_value
                watermarked_latents_fft[:, channel].imag = keep * watermarked_latents_fft[:, channel].imag + mask * key_value

        # put noise or zeros
        if len(heter_watermark_channel) > 0:
            if heter_watermark_region_mask is None or len(heter_watermark_channel) != len(heter_watermark_region_mask):
                raise ValueError('heter_watermark_region_mask must provide one mask per heterogeneous channel')
            heter_watermark_region_mask = heter_watermark_region_mask.to(torch.float64)
            w_content = self.fft(torch.randn(*shape, device=device))  # [N, c, h, w]

            for channel_id, channel_mask in zip(heter_watermark_channel, heter_watermark_region_mask):
                keep = 1 - channel_mask
                watermarked_latents_fft[:, channel_id].real = \
                    keep * watermarked_latents_fft[:, channel_id].real + channel_mask * w_content[:, channel_id].real
                watermarked_latents_fft[:, channel_id].imag = \
                    keep * watermarked_latents_fft[:, channel_id].imag + channel_mask * w_content[:, channel_id].imag

        return watermarked_latents_fft

    def _prepare_fourier_pattern_and_mask(self):
        """Sample base latents and build the Ring-ID patterns and region masks.

        Returns:
            Tuple ``(base_latents, pattern, watermark_region_mask, pattern_list)``:
            - ``base_latents``: float64 latents from ``get_random_latents``.
            - ``pattern``: the last entry of ``pattern_list``.
            - ``watermark_region_mask``: boolean masks stacked in the order of
              ``config.watermark_channel``.
            - ``pattern_list``: one complex Fourier pattern per selected key
              combination.

        Raises:
            ValueError: If ``config.assigned_keys`` exceeds the number of possible
                key combinations.
        """
        base_latents = get_random_latents(pipe=self.config.pipe, height=self.config.image_size[0], width=self.config.image_size[1])
        latent_size = base_latents.shape[-1]
        base_latents = base_latents.to(torch.float64)

        ring_region_mask = torch.tensor(
            self._ring_mask(size=latent_size, r_out=self.config.radius, r_in=self.config.radius_cutoff)
        )

        # get heterogeneous watermark mask
        heter_region_mask = None
        heter_watermark_region_mask = None
        if len(self.config.heter_watermark_channel) > 0:
            heter_region_mask = torch.tensor(
                self._ring_mask(size=latent_size, r_out=self.config.radius, r_in=self.config.radius_cutoff)
            )  # TODO: change to whole mask
            heter_watermark_region_mask = heter_region_mask.unsqueeze(0).repeat(
                len(self.config.heter_watermark_channel), 1, 1).to(self.config.device)

        # One mask per watermarked channel, in config.watermark_channel order.
        watermark_region_mask = torch.stack([
            ring_region_mask if channel_idx in self.config.ring_watermark_channel else heter_region_mask
            for channel_idx in self.config.watermark_channel
        ]).to(self.config.device)

        # Every slot takes one value per ring channel from the quantized levels;
        # a key combination picks one such assignment per slot.
        single_channel_num_slots = self.config.radius - self.config.radius_cutoff
        ring_levels = np.linspace(-self.config.ring_value_range, self.config.ring_value_range,
                                  self.config.quantization_levels).tolist()
        slot_values = [list(combo) for combo in
                       itertools.product(ring_levels, repeat=len(self.config.ring_watermark_channel))]
        key_value_combinations = list(itertools.product(slot_values, repeat=single_channel_num_slots))

        # random select from all possible value combinations, then generate patterns for selected ones.
        if self.config.assigned_keys > 0:
            if self.config.assigned_keys > len(key_value_combinations):
                raise ValueError(
                    f'assigned_keys={self.config.assigned_keys} exceeds the '
                    f'{len(key_value_combinations)} available key combinations')
            key_value_combinations = random.sample(key_value_combinations, k=self.config.assigned_keys)

        Fourier_watermark_pattern_list = [
            self._make_Fourier_ringid_pattern(
                self.config.device, list(combo), base_latents,
                radius=self.config.radius, radius_cutoff=self.config.radius_cutoff,
                ring_watermark_channel=self.config.ring_watermark_channel,
                heter_watermark_channel=self.config.heter_watermark_channel,
                heter_watermark_region_mask=heter_watermark_region_mask,
            )
            for combo in key_value_combinations
        ]

        if self.config.fix_gt:
            # Replace each pattern with the spectrum of its real-valued inverse transform.
            Fourier_watermark_pattern_list = [self.fft(self.ifft(pattern).real)
                                              for pattern in Fourier_watermark_pattern_list]

        if self.config.time_shift:
            # Apply fftshift in the spatial domain to the ring channels (in place).
            # NOTE(review): config.time_shift_factor is not applied here; confirm intended.
            for pattern in Fourier_watermark_pattern_list:
                pattern[:, self.config.ring_watermark_channel, ...] = self.fft(
                    torch.fft.fftshift(self.ifft(pattern[:, self.config.ring_watermark_channel, ...]), dim=(-1, -2)))

        # The last pattern is returned separately as the single pattern to embed.
        Fourier_watermark_pattern = Fourier_watermark_pattern_list[-1]
        return base_latents, Fourier_watermark_pattern, watermark_region_mask, Fourier_watermark_pattern_list

    def generate_Fourier_watermark_latents(self, device, radius, radius_cutoff, watermark_region_mask, watermark_channel,
                                           original_latents=None, watermark_pattern=None):
        """Embed a Fourier watermark pattern into ``original_latents``.

        For each channel in ``watermark_channel``, spectrum entries inside that
        channel's mask are replaced by the pattern's entries. The spectrum is then
        transformed back, and the real part is returned.

        Args:
            device, radius, radius_cutoff: Unused; kept for interface compatibility.
            watermark_region_mask: One boolean mask per entry of ``watermark_channel``
                (``~`` is applied to each mask).
            watermark_channel: Channel indices to watermark.
            original_latents: 4-D latents to watermark (required).
            watermark_pattern: Complex Fourier pattern indexed like the latents'
                spectrum (required).

        Returns:
            Real tensor with the same shape as ``original_latents``.

        Raises:
            NotImplementedError: If ``original_latents`` or ``watermark_pattern`` is missing.
            ValueError: If the numbers of channels and masks differ.
        """
        if original_latents is None:
            raise NotImplementedError('Original latents should be provided.')
        if watermark_pattern is None:
            raise NotImplementedError('Fourier watermark pattern should be provided.')
        if len(watermark_channel) != len(watermark_region_mask):
            raise ValueError('watermark_region_mask must provide one mask per watermark channel')

        watermarked_latents_fft = self.fft(original_latents)
        for channel, channel_mask in zip(watermark_channel, watermark_region_mask):
            watermarked_latents_fft[:, channel] = (
                watermarked_latents_fft[:, channel] * ~channel_mask + watermark_pattern[:, channel] * channel_mask
            )

        return self.ifft(watermarked_latents_fft).real
class RI(BaseWatermark):
    """RI (Ring-ID) watermarking algorithm."""

    # kwargs consumed by the inversion helpers themselves rather than forwarded
    # to ``forward_diffusion``.
    _NON_INVERSION_KWARGS = ('decoder_inv', 'guidance_scale', 'num_inference_steps')

    def __init__(self,
                 watermark_config: RIConfig,
                 *args, **kwargs):
        """
        Initialize the RI algorithm.

        Parameters:
            watermark_config (RIConfig): Configuration instance of the RI algorithm.
        """
        self.config = watermark_config
        self.utils = RIUtils(self.config)

        self.detector = RIDetector(
            watermarking_mask=self.utils.mask,
            ring_watermark_channel=self.config.ring_watermark_channel,
            heter_watermark_channel=self.config.heter_watermark_channel,
            pattern_list=self.utils.pattern_list,
            threshold=self.config.threshold,
            device=self.config.device
        )

    def _generate_watermarked_image(self, prompt: str, *args,
                                    **kwargs) -> Image.Image:
        """Generate an image whose initial latents carry the Ring-ID pattern.

        Pipeline-argument precedence: ``kwargs`` override the built-in defaults,
        which override ``config.gen_kwargs``. ``latents`` is always the watermarked
        latents.
        """
        watermarked_latents = self.utils.generate_Fourier_watermark_latents(
            device=self.config.device,
            radius=self.config.radius,
            radius_cutoff=self.config.radius_cutoff,
            original_latents=self.utils.latents,
            watermark_pattern=self.utils.pattern,
            watermark_channel=self.config.watermark_channel,
            watermark_region_mask=self.utils.mask,
        ).to(torch.float32)

        # save watermarked latents
        self.set_orig_watermarked_latents(watermarked_latents)

        # Set gen seed
        set_random_seed(self.config.gen_seed)

        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # config.gen_kwargs only fills in keys not already set above
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                generation_params.setdefault(key, value)

        # Call-site kwargs override everything...
        generation_params.update(kwargs)
        # ...except the watermarked latents
        generation_params["latents"] = watermarked_latents

        return self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

    def _reverse_image_latents(self, image, prompt, guidance_scale, num_inference_steps, decoder_inv, inversion_kwargs):
        """Encode ``image`` into latents and run the diffusion inversion.

        Args:
            image: Input image, passed to ``transform_to_model_format``.
            prompt: Prompt for the text embeddings.
            guidance_scale: Guidance scale; classifier-free guidance is used when > 1.0.
            num_inference_steps: Number of inversion steps.
            decoder_inv: Forwarded to ``get_media_latents``.
            inversion_kwargs: Extra keyword arguments for ``forward_diffusion``.

        Returns:
            ``(image_tensor, reversed_latents)``. ``image_tensor`` is the preprocessed
            image, and ``reversed_latents`` is the full return value of
            ``config.inversion.forward_diffusion``.
        """
        # Step 1: Get Text Embeddings
        do_classifier_free_guidance = (guidance_scale > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Step 2: Preprocess Image
        image_tensor = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        # Step 3: Get Image Latents
        image_latents = get_media_latents(pipe=self.config.pipe, media=image_tensor, sample=False, decoder_inv=decoder_inv)

        # Step 4: Reverse Image Latents
        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            **inversion_kwargs
        )
        return image_tensor, reversed_latents

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect the watermark in the image.

        Optional kwargs: ``guidance_scale`` and ``num_inference_steps`` (default to
        the config values), ``decoder_inv`` (default False), and ``detector_type`` /
        ``mode`` (forwarded to ``eval_watermark`` when present).
        """
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)
        # NOTE(review): 'detector_type' / 'mode' are not filtered out, so they are also
        # forwarded to forward_diffusion; confirm that it tolerates them.
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in self._NON_INVERSION_KWARGS}

        _, reversed_latents = self._reverse_image_latents(
            image, prompt, guidance_scale_to_use, num_steps_to_use,
            kwargs.get('decoder_inv', False), inversion_kwargs,
        )

        # Forward only the detector options the caller actually supplied.
        detector_kwargs = {key: kwargs[key] for key in ('detector_type', 'mode') if key in kwargs}
        return self.detector.eval_watermark(reversed_latents[-1], **detector_kwargs)

    def get_data_for_visualize(self,
                               image: Image.Image = None,
                               prompt: str = "",
                               guidance_scale: float = None,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """
        Collect data for visualization of the RingID watermarking process.

        Args:
            image: Image to invert.
            prompt: Prompt for the text embeddings used during inversion.
            guidance_scale: Inversion guidance scale. ``None`` (the default) means
                ``config.guidance_scale``.
            decoder_inv: Forwarded to ``get_media_latents``.
            **kwargs: ``num_inference_steps`` (defaults to
                ``config.num_inference_steps``). Other keys are forwarded to
                ``forward_diffusion``.

        Returns a DataForVisualization object containing all necessary data for RIVisualizer.
        """
        # Named parameters never appear in **kwargs, so read them directly.
        guidance_scale_to_use = self.config.guidance_scale if guidance_scale is None else guidance_scale
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in self._NON_INVERSION_KWARGS}

        image_tensor, reversed_latents = self._reverse_image_latents(
            image, prompt, guidance_scale_to_use, num_steps_to_use, decoder_inv, inversion_kwargs,
        )

        # The full inversion output (not only the final step) is kept for visualization.
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            image=image_tensor,
            reversed_latents=reversed_latents,
            orig_watermarked_latents=self.orig_watermarked_latents,
        )