Coverage for watermark / ri / ri.py: 93.78%

193 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import itertools 

17 

18import torch 

19from ..base import BaseWatermark, BaseConfig 

20import numpy as np 

21from typing import Dict 

22from torchvision import transforms 

23from torchvision.transforms import functional as F 

24from PIL import Image 

25import random 

26from detection.ri.ri_detection import RIDetector 

27from utils.media_utils import * 

28from utils.utils import set_random_seed 

29from visualize.data_for_visualization import DataForVisualization 

30 

class RIConfig(BaseConfig):
    """Configuration holder for the RI (RingID) watermarking algorithm.

    All values are read from ``self.config_dict``; a missing key raises
    ``KeyError``. ``config_dict`` itself is populated outside this class.
    """

    def initialize_parameters(self) -> None:
        """Copy the RI-specific settings from ``config_dict`` onto the instance."""
        cfg = self.config_dict

        # Ring key space: values per ring are drawn from a linspace over
        # [-ring_value_range, ring_value_range] with `quantization_levels` steps.
        self.ring_width = cfg['ring_width']
        self.quantization_levels = cfg['quantization_levels']
        self.ring_value_range = cfg['ring_value_range']

        # Pattern post-processing switches and number of keys to sample.
        self.fix_gt = cfg['fix_gt']
        self.time_shift = cfg['time_shift']
        self.time_shift_factor = cfg['time_shift_factor']
        self.assigned_keys = cfg['assigned_keys']
        self.channel_min = cfg['channel_min']

        # Ring geometry: rings span radii in (radius_cutoff, radius].
        self.radius = cfg['radius']
        self.anchor_x_offset = cfg['anchor_x_offset']
        self.anchor_y_offset = cfg['anchor_y_offset']
        self.radius_cutoff = cfg['radius_cutoff']

        # Channel layout: the combined list is the sorted union of both lists.
        self.heter_watermark_channel = cfg["heter_watermark_channel"]
        self.ring_watermark_channel = cfg["ring_watermark_channel"]
        self.watermark_channel = sorted(
            self.heter_watermark_channel + self.ring_watermark_channel
        )
        self.threshold = cfg['threshold']

    @property
    def algorithm_name(self) -> str:
        """Name under which this algorithm is identified."""
        return "RI"

60 

