Coverage for watermark / gs / gs.py: 90.91%
132 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1from ..base import BaseWatermark, BaseConfig
2import torch
3from typing import Dict
4from PIL import Image
5from utils.diffusion_config import DiffusionConfig
6import numpy as np
7from Crypto.Cipher import ChaCha20
8import random
9from scipy.stats import norm,truncnorm
10from functools import reduce
11from visualize.data_for_visualization import DataForVisualization
12from detection.gs.gs_detection import GSDetector
13from utils.media_utils import *
14from utils.utils import set_random_seed
class GSConfig(BaseConfig):
    """Configuration holder for the Gaussian Shading (GS) watermark."""

    def initialize_parameters(self) -> None:
        """Load GS settings from ``config_dict`` and derive the watermark tensors.

        Copies the raw settings onto the instance, computes the voting
        threshold and the latent grid size, then draws the random watermark
        bits and, when ChaCha20 is disabled, a random bit mask stored in
        ``self.key``. Both tensors come from one generator seeded with
        ``wm_key``.
        """
        settings = self.config_dict
        self.channel_copy = settings['channel_copy']
        self.hw_copy = settings['hw_copy']
        self.chacha = settings['chacha']
        self.wm_key = settings['wm_key']
        self.chacha_key_seed = settings['chacha_key_seed']
        self.chacha_nonce_seed = settings['chacha_nonce_seed']
        self.threshold = settings['threshold']

        # Without any replication a single vote decides; otherwise the
        # threshold is half the number of copies of each bit.
        if self.hw_copy == 1 and self.channel_copy == 1:
            self.vote_threshold = 1
        else:
            self.vote_threshold = self.channel_copy * self.hw_copy * self.hw_copy // 2

        # Latent grid size: image size divided by the pipeline's VAE scale factor.
        vae_scale = self.pipe.vae_scale_factor
        self.latents_height = self.image_size[0] // vae_scale
        self.latents_width = self.image_size[1] // vae_scale

        rng = torch.Generator(device=self.device)
        rng.manual_seed(self.wm_key)

        # The watermark is drawn at reduced size; it is tiled back up to the
        # full latent shape by the code that consumes it.
        watermark_shape = [
            1,
            4 // self.channel_copy,
            self.latents_height // self.hw_copy,
            self.latents_width // self.hw_copy,
        ]
        self.watermark = torch.randint(0, 2, watermark_shape, generator=rng, device=self.device)

        # The key is drawn from the same generator *after* the watermark, so
        # the order of these two draws must not change.
        if not self.chacha:
            key_shape = [1, 4, self.latents_height, self.latents_width]
            self.key = torch.randint(0, 2, key_shape, generator=rng, device=self.device)

    @property
    def algorithm_name(self) -> str:
        """Short identifier of this algorithm (``'GS'``)."""
        return 'GS'
