Coverage for watermark/prc/prc.py: 94.29% (175 statements)
coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
from ..base import BaseWatermark, BaseConfig
import torch
from typing import Dict, Tuple
from utils.diffusion_config import DiffusionConfig
import numpy as np
import galois
from scipy.sparse import csr_matrix
from scipy.special import binom
from visualize.data_for_visualization import DataForVisualization
from detection.prc.prc_detection import PRCDetector
from utils.media_utils import *
from utils.utils import set_random_seed
from PIL import Image

class PRCConfig(BaseConfig):
    """Config class for PRC algorithm."""

    def initialize_parameters(self) -> None:
        """Initialize algorithm-specific parameters."""
        self.fpr = self.config_dict['fpr']
        self.t = self.config_dict['prc_t']
        self.var = self.config_dict['var']
        self.threshold = self.config_dict['threshold']
        self.message = self._str_to_binary_array(self.config_dict['message'])
        self.message_length = len(self.message)  # 8 bits per character; keep <= 512 bits for robustness
        self.latents_height = self.image_size[0] // self.pipe.vae_scale_factor
        self.latents_width = self.image_size[1] // self.pipe.vae_scale_factor
        self.latents_channel = self.pipe.unet.config.in_channels
        self.n = self.latents_height * self.latents_width * self.latents_channel  # Dimension of the latent space
        self.GF = galois.GF(2)

        # Seeds for key generation
        self.gen_matrix_seed = self.config_dict['keygen']['gen_matrix_seed']
        self.indice_seed = self.config_dict['keygen']['indice_seed']
        self.one_time_pad_seed = self.config_dict['keygen']['one_time_pad_seed']
        self.test_bits_seed = self.config_dict['keygen']['test_bits_seed']
        self.permute_bits_seed = self.config_dict['keygen']['permute_bits_seed']

        # Seeds for encoding
        self.payload_seed = self.config_dict['encode']['payload_seed']
        self.error_seed = self.config_dict['encode']['error_seed']
        self.pseudogaussian_seed = self.config_dict['encode']['pseudogaussian_seed']

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        return 'PRC'

    def _str_to_binary_array(self, s: str) -> np.ndarray:
        """Convert string to binary array."""
        # Convert string to binary string
        binary_str = ''.join(format(ord(c), '08b') for c in s)

        # Convert binary string to NumPy array
        binary_array = np.array([int(bit) for bit in binary_str])

        return binary_array
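
# Worked example (illustrative only, not executed by the library): each character is mapped
# to its 8-bit character code, so a 2-character message yields 16 bits:
#
#     >>> ''.join(format(ord(c), '08b') for c in 'Hi')
#     '0100100001101001'
#
# The resulting bit array is what PRCConfig.message holds and what PRCUtils encodes below.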

class PRCUtils:
    """Utility class for PRC algorithm."""

    def __init__(self, config: PRCConfig, *args, **kwargs) -> None:
        """Initialize PRC utility."""
        self.config = config
        self.encoding_key, self.decoding_key = self._generate_encoding_key(self.config.message_length)

    def _generate_encoding_key(self, message_length: int) -> Tuple[tuple, tuple]:
        """Generate the PRC encoding and decoding keys."""
        # Set basic scheme parameters
        num_test_bits = int(np.ceil(np.log2(1 / self.config.fpr)))
        secpar = int(np.log2(binom(self.config.n, self.config.t)))
        g = secpar
        k = message_length + g + num_test_bits
        r = self.config.n - k - secpar
        noise_rate = 1 - 2 ** (-secpar / g ** 2)

        # Sample an n-by-k generator matrix (all rows except the first n-r will be overwritten)
        generator_matrix = self.config.GF.Random(shape=(self.config.n, k), seed=self.config.gen_matrix_seed)

        # Sample a scipy.sparse parity-check matrix together with the last r rows of the generator matrix
        row_indices = []
        col_indices = []
        data = []
        for row in range(r):
            np.random.seed(self.config.indice_seed + row)
            chosen_indices = np.random.choice(self.config.n - r + row, self.config.t - 1, replace=False)
            chosen_indices = np.append(chosen_indices, self.config.n - r + row)
            row_indices.extend([row] * self.config.t)
            col_indices.extend(chosen_indices)
            data.extend([1] * self.config.t)
            generator_matrix[self.config.n - r + row] = generator_matrix[chosen_indices[:-1]].sum(axis=0)
        parity_check_matrix = csr_matrix((data, (row_indices, col_indices)))

        # Compute scheme parameters
        max_bp_iter = int(np.log(self.config.n) / np.log(self.config.t))

        # Sample one-time pad and test bits
        one_time_pad = self.config.GF.Random(self.config.n, seed=self.config.one_time_pad_seed)
        test_bits = self.config.GF.Random(num_test_bits, seed=self.config.test_bits_seed)

        # Apply a common random permutation to the generator rows, one-time pad, and parity-check columns
        np.random.seed(self.config.permute_bits_seed)
        permutation = np.random.permutation(self.config.n)
        generator_matrix = generator_matrix[permutation]
        one_time_pad = one_time_pad[permutation]
        parity_check_matrix = parity_check_matrix[:, permutation]

        return (
            (generator_matrix, one_time_pad, test_bits, g, noise_rate),
            (generator_matrix, parity_check_matrix, one_time_pad, self.config.fpr, noise_rate, test_bits, g, max_bp_iter, self.config.t),
        )
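
    # Sanity-check sketch (illustrative, assuming an instantiated `utils = PRCUtils(config)`):
    # each parity-check row selects t generator rows that XOR to zero by construction, so the
    # parity-check matrix annihilates the generator matrix over GF(2):
    #
    #     G = np.array(utils.encoding_key[0], dtype=np.int64)   # n x k generator
    #     H = utils.decoding_key[1]                             # sparse r x n parity check
    #     assert not ((H @ G) % 2).any()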

    def _encode_message(self, encoding_key: tuple, message: np.ndarray = None) -> torch.Tensor:
        """Encode a binary message into a +/-1 PRC codeword."""
        generator_matrix, one_time_pad, test_bits, g, noise_rate = encoding_key
        n, k = generator_matrix.shape

        if message is None:
            payload = np.concatenate((test_bits, self.config.GF.Random(k - len(test_bits), seed=self.config.payload_seed)))
        else:
            assert len(message) <= k - len(test_bits) - g, "Message is too long"
            payload = np.concatenate((
                test_bits,
                self.config.GF.Random(g, seed=self.config.payload_seed),
                self.config.GF(message),
                self.config.GF.Zeros(k - len(test_bits) - g - len(message)),
            ))

        np.random.seed(self.config.error_seed)
        error = self.config.GF(np.random.binomial(1, noise_rate, n))

        return 1 - 2 * torch.tensor(payload @ generator_matrix.T + one_time_pad + error, dtype=float)
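
    # Note (explanatory, not executed): `payload @ G.T + one_time_pad + error` is a noisy
    # codeword over GF(2); the final `1 - 2 * x` step maps bits {0, 1} to signs {+1, -1},
    # e.g. the bit vector [0, 1, 1, 0] becomes the sign vector [+1, -1, -1, +1].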

    def _sample_prc_codeword(self, codeword: torch.Tensor, basis: torch.Tensor = None) -> torch.Tensor:
        """Sample a pseudo-Gaussian vector whose signs carry the PRC codeword."""
        codeword_np = codeword.numpy()
        np.random.seed(self.config.pseudogaussian_seed)
        pseudogaussian_np = codeword_np * np.abs(np.random.randn(*codeword_np.shape))
        pseudogaussian = torch.from_numpy(pseudogaussian_np).to(dtype=torch.float32)
        if basis is None:
            return pseudogaussian
        return pseudogaussian @ basis.T
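
    # Explanatory note: each +/-1 codeword entry is multiplied by the magnitude of an
    # independent standard normal draw. Because the codeword itself is pseudorandom to anyone
    # without the key, the resulting latents look like ordinary Gaussian diffusion noise,
    # while the sign of each entry still carries the corresponding codeword bit.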

    def inject_watermark(self) -> torch.Tensor:
        """Generate watermarked latents from the PRC codeword."""
        # Step 1: Encode message
        prc_codeword = self._encode_message(self.encoding_key, self.config.message)

        # Step 2: Sample PRC codeword and get watermarked latents
        watermarked_latents = self._sample_prc_codeword(prc_codeword).reshape(
            1, self.config.latents_channel, self.config.latents_height, self.config.latents_width
        ).to(self.config.device)

        return watermarked_latents
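
# Shape note (illustrative): the flat codeword has length n = latents_channel * latents_height
# * latents_width and is reshaped to (1, C, H, W). For example, with a typical 512x512 Stable
# Diffusion setup (vae_scale_factor = 8, 4 latent channels) this gives n = 4 * 64 * 64 = 16384.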

class PRC(BaseWatermark):
    """PRC watermark class."""

    def __init__(self,
                 watermark_config: PRCConfig,
                 *args, **kwargs):
        """
        Initialize PRC watermarking algorithm.

        Parameters:
            watermark_config (PRCConfig): Configuration instance of the PRC algorithm.
        """
        self.config = watermark_config
        self.utils = PRCUtils(self.config)

        self.detector = PRCDetector(
            var=self.config.var,
            decoding_key=self.utils.decoding_key,
            GF=self.config.GF,
            threshold=self.config.threshold,
            device=self.config.device
        )

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Generate watermarked image."""
        watermarked_latents = self.utils.inject_watermark()

        # Save the original watermarked latents
        self.set_orig_watermarked_latents(watermarked_latents)

        # Set generation seed
        set_random_seed(self.config.gen_seed)

        # Construct generation parameters
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Use kwargs to override default parameters
        for key, value in kwargs.items():
            generation_params[key] = value

        # Ensure the latents parameter is not overridden
        generation_params["latents"] = watermarked_latents

        return self.config.pipe(
            prompt,
            **generation_params
        ).images[0]
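
    # Usage sketch (illustrative; assumes `config` is a fully initialized PRCConfig bound to a
    # diffusion pipeline, and in practice the public BaseWatermark entry point would normally
    # be used instead of calling this protected method directly):
    #
    #     prc = PRC(config)
    #     image = prc._generate_watermarked_image("a photo of a cat")
    #     image.save("watermarked.png")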

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect watermark in image."""
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Step 1: Get text embeddings
        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Step 2: Preprocess image
        image = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        # Step 3: Get image latents
        image_latents = get_media_latents(pipe=self.config.pipe, media=image, sample=False, decoder_inv=kwargs.get('decoder_inv', False))

        # Forward any remaining kwargs to forward_diffusion, excluding the ones already handled above
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        # Step 4: Reverse the image latents back to the initial noise
        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=num_steps_to_use,
            **inversion_kwargs
        )[-1]

        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        else:
            return self.detector.eval_watermark(reversed_latents)
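
    # Usage sketch (illustrative; `prc` as above, `img` a PIL.Image to test):
    #
    #     result = prc._detect_watermark_in_image(img, prompt="")
    #     # `result` is the score dict produced by PRCDetector.eval_watermark; the exact keys
    #     # depend on the detector implementation.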

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str = "",
                               guidance_scale: float = 1,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Collect latents, matrices, and detection results for visualization."""
        # 1. Generate watermarked latents and collect intermediate data
        set_random_seed(self.config.gen_seed)

        # Step 1: Encode message
        prc_codeword = self.utils._encode_message(self.utils.encoding_key, self.config.message)

        # Step 2: Sample PRC codeword
        pseudogaussian_noise = self.utils._sample_prc_codeword(prc_codeword)

        # Step 3: Generate watermarked latents
        watermarked_latents = pseudogaussian_noise.reshape(
            1, self.config.latents_channel, self.config.latents_height, self.config.latents_width
        ).to(self.config.device)

        # 2. Generate the actual watermarked image using the same process as _generate_watermarked_image
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Generate the actual watermarked image
        watermarked_image = self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

        # 3. Perform watermark detection to get inverted latents (for comparison)
        inverted_latents = None
        try:
            # Use the same detection process as _detect_watermark_in_image
            guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
            num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

            # Get text embeddings
            do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
            prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
                prompt=prompt,
                device=self.config.device,
                do_classifier_free_guidance=do_classifier_free_guidance,
                num_images_per_prompt=1,
            )

            if do_classifier_free_guidance:
                text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
            else:
                text_embeddings = prompt_embeds

            # Preprocess the watermarked image for detection
            processed_image = transform_to_model_format(
                watermarked_image,
                target_size=self.config.image_size[0]
            ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

            # Get image latents
            image_latents = get_media_latents(
                pipe=self.config.pipe,
                media=processed_image,
                sample=False,
                decoder_inv=decoder_inv
            )

            # Reverse the image latents to recover the initial noise
            inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

            reversed_latents = self.config.inversion.forward_diffusion(
                latents=image_latents,
                text_embeddings=text_embeddings,
                guidance_scale=guidance_scale_to_use,
                num_inference_steps=num_steps_to_use,
                **inversion_kwargs
            )[-1]

            inverted_latents = reversed_latents

        except Exception as e:
            print(f"Warning: Could not perform inversion for visualization: {e}")
            inverted_latents = None

        # 3.5. Run actual detection to get the recovered PRC codeword
        recovered_prc = None
        try:
            if inverted_latents is not None:
                # Use the detector to recover the PRC codeword
                detection_result = self.detector.eval_watermark(inverted_latents)
                # The detector should expose the recovered PRC either as an attribute or in the result dict
                if hasattr(self.detector, 'recovered_prc') and self.detector.recovered_prc is not None:
                    recovered_prc = self.detector.recovered_prc
                elif 'recovered_prc' in detection_result:
                    recovered_prc = detection_result['recovered_prc']
                else:
                    print("Warning: Detector did not provide recovered_prc")
        except Exception as e:
            print(f"Warning: Could not recover PRC codeword for visualization: {e}")
            recovered_prc = None

        # 4. Prepare PRC-specific data
        # Convert the message to a bit tensor
        message_bits = torch.tensor(self.config._str_to_binary_array(self.config.config_dict['message']), dtype=torch.float32)

        # Get the generator matrix
        generator_matrix = torch.tensor(np.array(self.utils.encoding_key[0], dtype=float), dtype=torch.float32)

        # Get the parity-check matrix
        parity_check_matrix = self.utils.decoding_key[1]

        # PRC parameters for visualization
        prc_params = {
            'FPR': self.config.fpr,
            'Parameter t': self.config.t,
            'Variance': self.config.var,
            'Threshold': self.config.threshold,
            'Message Length': self.config.message_length,
            'Latent Dimension': self.config.n
        }

        # 5. Prepare visualization data
        # Wrap inverted_latents in a list to match the base class expectations
        reversed_latents_list = [inverted_latents] if inverted_latents is not None else [None]

        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            latent_lists=[watermarked_latents],
            orig_latents=watermarked_latents,
            orig_watermarked_latents=watermarked_latents,
            watermarked_latents=watermarked_latents,
            watermarked_image=watermarked_image,
            image=image,
            reversed_latents=reversed_latents_list,
            inverted_latents=inverted_latents,
            # PRC-specific data
            message_bits=message_bits,
            prc_codeword=prc_codeword.to(torch.float32),
            pseudogaussian_noise=pseudogaussian_noise.to(torch.float32),
            generator_matrix=generator_matrix,
            parity_check_matrix=parity_check_matrix,
            prc_params=prc_params,
            threshold=self.config.threshold,
            recovered_prc=torch.as_tensor(recovered_prc, dtype=torch.float32) if recovered_prc is not None else None
        )
408 )