Coverage for watermark/seal/seal.py: 93.38%

136 statements  

coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from ..base import BaseWatermark, BaseConfig
import torch
from typing import List
from utils.utils import set_random_seed
from utils.diffusion_config import DiffusionConfig
from visualize.data_for_visualization import DataForVisualization
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from sentence_transformers import SentenceTransformer
from PIL import Image
import math
from detection.seal.seal_detection import SEALDetector
from utils.media_utils import *


class SEALConfig(BaseConfig):
    """Config class for SEAL algorithm."""

    def initialize_parameters(self) -> None:
        """Initialize parameters for SEAL algorithm."""
        self.k_value = self.config_dict['k_value']
        self.b_value = self.config_dict['b_value']
        self.patch_distance_threshold = self.config_dict['patch_distance_threshold']
        self.theta_mid = self.config_dict['theta_mid']
        self.cap_processor = Blip2Processor.from_pretrained(self.config_dict['cap_processor'])
        self.cap_model = Blip2ForConditionalGeneration.from_pretrained(self.config_dict['cap_processor'], torch_dtype=torch.float16).to(self.device)
        self.sentence_model = SentenceTransformer(self.config_dict['sentence_model']).to(self.device)

        self.secret_salt = self.config_dict['secret_salt']

    @property
    def algorithm_name(self) -> str:
        """Return the name of the algorithm."""
        return "SEAL"

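# For orientation, a hypothetical config_dict that initialize_parameters() above
# would accept; the keys are exactly the ones read from self.config_dict, while
# every value (including the checkpoint names) is only an illustrative placeholder:
#
#   seal_config_dict = {
#       "k_value": 64,                                 # number of SimHash patches
#       "b_value": 16,                                 # bits per patch
#       "patch_distance_threshold": 0.15,              # per-patch match threshold
#       "theta_mid": 0.5,                              # detection decision threshold
#       "cap_processor": "Salesforce/blip2-opt-2.7b",  # BLIP-2 processor/model checkpoint
#       "sentence_model": "all-MiniLM-L6-v2",          # SentenceTransformer checkpoint
#       "secret_salt": 12345,                          # seed passed to _simhash
#   }
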

class SEALUtils:
    """Utility class for SEAL algorithm."""

    def __init__(self, config: SEALConfig, *args, **kwargs) -> None:
        """Initialize SEAL utility."""
        self.config = config

    def _simhash(self, v: torch.Tensor, k: int, b: int, seed: int) -> List[int]:
        """
        SimHash algorithm to generate hash keys for an embedding vector.

        Args:
            v: Input embedding vector
            k: Number of patches
            b: Number of bits per patch
            seed: Random seed

        Returns:
            List of hash keys
        """
        keys = []
        set_random_seed(seed)
        for j in range(k):
            bits = [0] * b
            for i in range(b):
                r_i = torch.randn_like(v)
                bits[i] = 1 if torch.dot(r_i, v) > 0 else 0
                bits[i] = (bits[i] + i + j * 1e4)
            hash_value = hash(tuple(bits))
            keys.append(hash_value)
        return keys

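    # A small worked sketch of what _simhash produces (illustrative numbers only,
    # not values from the paper or the test suite): with k=2 patches and b=2 bits,
    # each bit starts as 0/1 from the sign of <r_i, v>, is then offset by
    # i + j * 1e4 so that bit position and patch index both enter the hash, and
    # hash(tuple(bits)) yields one integer key per patch, e.g.
    #
    #   patch j=0: bits might be (0 + 0 + 0.0, 1 + 1 + 0.0)         -> hash((0.0, 2.0))
    #   patch j=1: bits might be (1 + 0 + 1e4, 0 + 1 + 1e4)         -> hash((10001.0, 10001.0))
    #
    # These keys are later used as per-patch random seeds in generate_initial_noise.
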

    def generate_caption(self, image: Image.Image) -> str:
        """
        Generate caption for an image.

        Args:
            image: PIL Image object

        Returns:
            Caption string
        """
        raw_image = image.convert('RGB')
        inputs = self.config.cap_processor(raw_image, return_tensors="pt")
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}
        out = self.config.cap_model.generate(**inputs)
        return self.config.cap_processor.decode(out[0], skip_special_tokens=True)


    def generate_initial_noise(self, embedding: torch.Tensor, k: int, b: int, seed: int) -> torch.Tensor:
        """
        Generate initial noise using the improved SimHash approach.

        Args:
            embedding: Input embedding vector (normalized)
            k: k_value (number of patches)
            b: b_value (number of bits per patch)
            seed: Random seed (secret_salt)

        Returns:
            Noise tensor with shape [1, 4, 64, 64]
        """

        # Get latent dimensions from config
        latent_height = self.config.image_size[0] // 8
        latent_width = self.config.image_size[1] // 8

        # Calculate patch grid dimensions
        patch_per_side = int(math.ceil(math.sqrt(k)))

        # Generate hash keys for the embedding
        keys = self._simhash(embedding, k, b, seed)

        # Create empty noise tensor
        initial_noise = torch.zeros(1, 4, latent_height, latent_width, device=self.config.device)

        # Calculate patch dimensions
        patch_height = max(1, latent_height // patch_per_side)
        patch_width = max(1, latent_width // patch_per_side)

        # Fill noise tensor with random patches based on hash keys
        patch_count = 0
        hash_mapping = {}

        for i in range(patch_per_side):
            for j in range(patch_per_side):
                if patch_count >= k:
                    break

                # Get hash key for this patch
                hash_key = keys[patch_count]
                hash_mapping[patch_count] = hash_key

                # Set random seed based on hash key
                set_random_seed(hash_key)

                # Calculate patch coordinates
                y_start = i * patch_height
                x_start = j * patch_width
                y_end = min(y_start + patch_height, latent_height)
                x_end = min(x_start + patch_width, latent_width)

                # Skip if patch is empty (can happen if grid is larger than latent dims)
                if y_end <= y_start or x_end <= x_start:
                    continue

                # Generate random noise for this patch
                initial_noise[:, :, y_start:y_end, x_start:x_end] = torch.randn(
                    (1, 4, y_end - y_start, x_end - x_start),
                    device=self.config.device
                )

                patch_count += 1

        return initial_noise

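    # Concretely, under the latent layout assumed above: a 512x512 image gives a
    # 64x64 latent (512 // 8), so with k = 64 patches, patch_per_side =
    # ceil(sqrt(64)) = 8 and each patch covers an 8x8 region of the latent. Every
    # patch is filled with Gaussian noise seeded by its own SimHash key, which is
    # what the detector later re-derives from the inspected image's caption.
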

class SEAL(BaseWatermark):
    """SEAL watermarking algorithm."""

    def __init__(self, watermark_config: SEALConfig, *args, **kwargs) -> None:
        """
        Initialize the SEAL algorithm.

        Parameters:
            watermark_config (SEALConfig): Configuration instance of the SEAL algorithm.
        """
        self.config = watermark_config
        self.utils = SEALUtils(self.config)
        self.original_embedding = None

        self.detector = SEALDetector(
            self.config.k_value,
            self.config.b_value,
            self.config.theta_mid,
            self.config.cap_processor,
            self.config.cap_model,
            self.config.sentence_model,
            self.config.patch_distance_threshold,
            self.config.device,
        )


    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Generate watermarked image."""

        ## Step 1: Generate original image
        image = self.config.pipe(prompt, height=self.config.image_size[0], width=self.config.image_size[1]).images[0]

        ## Step 2: Caption the original image
        image_caption = self.utils.generate_caption(image)

        ## Step 3: Get the embedding of the caption
        embedding = self.config.sentence_model.encode(image_caption, convert_to_tensor=True).to(self.config.device)
        embedding = embedding / torch.norm(embedding)
        self.original_embedding = embedding

        ## Step 4: Get the watermarked initial latents
        watermarked_latents = self.utils.generate_initial_noise(embedding, self.config.k_value, self.config.b_value, self.config.secret_salt)

        # Save watermarked latents
        self.set_orig_watermarked_latents(watermarked_latents)

        # Set generation seed
        set_random_seed(self.config.gen_seed)

        # Construct generation parameters
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Use kwargs to override default parameters
        for key, value in kwargs.items():
            generation_params[key] = value

        # Ensure latents parameter is not overridden
        generation_params["latents"] = watermarked_latents

        ## Step 5: Generate watermarked image
        watermarked_image = self.config.pipe(prompt, **generation_params).images[0]

        return watermarked_image


    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> bool:
        """Detect watermark in the image."""
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Step 1: Get text embeddings
        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Step 2: Preprocess image
        image_tensor = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)
        image_tensor = image_tensor.to(dtype=self.config.pipe.vae.dtype)

        # Step 3: Get image latents
        image_latents = get_media_latents(pipe=self.config.pipe, media=image_tensor, sample=False, decoder_inv=kwargs.get('decoder_inv', False))

        # Pass only known parameters to forward_diffusion, and let kwargs handle any additional parameters
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        # Step 4: Reverse image latents
        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=num_steps_to_use,
            **inversion_kwargs
        )[-1]

        # Step 5: Generate the reference noise again from the inspected image's caption
        inspected_caption = self.utils.generate_caption(image)
        inspected_embedding = self.config.sentence_model.encode(inspected_caption, convert_to_tensor=True).to(self.config.device)
        inspected_embedding = inspected_embedding / torch.norm(inspected_embedding)

        inspected_noise = self.utils.generate_initial_noise(inspected_embedding, self.config.k_value, self.config.b_value, self.config.secret_salt)

        # Detect the watermark
        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, inspected_noise, detector_type=kwargs['detector_type'])
        else:
            return self.detector.eval_watermark(reversed_latents, inspected_noise)


    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str = "",
                               guidance_scale: float = 1,
                               decoder_inv: bool = False,
                               *args, **kwargs) -> DataForVisualization:
        """
        Get data for visualization of SEAL watermarking process.

        Args:
            image: Input image for visualization
            prompt: Text prompt used for generation
            guidance_scale: Guidance scale for diffusion process
            decoder_inv: Whether to use decoder inversion

        Returns:
            DataForVisualization object with SEAL-specific data
        """
        # Step 1: Perform detection process for comparison
        inspected_caption = self.utils.generate_caption(image)
        inspected_embedding = self.config.sentence_model.encode(inspected_caption, convert_to_tensor=True).to(self.config.device)
        normalized_inspected_embedding = inspected_embedding / torch.norm(inspected_embedding)

        # Generate inspected noise for detection
        inspected_noise = self.utils.generate_initial_noise(normalized_inspected_embedding, self.config.k_value, self.config.b_value, self.config.secret_salt)

        # Step 2: Get inverted latents for detection visualization
        do_classifier_free_guidance = (guidance_scale > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds
        image_tensor = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)
        image_tensor = image_tensor.to(dtype=self.config.pipe.vae.dtype)
        image_latents = get_media_latents(pipe=self.config.pipe, media=image_tensor, sample=False, decoder_inv=decoder_inv)

        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}
        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=self.config.num_inference_steps,
            **inversion_kwargs
        )

        # Get original watermarked latents
        orig_watermarked_latents = self.get_orig_watermarked_latents()

        # Get original embedding
        original_embedding = self.original_embedding

        # Create DataForVisualization object with SEAL-specific data
        data_for_visualization = DataForVisualization(
            config=self.config,
            utils=self.utils,
            orig_watermarked_latents=orig_watermarked_latents,
            reversed_latents=reversed_latents,
            inspected_embedding=normalized_inspected_embedding,
            original_embedding=original_embedding,
            reference_noise=inspected_noise,
            image=image,
        )

        return data_for_visualization
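
# A hypothetical end-to-end usage sketch; how SEALConfig is constructed depends on
# BaseConfig (not shown in this file), so the constructor call below is an assumption:
#
#   config = SEALConfig(...)      # built from a config_dict like the one sketched above
#   seal = SEAL(config)
#   wm_image = seal._generate_watermarked_image("a photo of a mountain lake")
#   is_watermarked = seal._detect_watermark_in_image(wm_image)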