class RIUtils:
    """Utility class for the Ring-ID algorithm.

    Builds the ring-shaped Fourier-domain masks and key patterns and injects a
    chosen pattern into a latent tensor.
    """

    def __init__(self, config: "RIConfig", *args, **kwargs) -> None:
        """Store ``config`` and precompute latents, pattern, mask and pattern list.

        Extra positional/keyword arguments are accepted and ignored.
        """
        self.config = config

        (self.latents,
         self.pattern,
         self.mask,
         self.pattern_list) = self._prepare_fourier_pattern_and_mask()

    def fft(self, input_tensor):
        """Return the 2-D FFT of a 4-D tensor with the zero frequency shifted to the centre."""
        assert len(input_tensor.shape) == 4
        return torch.fft.fftshift(torch.fft.fft2(input_tensor), dim=(-1, -2))

    def ifft(self, input_tensor):
        """Inverse of :meth:`fft`: undo the shift, then apply the inverse 2-D FFT."""
        assert len(input_tensor.shape) == 4
        return torch.fft.ifft2(torch.fft.ifftshift(input_tensor, dim=(-1, -2)))

    def _ring_mask(self, size=65, r_out=16, r_in=8, x_offset=0, y_offset=0, mode='full'):
        """Construct a rotationally symmetric boolean ring mask.

        A one-pixel-wide radial line holding a distinct value per radius is
        rotated in one-degree steps; every pixel is then labelled with the
        non-zero value it received most often, and the pixels whose label
        belongs to the requested radius interval form the mask.

        Parameters:
            size: Side length of the square grid the mask is built on.
            r_out: Outer radius (number of labelled rings); clipped so the
                radial line stays inside the grid.
            r_in: Inner radius; rings with index <= ``max(r_in - 1, 0)`` are
                excluded, so the centre pixel (index 0) is never selected.
            x_offset, y_offset: Shift of the ring centre from the grid centre.
            mode: Only ``'full'`` is implemented.

        Returns:
            A boolean ``numpy`` array; for odd ``size`` the last row and column
            are dropped, giving ``(size - 1, size - 1)``.
        """
        assert size >= 3
        assert mode == 'full', f"mode '{mode}' not implemented"

        # Step 1: draw the labelled radial line starting at the (offset) centre.
        center = size // 2
        center_x, center_y = center + x_offset, center - y_offset

        # Clip r_out so the line does not run past the grid edge.
        if center_y + r_out > size:
            r_out = max(0, size - center_y)

        num_rings = r_out
        zero_bg_freq = torch.zeros(size, size)

        # Alternating-sign, strictly distinct labels: 200, -196, 192, ...
        ring_vector = torch.tensor([(200 - i * 4) * (-1) ** i for i in range(num_rings)])
        zero_bg_freq[center_x, center_y:center_y + num_rings] = ring_vector
        zero_bg_freq = zero_bg_freq[None, None, ...]
        ring_vector_np = ring_vector.numpy()

        # Step 2: sweep the line through 360 one-degree rotations.
        res = torch.zeros(360, size, size)
        res[0] = zero_bg_freq
        for angle in range(1, 360):
            res[angle] = F.rotate(zero_bg_freq, angle=angle)

        res = res.numpy()
        pure_bg = np.zeros((size, size))

        # Label each pixel with its most frequent non-zero value over all angles.
        for x in range(size):
            for y in range(size):
                values, count = np.unique(res[:, x, y], return_counts=True)
                if len(count) > 2:
                    nonzero = values != 0
                    pure_bg[x, y] = values[nonzero][np.argmax(count[nonzero])]
                elif len(count) == 2:
                    pure_bg[x, y] = values[values != 0][0]

        # Step 3: keep only the labels of rings r_out-1 down to right_end+1.
        right_end = 0 if r_in - 1 < 0 else r_in - 1
        cand_list = ring_vector_np[r_out - 1:right_end:-1]
        mask = np.isin(pure_bg, cand_list)

        # Step 4: an odd grid is cropped by one row/column (e.g. 65 -> 64).
        if size % 2:
            mask = mask[:size - 1, :size - 1]
        return mask

    def _per_ring_masks(self, size, radius, radius_cutoff, ring_width, device):
        """Build one float64 mask tensor per ring radius, outermost ring first."""
        return [
            torch.tensor(
                self._ring_mask(size=size, r_out=r_out, r_in=r_out - ring_width)
            ).to(device).to(torch.float64)
            for r_out in range(radius, radius_cutoff, -1)
        ]

    def _make_Fourier_ringid_pattern(
            self,
            device,
            key_value_combination,
            no_watermark_latents,
            radius,
            radius_cutoff,
            ring_watermark_channel,
            heter_watermark_channel,
            heter_watermark_region_mask=None,
            ring_width=1,
            ring_masks=None,
    ):
        """Build one Fourier-domain watermark pattern for a given key.

        Parameters:
            device: Device on which masks and random content are created.
            key_value_combination: One entry per ring (outermost first); each
                entry holds one value per ring watermark channel.
            no_watermark_latents: 4-D latent tensor; only its shape and the
                dtype of its FFT are used.
            radius, radius_cutoff: Rings cover radii ``radius`` down to
                ``radius_cutoff + 1``.
            ring_watermark_channel: Channel indices that receive the ring key.
            heter_watermark_channel: Channel indices that receive random
                Fourier content inside ``heter_watermark_region_mask``.
            heter_watermark_region_mask: One mask per heterogeneous channel;
                required when ``heter_watermark_channel`` is non-empty.
            ring_width: Only ``1`` is supported.
            ring_masks: Optional precomputed per-ring float masks (outermost
                first). Computing them is expensive and independent of the
                key, so callers generating many patterns should pass them in.
                When ``None`` they are computed here.

        Returns:
            A complex tensor shaped like the FFT of ``no_watermark_latents``,
            zero outside the watermarked regions.

        Raises:
            NotImplementedError: If ``ring_width`` is not 1.
            ValueError: On a key/slot count mismatch, a non-4-D latent, or a
                ``ring_masks`` list whose length does not match the radii.
        """
        if ring_width != 1:
            raise NotImplementedError('Proposed watermark generation only implemented for ring width = 1.')

        if len(key_value_combination) != (self.config.radius - self.config.radius_cutoff):
            raise ValueError('Mismatch between #key values and #slots')

        shape = no_watermark_latents.shape
        if len(shape) != 4:
            raise ValueError(f'Invalid shape for initial latent: {shape}')

        # Start from an all-zero spectrum with the dtype/shape of the latent FFT.
        watermarked_latents_fft = torch.zeros_like(self.fft(no_watermark_latents))

        if ring_masks is None:
            ring_masks = self._per_ring_masks(shape[-1], radius, radius_cutoff, ring_width, device)
        elif len(ring_masks) != len(range(radius, radius_cutoff, -1)):
            raise ValueError('Mismatch between #ring masks and #ring radii')

        # Put rings: inside each ring, real and imaginary parts are both set
        # to the key value of that (ring, channel) slot.
        for radius_index, ring_mask in enumerate(ring_masks):
            outside_ring = 1 - ring_mask
            for batch_index in range(shape[0]):
                for channel_index, channel_id in enumerate(ring_watermark_channel):
                    key_value = key_value_combination[radius_index][channel_index]
                    plane = watermarked_latents_fft[batch_index, channel_id]  # view into the result
                    plane.real = outside_ring * plane.real + ring_mask * key_value
                    plane.imag = outside_ring * plane.imag + ring_mask * key_value

        # Put noise: heterogeneous channels get the FFT of Gaussian noise
        # inside their region mask.
        if len(heter_watermark_channel) > 0:
            assert len(heter_watermark_channel) == len(heter_watermark_region_mask)
            heter_watermark_region_mask = heter_watermark_region_mask.to(torch.float64)
            w_content = self.fft(torch.randn(*shape, device=device))  # [N, c, h, w]

            for batch_index in range(shape[0]):
                for channel_id, channel_mask in zip(heter_watermark_channel, heter_watermark_region_mask):
                    outside_region = 1 - channel_mask
                    plane = watermarked_latents_fft[batch_index, channel_id]
                    noise = w_content[batch_index][channel_id]
                    plane.real = outside_region * plane.real + channel_mask * noise.real
                    plane.imag = outside_region * plane.imag + channel_mask * noise.imag

        return watermarked_latents_fft

    def _prepare_fourier_pattern_and_mask(self):
        """Create base latents, the per-channel region mask and all key patterns.

        Returns:
            A tuple ``(base_latents, pattern, watermark_region_mask,
            pattern_list)`` where ``pattern`` is the last entry of
            ``pattern_list`` and ``watermark_region_mask`` has one mask per
            entry of ``config.watermark_channel``.
        """
        config = self.config

        base_latents = get_random_latents(pipe=config.pipe, height=config.image_size[0],
                                          width=config.image_size[1])
        latent_size = base_latents.shape[-1]
        base_latents = base_latents.to(torch.float64)

        # The full ring-region mask is expensive to build, so compute it once.
        region_mask_np = self._ring_mask(size=latent_size, r_out=config.radius, r_in=config.radius_cutoff)
        sing_channel_ring_watermark_mask = torch.tensor(region_mask_np)

        # Heterogeneous channels currently use the same region as the rings.
        # TODO: change to whole mask
        single_channel_heter_watermark_mask = None
        heter_watermark_region_mask = None
        if len(config.heter_watermark_channel) > 0:
            single_channel_heter_watermark_mask = torch.tensor(region_mask_np)
            heter_watermark_region_mask = single_channel_heter_watermark_mask.unsqueeze(0).repeat(
                len(config.heter_watermark_channel), 1, 1).to(config.device)

        watermark_region_mask = torch.stack([
            sing_channel_ring_watermark_mask
            if channel_idx in config.ring_watermark_channel
            else single_channel_heter_watermark_mask
            for channel_idx in config.watermark_channel
        ]).to(config.device)  # [C, 64, 64]

        # Enumerate the key space: every ring slot takes one value per ring
        # channel from the quantized value range.
        single_channel_num_slots = config.radius - config.radius_cutoff
        levels = np.linspace(-config.ring_value_range, config.ring_value_range,
                             config.quantization_levels).tolist()
        per_slot_values = [list(combo) for combo in
                           itertools.product(levels, repeat=len(config.ring_watermark_channel))]
        key_value_combinations = list(
            itertools.product(*[per_slot_values for _ in range(single_channel_num_slots)]))

        # Randomly select from all possible value combinations, then generate
        # patterns for the selected ones only.
        if config.assigned_keys > 0:
            assert config.assigned_keys <= len(key_value_combinations)
            key_value_combinations = random.sample(key_value_combinations, k=config.assigned_keys)

        # Per-ring masks do not depend on the key: build them once for all keys.
        ring_masks = self._per_ring_masks(latent_size, config.radius, config.radius_cutoff, 1, config.device)

        Fourier_watermark_pattern_list = [
            self._make_Fourier_ringid_pattern(
                config.device, list(combo), base_latents,
                radius=config.radius, radius_cutoff=config.radius_cutoff,
                ring_watermark_channel=config.ring_watermark_channel,
                heter_watermark_channel=config.heter_watermark_channel,
                heter_watermark_region_mask=heter_watermark_region_mask,
                ring_masks=ring_masks)
            for combo in key_value_combinations
        ]

        if config.fix_gt:
            # Keep only what survives a round trip through a real-valued latent.
            Fourier_watermark_pattern_list = [self.fft(self.ifft(pattern).real)
                                              for pattern in Fourier_watermark_pattern_list]

        if config.time_shift:
            # fftshift the spatial-domain ring channels, in place.
            # NOTE: config.time_shift_factor is not applied here.
            for pattern in Fourier_watermark_pattern_list:
                pattern[:, config.ring_watermark_channel, ...] = self.fft(
                    torch.fft.fftshift(self.ifft(pattern[:, config.ring_watermark_channel, ...]),
                                       dim=(-1, -2)))

        # Use a single ring pattern (the last one) for verification.
        Fourier_watermark_pattern = Fourier_watermark_pattern_list[-1]
        return base_latents, Fourier_watermark_pattern, watermark_region_mask, Fourier_watermark_pattern_list

    def generate_Fourier_watermark_latents(self, device, radius, radius_cutoff, watermark_region_mask,
                                           watermark_channel, original_latents=None, watermark_pattern=None):
        """Inject ``watermark_pattern`` into ``original_latents`` in Fourier space.

        For each watermark channel, spectrum entries inside the channel's
        boolean mask are replaced by the pattern; the result is transformed
        back and its real part returned.

        Parameters:
            device, radius, radius_cutoff: Accepted for interface
                compatibility; not used by this method.
            watermark_region_mask: One boolean mask per watermark channel.
            watermark_channel: Channel indices to watermark.
            original_latents: Latents to watermark (required).
            watermark_pattern: Fourier-domain pattern (required).

        Raises:
            NotImplementedError: If latents or pattern are not provided.
        """
        if original_latents is None:
            raise NotImplementedError('Original latents should be provided.')

        if watermark_pattern is None:
            raise NotImplementedError('Fourier watermark pattern should be provided.')

        watermarked_latents_fft = torch.fft.fftshift(torch.fft.fft2(original_latents), dim=(-1, -2))

        assert len(watermark_channel) == len(watermark_region_mask)
        for channel, channel_mask in zip(watermark_channel, watermark_region_mask):
            # `~` requires a boolean mask: keep the spectrum outside, pattern inside.
            watermarked_latents_fft[:, channel] = (
                watermarked_latents_fft[:, channel] * ~channel_mask
                + watermark_pattern[:, channel] * channel_mask
            )

        return torch.fft.ifft2(torch.fft.ifftshift(watermarked_latents_fft, dim=(-1, -2))).real

