Coverage for detection / gm / gm_detection.py: 89.85%
197 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1# Copyright 2025 THU-BPM MarkDiffusion.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
15"""GaussMarker detection utilities.
17This module adapts the official GaussMarker detection pipeline to the
18MarkDiffusion detection API. It evaluates recovered diffusion latents to
19decide whether a watermark is present, reporting both hard decisions and
20auxiliary scores (bit/message accuracies, frequency-domain distances).
21"""
23from __future__ import annotations
25from pathlib import Path
26from typing import Dict, Optional, Union
28import numpy as np
29import torch
31import joblib
32from huggingface_hub import hf_hub_download
35from detection.base import BaseDetector
36from watermark.gm.gm import GaussianShadingChaCha, extract_complex_sign
37from watermark.gm.gnr import GNRRestorer
class GMDetector(BaseDetector):
    """Detector for GaussMarker watermarks.

    The detector scores recovered diffusion latents against the watermark
    held by ``watermark_generator`` and returns a metrics dictionary that
    contains both auxiliary scores and a final ``is_watermarked`` decision.

    Args:
        watermark_generator: Instance of :class:`GaussianShadingChaCha` that
            holds the original watermark bits and ChaCha20 key stream.
        watermarking_mask: Frequency-domain mask (or label map) indicating the
            region that carries the watermark. Non-boolean masks are
            interpreted as "non-zero means selected".
        gt_patch: Reference watermark pattern in the frequency domain.
        w_measurement: Measurement mode (e.g., ``"l1_complex"`` or
            ``"signal_complex"``), mirroring the official implementation.
            Stored lower-cased; sign accuracy is only computed when the
            mode name contains ``"signal"``.
        device: Torch device used for evaluation.
        bit_threshold: Optional override for the bit-accuracy decision
            threshold. Defaults to the generator's ``tau_bits`` value, or
            ``0.5`` when that is unset/falsy.
        message_threshold: Optional threshold for message accuracy decisions.
            Defaults to the generator's ``tau_onebit`` value, or the bit
            threshold when that is unset/falsy.
        l1_threshold: Optional threshold for frequency L1 distance decisions
            (smaller is better).
        gnr_checkpoint: Path or filename of GNR classifier checkpoint; supports
            HuggingFace Hub fallback if ``huggingface_repo`` is provided.
        gnr_classifier_type: See original implementation. When equal to ``1``
            the restorer is built with twice the latent channels as input and
            is given the base message; otherwise it only sees the latent
            channels.
        gnr_model_nf: Number of feature maps for GNR model.
        gnr_binary_threshold: Threshold to binarize GNR outputs.
        gnr_use_for_decision: Whether to use GNR bit accuracy to assist decision.
        gnr_threshold: Optional threshold override when using GNR.
        fuser_checkpoint: Path or filename of fuser model; supports HuggingFace
            Hub fallback if ``huggingface_repo`` is provided.
        fuser_threshold: Decision threshold when using fused score
            (defaults to ``0.5``).
        fuser_frequency_scale: Frequency score scaling factor.
        huggingface_repo: Optional HuggingFace repository for checkpoint download.
        hf_dir: Optional local cache directory for HuggingFace downloads.

    Raises:
        FileNotFoundError: If a requested GNR or fuser checkpoint cannot be
            located locally or downloaded.
    """

    def __init__(
        self,
        watermark_generator: GaussianShadingChaCha,
        watermarking_mask: torch.Tensor,
        gt_patch: torch.Tensor,
        w_measurement: str,
        device: Union[str, torch.device],
        bit_threshold: Optional[float] = None,
        message_threshold: Optional[float] = None,
        l1_threshold: Optional[float] = None,
        gnr_checkpoint: Optional[Union[str, Path]] = None,
        gnr_classifier_type: int = 0,
        gnr_model_nf: int = 128,
        gnr_binary_threshold: float = 0.5,
        gnr_use_for_decision: bool = True,
        gnr_threshold: Optional[float] = None,
        fuser_checkpoint: Optional[Union[str, Path]] = None,
        fuser_threshold: Optional[float] = None,
        fuser_frequency_scale: float = 0.01,
        huggingface_repo: Optional[str] = None,
        hf_dir: Optional[str] = None,
    ) -> None:
        """Store configuration and load the optional GNR / fuser models."""
        self.generator = watermark_generator
        device = torch.device(device)
        # Ensure watermark/message buffers are initialised
        _, base_message = self.generator.create_watermark_and_return_w_m()
        self.base_message = base_message.to(device)
        watermarking_mask = watermarking_mask.to(device)
        gt_patch = gt_patch.to(device)

        if bit_threshold is not None:
            threshold = bit_threshold
        else:
            threshold = self.generator.tau_bits or 0.5
        # NOTE(review): ``self.threshold`` / ``self.device`` are read later in
        # this class and are presumably set by ``BaseDetector.__init__``.
        super().__init__(threshold=float(threshold), device=device)

        self.watermarking_mask = watermarking_mask
        self.gt_patch = gt_patch
        self.w_measurement = w_measurement.lower()
        if message_threshold is not None:
            self.message_threshold = float(message_threshold)
        else:
            self.message_threshold = float(self.generator.tau_onebit or threshold)
        self.l1_threshold = float(l1_threshold) if l1_threshold is not None else None
        self.gnr_binary_threshold = gnr_binary_threshold
        self.gnr_use_for_decision = gnr_use_for_decision
        self.gnr_threshold = float(gnr_threshold) if gnr_threshold is not None else None
        # The repo / cache dir must be set before the loaders below run,
        # because both loaders read them.
        self.huggingface_repo = huggingface_repo
        self.hf_dir = hf_dir

        self.gnr_restorer = self._build_gnr_restorer(
            checkpoint=gnr_checkpoint,
            device=device,
            classifier_type=gnr_classifier_type,
            nf=gnr_model_nf,
        )
        self.fuser = self._load_fuser(fuser_checkpoint)
        self.fuser_threshold = float(fuser_threshold) if fuser_threshold is not None else 0.5
        self.fuser_frequency_scale = float(fuser_frequency_scale)

    # ------------------------------------------------------------------
    # Helper computations
    # ------------------------------------------------------------------

    @staticmethod
    def _centered_spectrum(latents: torch.Tensor) -> torch.Tensor:
        """Return the 2-D FFT of *latents*, shifted over the last two dims."""
        return torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))

    def _mask_selector(self) -> torch.Tensor:
        """Return the watermark mask as a boolean selector tensor."""
        mask = self.watermarking_mask
        if mask.dtype == torch.bool:
            return mask
        return mask != 0

    def _complex_l1(self, reversed_latents: torch.Tensor) -> float:
        """Mean absolute difference between the latent spectrum and ``gt_patch``.

        Only positions selected by the watermark mask are compared.

        NOTE(review): an empty mask yields ``0.0``, which is also the value of
        a perfect match; callers that threshold this distance will treat an
        empty mask as "watermarked". Kept as-is for compatibility.
        """
        spectrum = self._centered_spectrum(reversed_latents)
        selector = self._mask_selector()
        if selector.sum() == 0:
            return 0.0
        distance = torch.abs(spectrum[selector] - self.gt_patch[selector])
        return float(distance.mean().item())

    def _signal_accuracy(self, reversed_latents: torch.Tensor) -> float:
        """Fraction of masked positions whose complex sign matches ``gt_patch``.

        Returns ``0.0`` when the mask selects nothing.
        """
        spectrum = self._centered_spectrum(reversed_latents)
        selector = self._mask_selector()
        if selector.sum() == 0:
            return 0.0
        observed_sign = extract_complex_sign(spectrum[selector])
        expected_sign = extract_complex_sign(self.gt_patch[selector])
        return float((observed_sign == expected_sign).float().mean().item())

    # ------------------------------------------------------------------
    # Checkpoint resolution
    # ------------------------------------------------------------------

    def _existing_local_copy(self, checkpoint: Union[str, Path]) -> Optional[Path]:
        """Return an already-present copy of *checkpoint*, or ``None``.

        A file named like the checkpoint inside ``hf_dir`` takes precedence
        over the checkpoint path itself.
        """
        requested = Path(checkpoint)
        found = requested if requested.is_file() else None
        if self.hf_dir:
            cached = Path(self.hf_dir) / requested.name
            if cached.is_file():
                found = cached
        return found

    @staticmethod
    def _first_existing_candidate(
        checkpoint: Union[str, Path],
        extra_candidates: tuple = (),
    ) -> Optional[Path]:
        """Return the first existing file among the local search locations.

        Search order: *extra_candidates*, the checkpoint path as given, the
        path relative to this module's directory, and the path relative to
        the directory two levels above it.
        """
        module_dir = Path(__file__).resolve().parent
        relative = Path(checkpoint)
        candidates = [
            *extra_candidates,
            relative,
            module_dir / relative,
            module_dir.parent.parent / relative,
        ]
        return next((path for path in candidates if path.is_file()), None)

    def _build_gnr_restorer(
        self,
        checkpoint: Optional[Union[str, Path]],
        device: torch.device,
        classifier_type: int,
        nf: int,
    ) -> Optional[GNRRestorer]:
        """Locate the GNR checkpoint and build the restorer.

        Returns ``None`` when no checkpoint is configured.

        Raises:
            FileNotFoundError: If the HuggingFace download fails, or if no
                local candidate exists.
        """
        if not checkpoint:
            return None

        checkpoint_path: Optional[Path] = None

        # If a HuggingFace repo is provided, try cache dir and HF download
        if self.huggingface_repo:
            checkpoint_path = self._existing_local_copy(checkpoint)
            if checkpoint_path is None:
                try:
                    checkpoint_path = Path(
                        hf_hub_download(
                            repo_id=self.huggingface_repo,
                            filename=Path(checkpoint).name,
                            cache_dir=self.hf_dir,
                        )
                    )
                except Exception as e:
                    raise FileNotFoundError(
                        f"GNR checkpoint not found on HuggingFace repo '{self.huggingface_repo}'. error: {e}"
                    ) from e

        # Fallback: look for local candidates
        if checkpoint_path is None:
            checkpoint_path = self._first_existing_candidate(checkpoint)
        if checkpoint_path is None:
            raise FileNotFoundError(f"GNR checkpoint not found at '{checkpoint}'")

        latent_channels = self.base_message.shape[1]
        uses_base_message = classifier_type == 1
        return GNRRestorer(
            checkpoint_path=checkpoint_path,
            # Classifier type 1 gets twice the latent channels as input.
            in_channels=latent_channels * (2 if uses_base_message else 1),
            out_channels=latent_channels,
            nf=nf,
            device=device,
            classifier_type=classifier_type,
            base_message=self.base_message if uses_base_message else None,
        )

    @staticmethod
    def _snapshot_fuser_candidate(checkpoint: Union[str, Path]) -> Path:
        """Download the official GaussMarker repo snapshot as a last resort.

        Returns the path where the fuser file is expected after the download;
        the caller still has to verify that it exists.

        Raises:
            FileNotFoundError: If the snapshot download fails for any reason.
        """
        requested = Path(checkpoint)
        # Use the checkpoint's leading directory only when it really is a
        # relative directory component. A bare filename would otherwise
        # create a directory named like the file, and an absolute path would
        # target the filesystem root.
        if len(requested.parts) > 1 and not requested.is_absolute():
            local_dir = str(requested.parts[0])
        else:
            local_dir = "gm_fuser"
        try:
            # Imported here because the module header does not import them;
            # without these the fallback could only ever raise NameError.
            import os

            from huggingface_hub import snapshot_download

            # NOTE(review): ``local_dir_use_symlinks`` is deprecated in recent
            # huggingface_hub releases - confirm it is still accepted by the
            # pinned version.
            snapshot_download(
                repo_id="Generative-Watermark-Toolkits/MarkDiffusion-gm",
                local_dir=local_dir,
                repo_type="model",
                local_dir_use_symlinks=False,
                endpoint=os.getenv("HF_ENDPOINT", "https://huggingface.co"),
            )
        except Exception as e:
            raise FileNotFoundError(
                f"Fuser checkpoint not found on HuggingFace and snapshot fallback failed. error: {e}"
            ) from e
        return Path(local_dir) / requested.name

    def _load_fuser(self, checkpoint: Optional[Union[str, Path]]):
        """Locate and load the fuser model with ``joblib``.

        Returns ``None`` when no checkpoint is configured.

        Raises:
            ImportError: If ``joblib`` is unavailable.
            FileNotFoundError: If the checkpoint cannot be found or downloaded.
        """
        if not checkpoint:
            return None
        if joblib is None:
            raise ImportError(
                "joblib is required to load the GaussMarker fuser. Install joblib or disable the fuser."
            )

        model_path: Optional[Path] = None
        snapshot_candidates: list[Path] = []

        # If a HuggingFace repo is provided, try cache dir and HF download
        if self.huggingface_repo:
            model_path = self._existing_local_copy(checkpoint)
            if model_path is None:
                try:
                    model_path = Path(
                        hf_hub_download(
                            repo_id=self.huggingface_repo,
                            filename=Path(checkpoint).name,
                            cache_dir=self.hf_dir,
                        )
                    )
                except Exception:
                    # As a fallback, try snapshot of the official repo used by GaussMarker fuser
                    snapshot_candidates.append(self._snapshot_fuser_candidate(checkpoint))

        # Fallback: local candidates (snapshot location first, if any)
        if model_path is None:
            model_path = self._first_existing_candidate(
                checkpoint, tuple(snapshot_candidates)
            )

        if model_path is None or not model_path.is_file():
            raise FileNotFoundError(f"Fuser checkpoint not found at '{checkpoint}'")

        # SECURITY: joblib.load unpickles the file, which can execute
        # arbitrary code. Only point this at trusted repositories/paths.
        return joblib.load(model_path)

    def _fused_score(self, spatial_score: float, frequency_score: float) -> float:
        """Score the ``[spatial, frequency]`` feature pair with the fuser model.

        Uses ``predict_proba`` (column 1) when available, otherwise
        ``decision_function``.

        Raises:
            AttributeError: If the fuser exposes neither method.
        """
        features = np.array([[spatial_score, frequency_score]], dtype=np.float32)
        if hasattr(self.fuser, "predict_proba"):
            return float(self.fuser.predict_proba(features)[0, 1])
        if hasattr(self.fuser, "decision_function"):
            return float(self.fuser.decision_function(features)[0])
        raise AttributeError("Unsupported fuser model: missing predict_proba/decision_function")

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------
    def eval_watermark(
        self,
        reversed_latents: torch.Tensor,
        reference_latents: Optional[torch.Tensor] = None,
        detector_type: str = "bit_acc",
    ) -> Dict[str, Union[bool, float]]:
        """Evaluate recovered latents and decide whether they are watermarked.

        Args:
            reversed_latents: Latents recovered by inversion; moved to the
                detector device as ``float32`` before scoring.
            reference_latents: Accepted for API compatibility; not used here.
            detector_type: Case-insensitive decision rule. One of
                ``"bit_acc"``, ``"is_watermarked"``, ``"all"`` (fused score if
                a fuser is loaded, else bit accuracy), ``"message_acc"``,
                ``"complex_l1"``, ``"signal_acc"``, ``"gnr_bit_acc"`` or
                ``"fused"``.

        Returns:
            Metrics dictionary. Always contains ``bit_acc``, ``message_acc``,
            ``complex_l1``, ``frequency_score``, ``tau_bits``, ``tau_onebit``,
            ``decision_bit_acc``, ``decision_threshold`` and
            ``is_watermarked``. ``signal_acc``, ``gnr_bit_acc``,
            ``gnr_message_acc``, ``fused_score`` and ``fused_threshold`` are
            added when the corresponding component is active.

        Raises:
            ValueError: If ``detector_type`` is unknown, or if it requires a
                component (signal mode, GNR, fuser) that is not available.
            AttributeError: If the loaded fuser has neither ``predict_proba``
                nor ``decision_function``.
        """
        detector_type = detector_type.lower()
        reversed_latents = reversed_latents.to(self.device, dtype=torch.float32)

        # Bit-level reconstruction
        predicted_bits = self.generator.pred_w_from_latent(reversed_latents)
        reference_bits = self.generator.watermark_tensor(self.device)
        bit_acc = float((predicted_bits == reference_bits).float().mean().item())

        # Message bit accuracy (post ChaCha decryption)
        reversed_m = self.generator.pred_m_from_latent(reversed_latents)
        message_bits = torch.from_numpy(
            self.generator.message_bits.astype("float32")
        ).to(self.device)
        message_acc = float(
            (reversed_m.flatten().float() == message_bits).float().mean().item()
        )

        # Frequency-domain statistics
        complex_l1 = self._complex_l1(reversed_latents)
        signal_acc = None
        if "signal" in self.w_measurement:
            signal_acc = self._signal_accuracy(reversed_latents)

        # Optional GNR restoration of the decrypted message
        gnr_bit_acc = None
        gnr_message_acc = None
        if self.gnr_restorer is not None:
            restored_binary = self.gnr_restorer.restore_binary(
                reversed_m, threshold=self.gnr_binary_threshold
            )
            restored_w = self.generator.pred_w_from_m(restored_binary)
            gnr_bit_acc = float((restored_w == reference_bits).float().mean().item())
            gnr_message_acc = float(
                (restored_binary.flatten() == message_bits).float().mean().item()
            )

        # Negated so that a smaller L1 distance gives a larger score.
        frequency_score = -complex_l1 * self.fuser_frequency_scale
        metrics: Dict[str, Union[bool, float]] = {
            "bit_acc": bit_acc,
            "message_acc": message_acc,
            "complex_l1": complex_l1,
            "frequency_score": frequency_score,
            "tau_bits": float(self.generator.tau_bits or 0.5),
            "tau_onebit": float(self.generator.tau_onebit or 0.5),
        }
        if signal_acc is not None:
            metrics["signal_acc"] = signal_acc
        if gnr_bit_acc is not None:
            metrics["gnr_bit_acc"] = gnr_bit_acc
        if gnr_message_acc is not None:
            metrics["gnr_message_acc"] = gnr_message_acc

        # Determine binary decision based on requested detector type
        decision_threshold = (
            self.threshold if self.gnr_threshold is None else self.gnr_threshold
        )
        decision_bit_acc = bit_acc
        if self.gnr_use_for_decision and gnr_bit_acc is not None:
            # The GNR result may only raise the accuracy used for the decision.
            decision_bit_acc = max(decision_bit_acc, gnr_bit_acc)
        metrics["decision_bit_acc"] = decision_bit_acc
        metrics["decision_threshold"] = decision_threshold

        fused_score = None
        fused_threshold = None
        if self.fuser is not None:
            fused_threshold = self.fuser_threshold
            spatial_score = gnr_bit_acc if gnr_bit_acc is not None else bit_acc
            fused_score = self._fused_score(spatial_score, frequency_score)
            metrics["fused_score"] = fused_score
            metrics["fused_threshold"] = fused_threshold

        if detector_type == "message_acc":
            is_watermarked = message_acc >= self.message_threshold
        elif detector_type == "complex_l1":
            # Distance metric: smaller is better, hence "<=".
            l1_limit = self.l1_threshold if self.l1_threshold is not None else self.threshold
            is_watermarked = complex_l1 <= l1_limit
        elif detector_type == "signal_acc":
            if signal_acc is None:
                raise ValueError("Signal accuracy requested but watermark measurement does not use signal mode.")
            is_watermarked = signal_acc >= self.threshold
        elif detector_type == "gnr_bit_acc":
            if gnr_bit_acc is None:
                raise ValueError("GNR checkpoint not provided, cannot compute GNR-based accuracy.")
            is_watermarked = gnr_bit_acc >= decision_threshold
        elif detector_type == "fused":
            if fused_score is None:
                raise ValueError("Fuser checkpoint not provided, cannot compute fused score.")
            is_watermarked = fused_score >= fused_threshold
        elif detector_type in {"all", "bit_acc", "is_watermarked"}:
            # Prefer the fused score whenever a fuser is loaded.
            if fused_score is not None:
                is_watermarked = fused_score >= fused_threshold
            else:
                is_watermarked = decision_bit_acc >= decision_threshold
        else:
            raise ValueError(f"Unsupported detector_type '{detector_type}' for GaussMarker.")

        metrics["is_watermarked"] = bool(is_watermarked)
        return metrics