Coverage for watermark / tr / tr.py: 91.67%

144 statements  

coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

from ..base import BaseWatermark, BaseConfig
from utils.media_utils import *
import torch
from typing import Dict, Union, List, Optional
from utils.utils import set_random_seed, inherit_docstring
from utils.diffusion_config import DiffusionConfig
import copy
import numpy as np
from PIL import Image
from visualize.data_for_visualization import DataForVisualization
from detection.tr.tr_detection import TRDetector


class TRConfig(BaseConfig):
    """Config class for the TR algorithm; loads the config file and initializes parameters."""

    def initialize_parameters(self) -> None:
        """Initialize algorithm-specific parameters."""
        self.w_seed = self.config_dict['w_seed']
        self.w_channel = self.config_dict['w_channel']
        self.w_pattern = self.config_dict['w_pattern']
        # self.w_mask_shape = self.config_dict['w_mask_shape']
        self.w_radius = self.config_dict['w_radius']
        self.w_pattern_const = self.config_dict['w_pattern_const']
        self.threshold = self.config_dict['threshold']

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        return 'TR'
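

# Sketch of the config entries that initialize_parameters reads above. Only the
# key names come from the code; the sample values are illustrative assumptions:
#
#   w_seed: 10            # RNG seed for sampling the ground-truth pattern
#   w_channel: 3          # latent channel to watermark; -1 means all channels
#   w_pattern: "ring"     # seed_ring / seed_zeros / seed_rand / rand / zeros / const / ring
#   w_radius: 10          # radius of the circular mask in the latent frequency plane
#   w_pattern_const: 0    # fill value used by the 'const' pattern
#   threshold: 77         # decision threshold handed to TRDetector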

class TRUtils:
    """Utility class for the TR algorithm; contains helper functions."""

    def __init__(self, config: TRConfig, *args, **kwargs) -> None:
        """
        Initialize the Tree-Ring watermarking utilities.

        Parameters:
            config (TRConfig): Configuration for the Tree-Ring algorithm.
        """
        self.config = config
        self.gt_patch = self._get_watermarking_pattern()
        self.watermarking_mask = self._get_watermarking_mask(self.config.init_latents)

    def _circle_mask(self, size: int = 64, r: int = 10, x_offset: int = 0, y_offset: int = 0) -> np.ndarray:
        """Generate a circular boolean mask."""
        x0 = y0 = size // 2
        x0 += x_offset
        y0 += y_offset
        y, x = np.ogrid[:size, :size]
        y = y[::-1]

        return ((x - x0) ** 2 + (y - y0) ** 2) <= r ** 2
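
    # A small worked example of the mask geometry (computed by hand; a centred
    # mask is symmetric, so the y-flip above only matters for nonzero offsets):
    #
    #   _circle_mask(size=5, r=1) ->
    #       [[0, 0, 0, 0, 0],
    #        [0, 0, 1, 0, 0],
    #        [0, 1, 1, 1, 0],
    #        [0, 0, 1, 0, 0],
    #        [0, 0, 0, 0, 0]]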

    def _get_watermarking_pattern(self) -> torch.Tensor:
        """Get the ground-truth watermarking pattern."""
        set_random_seed(self.config.w_seed)

        gt_init = get_random_latents(pipe=self.config.pipe, height=self.config.image_size[0], width=self.config.image_size[1])

        if 'seed_ring' in self.config.w_pattern:
            gt_patch = gt_init

            gt_patch_tmp = copy.deepcopy(gt_patch)
            for i in range(self.config.w_radius, 0, -1):
                tmp_mask = self._circle_mask(gt_init.shape[-1], r=i)
                tmp_mask = torch.tensor(tmp_mask).to(self.config.device)

                for j in range(gt_patch.shape[1]):
                    gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item()
        elif 'seed_zeros' in self.config.w_pattern:
            gt_patch = gt_init * 0
        elif 'seed_rand' in self.config.w_pattern:
            gt_patch = gt_init
        elif 'rand' in self.config.w_pattern:
            gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))
            gt_patch[:] = gt_patch[0]
        elif 'zeros' in self.config.w_pattern:
            gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0
        elif 'const' in self.config.w_pattern:
            gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2)) * 0
            gt_patch += self.config.w_pattern_const
        elif 'ring' in self.config.w_pattern:
            gt_patch = torch.fft.fftshift(torch.fft.fft2(gt_init), dim=(-1, -2))

            gt_patch_tmp = copy.deepcopy(gt_patch)
            for i in range(self.config.w_radius, 0, -1):
                tmp_mask = self._circle_mask(gt_init.shape[-1], r=i)
                tmp_mask = torch.tensor(tmp_mask).to(self.config.device)

                for j in range(gt_patch.shape[1]):
                    gt_patch[:, j, tmp_mask] = gt_patch_tmp[0, j, 0, i].item()
        else:
            # Fail loudly on unknown patterns instead of hitting a NameError below.
            raise NotImplementedError(f'w_pattern: {self.config.w_pattern}')

        return gt_patch
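
    # Note on the branches above: the 'seed_*' patterns live directly in latent
    # (spatial) space, while 'rand' / 'zeros' / 'const' / 'ring' are defined in
    # the fftshift-ed Fourier domain of the latents. The 'ring' variants fill
    # concentric circles with constant values, which is what gives Tree-Ring
    # its robustness to image rotation.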

    def _get_watermarking_mask(self, init_latents: torch.Tensor) -> torch.Tensor:
        """Get the watermarking mask."""
        watermarking_mask = torch.zeros(init_latents.shape, dtype=torch.bool).to(self.config.device)

        # if self.config.w_mask_shape == 'circle':
        np_mask = self._circle_mask(init_latents.shape[-1], r=self.config.w_radius)
        torch_mask = torch.tensor(np_mask).to(self.config.device)

        if self.config.w_channel == -1:
            # all channels
            watermarking_mask[:, :] = torch_mask
        else:
            watermarking_mask[:, self.config.w_channel] = torch_mask
        # elif self.config.w_mask_shape == 'square':
        #     anchor_p = init_latents.shape[-1] // 2
        #     if self.config.w_channel == -1:
        #         # all channels
        #         watermarking_mask[:, :, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius] = True
        #     else:
        #         watermarking_mask[:, self.config.w_channel, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius, anchor_p-self.config.w_radius:anchor_p+self.config.w_radius] = True
        # elif self.config.w_mask_shape == 'no':
        #     pass
        # else:
        #     raise NotImplementedError(f'w_mask_shape: {self.config.w_mask_shape}')

        return watermarking_mask
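
    # Example (assuming Stable-Diffusion-style latents of shape [1, 4, 64, 64]):
    # with w_channel=3 and w_radius=10, the mask is True only on channel 3,
    # inside a radius-10 disk centred in the 64x64 frequency plane; with
    # w_channel=-1 the same disk is applied to every channel.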

    def inject_watermark(self, init_latents: torch.Tensor) -> torch.Tensor:
        """Inject the watermark pattern into the initial latents in Fourier space."""
        init_latents_w_fft = torch.fft.fftshift(torch.fft.fft2(init_latents), dim=(-1, -2))
        target_patch = self.gt_patch

        # The FFT of the latents is complex; promote a real-valued pattern
        # (the 'seed_*' variants) to complex so the masked assignment matches dtypes.
        if not torch.is_complex(target_patch):
            real = target_patch.to(torch.float32)
            imag = torch.zeros_like(real)
            target_patch = torch.complex(real, imag)
        target_patch = target_patch.to(init_latents_w_fft.dtype)

        init_latents_w_fft[self.watermarking_mask] = target_patch[self.watermarking_mask].clone()

        init_latents_w = torch.fft.ifft2(torch.fft.ifftshift(init_latents_w_fft, dim=(-1, -2))).real
        return init_latents_w
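
    # Taking only `.real` after the inverse FFT is the usual Tree-Ring
    # approximation: the masked Fourier coefficients are generally not
    # conjugate-symmetric, so the exact inverse transform is complex and the
    # small imaginary part is discarded, slightly perturbing the latents.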

@inherit_docstring
class TR(BaseWatermark):
    def __init__(self,
                 watermark_config: TRConfig,
                 *args, **kwargs):
        """
        Initialize the TR watermarking algorithm.

        Parameters:
            watermark_config (TRConfig): Configuration instance of the Tree-Ring algorithm.
        """
        self.config = watermark_config
        self.utils = TRUtils(self.config)

        self.detector = TRDetector(
            watermarking_mask=self.utils.watermarking_mask,
            gt_patch=self.utils.gt_patch,
            threshold=self.config.threshold,
            device=self.config.device
        )

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Internal method to generate a watermarked image."""
        watermarked_latents = self.utils.inject_watermark(self.config.init_latents)

        # Save the watermarked latents
        self.set_orig_watermarked_latents(watermarked_latents)

        # Construct generation parameters
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Use kwargs to override default parameters
        for key, value in kwargs.items():
            generation_params[key] = value

        # Ensure the latents parameter is not overridden
        generation_params["latents"] = watermarked_latents

        return self.config.pipe(
            prompt,
            **generation_params
        ).images[0]
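
    # Hedged usage sketch: kwargs forwarded to this method override the
    # defaults above, except `latents`, which is pinned to the watermarked
    # tensor afterwards. For example (parameter values are assumptions):
    #
    #   tr._generate_watermarked_image("a photo of a cat",
    #                                  num_inference_steps=25,
    #                                  guidance_scale=7.5)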

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect the watermark in the image."""
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Step 1: Get text embeddings
        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,  # TODO: Multiple image generation to be supported
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Step 2: Preprocess the image
        image = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        # Step 3: Get image latents
        image_latents = get_media_latents(pipe=self.config.pipe, media=image, sample=False, decoder_inv=kwargs.get('decoder_inv', False))

        # Step 4: Reverse the image latents
        # Pass only known parameters to forward_diffusion, and let kwargs handle any additional parameters
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=num_steps_to_use,
            **inversion_kwargs
        )[-1]

        # Step 5: Evaluate the watermark
        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        else:
            return self.detector.eval_watermark(reversed_latents)
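
    # Detection mirrors generation: the image is encoded back to latents,
    # forward_diffusion runs a DDIM-style inversion to estimate the initial
    # noise, and TRDetector compares that estimate with gt_patch inside
    # watermarking_mask. A detector_type kwarg, when given, is forwarded to
    # eval_watermark to select the comparison metric.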

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str = "",
                               guidance_scale: Optional[float] = None,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Get data for visualization, including the detection inversion (similar to the GS logic)."""
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = guidance_scale if guidance_scale is not None else self.config.guidance_scale

        # Step 1: Generate watermarked latents (generation process)
        set_random_seed(self.config.gen_seed)
        watermarked_latents = self.utils.inject_watermark(self.config.init_latents)
        # Record them so orig_watermarked_latents is defined even if
        # _generate_watermarked_image was never called on this instance.
        self.set_orig_watermarked_latents(watermarked_latents)

        # Step 2: Generate the actual watermarked image using the same process as _generate_watermarked_image
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Generate the actual watermarked image
        watermarked_image = self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

        # Step 3: Perform watermark detection to get the inverted latents (detection process)
        # Initialize up front so the return below is well-defined if inversion fails.
        reversed_latents_list = None
        try:
            # Get text embeddings for detection
            do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
            prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
                prompt=prompt,
                device=self.config.device,
                do_classifier_free_guidance=do_classifier_free_guidance,
                num_images_per_prompt=1,  # TODO: Multiple image generation to be supported
            )

            if do_classifier_free_guidance:
                text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
            else:
                text_embeddings = prompt_embeds

            # Preprocess the watermarked image for detection
            processed_image = transform_to_model_format(
                watermarked_image,
                target_size=self.config.image_size[0]
            ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

            # Get image latents
            image_latents = get_media_latents(
                pipe=self.config.pipe,
                media=processed_image,
                sample=False,
                decoder_inv=decoder_inv
            )

            # Reverse the image latents to get the inverted noise
            inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['prompt', 'decoder_inv', 'guidance_scale', 'num_inference_steps']}

            reversed_latents_list = self.config.inversion.forward_diffusion(
                latents=image_latents,
                text_embeddings=text_embeddings,
                guidance_scale=guidance_scale_to_use,
                num_inference_steps=self.config.num_inference_steps,
                **inversion_kwargs
            )

        except Exception as e:
            print(f"Warning: Could not perform inversion for visualization: {e}")

        # Step 4: Prepare visualization data
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            reversed_latents=reversed_latents_list,
            orig_watermarked_latents=self.orig_watermarked_latents,
            image=image,
        )

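
# A minimal end-to-end usage sketch. Hedged: how TRConfig is constructed and
# whether BaseWatermark exposes public generate/detect wrappers is not shown
# in this file, so the lines below call the methods defined above directly.
#
#   config = TRConfig(...)  # algorithm + diffusion configuration
#   tr = TR(config)
#   img = tr._generate_watermarked_image("a photo of a cat")
#   result = tr._detect_watermark_in_image(img, prompt="a photo of a cat")
#   print(result)  # a Dict[str, float] of detection scores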