Coverage for watermark / gs / gs.py: 90.91%

132 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1from ..base import BaseWatermark, BaseConfig 

2import torch 

3from typing import Dict 

4from PIL import Image 

5from utils.diffusion_config import DiffusionConfig 

6import numpy as np 

7from Crypto.Cipher import ChaCha20 

8import random 

9from scipy.stats import norm,truncnorm 

10from functools import reduce 

11from visualize.data_for_visualization import DataForVisualization 

12from detection.gs.gs_detection import GSDetector 

13from utils.media_utils import * 

14from utils.utils import set_random_seed 

15 

class GSConfig(BaseConfig):
    """Configuration for the Gaussian Shading (GS) watermarking algorithm.

    Reads the GS-specific entries of ``config_dict`` and derives from them the
    latent-space geometry, the random watermark bits and (when ChaCha20 is not
    used) the random XOR key.
    """

    def initialize_parameters(self) -> None:
        """Populate the GS-specific attributes from ``self.config_dict``.

        Raw options copied verbatim: ``channel_copy``, ``hw_copy``, ``chacha``,
        ``wm_key``, ``chacha_key_seed``, ``chacha_nonce_seed``, ``threshold``.

        Derived values:
            vote_threshold: 1 when no replication is used, otherwise half the
                number of copies of each watermark bit (integer division).
            latents_height / latents_width: ``image_size`` divided by the
                pipeline's ``vae_scale_factor``.
            watermark: random 0/1 tensor of shape
                (1, 4 // channel_copy, H // hw_copy, W // hw_copy).
            key: random 0/1 tensor of shape (1, 4, H, W); only created when
                ``chacha`` is falsy.

        Both random tensors come from one ``torch.Generator`` seeded with
        ``wm_key``; the watermark is always drawn before the key.
        """
        cfg = self.config_dict
        self.channel_copy = cfg['channel_copy']
        self.hw_copy = cfg['hw_copy']
        self.chacha = cfg['chacha']
        self.wm_key = cfg['wm_key']
        self.chacha_key_seed = cfg['chacha_key_seed']
        self.chacha_nonce_seed = cfg['chacha_nonce_seed']
        self.threshold = cfg['threshold']

        if self.hw_copy == 1 and self.channel_copy == 1:
            self.vote_threshold = 1
        else:
            # Majority vote over all replicated copies of a bit.
            self.vote_threshold = self.channel_copy * self.hw_copy * self.hw_copy // 2

        self.latents_height = self.image_size[0] // self.pipe.vae_scale_factor
        self.latents_width = self.image_size[1] // self.pipe.vae_scale_factor

        rng = torch.Generator(device=self.device)
        rng.manual_seed(self.wm_key)

        mark_shape = (
            1,
            4 // self.channel_copy,
            self.latents_height // self.hw_copy,
            self.latents_width // self.hw_copy,
        )
        self.watermark = torch.randint(0, 2, mark_shape, generator=rng, device=self.device)

        if self.chacha:
            # ChaCha20 mode scrambles bits with a stream cipher; no XOR key needed.
            return

        key_shape = (1, 4, self.latents_height, self.latents_width)
        self.key = torch.randint(0, 2, key_shape, generator=rng, device=self.device)

    @property
    def algorithm_name(self) -> str:
        """Short identifier of this algorithm: ``'GS'``."""
        return 'GS'

43class GSUtils: 

44 """Utility class for Gaussian Shading algorithm.""" 

45 

46 def __init__(self, config: GSConfig, *args, **kwargs) -> None: 

47 """ 

48 Initialize the Gaussian Shading watermarking utility. 

49  

50 Parameters: 

51 config (GSConfig): Configuration for the Gaussian Shading watermarking algorithm. 

52 """ 

53 self.config = config 

54 self.chacha_key = self._get_bytes_with_seed(self.config.chacha_key_seed, 32) 

55 self.chacha_nonce = self._get_bytes_with_seed(self.config.chacha_nonce_seed, 12) 

56 self.latentlength = 4 * self.config.latents_height * self.config.latents_width 

57 self.marklength = self.latentlength//(self.config.channel_copy * self.config.hw_copy * self.config.hw_copy) 

58 

59 def _get_bytes_with_seed(self, seed: int, n: int) -> bytes: 

60 random.seed(seed) 

61 return bytes(random.getrandbits(8) for _ in range(n)) 

