Coverage for watermark/wind/wind.py: 93.80%

129 statements  

coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch
import hashlib
import numpy as np
import logging
from typing import Dict, Any, List, Optional
from PIL import Image
from utils.utils import load_config_file, set_random_seed
from utils.diffusion_config import DiffusionConfig
from utils.media_utils import transform_to_model_format, get_media_latents
from watermark.base import BaseConfig, BaseWatermark
from exceptions.exceptions import AlgorithmNameMismatchError
from detection.wind.wind_detection import WINDetector
from visualize.data_for_visualization import DataForVisualization

logger = logging.getLogger(__name__)


class WINDConfig(BaseConfig):
    """Configuration class for the WIND algorithm."""

    def initialize_parameters(self) -> None:
        """Read the WIND parameters from the loaded config dictionary."""
        self.w_seed = self.config_dict['w_seed']
        self.N = self.config_dict['num_noises']
        self.M = self.config_dict['num_groups']
        self.secret_salt = self.config_dict['secret_salt'].encode()
        self.hash_func = getattr(hashlib, self.config_dict['hash_function'])
        self.group_radius = self.config_dict['group_radius']
        self.threshold = self.config_dict['threshold']
        self.current_index = self.config_dict['current_index']
        self.noise_groups = self._precompute_noise_groups()

    def _precompute_noise_groups(self) -> Dict[int, List[torch.Tensor]]:
        """Regenerate all N noises and bucket them into M groups (noise i -> group i % M)."""
        groups = {}
        for i in range(self.N):
            g = i % self.M
            if g not in groups:
                groups[g] = []
            seed = self._generate_seed(i)
            groups[g].append(self._generate_noise(seed))
        return groups

    def _generate_seed(self, index: int) -> bytes:
        """Derive a deterministic per-index seed from the index and the secret salt."""
        # Note: secret_salt is bytes, so its repr is what gets interpolated into
        # the hashed string; detection must derive seeds the same way to match.
        return self.hash_func(f"{index}{self.secret_salt}".encode()).digest()

    def _generate_noise(self, seed: bytes) -> torch.Tensor:
        """Deterministically generate an initial latent noise tensor from a seed."""
        # Seed NumPy's RNG with the first 4 bytes of the digest
        rng = np.random.RandomState(int.from_bytes(seed[:4], 'big'))
        latent_height = self.image_size[0] // 8
        latent_width = self.image_size[1] // 8
        return torch.from_numpy(rng.randn(4, latent_height, latent_width)).float().to(self.device)

    @property
    def algorithm_name(self) -> str:
        return 'WIND'
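
# Illustrative sketch of the config keys that initialize_parameters() reads.
# The values below are hypothetical placeholders, not MarkDiffusion defaults:
#
#   {
#       "w_seed": 42,                # seed for the group Fourier patterns
#       "num_noises": 1000,          # N: reproducible initial noises
#       "num_groups": 10,            # M: noise i is assigned to group i % M
#       "secret_salt": "my-salt",    # salted into every per-noise seed
#       "hash_function": "sha256",   # any hashlib algorithm name
#       "group_radius": 10,          # radius of the Fourier-domain ring mask
#       "threshold": 0.75,           # detection decision threshold
#       "current_index": 0,          # which of the N noises to embed next
#   }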


class WINDUtils:
    """Utility class for the WIND algorithm: builds group patterns and injects them."""

    def __init__(self, config: WINDConfig):
        self.config = config
        self.group_patterns = self._generate_group_patterns()
        self.original_noise = None

    def _generate_group_patterns(self) -> Dict[int, torch.Tensor]:
        """Generate one Fourier-domain ring pattern per group."""
        set_random_seed(self.config.w_seed)
        patterns = {}
        latent_height = self.config.image_size[0] // 8
        latent_width = self.config.image_size[1] // 8
        # Assuming square latents for mask generation, as per the current implementation
        size = latent_height

        for g in range(self.config.M):
            pattern = torch.fft.fftshift(
                torch.fft.fft2(torch.randn(4, latent_height, latent_width).to(self.config.device)),
                dim=(-1, -2)
            )
            mask = self._circle_mask(size, self.config.group_radius)
            pattern *= mask
            patterns[g] = pattern
        return patterns

    def _circle_mask(self, size: int, r: int) -> torch.Tensor:
        """Build a binary ring mask between radii r-2 and r, centered in a size x size grid."""
        y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
        center = size // 2
        dist = (x - center) ** 2 + (y - center) ** 2
        return ((dist >= (r - 2) ** 2) & (dist <= r ** 2)).float().to(self.config.device)

    def inject_watermark(self, index: int) -> torch.Tensor:
        """Embed the group pattern for `index` into its reproducible initial noise."""
        seed = self.config._generate_seed(index)
        z_i = self.config._generate_noise(seed)
        self.original_noise = z_i
        g = index % self.config.M
        z_fft = torch.fft.fftshift(torch.fft.fft2(z_i), dim=(-1, -2))

        latent_height = self.config.image_size[0] // 8
        # Assuming square latents for mask generation
        size = latent_height

        # group_patterns are already masked, so re-applying the binary ring mask is a no-op
        mask = self._circle_mask(size, self.config.group_radius)
        z_fft = z_fft + self.group_patterns[g] * mask

        # ifftshift over the same spatial dims used for fftshift above
        z_combined = torch.fft.ifft2(torch.fft.ifftshift(z_fft, dim=(-1, -2))).real
        return z_combined
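
# A minimal standalone sketch of the injection step above (shapes, radius, and
# the `pattern` and `ring` tensors are illustrative stand-ins, not values from
# this module):
#
#   z = torch.randn(4, 64, 64)                                   # reproducible noise z_i
#   z_fft = torch.fft.fftshift(torch.fft.fft2(z), dim=(-1, -2))  # centered Fourier domain
#   z_fft = z_fft + pattern * ring                               # add group key in the ring only
#   z_marked = torch.fft.ifft2(torch.fft.ifftshift(z_fft, dim=(-1, -2))).real
#
# Only the narrow mid-frequency ring is modified, which is what the detector
# can later compare against group_patterns on the inverted latents.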


class WIND(BaseWatermark):

    def __init__(self, watermark_config: WINDConfig, *args, **kwargs):
        """
        Initialize the WIND algorithm.

        Parameters:
            watermark_config (WINDConfig): Configuration instance of the WIND algorithm.
        """
        self.config = watermark_config
        self.utils = WINDUtils(self.config)

        # The detector receives both the per-group Fourier patterns and the
        # precomputed noise groups, so it can match a group pattern first and
        # then the exact initial noise within that group.
        self.detector = WINDetector(
            noise_groups=self.config.noise_groups,
            group_patterns=self.utils.group_patterns,
            threshold=self.config.threshold,
            device=self.config.device,
            group_radius=self.config.group_radius
        )

    def _generate_watermarked_image(self, prompt: str, *args, **kwargs) -> Image.Image:
        """Generate a watermarked image."""
        index = self.config.current_index % self.config.M

        watermarked_z = self.utils.inject_watermark(index).unsqueeze(0)  # [1, 4, H/8, W/8], e.g. [1, 4, 64, 64] for 512x512
        self.set_orig_watermarked_latents(watermarked_z)
        set_random_seed(self.config.gen_seed)

        generation_params = {
            "num_images_per_prompt": self.config.num_images,
            "guidance_scale": self.config.guidance_scale,
            "num_inference_steps": self.config.num_inference_steps,
            "height": self.config.image_size[0],
            "width": self.config.image_size[1],
            "latents": watermarked_z,
        }

        if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
            for key, value in self.config.gen_kwargs.items():
                if key not in generation_params:
                    generation_params[key] = value

        # Use kwargs to override the default parameters
        for key, value in kwargs.items():
            generation_params[key] = value

        # Ensure the latents parameter is not overridden
        generation_params["latents"] = watermarked_z

        result = self.config.pipe(
            prompt,
            **generation_params
        )

        if isinstance(result, tuple):
            return result[0].images[0]
        else:
            return result.images[0]
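
    # Hypothetical call illustrating the override order above: caller kwargs win
    # over config defaults, but "latents" is always pinned to the watermarked noise.
    #
    #   wind._generate_watermarked_image("a photo of a cat", num_inference_steps=30)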

    def _detect_watermark_in_image(self,
                                   image: Image.Image,
                                   prompt: str = "",
                                   *args,
                                   **kwargs) -> Dict[str, Any]:
        """Detect the watermark by inverting the image back toward its initial latents."""
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_inference_steps = kwargs.get('num_inference_steps', 50)

        do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        processed_img = transform_to_model_format(
            image,
            target_size=self.config.image_size[0]
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=processed_img,
            sample=False,
            decoder_inv=kwargs.get('decoder_inv', False)
        )

        # Keyword arguments consumed above or by the detector are not forwarded to the inversion
        inversion_kwargs = {k: v for k, v in kwargs.items()
                            if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps', 'detector_type']}

        reversed_latents = self.config.inversion.forward_diffusion(
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale_to_use,
            latents=image_latents,
            text_embeddings=text_embeddings,
            **inversion_kwargs
        )[-1]

        if 'detector_type' in kwargs:
            return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
        else:
            return self.detector.eval_watermark(reversed_latents)

    def get_data_for_visualize(self,
                               image: Image.Image,
                               prompt: str = "",
                               guidance_scale: Optional[float] = None,
                               decoder_inv: bool = False,
                               *args,
                               **kwargs) -> DataForVisualization:
        """Collect the original and inverted latents needed for visualization."""
        guidance_scale = guidance_scale or self.config.guidance_scale
        num_inference_steps = kwargs.get('num_inference_steps', 50)

        do_classifier_free_guidance = (guidance_scale > 1.0)
        prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
            prompt=prompt,
            device=self.config.device,
            do_classifier_free_guidance=do_classifier_free_guidance,
            num_images_per_prompt=1,
        )

        if do_classifier_free_guidance:
            text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
        else:
            text_embeddings = prompt_embeds

        processed_img = transform_to_model_format(
            image,
            target_size=self.config.image_size[0]
        ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

        image_latents = get_media_latents(
            pipe=self.config.pipe,
            media=processed_img,
            sample=False,
            decoder_inv=decoder_inv
        )

        # Exclude num_inference_steps from the forwarded kwargs so it is not passed twice
        inversion_kwargs = {k: v for k, v in kwargs.items() if k != 'num_inference_steps'}

        reversed_latents = self.config.inversion.forward_diffusion(
            latents=image_latents,
            text_embeddings=text_embeddings,
            guidance_scale=guidance_scale,
            num_inference_steps=num_inference_steps,
            reverse=True,
            **inversion_kwargs
        )

        data = DataForVisualization(
            config=self.config,
            utils=self.utils,
            image=image,
            reversed_latents=reversed_latents,
            orig_watermarked_latents=self.orig_watermarked_latents,
        )

        return data
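
# Hypothetical end-to-end usage (the config path and constructor arguments are
# assumptions; in the repo, instantiation normally goes through the framework's
# config loading utilities such as load_config_file):
#
#   config_dict = load_config_file('config/WIND.json')  # hypothetical path
#   wm_config = WINDConfig(...)                          # built by the framework
#   wind = WIND(wm_config)
#   image = wind._generate_watermarked_image("a photo of a cat")
#   result = wind._detect_watermark_in_image(image, prompt="a photo of a cat")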