class GSUtils:
    """Utility class for Gaussian Shading algorithm.

    Builds the watermarked initial latents: the watermark bits held by the
    config are tiled to the full latent shape, masked (either added modulo 2
    to the config's key or encrypted with ChaCha20) and then converted into
    latent values whose sign encodes each masked bit.
    """

    def __init__(self, config: "GSConfig", *args, **kwargs) -> None:
        """
        Initialize the Gaussian Shading watermarking utility.

        Parameters:
            config (GSConfig): Configuration for the Gaussian Shading watermarking algorithm.
        """
        self.config = config
        # 32-byte key and 12-byte nonce for ChaCha20, derived from the configured seeds.
        self.chacha_key = self._get_bytes_with_seed(self.config.chacha_key_seed, 32)
        self.chacha_nonce = self._get_bytes_with_seed(self.config.chacha_nonce_seed, 12)
        # Total number of latent elements (4 channels x height x width).
        self.latentlength = 4 * self.config.latents_height * self.config.latents_width
        # Latent elements divided by the number of tiled copies.
        self.marklength = self.latentlength // (self.config.channel_copy * self.config.hw_copy * self.config.hw_copy)

    def _get_bytes_with_seed(self, seed: int, n: int) -> bytes:
        """Return ``n`` pseudo-random bytes determined by ``seed``.

        A dedicated ``random.Random`` instance is used so that deriving the
        key material does not reseed the process-wide ``random`` module. The
        bytes are the same as those produced by seeding the global generator
        with ``seed``.
        """
        rng = random.Random(seed)
        return bytes(rng.getrandbits(8) for _ in range(n))

    def _stream_key_encrypt(self, sd):
        """Encrypt the watermark using ChaCha20 cipher.

        Packs the bit array ``sd`` into bytes, encrypts it with the stored
        key and nonce, and returns the ciphertext unpacked into a uint8 bit
        array. ``np.packbits`` pads to a whole byte, so the result can be
        longer than ``sd`` when its length is not a multiple of 8.
        """
        cipher = ChaCha20.new(key=self.chacha_key, nonce=self.chacha_nonce)
        m_byte = cipher.encrypt(np.packbits(sd).tobytes())
        m_bit = np.unpackbits(np.frombuffer(m_byte, dtype=np.uint8))
        return m_bit

    def _truncSampling(self, message):
        """Truncated Gaussian sampling for watermarking.

        Each of the first ``latentlength`` bits of ``message`` selects one
        half of the standard normal distribution (bit 0 -> below the median,
        bit 1 -> above it) and one value is sampled from that half.

        Parameters:
            message: 1-D sequence of 0/1 values with at least
                ``latentlength`` entries; extra entries are ignored.

        Returns:
            torch.Tensor: float tensor of shape
            ``(1, 4, latents_height, latents_width)`` on ``config.device``.

        Raises:
            ValueError: if ``message`` has fewer than ``latentlength`` entries.
        """
        bits = np.asarray(message).reshape(-1)[: self.latentlength].astype(np.int64)
        if bits.size < self.latentlength:
            raise ValueError(
                f"message has {bits.size} bits but {self.latentlength} are required"
            )

        denominator = 2.0
        # Quantile boundaries of the equal-probability intervals: [-inf, 0, +inf].
        bounds = norm.ppf(np.arange(int(denominator) + 1) / denominator)

        # One vectorized call instead of one scipy call per latent element.
        # Sampling still goes through the global NumPy RNG, in element order;
        # this is expected to give the same values as per-element sampling.
        z = truncnorm.rvs(bounds[bits], bounds[bits + 1])

        z = torch.from_numpy(np.asarray(z, dtype=np.float64))
        z = z.reshape(1, 4, self.config.latents_height, self.config.latents_width).float()
        return z.to(self.config.device)

    def _tiled_watermark(self) -> torch.Tensor:
        """Tile the config's watermark along channel, height and width."""
        return self.config.watermark.repeat(
            1, self.config.channel_copy, self.config.hw_copy, self.config.hw_copy
        )

    def _create_watermark(self) -> torch.Tensor:
        """Create watermark pattern without encryption.

        The tiled watermark is added to ``config.key`` modulo 2 before sampling.
        """
        sd = self._tiled_watermark()
        m = ((sd + self.config.key) % 2).flatten().cpu().numpy()
        return self._truncSampling(m)

    def _create_watermark_chacha(self) -> torch.Tensor:
        """Create watermark pattern using ChaCha20 cipher."""
        sd = self._tiled_watermark()
        m = self._stream_key_encrypt(sd.flatten().cpu().numpy())
        return self._truncSampling(m)

    def inject_watermark(self) -> torch.Tensor:
        """Inject watermark into latent space.

        Returns the watermarked latents, using the ChaCha20 path when
        ``config.chacha`` is truthy and the key-mask path otherwise.
        """
        if self.config.chacha:
            return self._create_watermark_chacha()
        return self._create_watermark()
