Coverage for detection / sfw / sfw_detection.py: 94.55%

55 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

14 

15 

16import torch 

17from detection.base import BaseDetector 

18from scipy.stats import ncx2 

19 

class SFWDetector(BaseDetector):
    """Detect SFW watermarks by comparing latent FFT coefficients to a target.

    The comparison depends on ``wm_type``:

    * ``"HSQR"``: a boolean pattern (``gt_patch``) is turned into a complex
      patch and compared with a fixed window of one latent channel's centred
      FFT (see :meth:`get_distance_hsqr`).
    * anything else (default ``"HSTR"``): ``gt_patch`` is compared with the
      FFT coefficients selected by ``watermarking_mask``.
    """

    def __init__(self,
                 watermarking_mask: torch.Tensor,
                 gt_patch: torch.Tensor,
                 w_channel: int,
                 threshold: float,
                 device: torch.device,
                 wm_type: str = "HSTR"):
        """Store the detection configuration.

        Args:
            watermarking_mask: Index/mask applied to the latent FFT and to
                ``gt_patch`` in the non-HSQR branches (presumably boolean
                with the latents' shape -- TODO confirm against the caller).
            gt_patch: Ground-truth watermark. For ``"HSQR"`` it must be a
                boolean tensor (it is used as a ``torch.where`` condition);
                otherwise it is indexed with ``watermarking_mask`` and
                compared with complex FFT values (presumably complex itself).
            w_channel: Channel index (dim 1 of the latents) read by the HSQR
                comparison.
            threshold: Decision threshold forwarded to ``BaseDetector``. A
                score strictly below it is reported as watermarked.
            device: Device forwarded to ``BaseDetector``.
            wm_type: ``"HSQR"`` or any other value (default ``"HSTR"``).
        """
        super().__init__(threshold, device)
        self.watermarking_mask = watermarking_mask
        self.gt_patch = gt_patch
        self.w_channel = w_channel
        self.wm_type = wm_type

    @torch.no_grad()
    def get_distance_hsqr(self, qr_gt_bool, target_fft, p=1):
        """Return the mean absolute distance between the HSQR pattern and the FFT.

        The boolean pattern is mapped to +45.0 / -45.0 and split into a left
        and a right column half. The halves are packed into one complex patch:
        the left half is the real part and the right half is the imaginary
        part. That patch is compared with a fixed window of ``target_fft`` in
        batch element 0, channel ``self.w_channel``.

        Args:
            qr_gt_bool: Boolean pattern of shape ``(1, c_wm, L, L)`` or
                ``(c_wm, L, L)``. Only the first pattern (index 0) is used.
                L must be even so both halves have the same width. The inline
                notes in the original assume L == 42 for 64x64 latents -- TODO
                confirm for other resolutions.
            target_fft: Centred FFT of the latents, indexed as
                ``(batch, channel, row, col)``.
            p: Unused; accepted for interface compatibility.

        Returns:
            The mean absolute complex difference, as a Python float.
        """
        if qr_gt_bool.dim() == 4:
            # Drop only a leading batch dimension. A bare squeeze(0) would also
            # collapse a (1, L, L) pattern to 2-D and break the [0, ...] index.
            qr_gt_bool = qr_gt_bool.squeeze(0)

        center_row = target_fft.shape[-2] // 2
        qr_pix_len = qr_gt_bool.shape[-1]
        qr_pix_half = (qr_pix_len + 1) // 2

        # True -> +45.0, False -> -45.0. Create the scalars on the pattern's
        # device so torch.where does not mix devices.
        qr_gt_f32 = torch.where(
            qr_gt_bool,
            torch.tensor(45.0, device=qr_gt_bool.device),
            torch.tensor(-45.0, device=qr_gt_bool.device),
        ).to(torch.float32)
        qr_left = qr_gt_f32[0, :, :qr_pix_half]
        qr_right = qr_gt_f32[0, :, qr_pix_half:]
        qr_complex = torch.complex(qr_left, qr_right).to(target_fft.device)

        # Fixed placement: rows start one past the FFT window origin (10, the
        # same `start` used in eval_watermark). Columns start one past the
        # centre index computed from the height, which equals the centre
        # column for square latents.
        row_start = 10 + 1
        row_end = row_start + qr_pix_len
        col_start = center_row + 1
        col_end = col_start + qr_pix_half
        qr_slice = (0, self.w_channel, slice(row_start, row_end), slice(col_start, col_end))

        diff = torch.abs(qr_complex - target_fft[qr_slice])
        return torch.mean(diff).item()

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       reference_latents: torch.Tensor = None,
                       detector_type: str = "l1_distance") -> dict:
        """Score ``reversed_latents`` for the SFW watermark.

        Args:
            reversed_latents: Latents indexed as ``(batch, channel, H, W)``.
                The four-way slicing below requires at least four dimensions.
            reference_latents: Unused; accepted for interface compatibility.
            detector_type: ``"l1_distance"`` (any ``wm_type``) or
                ``"p_value"`` (non-HSQR only).

        Returns:
            A dict with ``'is_watermarked'`` and the score under the key
            ``'l1_distance'`` or ``'p_value'``. Latents too small to contain
            the ``[10, 54)`` FFT window always get the fixed result
            ``{'is_watermarked': False, 'l1_distance': 0.0, 'bit_acc': 0.0}``,
            whatever the ``detector_type``.

        Raises:
            ValueError: If ``detector_type`` is not supported for ``wm_type``.
        """
        start, end = 10, 54
        h, w = reversed_latents.shape[-2], reversed_latents.shape[-1]

        # Handle small inputs (e.g. CI tests with 64x64 images -> 8x8 latents).
        # The window [start, end) must fit in both spatial dims. Checking only
        # h against the window width (44) would let 44 <= h < 54 through,
        # which crashes the HSQR slice and scores a truncated window for HSTR.
        if h < end or w < end:
            return {
                'is_watermarked': False,
                'l1_distance': 0.0,
                'bit_acc': 0.0
            }

        center_slice = (slice(None), slice(None), slice(start, end), slice(start, end))
        # Centred FFT of the central window only; every coefficient outside
        # the window stays zero.
        reversed_latents_fft = torch.zeros_like(reversed_latents, dtype=torch.complex64)
        reversed_latents_fft[center_slice] = torch.fft.fftshift(
            torch.fft.fft2(reversed_latents[center_slice]), dim=(-1, -2))

        if self.wm_type == "HSQR":
            if detector_type == 'l1_distance':
                hsqr_distance = self.get_distance_hsqr(qr_gt_bool=self.gt_patch,
                                                       target_fft=reversed_latents_fft)
                return {
                    'is_watermarked': hsqr_distance < self.threshold,
                    'l1_distance': hsqr_distance
                }
            raise ValueError(f"SFW(HSQR)'s watermark detector type {detector_type} not supported")

        if detector_type == 'l1_distance':
            masked_fft = reversed_latents_fft[self.watermarking_mask]
            masked_target = self.gt_patch[self.watermarking_mask]
            l1_distance = torch.abs(masked_fft - masked_target).mean().item()
            return {
                'is_watermarked': l1_distance < self.threshold,
                'l1_distance': l1_distance
            }

        if detector_type == 'p_value':
            # Unlike the L1 branch, this uses the FFT of the full latent, not
            # the windowed FFT computed above.
            fft_values = torch.fft.fftshift(torch.fft.fft2(reversed_latents),
                                            dim=(-1, -2))[self.watermarking_mask].flatten()
            target_patch = self.gt_patch[self.watermarking_mask].flatten()
            # Real and imaginary parts are treated as independent components.
            target_patch = torch.concatenate([target_patch.real, target_patch.imag])
            fft_values = torch.concatenate([fft_values.real, fft_values.imag])

            # Normalise by the empirical std of the observed coefficients, then
            # evaluate the squared distance to the target under a noncentral
            # chi-squared CDF (df = component count, nc = lambda_).
            sigma_ = fft_values.std()
            lambda_ = (target_patch ** 2 / sigma_ ** 2).sum().item()
            x = (((fft_values - target_patch) / sigma_) ** 2).sum().item()
            p = ncx2.cdf(x=x, df=len(target_patch), nc=lambda_)
            return {
                'is_watermarked': bool(p < self.threshold),
                'p_value': p
            }

        raise ValueError(f"SFW(HSTR)'s watermark detector type {detector_type} not supported")

107