62 

63 def _stream_key_encrypt(self, sd): 

64 """Encrypt the watermark using ChaCha20 cipher.""" 

65 cipher = ChaCha20.new(key=self.chacha_key, nonce=self.chacha_nonce) 

66 m_byte = cipher.encrypt(np.packbits(sd).tobytes()) 

67 m_bit = np.unpackbits(np.frombuffer(m_byte, dtype=np.uint8)) 

68 return m_bit 

69 

70 def _truncSampling(self, message): 

71 """Truncated Gaussian sampling for watermarking.""" 

72 z = np.zeros(self.latentlength) 

73 denominator = 2.0 

74 ppf = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)] 

75 for i in range(self.latentlength): 

76 dec_mes = reduce(lambda a, b: 2 * a + b, message[i : i + 1]) 

77 dec_mes = int(dec_mes) 

78 z[i] = truncnorm.rvs(ppf[dec_mes], ppf[dec_mes + 1]) 

79 z = torch.from_numpy(z).reshape(1, 4, self.config.latents_height, self.config.latents_width).float() 

80 return z.to(self.config.device) 

81 

82 def _create_watermark(self) -> torch.Tensor: 

83 """Create watermark pattern without encryption.""" 

84 sd = self.config.watermark.repeat(1,self.config.channel_copy,self.config.hw_copy,self.config.hw_copy) 

85 m = ((sd + self.config.key) % 2).flatten().cpu().numpy() 

86 w = self._truncSampling(m) 

87 return w 

88 

89 def _create_watermark_chacha(self) -> torch.Tensor: 

90 """Create watermark pattern using ChaCha20 cipher.""" 

91 sd = self.config.watermark.repeat(1,self.config.channel_copy,self.config.hw_copy,self.config.hw_copy) 

92 m = self._stream_key_encrypt(sd.flatten().cpu().numpy()) 

93 w = self._truncSampling(m) 

94 return w 

95 

96 def inject_watermark(self) -> torch.Tensor: 

97 """Inject watermark into latent space.""" 

98 if self.config.chacha: 

99 watermarked = self._create_watermark_chacha() 

100 else: 

101 watermarked = self._create_watermark() 

102 return watermarked 

103 