300 

301 

class RI(BaseWatermark):
    """RI watermarking algorithm."""

    def __init__(self,
                 watermark_config: RIConfig,
                 *args, **kwargs):
        """
        Initialize the RI algorithm.

        Parameters:
            watermark_config (RIConfig): Configuration instance of the RI algorithm.

        Extra positional/keyword arguments are accepted and ignored.
        """
        self.config = watermark_config
        self.utils = RIUtils(self.config)

        self.detector = RIDetector(
            watermarking_mask=self.utils.mask,
            ring_watermark_channel=self.config.ring_watermark_channel,
            heter_watermark_channel=self.config.heter_watermark_channel,
            pattern_list=self.utils.pattern_list,
            threshold=self.config.threshold,
            device=self.config.device
        )

    def _generate_watermarked_image(self, prompt: str, *args,
                                    **kwargs) -> Image.Image:
        """Generate an image with a watermarked latent representation.

        Generation parameters are resolved in this order (later wins):
        built-in defaults from the config, ``config.gen_kwargs`` (only for keys
        not already set), then ``kwargs``. The ``latents`` entry is always the
        watermarked latents and cannot be overridden.
        """
        watermarked_latents = self.utils.generate_Fourier_watermark_latents(
            device=self.config.device,
            radius=self.config.radius,
            radius_cutoff=self.config.radius_cutoff,
            original_latents=self.utils.latents,
            watermark_pattern=self.utils.pattern,
            watermark_channel=self.config.watermark_channel,
            watermark_region_mask=self.utils.mask,
        ).to(torch.float32)

        # Save watermarked latents
        self.set_orig_watermarked_latents(watermarked_latents)

        # Set gen seed
        set_random_seed(self.config.gen_seed)

        # Construct generation parameters
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # config.gen_kwargs only fills in keys that are not set above.
        gen_kwargs = getattr(self.config, "gen_kwargs", None)
        if gen_kwargs:
            for key, value in gen_kwargs.items():
                generation_params.setdefault(key, value)

        # Call-site kwargs override everything ...
        generation_params.update(kwargs)

        # ... except the latents, which must stay the watermarked ones.
        generation_params["latents"] = watermarked_latents

        return self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

    def _invert_image(self, image, prompt, guidance_scale, num_inference_steps,
                      decoder_inv, inversion_kwargs):
        """Encode the prompt, map ``image`` to latents and run forward diffusion.

        Returns the preprocessed image tensor and the full result of
        ``config.inversion.forward_diffusion``.
        """
        # Step 1: Get Text Embeddings
        do_classifier_free_guidance = (guidance_scale > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )
        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Step 2: Preprocess Image (batch of one, matching embedding dtype/device)
        image = transform_to_model_format(image, target_size=self.config.image_size[0])
        image = image.unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        # Step 3: Get Image Latents
        image_latents = get_media_latents(pipe=self.config.pipe, media=image, sample=False,
                                          decoder_inv=decoder_inv)

        # Step 4: Reverse Image Latents
        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            **inversion_kwargs
        )
        return image, reversed_latents

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect the watermark in the image.

        Recognised ``kwargs``: ``guidance_scale`` and ``num_inference_steps``
        (default to the config values), ``decoder_inv`` (default ``False``),
        and ``detector_type`` / ``mode``, which are forwarded to the detector
        when present. All kwargs other than the first three are also passed
        on to ``forward_diffusion``.
        """
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Everything not consumed here is handed to forward_diffusion.
        inversion_kwargs = {k: v for k, v in kwargs.items()
                            if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        _, reversed_latents_list = self._invert_image(
            image, prompt, guidance_scale_to_use, num_steps_to_use,
            kwargs.get('decoder_inv', False), inversion_kwargs)
        reversed_latents = reversed_latents_list[-1]

        # Forward only the detector options the caller actually supplied so the
        # detector's own defaults apply otherwise.
        detector_kwargs = {key: kwargs[key] for key in ('detector_type', 'mode') if key in kwargs}
        return self.detector.eval_watermark(reversed_latents, **detector_kwargs)

    def get_data_for_visualize(self,
                               image: Image.Image = None,
                               prompt: str = "",
                               guidance_scale: float = 1,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """
        Collect data for visualization of the RingID watermarking process.

        Parameters:
            image: Image to invert.
            prompt: Prompt used for the text embeddings.
            guidance_scale: Currently not used; the inversion always runs with
                ``config.guidance_scale`` (see note in the body).
            decoder_inv: Forwarded to ``get_media_latents``.
            **kwargs: ``num_inference_steps`` overrides the config value; all
                other entries are passed on to ``forward_diffusion``.

        Returns a DataForVisualization object containing all necessary data for RIVisualizer.
        """
        # NOTE(review): `guidance_scale` is a named parameter, so it can never
        # appear in kwargs and the config value has always been used here.
        # Honouring the argument (default 1) would change existing results, so
        # the behaviour is kept as is - confirm the intended semantics.
        guidance_scale_to_use = self.config.guidance_scale
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        inversion_kwargs = {k: v for k, v in kwargs.items()
                            if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        # Use the explicit `decoder_inv` argument; it used to be looked up in
        # kwargs, where a named parameter never lands, so it was always False.
        image, reversed_latents = self._invert_image(
            image, prompt, guidance_scale_to_use, num_steps_to_use,
            decoder_inv, inversion_kwargs)

        # Create DataForVisualization object with extended attributes for RI
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            image=image,
            reversed_latents=reversed_latents,
            orig_watermarked_latents=self.orig_watermarked_latents,
        )