Coverage for watermark/prc/prc.py: 94.29% (175 statements)

coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

from ..base import BaseWatermark, BaseConfig
import torch
from typing import Dict, Tuple
from utils.diffusion_config import DiffusionConfig
import numpy as np
import galois
from scipy.sparse import csr_matrix
from scipy.special import binom
from visualize.data_for_visualization import DataForVisualization
from detection.prc.prc_detection import PRCDetector
from utils.media_utils import *
from utils.utils import set_random_seed
from PIL import Image


class PRCConfig(BaseConfig):
    """Config class for PRC algorithm."""

    def initialize_parameters(self) -> None:
        """Initialize algorithm-specific parameters."""
        self.fpr = self.config_dict['fpr']
        self.t = self.config_dict['prc_t']
        self.var = self.config_dict['var']
        self.threshold = self.config_dict['threshold']
        self.message = self._str_to_binary_array(self.config_dict['message'])
        self.message_length = len(self.message)  # 8 bits per character; <= 512 bits for robustness
        self.latents_height = self.image_size[0] // self.pipe.vae_scale_factor
        self.latents_width = self.image_size[1] // self.pipe.vae_scale_factor
        self.latents_channel = self.pipe.unet.config.in_channels
        self.n = self.latents_height * self.latents_width * self.latents_channel  # Dimension of the latent space
        self.GF = galois.GF(2)

        # Seeds for key generation
        self.gen_matrix_seed = self.config_dict['keygen']['gen_matrix_seed']
        self.indice_seed = self.config_dict['keygen']['indice_seed']
        self.one_time_pad_seed = self.config_dict['keygen']['one_time_pad_seed']
        self.test_bits_seed = self.config_dict['keygen']['test_bits_seed']
        self.permute_bits_seed = self.config_dict['keygen']['permute_bits_seed']

        # Seeds for encoding
        self.payload_seed = self.config_dict['encode']['payload_seed']
        self.error_seed = self.config_dict['encode']['error_seed']
        self.pseudogaussian_seed = self.config_dict['encode']['pseudogaussian_seed']
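
    # Illustrative config_dict layout consumed above. The key names come from this file;
    # the values shown are hypothetical placeholders, not recommended settings:
    #
    #   config_dict = {
    #       'fpr': 1e-5, 'prc_t': 3, 'var': 1.5, 'threshold': 0.5, 'message': 'hello',
    #       'keygen': {'gen_matrix_seed': 0, 'indice_seed': 1, 'one_time_pad_seed': 2,
    #                  'test_bits_seed': 3, 'permute_bits_seed': 4},
    #       'encode': {'payload_seed': 5, 'error_seed': 6, 'pseudogaussian_seed': 7},
    #   }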

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        return 'PRC'

    def _str_to_binary_array(self, s: str) -> np.ndarray:
        """Convert string to binary array."""
        # Convert string to binary string
        binary_str = ''.join(format(ord(c), '08b') for c in s)

        # Convert binary string to NumPy array
        binary_array = np.array([int(bit) for bit in binary_str])

        return binary_array
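
    # Worked example of the conversion above: each character becomes its 8-bit ASCII code,
    # e.g. _str_to_binary_array("AB") -> array([0,1,0,0,0,0,0,1, 0,1,0,0,0,0,1,0]).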


class PRCUtils:
    """Utility class for PRC algorithm."""

    def __init__(self, config: PRCConfig, *args, **kwargs) -> None:
        """Initialize PRC utility."""
        self.config = config
        self.encoding_key, self.decoding_key = self._generate_encoding_key(self.config.message_length)

    def _generate_encoding_key(self, message_length: int) -> Tuple[tuple, tuple]:
        """Generate the encoding and decoding keys for the PRC algorithm."""
        # Set basic scheme parameters
        num_test_bits = int(np.ceil(np.log2(1 / self.config.fpr)))
        secpar = int(np.log2(binom(self.config.n, self.config.t)))
        g = secpar
        k = message_length + g + num_test_bits
        r = self.config.n - k - secpar
        noise_rate = 1 - 2 ** (-secpar / g ** 2)

        # Sample an n-by-k generator matrix (all but the first n - r rows will be overwritten)
        generator_matrix = self.config.GF.Random(shape=(self.config.n, k), seed=self.config.gen_matrix_seed)

        # Sample a scipy.sparse parity-check matrix together with the last r rows of the generator matrix
        row_indices = []
        col_indices = []
        data = []
        for row in range(r):
            np.random.seed(self.config.indice_seed + row)
            chosen_indices = np.random.choice(self.config.n - r + row, self.config.t - 1, replace=False)
            chosen_indices = np.append(chosen_indices, self.config.n - r + row)
            row_indices.extend([row] * self.config.t)
            col_indices.extend(chosen_indices)
            data.extend([1] * self.config.t)
            generator_matrix[self.config.n - r + row] = generator_matrix[chosen_indices[:-1]].sum(axis=0)
        parity_check_matrix = csr_matrix((data, (row_indices, col_indices)))

        # Compute scheme parameters
        max_bp_iter = int(np.log(self.config.n) / np.log(self.config.t))

        # Sample one-time pad and test bits
        one_time_pad = self.config.GF.Random(self.config.n, seed=self.config.one_time_pad_seed)
        test_bits = self.config.GF.Random(num_test_bits, seed=self.config.test_bits_seed)

        # Permute bits
        np.random.seed(self.config.permute_bits_seed)
        permutation = np.random.permutation(self.config.n)
        generator_matrix = generator_matrix[permutation]
        one_time_pad = one_time_pad[permutation]
        parity_check_matrix = parity_check_matrix[:, permutation]

        encoding_key = (generator_matrix, one_time_pad, test_bits, g, noise_rate)
        decoding_key = (generator_matrix, parity_check_matrix, one_time_pad, self.config.fpr,
                        noise_rate, test_bits, g, max_bp_iter, self.config.t)
        return encoding_key, decoding_key
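
    # Key structure produced above: generator_matrix is an n x k matrix over GF(2), and
    # parity_check_matrix is an r x n sparse matrix with t ones per row, where n is the
    # latent dimension, k = message_length + g + num_test_bits, and r = n - k - secpar.
    # As a rough sense of scale (an assumption about a typical Stable Diffusion setup,
    # not taken from this repo): 512x512 images with vae_scale_factor 8 and 4 latent
    # channels give n = 4 * 64 * 64 = 16384.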

    def _encode_message(self, encoding_key: tuple, message: np.ndarray = None) -> torch.Tensor:
        """Encode a message into a PRC codeword with entries in {+1, -1}."""
        generator_matrix, one_time_pad, test_bits, g, noise_rate = encoding_key
        n, k = generator_matrix.shape

        if message is None:
            payload = np.concatenate((test_bits, self.config.GF.Random(k - len(test_bits), seed=self.config.payload_seed)))
        else:
            assert len(message) <= k - len(test_bits) - g, "Message is too long"
            payload = np.concatenate((
                test_bits,
                self.config.GF.Random(g, seed=self.config.payload_seed),
                self.config.GF(message),
                self.config.GF.Zeros(k - len(test_bits) - g - len(message)),
            ))

        np.random.seed(self.config.error_seed)
        error = self.config.GF(np.random.binomial(1, noise_rate, n))

        # Map the GF(2) codeword to +/-1 values: bit 0 -> +1, bit 1 -> -1
        return 1 - 2 * torch.tensor(payload @ generator_matrix.T + one_time_pad + error, dtype=float)
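
    # The value returned above is (payload @ G^T + one_time_pad + error) over GF(2): a
    # codeword of the sparse parity-check code defined in _generate_encoding_key, masked
    # by a one-time pad and flipped by Bernoulli(noise_rate) noise, then mapped to +/-1.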

    def _sample_prc_codeword(self, codeword: torch.Tensor, basis: torch.Tensor = None) -> torch.Tensor:
        """Convert a PRC codeword into pseudo-Gaussian noise, optionally projected onto a basis."""
        codeword_np = codeword.numpy()
        np.random.seed(self.config.pseudogaussian_seed)
        pseudogaussian_np = codeword_np * np.abs(np.random.randn(*codeword_np.shape))
        pseudogaussian = torch.from_numpy(pseudogaussian_np).to(dtype=torch.float32)
        if basis is None:
            return pseudogaussian
        return pseudogaussian @ basis.T
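
    # A uniformly random sign times |N(0, 1)| is exactly N(0, 1). Because the PRC codeword
    # entries are pseudorandom +/-1 values, `pseudogaussian` is designed to be
    # indistinguishable from standard Gaussian latent noise, with the watermark carried
    # only in the signs.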

    def inject_watermark(self) -> torch.Tensor:
        """Generate watermarked latents from PRC codeword."""
        # Step 1: Encode message
        prc_codeword = self._encode_message(self.encoding_key, self.config.message)
        # Step 2: Sample PRC codeword and get watermarked latents
        watermarked_latents = self._sample_prc_codeword(prc_codeword).reshape(
            1, self.config.latents_channel, self.config.latents_height, self.config.latents_width
        ).to(self.config.device)

        return watermarked_latents


class PRC(BaseWatermark):
    """PRC watermark class."""

    def __init__(self,
                 watermark_config: PRCConfig,
                 *args, **kwargs):
        """
        Initialize PRC watermarking algorithm.

        Parameters:
            watermark_config (PRCConfig): Configuration instance of the PRC algorithm.
        """
        self.config = watermark_config
        self.utils = PRCUtils(self.config)

        self.detector = PRCDetector(
            var=self.config.var,
            decoding_key=self.utils.decoding_key,
            GF=self.config.GF,
            threshold=self.config.threshold,
            device=self.config.device
        )

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Generate watermarked image."""
        watermarked_latents = self.utils.inject_watermark()

        # Save watermarked latents
        self.set_orig_watermarked_latents(watermarked_latents)

        # Set generation seed
        set_random_seed(self.config.gen_seed)

        # Construct generation parameters
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Use kwargs to override default parameters
        for key, value in kwargs.items():
            generation_params[key] = value

        # Ensure the latents parameter is not overridden
        generation_params["latents"] = watermarked_latents

        return self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect watermark in image."""
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Step 1: Get Text Embeddings
        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Step 2: Preprocess Image
        image = transform_to_model_format(
            image, target_size=self.config.image_size[0]
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        # Step 3: Get Image Latents
        image_latents = get_media_latents(
            pipe=self.config.pipe, media=image, sample=False, decoder_inv=kwargs.get('decoder_inv', False)
        )

        # Pass only known parameters to forward_diffusion, and let kwargs handle any additional parameters
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        # Step 4: Reverse Image Latents
        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=num_steps_to_use,
            **inversion_kwargs
        )[-1]

        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        else:
            return self.detector.eval_watermark(reversed_latents)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str = "",
                               guidance_scale: float = 1,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Collect generation, inversion, and PRC-specific data for visualization."""

        # 1. Generate watermarked latents and collect intermediate data
        set_random_seed(self.config.gen_seed)

        # Step 1: Encode message
        prc_codeword = self.utils._encode_message(self.utils.encoding_key, self.config.message)

        # Step 2: Sample PRC codeword
        pseudogaussian_noise = self.utils._sample_prc_codeword(prc_codeword)

        # Step 3: Generate watermarked latents
        watermarked_latents = pseudogaussian_noise.reshape(
            1, self.config.latents_channel, self.config.latents_height, self.config.latents_width
        ).to(self.config.device)

        # 2. Generate actual watermarked image using the same process as _generate_watermarked_image
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Generate the actual watermarked image
        watermarked_image = self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

        # 3. Perform watermark detection to get inverted latents (for comparison)
        inverted_latents = None
        try:
            # Use the same detection process as _detect_watermark_in_image
            guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
            num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

            # Get Text Embeddings
            do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
            prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
                prompt=prompt,
                device=self.config.device,
                do_classifier_free_guidance=do_classifier_free_guidance,
                num_images_per_prompt=1,
            )

            if do_classifier_free_guidance:
                text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
            else:
                text_embeddings = prompt_embeds

            # Preprocess watermarked image for detection
            processed_image = transform_to_model_format(
                watermarked_image,
                target_size=self.config.image_size[0]
            ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

            # Get Image Latents
            image_latents = get_media_latents(
                pipe=self.config.pipe,
                media=processed_image,
                sample=False,
                decoder_inv=decoder_inv
            )

            # Reverse Image Latents to get inverted noise
            inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

            reversed_latents = self.config.inversion.forward_diffusion(
                latents=image_latents,
                text_embeddings=text_embeddings,
                guidance_scale=guidance_scale_to_use,
                num_inference_steps=num_steps_to_use,
                **inversion_kwargs
            )[-1]

            inverted_latents = reversed_latents

        except Exception as e:
            print(f"Warning: Could not perform inversion for visualization: {e}")
            inverted_latents = None

        # 3.5. Run actual detection to get recovered PRC codeword
        recovered_prc = None
        try:
            if inverted_latents is not None:
                # Use the detector to recover the PRC codeword
                detection_result = self.detector.eval_watermark(inverted_latents)
                # The detector should have a recovered_prc attribute or return it
                if hasattr(self.detector, 'recovered_prc') and self.detector.recovered_prc is not None:
                    recovered_prc = self.detector.recovered_prc
                elif 'recovered_prc' in detection_result:
                    recovered_prc = detection_result['recovered_prc']
                else:
                    print("Warning: Detector did not provide recovered_prc")
        except Exception as e:
            print(f"Warning: Could not recover PRC codeword for visualization: {e}")
            recovered_prc = None

        # 4. Prepare PRC-specific data
        # Convert message to binary
        message_bits = torch.tensor(self.config._str_to_binary_array(self.config.config_dict['message']), dtype=torch.float32)

        # Get generator matrix
        generator_matrix = torch.tensor(np.array(self.utils.encoding_key[0], dtype=float), dtype=torch.float32)

        # Get parity check matrix
        parity_check_matrix = self.utils.decoding_key[1]

        # PRC parameters for visualization
        prc_params = {
            'FPR': self.config.fpr,
            'Parameter t': self.config.t,
            'Variance': self.config.var,
            'Threshold': self.config.threshold,
            'Message Length': self.config.message_length,
            'Latent Dimension': self.config.n
        }

        # 5. Prepare visualization data
        # Convert inverted_latents to list format to match base class expectations
        reversed_latents_list = [inverted_latents] if inverted_latents is not None else [None]

        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            latent_lists=[watermarked_latents],
            orig_latents=watermarked_latents,
            orig_watermarked_latents=watermarked_latents,
            watermarked_latents=watermarked_latents,
            watermarked_image=watermarked_image,
            image=image,
            reversed_latents=reversed_latents_list,
            inverted_latents=inverted_latents,
            # PRC-specific data
            message_bits=message_bits,
            prc_codeword=torch.tensor(prc_codeword, dtype=torch.float32),
            pseudogaussian_noise=torch.tensor(pseudogaussian_noise, dtype=torch.float32),
            generator_matrix=generator_matrix,
            parity_check_matrix=parity_check_matrix,
            prc_params=prc_params,
            threshold=self.config.threshold,
            recovered_prc=torch.tensor(recovered_prc, dtype=torch.float32) if recovered_prc is not None else None
        )
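

# Usage sketch (illustrative only, not executed). PRCConfig is constructed by the
# surrounding framework via BaseConfig, so the construction below is a hypothetical
# placeholder; `config_dict` would carry the keys read in initialize_parameters():
#
#     config = PRCConfig(...)                      # built by the framework, not shown here
#     prc = PRC(config)
#     img = prc._generate_watermarked_image("a photo of a cat")
#     scores = prc._detect_watermark_in_image(img, prompt="a photo of a cat")
#
# `scores` is the Dict[str, float] produced by PRCDetector.eval_watermark on the latents
# recovered by the inversion step (config.inversion.forward_diffusion). BaseWatermark may
# expose public wrappers around these underscore-prefixed methods; that depends on the
# base class.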