Coverage for watermark/seal/seal.py: 93.38%
136 statements
coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..base import BaseWatermark, BaseConfig
import torch
from typing import List
from utils.utils import set_random_seed
from utils.diffusion_config import DiffusionConfig
from visualize.data_for_visualization import DataForVisualization
from transformers import Blip2Processor, Blip2ForConditionalGeneration
from sentence_transformers import SentenceTransformer
from PIL import Image
import math
from detection.seal.seal_detection import SEALDetector
from utils.media_utils import *

class SEALConfig(BaseConfig):
    """Config class for SEAL algorithm."""

    def initialize_parameters(self) -> None:
        """Initialize parameters for SEAL algorithm."""
        self.k_value = self.config_dict['k_value']
        self.b_value = self.config_dict['b_value']
        self.patch_distance_threshold = self.config_dict['patch_distance_threshold']
        self.theta_mid = self.config_dict['theta_mid']
        self.cap_processor = Blip2Processor.from_pretrained(self.config_dict['cap_processor'])
        self.cap_model = Blip2ForConditionalGeneration.from_pretrained(self.config_dict['cap_processor'], torch_dtype=torch.float16).to(self.device)
        self.sentence_model = SentenceTransformer(self.config_dict['sentence_model']).to(self.device)
        self.secret_salt = self.config_dict['secret_salt']

    @property
    def algorithm_name(self) -> str:
        """Return the name of the algorithm."""
        return "SEAL"
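
# Illustrative SEALConfig inputs. The key names come from initialize_parameters
# above; the values and model IDs below are assumptions for illustration, not the
# repository defaults:
#   config_dict = {
#       "k_value": 16,                                # number of SimHash patches
#       "b_value": 8,                                 # bits per patch key
#       "patch_distance_threshold": 0.15,             # per-patch threshold passed to SEALDetector
#       "theta_mid": 0.5,                             # decision threshold passed to SEALDetector
#       "cap_processor": "Salesforce/blip2-opt-2.7b", # BLIP-2 repo used for processor and caption model
#       "sentence_model": "all-MiniLM-L6-v2",         # SentenceTransformer that embeds captions
#       "secret_salt": 42,                            # seed that keys the SimHash projections
#   }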

class SEALUtils:
    """Utility class for SEAL algorithm."""

    def __init__(self, config: SEALConfig, *args, **kwargs) -> None:
        """Initialize SEAL utility."""
        self.config = config

    def _simhash(self, v: torch.Tensor, k: int, b: int, seed: int) -> List[int]:
        """
        SimHash algorithm to generate hash keys for an embedding vector.

        Args:
            v: Input embedding vector
            k: Number of patches
            b: Number of bits per patch
            seed: Random seed

        Returns:
            List of hash keys
        """
        keys = []
        set_random_seed(seed)
        for j in range(k):
            bits = [0] * b
            for i in range(b):
                r_i = torch.randn_like(v)
                bits[i] = 1 if torch.dot(r_i, v) > 0 else 0
                bits[i] = (bits[i] + i + j * 1e4)
            hash_value = hash(tuple(bits))
            keys.append(hash_value)
        return keys
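
    # Each bit above is the sign of a random projection of the embedding (a
    # random-hyperplane locality-sensitive hash), so captions with nearby embeddings
    # tend to reproduce the same bit pattern and hence the same per-patch keys,
    # which is what lets detection regenerate approximately the same initial noise
    # from the inspected image.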

    def generate_caption(self, image: Image.Image) -> str:
        """
        Generate caption for an image.

        Args:
            image: PIL Image object

        Returns:
            Caption string
        """
        raw_image = image.convert('RGB')
        inputs = self.config.cap_processor(raw_image, return_tensors="pt")
        inputs = {k: v.to(self.config.device) for k, v in inputs.items()}
        out = self.config.cap_model.generate(**inputs)
        return self.config.cap_processor.decode(out[0], skip_special_tokens=True)

    def generate_initial_noise(self, embedding: torch.Tensor, k: int, b: int, seed: int) -> torch.Tensor:
        """
        Generate initial noise using the improved SimHash approach.

        Args:
            embedding: Input embedding vector (normalized)
            k: k_value (number of patches)
            b: b_value (number of bits per patch)
            seed: Random seed (secret_salt)

        Returns:
            Noise tensor with shape [1, 4, 64, 64]
        """
        # Get latent dimensions from config
        latent_height = self.config.image_size[0] // 8
        latent_width = self.config.image_size[1] // 8

        # Calculate patch grid dimensions
        patch_per_side = int(math.ceil(math.sqrt(k)))

        # Generate hash keys for the embedding
        keys = self._simhash(embedding, k, b, seed)

        # Create empty noise tensor
        initial_noise = torch.zeros(1, 4, latent_height, latent_width, device=self.config.device)

        # Calculate patch dimensions
        patch_height = max(1, latent_height // patch_per_side)
        patch_width = max(1, latent_width // patch_per_side)

        # Fill noise tensor with random patches based on hash keys
        patch_count = 0
        hash_mapping = {}

        for i in range(patch_per_side):
            for j in range(patch_per_side):
                if patch_count >= k:
                    break

                # Get hash key for this patch
                hash_key = keys[patch_count]
                hash_mapping[patch_count] = hash_key

                # Set random seed based on hash key
                set_random_seed(hash_key)

                # Calculate patch coordinates
                y_start = i * patch_height
                x_start = j * patch_width
                y_end = min(y_start + patch_height, latent_height)
                x_end = min(x_start + patch_width, latent_width)

                # Skip if patch is empty (can happen if grid is larger than latent dims)
                if y_end <= y_start or x_end <= x_start:
                    continue

                # Generate random noise for this patch
                initial_noise[:, :, y_start:y_end, x_start:x_end] = torch.randn(
                    (1, 4, y_end - y_start, x_end - x_start),
                    device=self.config.device
                )

                patch_count += 1

        return initial_noise
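
    # Worked example with illustrative numbers: for image_size = (512, 512) the
    # latent is 512 // 8 = 64 pixels per side; with k_value = 16 the grid is
    # ceil(sqrt(16)) = 4 patches per side, so each of the 16 patches is a 16x16
    # region of the latent, filled with Gaussian noise seeded by that patch's
    # SimHash key.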


class SEAL(BaseWatermark):
    """SEAL watermarking algorithm."""

    def __init__(self, watermark_config: SEALConfig, *args, **kwargs) -> None:
        """
        Initialize the SEAL algorithm.

        Parameters:
            watermark_config (SEALConfig): Configuration instance of the SEAL algorithm.
        """
        self.config = watermark_config
        self.utils = SEALUtils(self.config)
        self.original_embedding = None

        self.detector = SEALDetector(
            self.config.k_value,
            self.config.b_value,
            self.config.theta_mid,
            self.config.cap_processor,
            self.config.cap_model,
            self.config.sentence_model,
            self.config.patch_distance_threshold,
            self.config.device,
        )

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Generate watermarked image."""
        ## Step 1: Generate original image
        image = self.config.pipe(prompt, height=self.config.image_size[0], width=self.config.image_size[1]).images[0]

        ## Step 2: Caption the original image
        image_caption = self.utils.generate_caption(image)

        ## Step 3: Get the embedding of the caption
        embedding = self.config.sentence_model.encode(image_caption, convert_to_tensor=True).to(self.config.device)
        embedding = embedding / torch.norm(embedding)
        self.original_embedding = embedding

        ## Step 4: Get the watermarked initial latents
        watermarked_latents = self.utils.generate_initial_noise(embedding, self.config.k_value, self.config.b_value, self.config.secret_salt)

        # Save watermarked latents
        self.set_orig_watermarked_latents(watermarked_latents)

        # Set generation seed
        set_random_seed(self.config.gen_seed)

        # Construct generation parameters
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }

        # Add parameters from config.gen_kwargs
        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Use kwargs to override default parameters
        for key, value in kwargs.items():
            generation_params[key] = value

        # Ensure latents parameter is not overridden
        generation_params["latents"] = watermarked_latents

        ## Step 5: Generate watermarked image
        watermarked_image = self.config.pipe(prompt, **generation_params).images[0]

        return watermarked_image

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str="",
                                   *args,
                                   **kwargs) -> bool:
        """Detect watermark in the image."""
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Step 1: Get Text Embeddings
        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Step 2: Preprocess Image
        image_tensor = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)
        image_tensor = image_tensor.to(dtype=self.config.pipe.vae.dtype)

        # Step 3: Get Image Latents
        image_latents = get_media_latents(pipe=self.config.pipe, media=image_tensor, sample=False, decoder_inv=kwargs.get('decoder_inv', False))

        # Pass only known parameters to forward_diffusion, and let kwargs handle any additional parameters
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        # Step 4: Reverse Image Latents
        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=num_steps_to_use,
            **inversion_kwargs
        )[-1]

        # Step 5: Regenerate the noise z~ from the inspected image's caption
        inspected_caption = self.utils.generate_caption(image)
        inspected_embedding = self.config.sentence_model.encode(inspected_caption, convert_to_tensor=True).to(self.config.device)
        inspected_embedding = inspected_embedding / torch.norm(inspected_embedding)

        inspected_noise = self.utils.generate_initial_noise(inspected_embedding, self.config.k_value, self.config.b_value, self.config.secret_salt)

        # Detect the watermark
        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, inspected_noise, detector_type=kwargs['detector_type'])
        else:
            return self.detector.eval_watermark(reversed_latents, inspected_noise)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str="",
                               guidance_scale: float=1,
                               decoder_inv: bool=False,
                               *args, **kwargs) -> DataForVisualization:
        """
        Get data for visualization of the SEAL watermarking process.

        Args:
            image: Input image for visualization
            prompt: Text prompt used for generation
            guidance_scale: Guidance scale for the diffusion process
            decoder_inv: Whether to use decoder inversion

        Returns:
            DataForVisualization object with SEAL-specific data
        """
        # Step 1: Perform detection process for comparison
        inspected_caption = self.utils.generate_caption(image)
        inspected_embedding = self.config.sentence_model.encode(inspected_caption, convert_to_tensor=True).to(self.config.device)
        normalized_inspected_embedding = inspected_embedding / torch.norm(inspected_embedding)

        # Generate inspected noise for detection
        inspected_noise = self.utils.generate_initial_noise(normalized_inspected_embedding, self.config.k_value, self.config.b_value, self.config.secret_salt)

        # Step 2: Get inverted latents for detection visualization
        do_classifier_free_guidance = (guidance_scale > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        image_tensor = transform_to_model_format(image, target_size=self.config.image_size[0]).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)
        image_tensor = image_tensor.to(dtype=self.config.pipe.vae.dtype)
        image_latents = get_media_latents(pipe=self.config.pipe, media=image_tensor, sample=False, decoder_inv=decoder_inv)

        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}
        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=self.config.num_inference_steps,
            **inversion_kwargs
        )

        # Get original watermarked latents
        orig_watermarked_latents = self.get_orig_watermarked_latents()

        # Get original embedding
        original_embedding = self.original_embedding

        # Create DataForVisualization object with SEAL-specific data
        data_for_visualization = DataForVisualization(
            config=self.config,
            utils=self.utils,
            orig_watermarked_latents=orig_watermarked_latents,
            reversed_latents=reversed_latents,
            inspected_embedding=normalized_inspected_embedding,
            original_embedding=original_embedding,
            reference_noise=inspected_noise,
            image=image,
        )

        return data_for_visualization
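
For reference, the standalone sketch below mirrors the SimHash-to-patch-noise construction from SEALUtils using plain torch only. It is illustrative: the random embedding stands in for the normalized sentence embedding of a BLIP-2 caption, the per-patch key derivation is simplified relative to _simhash above, the grid is assumed to divide the 64x64 latent evenly, and the seed reduction (key % 2**31) is an assumption to keep torch.manual_seed in range; it is not part of seal.py or the MarkDiffusion API.

# illustrative_seal_noise.py -- standalone sketch, not part of seal.py
import math
import torch


def simhash_keys(v: torch.Tensor, k: int, b: int, seed: int) -> list:
    """Derive k hash keys, each from b random-hyperplane sign bits of the embedding."""
    torch.manual_seed(seed)
    keys = []
    for j in range(k):
        bits = tuple(int(torch.dot(torch.randn_like(v), v) > 0) for _ in range(b))
        keys.append(hash(bits + (j,)))  # make keys position-dependent, one per patch
    return keys


def keyed_patch_noise(keys: list, latent_size: int = 64, channels: int = 4) -> torch.Tensor:
    """Fill each grid patch of a [1, C, H, W] latent with noise seeded by its key."""
    side = int(math.ceil(math.sqrt(len(keys))))
    patch = latent_size // side  # assumes the grid divides the latent evenly
    noise = torch.zeros(1, channels, latent_size, latent_size)
    for idx, key in enumerate(keys):
        i, j = divmod(idx, side)
        torch.manual_seed(key % (2 ** 31))  # hash() may return negative values
        noise[:, :, i * patch:(i + 1) * patch, j * patch:(j + 1) * patch] = \
            torch.randn(1, channels, patch, patch)
    return noise


if __name__ == "__main__":
    embedding = torch.randn(384)              # stand-in for a caption embedding
    embedding = embedding / embedding.norm()  # SEAL normalizes before hashing
    latents = keyed_patch_noise(simhash_keys(embedding, k=16, b=8, seed=42))
    print(latents.shape)                      # torch.Size([1, 4, 64, 64])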