class GS(BaseWatermark):
    """Main class for Gaussian Shading watermarking algorithm.

    Generation seeds the RNGs, builds watermarked initial latents via
    ``GSUtils`` and runs the diffusion pipeline from them. Detection encodes
    an image back to latents, inverts the diffusion process with
    ``config.inversion`` and passes the recovered latents to ``GSDetector``.
    """

    # kwargs consumed by this class itself; all other kwargs are forwarded to
    # ``inversion.forward_diffusion``. ``detector_type`` is meant for the
    # detector and must not leak into the inversion call.
    _NON_INVERSION_KWARGS = ('decoder_inv', 'guidance_scale', 'num_inference_steps', 'detector_type')

    def __init__(self,
                 watermark_config: GSConfig,
                 *args, **kwargs):
        """
        Initialize the Gaussian Shading watermarking algorithm.

        Parameters:
            watermark_config (GSConfig): Configuration instance of the GS algorithm.
        """
        self.config = watermark_config
        self.utils = GSUtils(self.config)

        self.detector = GSDetector(
            watermarking_mask=self.config.watermark,
            chacha=self.config.chacha,
            # ChaCha mode hands over the (key, nonce) pair, XOR mode the bit
            # key; presumably used by the detector to undo the scrambling.
            wm_key=(self.utils.chacha_key, self.utils.chacha_nonce) if self.config.chacha else self.config.key,
            channel_copy=self.config.channel_copy,
            hw_copy=self.config.hw_copy,
            vote_threshold=self.config.vote_threshold,
            threshold=self.config.threshold,
            device=self.config.device
        )

    def _build_generation_params(self, watermarked_latents: torch.Tensor) -> Dict[str, object]:
        """Return default pipeline kwargs for generating from ``watermarked_latents``.

        Values derived from the config take precedence. Entries from
        ``config.gen_kwargs`` (if present and non-empty) are added only for
        keys not already set.
        """
        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_latents,
        }
        gen_kwargs = getattr(self.config, "gen_kwargs", None)
        if gen_kwargs:
            for key, value in gen_kwargs.items():
                generation_params.setdefault(key, value)
        return generation_params

    def _invert_image(self, image, prompt: str, guidance_scale, num_inference_steps,
                      decoder_inv: bool, inversion_kwargs: Dict[str, object]):
        """Encode ``image`` to latents and invert it with ``config.inversion``.

        Returns the full result of ``inversion.forward_diffusion``. Callers
        that need the final latents index it with ``[-1]``.
        """
        # Text embeddings; with classifier-free guidance the negative
        # embeddings are stacked in front of the positive ones.
        do_classifier_free_guidance = (guidance_scale > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )
        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        # Preprocess to a batched tensor matching the embeddings' dtype.
        image_tensor = transform_to_model_format(
            image, target_size=self.config.image_size[0]
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=image_tensor,
            sample=False,
            decoder_inv=decoder_inv,
        )

        return self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            **inversion_kwargs
        )

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Generate image with Gaussian Shading watermark.

        ``kwargs`` override the default pipeline parameters, except
        ``latents``, which is always the watermarked latents.
        """
        set_random_seed(self.config.gen_seed)
        watermarked_latents = self.utils.inject_watermark()

        # Save watermarked latents for later comparison/visualization.
        self.set_orig_watermarked_latents(watermarked_latents)

        generation_params = self._build_generation_params(watermarked_latents)
        generation_params.update(kwargs)
        # Ensure the latents parameter is not overridden.
        generation_params["latents"] = watermarked_latents

        return self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str="",
                                   *args,
                                   **kwargs) -> Dict[str, float]:
        """Detect Gaussian Shading watermark.

        Recognized kwargs: ``guidance_scale`` and ``num_inference_steps``
        (default to config values), ``decoder_inv`` (default False) and
        ``detector_type`` (forwarded to the detector). Any other kwargs are
        forwarded to ``inversion.forward_diffusion``.
        """
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in self._NON_INVERSION_KWARGS}

        reversed_latents = self._invert_image(
            image,
            prompt,
            guidance_scale_to_use,
            num_steps_to_use,
            kwargs.get("decoder_inv", False),
            inversion_kwargs,
        )[-1]

        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        return self.detector.eval_watermark(reversed_latents)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str="",
                               guidance_scale: float=1,
                               decoder_inv: bool=False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Get Gaussian Shading visualization data.

        Regenerates the watermarked image from ``prompt`` with the configured
        seed, then inverts that regenerated image; ``image`` itself is only
        passed through to ``DataForVisualization``.

        Parameters:
            image: Image stored in the returned visualization data.
            prompt: Prompt used for generation and for the inversion embeddings.
            guidance_scale: NOTE(review): currently ignored; inversion uses
                ``config.guidance_scale``. Kept for interface compatibility.
            decoder_inv: Forwarded to ``get_media_latents``.
            **kwargs: ``num_inference_steps`` overrides the config value; other
                non-reserved kwargs are forwarded to ``forward_diffusion``.

        Raises:
            ValueError: if the inversion step fails (original error chained).
        """
        # 1. Generate watermarked latents and record them, mirroring
        #    _generate_watermarked_image so the visualization refers to the
        #    latents that produced the image below.
        set_random_seed(self.config.gen_seed)
        watermarked_latents = self.utils.inject_watermark()
        self.set_orig_watermarked_latents(watermarked_latents)

        # 2. Generate the actual watermarked image.
        generation_params = self._build_generation_params(watermarked_latents)
        watermarked_image = self.config.pipe(
            prompt,
            **generation_params
        ).images[0]

        # 3. Invert the generated image to recover the latent trajectory.
        # `guidance_scale` is a named parameter, so it can never appear in
        # kwargs; the config value is what the original lookup always used.
        guidance_scale_to_use = self.config.guidance_scale
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)
        inversion_kwargs = {k: v for k, v in kwargs.items() if k not in self._NON_INVERSION_KWARGS}
        try:
            reversed_latents = self._invert_image(
                watermarked_image,
                prompt,
                guidance_scale_to_use,
                num_steps_to_use,
                decoder_inv,
                inversion_kwargs,
            )
        except Exception as e:
            raise ValueError(f"Warning: Could not perform inversion for visualization: {e}") from e

        # 4. Prepare visualization data.
        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            orig_watermarked_latents=self.orig_watermarked_latents,
            reversed_latents=reversed_latents,
            image=image
        )