class GS(BaseWatermark):
    """Main class for Gaussian Shading watermarking algorithm.

    Generation starts the diffusion pipeline from latents produced by
    ``GSUtils.inject_watermark``. Detection encodes an image back to latents,
    runs the configured inversion, and hands the final inverted latents to
    ``GSDetector``.
    """

    def __init__(self,
                 watermark_config: GSConfig,
                 *args, **kwargs):
        """
        Initialize the Gaussian Shading watermarking algorithm.

        Parameters:
            watermark_config (GSConfig): Configuration instance of the GS algorithm.
        """
        self.config = watermark_config
        self.utils = GSUtils(self.config)

        # With ChaCha20 the detector gets the (key, nonce) pair; otherwise it
        # gets the random bit mask stored on the config.
        if self.config.chacha:
            detector_key = (self.utils.chacha_key, self.utils.chacha_nonce)
        else:
            detector_key = self.config.key

        self.detector = GSDetector(
            watermarking_mask=self.config.watermark,
            chacha=self.config.chacha,
            wm_key=detector_key,
            channel_copy=self.config.channel_copy,
            hw_copy=self.config.hw_copy,
            vote_threshold=self.config.vote_threshold,
            threshold=self.config.threshold,
            device=self.config.device
        )

    def _gs_generation_params(self, watermarked_latents: torch.Tensor) -> dict:
        """Build the pipeline call arguments from the config for the given latents.

        Entries of ``config.gen_kwargs`` (if present and non-empty) are added
        only for keys that are not already set.
        """
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }
        extra = getattr(self.config, "gen_kwargs", None)
        if extra:
            for key, value in extra.items():
                generation_params.setdefault(key, value)
        return generation_params

    def _gs_invert_image(self, image: Image.Image, prompt: str, decoder_inv: bool, kwargs: dict):
        """Encode ``image`` to latents and run the configured inversion on it.

        ``kwargs`` may carry ``guidance_scale`` and ``num_inference_steps``
        overrides (config values are the defaults); every other entry except
        ``decoder_inv`` is forwarded to ``forward_diffusion``.

        Returns:
            The complete result of ``config.inversion.forward_diffusion``;
            callers index it as needed.
        """
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Text embeddings; the negative embeddings are prepended only when
        # classifier-free guidance is active.
        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )
        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Image -> model-format tensor matching the embeddings' dtype and device.
        processed = transform_to_model_format(
            image,
            target_size=self.config.image_size[0]
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=processed,
            sample=False,
            decoder_inv=decoder_inv
        )

        # Forward everything not consumed above to the inversion call.
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

        return self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale_to_use,
            num_inference_steps=num_steps_to_use,
            **inversion_kwargs
        )

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Generate image with Gaussian Shading watermark.

        Keyword arguments override the config-derived pipeline arguments,
        except ``latents``, which is always the watermarked latents.
        """
        set_random_seed(self.config.gen_seed)
        watermarked_latents = self.utils.inject_watermark()

        # save watermarked latents
        self.set_orig_watermarked_latents(watermarked_latents)

        generation_params = self._gs_generation_params(watermarked_latents)

        # Use kwargs to override default parameters
        generation_params.update(kwargs)

        # Ensure latents parameter is not overridden
        generation_params["latents"] = watermarked_latents

        return self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str="",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect Gaussian Shading watermark.

        Inverts ``image`` and evaluates the last element of the inversion
        result with the detector. ``detector_type`` in ``kwargs`` is passed
        to the detector when given; note it is also forwarded to
        ``forward_diffusion`` along with the other extra keyword arguments.
        """
        reversed_latents = self._gs_invert_image(
            image, prompt, kwargs.get("decoder_inv", False), kwargs
        )[-1]

        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        return self.detector.eval_watermark(reversed_latents)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str="",
                               guidance_scale: float=1,
                               decoder_inv: bool=False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Get Gaussian Shading visualization data

        Regenerates a watermarked image from the config, inverts it, and
        packages the result together with the caller's ``image``.

        Raises:
            ValueError: if any step of the inversion fails (the original
            exception is chained as the cause).
        """
        # NOTE(review): the explicit `guidance_scale` argument is accepted but
        # never read; inversion always uses config.guidance_scale because a
        # named parameter cannot appear in **kwargs. Confirm which is intended.

        # 1. Generate watermarked latents
        set_random_seed(self.config.gen_seed)
        watermarked_latents = self.utils.inject_watermark()

        # 2. Generate the actual watermarked image (config parameters only;
        #    caller kwargs are not applied to generation here)
        watermarked_image = self.config.pipe(
            prompt,
            **self._gs_generation_params(watermarked_latents)
        ).images[0]

        # 3. Invert the generated image; the full inversion result is kept
        try:
            reversed_latents = self._gs_invert_image(watermarked_image, prompt, decoder_inv, kwargs)
        except Exception as e:
            raise ValueError(f"Warning: Could not perform inversion for visualization: {e}") from e

        # 4. Prepare visualization data
        # NOTE(review): this reads self.orig_watermarked_latents rather than
        # the latents generated above; presumably it is set by an earlier
        # generation call via set_orig_watermarked_latents - confirm.
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            orig_watermarked_latents=self.orig_watermarked_latents,
            reversed_latents=reversed_latents,
            image=image
        )