Coverage for markdiffusion / detection / sfw / sfw_detection.py: 94.64%
56 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 19:25 +0000
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 19:25 +0000
1# Copyright 2025 THU-BPM MarkDiffusion.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16import torch
17from markdiffusion.detection.base import BaseDetector
18from scipy.stats import ncx2
class SFWDetector(BaseDetector):
    """Detector for SFW watermarks placed in the Fourier domain of latents.

    Two watermark variants are handled, selected by ``wm_type``:

    * ``"HSQR"``: ``gt_patch`` is a boolean QR-code pattern. It is compared
      with a region of the centred FFT of the reversed latents using a mean
      L1 distance (see :meth:`get_distance_hsqr`).
    * anything else (default ``"HSTR"``): ``gt_patch`` is an FFT-domain
      patch. It is compared with the FFT of the reversed latents at the
      positions selected by ``watermarking_mask``, using either a mean L1
      distance or a non-central chi-squared p-value.
    """

    def __init__(self,
                 watermarking_mask: torch.Tensor,
                 gt_patch: torch.Tensor,
                 w_channel: int,
                 threshold: float,
                 device: torch.device,
                 wm_type: str = "HSTR",
                 threshold_p_value: float = 0.01):
        """Store the watermark key material and detection thresholds.

        Args:
            watermarking_mask: Boolean mask selecting the watermarked FFT
                coefficients. It is used only by the non-HSQR branch and is
                presumably shaped like the reversed latents (TODO confirm).
            gt_patch: Ground-truth watermark. For HSQR this is a boolean QR
                pattern. Otherwise it is a tensor indexed with
                ``watermarking_mask`` and compared against FFT coefficients.
            w_channel: Latent channel that carries the HSQR watermark.
            threshold: L1-distance threshold. A distance strictly below it
                is reported as watermarked. It is forwarded to
                ``BaseDetector``, which presumably stores it as
                ``self.threshold``.
            device: Forwarded to ``BaseDetector``.
            wm_type: ``"HSQR"`` selects the QR-code branch; any other value
                selects the masked-patch (HSTR) branch.
            threshold_p_value: A p-value strictly below this is reported as
                watermarked by the ``"p_value"`` detector.
        """
        super().__init__(threshold, device)
        self.watermarking_mask = watermarking_mask
        self.gt_patch = gt_patch
        self.w_channel = w_channel
        self.wm_type = wm_type
        self.threshold_p_value = threshold_p_value

    @staticmethod
    def _center_fft(latents: torch.Tensor, start: int = 10, end: int = 54) -> torch.Tensor:
        """Return a centred FFT of one spatial window, zero-padded to the input shape.

        The result has ``latents``'s shape and dtype ``complex64``. It holds
        the ``fftshift``-ed 2-D FFT of the spatial window
        ``[start:end, start:end]`` and zeros everywhere else.
        """
        window = (..., slice(start, end), slice(start, end))
        fft = torch.zeros_like(latents, dtype=torch.complex64)
        fft[window] = torch.fft.fftshift(torch.fft.fft2(latents[window]), dim=(-1, -2))
        return fft

    @torch.no_grad()
    def get_distance_hsqr(self, qr_gt_bool, target_fft, p=1):
        """Return the mean L1 distance between a QR watermark and an FFT region.

        Steps:

        1. Map the boolean QR code to +45 / -45.
        2. Use the left half of its first channel as the real part and the
           right half as the imaginary part of a complex patch.
        3. Compare that patch with ``target_fft[0, self.w_channel]``, rows
           ``11 : 11 + qr_len`` and columns starting one past the centre
           index.

        Args:
            qr_gt_bool: Boolean QR pattern, e.g. ``(c_wm, 42, 42)``. A
                leading singleton dimension is squeezed away, and only
                channel 0 is used. The last dimension is assumed to be even
                so that both halves have equal width.
            target_fft: Complex FFT of the latents, e.g. ``(1, 4, 64, 64)``
                ``complex64``. Only batch element 0 is read.
            p: Unused; kept for API compatibility.

        Returns:
            The mean absolute difference, as a Python float.
        """
        qr_gt_bool = qr_gt_bool.squeeze(0)
        # Centre index of the shifted spectrum. It is taken from the height
        # axis but used as a column offset, which is equivalent for square
        # latents. NOTE(review): it must stay in sync with the embedding side
        # for non-square inputs.
        center_row = target_fft.shape[-2] // 2                      # 32 for 64x64
        qr_pix_len = qr_gt_bool.shape[-1]                           # 42
        qr_pix_half = (qr_pix_len + 1) // 2                         # 21

        # Boolean -> +/-45 float32 values.
        qr_gt_f32 = torch.where(qr_gt_bool, torch.tensor(45.0), torch.tensor(-45.0)).to(torch.float32)
        qr_left = qr_gt_f32[0, :, :qr_pix_half]                     # (42, 21)
        qr_right = qr_gt_f32[0, :, qr_pix_half:]                    # (42, 21)
        qr_complex = torch.complex(qr_left, qr_right).to(target_fft.device)

        # Row offset 10 matches the start of the FFT window used in
        # eval_watermark; the extra +1 skips its first row.
        row_start = 10 + 1                                          # 11
        row_end = row_start + qr_pix_len                            # 53
        col_start = center_row + 1                                  # 33
        col_end = col_start + qr_pix_half                           # 54
        region = target_fft[0, self.w_channel, row_start:row_end, col_start:col_end]
        return torch.mean(torch.abs(qr_complex - region)).item()

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       reference_latents: torch.Tensor = None,
                       detector_type: str = "l1_distance") -> dict:
        """Decide whether ``reversed_latents`` carry this detector's watermark.

        Args:
            reversed_latents: Latents recovered from the image under test.
                Indexing assumes ``(batch, channel, H, W)`` with H, W >= 54
                for the full window (TODO confirm with callers).
            reference_latents: Unused by this detector.
            detector_type: ``"l1_distance"`` (HSQR and HSTR) or
                ``"p_value"`` (HSTR only).

        Returns:
            One of the following dicts:

            * ``{'is_watermarked', 'l1_distance'}`` for the L1 detectors.
            * ``{'is_watermarked', 'p_value'}`` for the p-value detector.
            * A fixed "not watermarked" dict, which also carries
              ``'bit_acc'``, when the latent height is below 44.

        Raises:
            ValueError: If ``detector_type`` is not supported for the
                configured ``wm_type``.
        """
        h = reversed_latents.shape[-2]

        # Handle small inputs (e.g. CI tests with 64x64 images -> 8x8 latents):
        # they cannot contain the watermark window, so report "not watermarked".
        if h < 44:
            return {
                'is_watermarked': False,
                'l1_distance': 0.0,
                'bit_acc': 0.0
            }

        if self.wm_type == "HSQR":
            if detector_type != 'l1_distance':
                raise ValueError(f"SFW(HSQR)'s watermark detector type {detector_type} not supported")
            hsqr_distance = self.get_distance_hsqr(qr_gt_bool=self.gt_patch,
                                                   target_fft=self._center_fft(reversed_latents))
            return {
                'is_watermarked': hsqr_distance < self.threshold,
                'l1_distance': hsqr_distance
            }

        if detector_type == 'l1_distance':
            reversed_latents_fft = self._center_fft(reversed_latents)
            mask = self.watermarking_mask
            l1_distance = torch.abs(reversed_latents_fft[mask] - self.gt_patch[mask]).mean().item()
            return {
                'is_watermarked': l1_distance < self.threshold,
                'l1_distance': l1_distance
            }

        if detector_type == 'p_value':
            # NOTE(review): unlike the L1 branch, this uses the FFT of the
            # full latent rather than the [10:54] window. Confirm this
            # matches the embedding.
            mask = self.watermarking_mask
            observed = torch.fft.fftshift(torch.fft.fft2(reversed_latents), dim=(-1, -2))[mask].flatten()
            target_patch = self.gt_patch[mask].flatten()
            # Treat real and imaginary parts as independent real samples.
            target_patch = torch.cat([target_patch.real, target_patch.imag])
            observed = torch.cat([observed.real, observed.imag])
            sigma_ = observed.std()
            lambda_ = (target_patch ** 2 / sigma_ ** 2).sum().item()
            x = (((observed - target_patch) / sigma_) ** 2).sum().item()
            p = float(ncx2.cdf(x=x, df=len(target_patch), nc=lambda_))
            return {
                'is_watermarked': bool(p < self.threshold_p_value),
                'p_value': p
            }

        raise ValueError(f"SFW(HSTR)'s watermark detector type {detector_type} not supported")