# Coverage report for watermark/wind/wind.py: 93.80% of 129 statements covered
# (coverage.py v7.13.0, created at 2025-12-22 11:32 +0000)
# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import hashlib
import numpy as np
import logging
from typing import Dict, Any, Union, List, Optional
from PIL import Image
from utils.media_utils import *
from utils.utils import load_config_file, set_random_seed
from utils.diffusion_config import DiffusionConfig
from utils.media_utils import transform_to_model_format, get_media_latents
from watermark.base import BaseConfig, BaseWatermark
from exceptions.exceptions import AlgorithmNameMismatchError
from detection.wind.wind_detection import WINDetector
from visualize.data_for_visualization import DataForVisualization

logger = logging.getLogger(__name__)

class WINDConfig(BaseConfig):

    def initialize_parameters(self) -> None:
        """Initialize WIND parameters from the loaded config dictionary."""
        self.w_seed = self.config_dict['w_seed']
        self.N = self.config_dict['num_noises']   # size of the reproducible noise pool
        self.M = self.config_dict['num_groups']   # number of Fourier-pattern groups
        self.secret_salt = self.config_dict['secret_salt'].encode()
        self.hash_func = getattr(hashlib, self.config_dict['hash_function'])
        self.group_radius = self.config_dict['group_radius']
        self.threshold = self.config_dict['threshold']
        self.current_index = self.config_dict['current_index']
        self.noise_groups = self._precompute_noise_groups()

    def _precompute_noise_groups(self):
        """Precompute the noise pool, bucketed by group: noise i goes to group i % M."""
        groups = {}
        for i in range(self.N):
            g = i % self.M
            if g not in groups:
                groups[g] = []
            seed = self._generate_seed(i)
            groups[g].append(self._generate_noise(seed))
        return groups

    def _generate_seed(self, index: int) -> bytes:
        """Derive a deterministic seed for noise `index` from the secret salt."""
        return self.hash_func(f"{index}{self.secret_salt}".encode()).digest()
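    # Note: `self.secret_salt` is bytes, so the f-string embeds its repr
    # (e.g. "0b'salt'" for index 0 and salt b'salt'). This is deterministic,
    # but anyone reproducing seeds outside this class must build the exact
    # same string, repr quoting included.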

    def _generate_noise(self, seed: bytes) -> torch.Tensor:
        """Deterministically expand a seed into a latent-shaped Gaussian noise tensor."""
        rng = np.random.RandomState(int.from_bytes(seed[:4], 'big'))
        latent_height = self.image_size[0] // 8
        latent_width = self.image_size[1] // 8
        return torch.from_numpy(rng.randn(4, latent_height, latent_width)).float().to(self.device)

    @property
    def algorithm_name(self) -> str:
        return 'WIND'
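
# A minimal illustrative config_dict for WINDConfig. The keys mirror
# initialize_parameters above; the concrete values below are assumptions,
# not defaults shipped with the project:
#
#     {
#         "w_seed": 42,
#         "num_noises": 1000,         # N: size of the noise pool
#         "num_groups": 10,           # M: number of Fourier-pattern groups
#         "secret_salt": "my-salt",
#         "hash_function": "sha256",  # any hashlib constructor name
#         "group_radius": 10,
#         "threshold": 0.7,
#         "current_index": 0,
#     }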

class WINDUtils:

    def __init__(self, config: WINDConfig):
        self.config = config
        self.group_patterns = self._generate_group_patterns()
        self.original_noise = None

    def _generate_group_patterns(self) -> Dict[int, torch.Tensor]:
        """Generate one fixed, ring-masked Fourier pattern per group."""
        set_random_seed(self.config.w_seed)
        patterns = {}
        latent_height = self.config.image_size[0] // 8
        latent_width = self.config.image_size[1] // 8
        # Assuming square latents for mask generation as per current implementation
        size = latent_height

        for g in range(self.config.M):
            pattern = torch.fft.fftshift(
                torch.fft.fft2(torch.randn(4, latent_height, latent_width).to(self.config.device)),
                dim=(-1, -2)
            )
            mask = self._circle_mask(size, self.config.group_radius)
            pattern *= mask
            patterns[g] = pattern
        return patterns

    def _circle_mask(self, size: int, r: int) -> torch.Tensor:
        # indexing='ij' matches the historical meshgrid default and silences the
        # newer torch warning; the mask is radially symmetric, so values are unchanged.
        y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
        center = size // 2
        dist = (x - center)**2 + (y - center)**2
        return ((dist >= (r-2)**2) & (dist <= r**2)).float().to(self.config.device)
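    # The mask is a thin ring (annulus): pixels whose squared distance from
    # the center falls in [(r-2)^2, r^2] are 1, everything else 0. Applied to
    # an fftshifted spectrum (DC at the center), the ring selects a narrow
    # frequency band at radius ~r, which is where each group's pattern lives.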

    def inject_watermark(self, index: int) -> torch.Tensor:
        seed = self.config._generate_seed(index)
        z_i = self.config._generate_noise(seed)
        self.original_noise = z_i
        g = index % self.config.M
        z_fft = torch.fft.fftshift(torch.fft.fft2(z_i), dim=(-1, -2))

        latent_height = self.config.image_size[0] // 8
        # Assuming square latents for mask generation
        size = latent_height

        mask = self._circle_mask(size, self.config.group_radius)
        z_fft = z_fft + self.group_patterns[g] * mask

        # ifftshift must undo fftshift on the same spatial dims; without
        # dim=(-1, -2) it would also roll the channel dim, permuting the 4
        # latent channels.
        z_combined = torch.fft.ifft2(torch.fft.ifftshift(z_fft, dim=(-1, -2))).real
        return z_combined
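    # In symbols: z_w = Re(IFFT(fftshift(FFT(z_i)) + P_g * mask)), where P_g
    # is the Fourier pattern of group g = index % M. Since the patterns are
    # already ring-masked at construction time, multiplying by `mask` again
    # here only re-applies the same mask.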

class WIND(BaseWatermark):

    def __init__(self, watermark_config: WINDConfig, *args, **kwargs):
        """
        Initialize the WIND algorithm.

        Parameters:
            watermark_config (WINDConfig): Configuration instance of the WIND algorithm.
        """
        self.config = watermark_config
        self.utils = WINDUtils(self.config)

        self.detector = WINDetector(
            noise_groups=self.config.noise_groups,
            group_patterns=self.utils.group_patterns,
            threshold=self.config.threshold,
            device=self.config.device,
            group_radius=self.config.group_radius
        )
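    # WIND detection is two-stage, which is why the detector receives both
    # inputs: the ring-shaped Fourier pattern first narrows the search to a
    # group, then the reversed latent is matched against that group's
    # precomputed noises from `noise_groups`.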

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Generate a watermarked image."""
        index = self.config.current_index % self.config.M

        watermarked_z = self.utils.inject_watermark(index).unsqueeze(0)  # [1, 4, H/8, W/8]
        self.set_orig_watermarked_latents(watermarked_z)
        set_random_seed(self.config.gen_seed)

        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_z,
        }

        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Use kwargs to override default parameters
        for key, value in kwargs.items():
            generation_params[key] = value

        # Ensure latents parameter is not overridden
        generation_params["latents"] = watermarked_z

        result = self.config.pipe(
            prompt,
            **generation_params
        )

        if isinstance(result, tuple):
            return result[0].images[0]
        else:
            return result.images[0]

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, Any]:
        """Invert the image back to its initial latents and run WIND detection."""
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_inference_steps = kwargs.get('num_inference_steps', 50)

        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        processed_img = transform_to_model_format(
            image,
            target_size=self.config.image_size[0]
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=processed_img,
            sample=False,
            decoder_inv=kwargs.get('decoder_inv', False)
        )

        # Keep only kwargs the inversion call does not already receive explicitly
        inversion_kwargs = {k: v for k, v in kwargs.items()
                            if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        reversed_latents = self.config.inversion.forward_diffusion(
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale_to_use,
            latents=image_latents,
            text_embeddings=text_embeddings,
            **inversion_kwargs
        )[-1]

        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        else:
            return self.detector.eval_watermark(reversed_latents)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str = "",
                               guidance_scale: Optional[float] = None,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs):
        """Collect the latents and metadata needed to visualize WIND detection."""
        guidance_scale = guidance_scale or self.config.guidance_scale
        num_inference_steps = kwargs.get('num_inference_steps', 50)

        do_classifier_free_guidance = (guidance_scale > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        processed_img = transform_to_model_format(
            image,
            target_size=self.config.image_size[0]
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=processed_img,
            sample=False,
            decoder_inv=decoder_inv
        )

        # Drop num_inference_steps from kwargs: it is already passed explicitly,
        # and forwarding it twice would raise a duplicate-keyword TypeError.
        inversion_kwargs = {k: v for k, v in kwargs.items() if k != 'num_inference_steps'}

        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            reverse=True,
            **inversion_kwargs
        )

        data = DataForVisualization(
            config=self.config,
            utils=self.utils,
            image=image,
            reversed_latents=reversed_latents,
            orig_watermarked_latents=self.orig_watermarked_latents,
        )

        return data
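
# A minimal end-to-end usage sketch (illustrative only; it assumes a fully
# built WINDConfig whose `pipe` is a diffusers Stable Diffusion pipeline and
# whose `inversion` wraps DDIM inversion, as the methods above require):
#
#     config = WINDConfig(...)  # constructed from a config file elsewhere
#     wind = WIND(config)
#     img = wind._generate_watermarked_image("a photo of a cat")
#     result = wind._detect_watermark_in_image(img, prompt="a photo of a cat")
#     # `result` is the detector's verdict dict, scored against config.threshold.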