Coverage for detection/gm/gm_detection.py: 89.85%

197 statements  

coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""GaussMarker detection utilities.

This module adapts the official GaussMarker detection pipeline to the
MarkDiffusion detection API. It evaluates recovered diffusion latents to
decide whether a watermark is present, reporting both hard decisions and
auxiliary scores (bit/message accuracies, frequency-domain distances).
"""

from __future__ import annotations

import os
from pathlib import Path
from typing import Dict, Optional, Union

import numpy as np
import torch

try:
    import joblib
except ImportError:  # joblib is optional; it is only needed to load a fuser checkpoint
    joblib = None
from huggingface_hub import hf_hub_download, snapshot_download

from detection.base import BaseDetector
from watermark.gm.gm import GaussianShadingChaCha, extract_complex_sign
from watermark.gm.gnr import GNRRestorer


class GMDetector(BaseDetector):
    """Detector for GaussMarker watermarks.

    Args:
        watermark_generator: Instance of :class:`GaussianShadingChaCha` that
            holds the original watermark bits and ChaCha20 key stream.
        watermarking_mask: Frequency-domain mask (or label map) indicating the
            region that carries the watermark.
        gt_patch: Reference watermark pattern in the frequency domain.
        w_measurement: Measurement mode (e.g., ``"l1_complex"`` or
            ``"signal_complex"``), mirroring the official implementation.
        device: Torch device used for evaluation.
        bit_threshold: Optional override for the bit-accuracy decision
            threshold. Defaults to the generator's ``tau_bits`` value.
        message_threshold: Optional threshold for message-accuracy decisions.
        l1_threshold: Optional threshold for frequency-domain L1 decisions
            (smaller is better).
        gnr_checkpoint: Path or filename of the GNR classifier checkpoint;
            falls back to a HuggingFace Hub download when ``huggingface_repo``
            is provided.
        gnr_classifier_type: GNR classifier variant. With ``1`` the restorer is
            additionally conditioned on the base message (doubling its input
            channels); with ``0`` it sees only the reversed message.
        gnr_model_nf: Number of feature maps in the GNR model.
        gnr_binary_threshold: Threshold used to binarize GNR outputs.
        gnr_use_for_decision: Whether GNR bit accuracy may assist the decision.
        gnr_threshold: Optional threshold override when using GNR.
        fuser_checkpoint: Path or filename of the fuser model; falls back to a
            HuggingFace Hub download when ``huggingface_repo`` is provided.
        fuser_threshold: Decision threshold applied to the fused score.
        fuser_frequency_scale: Scaling factor applied to the frequency score.
        huggingface_repo: Optional HuggingFace repository used to download
            checkpoints.
        hf_dir: Optional local cache directory for HuggingFace downloads.
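
    Example:
        A minimal usage sketch; ``generator``, ``watermarking_mask``,
        ``gt_patch``, and ``reversed_latents`` are placeholders assumed to come
        from the GaussMarker generation/inversion pipeline::

            detector = GMDetector(
                watermark_generator=generator,        # configured GaussianShadingChaCha
                watermarking_mask=watermarking_mask,  # frequency-domain mask
                gt_patch=gt_patch,                    # reference frequency pattern
                w_measurement="l1_complex",
                device="cuda",
            )
            result = detector.eval_watermark(reversed_latents)
            print(result["is_watermarked"], result["bit_acc"])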

    """

    def __init__(
        self,
        watermark_generator: GaussianShadingChaCha,
        watermarking_mask: torch.Tensor,
        gt_patch: torch.Tensor,
        w_measurement: str,
        device: Union[str, torch.device],
        bit_threshold: Optional[float] = None,
        message_threshold: Optional[float] = None,
        l1_threshold: Optional[float] = None,
        gnr_checkpoint: Optional[Union[str, Path]] = None,
        gnr_classifier_type: int = 0,
        gnr_model_nf: int = 128,
        gnr_binary_threshold: float = 0.5,
        gnr_use_for_decision: bool = True,
        gnr_threshold: Optional[float] = None,
        fuser_checkpoint: Optional[Union[str, Path]] = None,
        fuser_threshold: Optional[float] = None,
        fuser_frequency_scale: float = 0.01,
        huggingface_repo: Optional[str] = None,
        hf_dir: Optional[str] = None,
    ) -> None:
        self.generator = watermark_generator
        device = torch.device(device)
        # Ensure watermark/message buffers are initialised
        _, base_message = self.generator.create_watermark_and_return_w_m()
        self.base_message = base_message.to(device)
        watermarking_mask = watermarking_mask.to(device)
        gt_patch = gt_patch.to(device)

        threshold = bit_threshold if bit_threshold is not None else (
            self.generator.tau_bits or 0.5
        )
        super().__init__(threshold=float(threshold), device=device)

        self.watermarking_mask = watermarking_mask
        self.gt_patch = gt_patch
        self.w_measurement = w_measurement.lower()
        self.message_threshold = (
            float(message_threshold)
            if message_threshold is not None
            else float(self.generator.tau_onebit or threshold)
        )
        self.l1_threshold = float(l1_threshold) if l1_threshold is not None else None
        self.gnr_binary_threshold = gnr_binary_threshold
        self.gnr_use_for_decision = gnr_use_for_decision
        self.gnr_threshold = float(gnr_threshold) if gnr_threshold is not None else None
        self.huggingface_repo = huggingface_repo
        self.hf_dir = hf_dir

        self.gnr_restorer = self._build_gnr_restorer(
            checkpoint=gnr_checkpoint,
            device=device,
            classifier_type=gnr_classifier_type,
            nf=gnr_model_nf,
        )
        self.fuser = self._load_fuser(fuser_checkpoint)
        self.fuser_threshold = float(fuser_threshold) if fuser_threshold is not None else 0.5
        self.fuser_frequency_scale = float(fuser_frequency_scale)

    # ------------------------------------------------------------------
    # Helper computations
    # ------------------------------------------------------------------

    def _complex_l1(self, reversed_latents: torch.Tensor) -> float:
        """Mean L1 distance between the FFT of the reversed latents and the
        reference pattern, restricted to the watermarked region."""
        fft_latents = torch.fft.fftshift(torch.fft.fft2(reversed_latents), dim=(-1, -2))
        if self.watermarking_mask.dtype == torch.bool:
            selector = self.watermarking_mask
        else:
            selector = self.watermarking_mask != 0
        if selector.sum() == 0:
            return 0.0
        diff = torch.abs(fft_latents[selector] - self.gt_patch[selector])
        return float(diff.mean().item())

    def _signal_accuracy(self, reversed_latents: torch.Tensor) -> float:
        """Fraction of masked frequency positions whose complex sign matches
        the reference pattern."""
        fft_latents = torch.fft.fftshift(torch.fft.fft2(reversed_latents), dim=(-1, -2))
        if self.watermarking_mask.dtype == torch.bool:
            selector = self.watermarking_mask
        else:
            selector = self.watermarking_mask != 0
        if selector.sum() == 0:
            return 0.0
        latents_sign = extract_complex_sign(fft_latents[selector])
        target_sign = extract_complex_sign(self.gt_patch[selector])
        return float((latents_sign == target_sign).float().mean().item())

    def _build_gnr_restorer(
        self,
        checkpoint: Optional[Union[str, Path]],
        device: torch.device,
        classifier_type: int,
        nf: int,
    ) -> Optional[GNRRestorer]:
        """Instantiate the GNR restorer from a local or HuggingFace-hosted checkpoint."""
        if not checkpoint:
            return None

        checkpoint_path: Optional[Path] = None
        candidates: list[Path] = []

        # If a HuggingFace repo is provided, try the local cache dir and an HF download
        if self.huggingface_repo:
            # Existing local file
            orig = Path(checkpoint)
            local_path = orig if orig.is_file() else None
            if self.hf_dir:
                potential_local = Path(self.hf_dir) / Path(checkpoint).name
                if potential_local.is_file():
                    local_path = potential_local
            if local_path and local_path.is_file():
                checkpoint_path = local_path
            else:
                try:
                    hf_path = hf_hub_download(
                        repo_id=self.huggingface_repo,
                        filename=Path(checkpoint).name,
                        cache_dir=self.hf_dir,
                    )
                    checkpoint_path = Path(hf_path)
                except Exception as e:
                    raise FileNotFoundError(
                        f"GNR checkpoint not found on HuggingFace repo "
                        f"'{self.huggingface_repo}'. error: {e}"
                    ) from e
        # Fallback: look for local candidates
        if checkpoint_path is None:
            candidates = [Path(checkpoint)]
            base_dir = Path(__file__).resolve().parent
            candidates.append(base_dir / Path(checkpoint))
            candidates.append(base_dir.parent.parent / Path(checkpoint))
            for candidate in candidates:
                if candidate.is_file():
                    checkpoint_path = candidate
                    break
        if checkpoint_path is None:
            raise FileNotFoundError(f"GNR checkpoint not found at '{checkpoint}'")

        latent_channels = self.base_message.shape[1]
        in_channels = latent_channels * (2 if classifier_type == 1 else 1)
        return GNRRestorer(
            checkpoint_path=checkpoint_path,
            in_channels=in_channels,
            out_channels=latent_channels,
            nf=nf,
            device=device,
            classifier_type=classifier_type,
            base_message=self.base_message if classifier_type == 1 else None,
        )

    def _load_fuser(self, checkpoint: Optional[Union[str, Path]]):
        """Load the score-fusion classifier via joblib, downloading it if needed."""
        if not checkpoint:
            return None
        if joblib is None:
            raise ImportError(
                "joblib is required to load the GaussMarker fuser. Install joblib or disable the fuser."
            )

        model_path: Optional[Path] = None
        candidates: list[Path] = []

        # If a HuggingFace repo is provided, try the local cache dir and an HF download
        if self.huggingface_repo:
            orig = Path(checkpoint)
            local_path = orig if orig.is_file() else None
            if self.hf_dir:
                potential_local = Path(self.hf_dir) / Path(checkpoint).name
                if potential_local.is_file():
                    local_path = potential_local
            if local_path and local_path.is_file():
                model_path = local_path
            else:
                try:
                    hf_path = hf_hub_download(
                        repo_id=self.huggingface_repo,
                        filename=Path(checkpoint).name,
                        cache_dir=self.hf_dir,
                    )
                    model_path = Path(hf_path)
                except Exception:
                    # As a fallback, try a snapshot of the official repo used by the GaussMarker fuser
                    try:
                        parts = Path(checkpoint).parts
                        local_dir = str(parts[0]) if len(parts) > 0 else "gm_fuser"
                        snapshot_download(
                            repo_id="Generative-Watermark-Toolkits/MarkDiffusion-gm",
                            local_dir=local_dir,
                            repo_type="model",
                            local_dir_use_symlinks=False,
                            endpoint=os.getenv("HF_ENDPOINT", "https://huggingface.co"),
                        )
                        candidates.append(Path(local_dir) / Path(checkpoint).name)
                    except Exception as e:
                        raise FileNotFoundError(
                            "Fuser checkpoint not found on HuggingFace and snapshot fallback failed. "
                            f"error: {e}"
                        ) from e

        # Fallback: local candidates
        if model_path is None:
            candidates.append(Path(checkpoint))
            base_dir = Path(__file__).resolve().parent
            candidates.append(base_dir / Path(checkpoint))
            candidates.append(base_dir.parent.parent / Path(checkpoint))
            for candidate in candidates:
                if candidate.is_file():
                    model_path = candidate
                    break

        if model_path is None or not model_path.is_file():
            raise FileNotFoundError(f"Fuser checkpoint not found at '{checkpoint}'")

        return joblib.load(model_path)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def eval_watermark(
        self,
        reversed_latents: torch.Tensor,
        reference_latents: Optional[torch.Tensor] = None,
        detector_type: str = "bit_acc",
    ) -> Dict[str, Union[bool, float]]:
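        """Score ``reversed_latents`` and decide whether they carry the watermark.

        Args:
            reversed_latents: Recovered (inverted) diffusion latents of the
                image under test.
            reference_latents: Unused by this detector.
            detector_type: Statistic that drives the binary decision; one of
                ``"bit_acc"``/``"is_watermarked"``, ``"message_acc"``,
                ``"complex_l1"``, ``"signal_acc"``, ``"gnr_bit_acc"``,
                ``"fused"``, or ``"all"``.

        Returns:
            Dictionary of auxiliary scores plus the boolean
            ``"is_watermarked"`` decision.
        """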

        detector_type = detector_type.lower()
        reversed_latents = reversed_latents.to(self.device, dtype=torch.float32)

        # Bit-level reconstruction
        bit_watermark = self.generator.pred_w_from_latent(reversed_latents)
        reference_bits = self.generator.watermark_tensor(self.device)
        bit_acc = float((bit_watermark == reference_bits).float().mean().item())

        # Message bit accuracy (post ChaCha decryption)
        reversed_m = self.generator.pred_m_from_latent(reversed_latents)
        message_bits = torch.from_numpy(self.generator.message_bits.astype("float32")).to(self.device)
        message_acc = float((reversed_m.flatten().float() == message_bits).float().mean().item())

        # Frequency-domain statistics
        complex_l1 = self._complex_l1(reversed_latents)
        signal_acc = self._signal_accuracy(reversed_latents) if "signal" in self.w_measurement else None

        gnr_bit_acc = None
        gnr_message_acc = None
        if self.gnr_restorer is not None:
            restored_binary = self.gnr_restorer.restore_binary(reversed_m, threshold=self.gnr_binary_threshold)
            restored_w = self.generator.pred_w_from_m(restored_binary)
            gnr_bit_acc = float((restored_w == reference_bits).float().mean().item())
            gnr_message_acc = float((restored_binary.flatten() == message_bits).float().mean().item())

        frequency_score = -complex_l1 * self.fuser_frequency_scale
        metrics: Dict[str, Union[bool, float]] = {
            "bit_acc": bit_acc,
            "message_acc": message_acc,
            "complex_l1": complex_l1,
            "frequency_score": frequency_score,
            "tau_bits": float(self.generator.tau_bits or 0.5),
            "tau_onebit": float(self.generator.tau_onebit or 0.5),
        }
        if signal_acc is not None:
            metrics["signal_acc"] = signal_acc
        if gnr_bit_acc is not None:
            metrics["gnr_bit_acc"] = gnr_bit_acc
        if gnr_message_acc is not None:
            metrics["gnr_message_acc"] = gnr_message_acc

        # Determine binary decision based on requested detector type
        decision_threshold = self.threshold if self.gnr_threshold is None else self.gnr_threshold
        decision_bit_acc = bit_acc
        if self.gnr_restorer is not None and self.gnr_use_for_decision and gnr_bit_acc is not None:
            decision_bit_acc = max(decision_bit_acc, gnr_bit_acc)
        metrics["decision_bit_acc"] = decision_bit_acc
        metrics["decision_threshold"] = decision_threshold

        fused_score = None
        fused_threshold = self.fuser_threshold if self.fuser is not None else None
        if self.fuser is not None:
            spatial_score = gnr_bit_acc if gnr_bit_acc is not None else bit_acc
            frequency_score = metrics["frequency_score"]
            features = np.array([[spatial_score, frequency_score]], dtype=np.float32)
            if hasattr(self.fuser, "predict_proba"):
                fused_score = float(self.fuser.predict_proba(features)[0, 1])
            elif hasattr(self.fuser, "decision_function"):
                fused_score = float(self.fuser.decision_function(features)[0])
            else:
                raise AttributeError("Unsupported fuser model: missing predict_proba/decision_function")
            metrics["fused_score"] = fused_score
            metrics["fused_threshold"] = fused_threshold

        if detector_type == "message_acc":
            is_watermarked = message_acc >= self.message_threshold
        elif detector_type == "complex_l1":
            threshold = self.l1_threshold if self.l1_threshold is not None else self.threshold
            is_watermarked = complex_l1 <= threshold
        elif detector_type == "signal_acc":
            if signal_acc is None:
                raise ValueError("Signal accuracy requested but watermark measurement does not use signal mode.")
            is_watermarked = signal_acc >= self.threshold
        elif detector_type == "gnr_bit_acc":
            if gnr_bit_acc is None:
                raise ValueError("GNR checkpoint not provided, cannot compute GNR-based accuracy.")
            is_watermarked = gnr_bit_acc >= decision_threshold
        elif detector_type == "fused":
            if fused_score is None:
                raise ValueError("Fuser checkpoint not provided, cannot compute fused score.")
            is_watermarked = fused_score >= fused_threshold
        elif detector_type == "all":
            if fused_score is not None:
                is_watermarked = fused_score >= fused_threshold
            else:
                is_watermarked = decision_bit_acc >= decision_threshold
        elif detector_type in {"bit_acc", "is_watermarked"}:
            if fused_score is not None:
                is_watermarked = fused_score >= fused_threshold
            else:
                is_watermarked = decision_bit_acc >= decision_threshold
        else:
            raise ValueError(f"Unsupported detector_type '{detector_type}' for GaussMarker.")

        metrics["is_watermarked"] = bool(is_watermarked)
        return metrics