Coverage for detection / wind / wind_detection.py: 98.44%

64 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import torch 

17from detection.base import BaseDetector 

18import torch.nn.functional as F 

19import numpy as np 

20import logging 

21from typing import Dict, Any 

22 

class WINDetector(BaseDetector):
    """WIND watermark detector (two-stage robust detection).

    Stage 1 selects the group whose Fourier pattern is most similar to the
    ring-masked magnitude spectrum of the latents. Stage 2 subtracts that
    group's pattern from the spectrum, transforms back, and compares the
    result with every candidate noise stored for the group.
    """

    def __init__(self,
                 noise_groups: Dict[int, torch.Tensor],
                 group_patterns: Dict[int, torch.Tensor],
                 threshold: float,
                 device: torch.device,
                 group_radius: int = 10):
        """Store the group data on ``device`` and configure the detector.

        Args:
            noise_groups: Maps a group id to the candidate noises of that
                group. Each value is iterated and every item is moved to
                ``device``, so it may be a sequence of tensors (or,
                presumably, a tensor stacked along dim 0 - TODO confirm).
            group_patterns: Maps a group id to that group's Fourier pattern.
            threshold: Decision threshold, forwarded to ``BaseDetector``.
            device: Device all stored tensors and masks are placed on.
            group_radius: Outer radius of the ring mask used in both stages.
        """
        super().__init__(threshold, device)

        # Keys are coerced to int so they compare equal to the int group ids
        # produced by _retrieve_group.
        self.noise_groups = {
            int(g): [noise.to(device) for noise in noise_list]
            for g, noise_list in noise_groups.items()
        }

        self.group_patterns = {
            int(g): pattern.to(device)
            for g, pattern in group_patterns.items()
        }

        self.group_radius = group_radius
        self.device = device
        # Not read anywhere in this class; kept so existing attribute
        # access from other code keeps working.
        self.group_threshold = 0.7
        self.logger = logging.getLogger(__name__)

    def _circle_mask(self, size: int, r: int) -> torch.Tensor:
        """Return a ``size`` x ``size`` float ring mask centred on the grid.

        An entry is 1.0 when its squared distance from
        ``(size // 2, size // 2)`` lies in ``[(r - 2) ** 2, r ** 2]`` and 0.0
        otherwise. The original comment states this matches the mask used in
        the project's utils.
        """
        y, x = torch.meshgrid(
            torch.arange(size, device=self.device),
            torch.arange(size, device=self.device),
            indexing='ij'
        )
        center = size // 2
        dist = (x - center) ** 2 + (y - center) ** 2
        return ((dist >= (r - 2) ** 2) & (dist <= r ** 2)).float()

    def _fft_transform(self, latents: torch.Tensor) -> torch.Tensor:
        """Return the 2-D FFT of ``latents`` shifted over the last two dims."""
        return torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))

    def _retrieve_group(self, z_fft: torch.Tensor) -> int:
        """Stage 1: return the id of the best matching group pattern.

        Args:
            z_fft: Shifted Fourier transform of the latents.

        Returns:
            The group id with the highest cosine similarity between the
            ring-masked magnitudes of ``z_fft`` and of the group pattern, or
            -1 when no group could be scored.
        """
        mask = self._circle_mask(z_fft.shape[-1], self.group_radius)

        # The masked latent spectrum does not depend on the group, so it is
        # computed once. A failure here would have failed for every group in
        # the per-group loop, hence the same "no group" result.
        try:
            z_vec = (torch.abs(z_fft) * mask).flatten().unsqueeze(0)
        except Exception as e:
            self.logger.warning("Error preparing latent spectrum: %s", e)
            return -1

        similarities = []
        for group_id, pattern in self.group_patterns.items():
            try:
                pattern_vec = (torch.abs(pattern) * mask).flatten().unsqueeze(0)
                sim = F.cosine_similarity(z_vec, pattern_vec).item()
            except Exception as e:
                # Best effort: a pattern that cannot be compared (e.g. a
                # shape mismatch) is skipped instead of aborting detection.
                self.logger.warning("Error processing group %s: %s", group_id, e)
                continue
            similarities.append((group_id, sim))

        if not similarities:
            return -1
        return max(similarities, key=lambda item: item[1])[0]

    def _match_noise(self, z: torch.Tensor, group_id: int) -> Dict[str, Any]:
        """Stage 2: find the candidate noise of ``group_id`` closest to ``z``.

        Args:
            z: Latents in the spatial domain.
            group_id: Group chosen in stage 1; -1 means no group was found.

        Returns:
            A dict with ``'cosine_similarity'`` (best similarity, clamped to
            be at least 0.0) and ``'best_match'`` (the matching candidate
            tensor, or ``None``). An unknown group yields 0.0 and ``None``.
        """
        if (group_id == -1
                or group_id not in self.noise_groups
                or group_id not in self.group_patterns):
            return {'cosine_similarity': 0.0, 'best_match': None}

        mask = self._circle_mask(z.shape[-1], self.group_radius)
        # Remove the group's ring pattern from the shifted spectrum.
        z_fft = self._fft_transform(z) - self.group_patterns[group_id] * mask
        # NOTE(review): ifftshift is called without ``dim`` and therefore
        # shifts every dimension, whereas the forward fftshift only shifts
        # the last two. Left unchanged because the embedding code is not
        # visible here and may rely on the same convention - confirm.
        z_cleaned = torch.fft.ifft2(torch.fft.ifftshift(z_fft)).real
        z_vec = z_cleaned.flatten().unsqueeze(0)

        max_sim = -1.0
        best_noise = None

        for candidate in self.noise_groups[group_id]:
            sim = F.cosine_similarity(
                z_vec,
                candidate.flatten().unsqueeze(0)
            ).item()
            if sim > max_sim:
                max_sim = sim
                best_noise = candidate

        return {
            'cosine_similarity': max(max_sim, 0.0),
            'best_match': best_noise
        }

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       reference_latents: torch.Tensor = None,
                       detector_type: str = "cosine_similarity") -> Dict[str, Any]:
        """Run the two-stage watermark detection.

        Args:
            reversed_latents: Latents obtained through reverse diffusion
                (documented by the original author as [C,H,W]). Non-tensor
                input is converted with ``torch.tensor``.
            reference_latents: Not used (kept for API compatibility).
            detector_type: Detection method; only ``'cosine_similarity'`` is
                supported.

        Returns:
            Dictionary containing the detection results:
                - group_id: Identified group id (-1 if none / on failure)
                - cosine_similarity: Best stage-2 similarity score
                - is_watermarked: ``True`` when the score exceeds
                  ``self.threshold``

        Raises:
            ValueError: If ``detector_type`` is not ``'cosine_similarity'``.
        """
        if detector_type != "cosine_similarity":
            raise ValueError(
                f"WIND detector only supports 'cosine_similarity' method, got {detector_type}"
            )

        try:
            # Input validation
            if not isinstance(reversed_latents, torch.Tensor):
                reversed_latents = torch.tensor(reversed_latents, device=self.device)
            reversed_latents = reversed_latents.to(self.device)

            # Stage 1: Group identification
            z_fft = self._fft_transform(reversed_latents)
            group_id = self._retrieve_group(z_fft)

            # Stage 2: Noise matching
            match_result = self._match_noise(reversed_latents, group_id)
            score = match_result['cosine_similarity']

            return {
                'group_id': group_id,
                'cosine_similarity': score,
                'is_watermarked': bool(score > self.threshold),
            }

        except Exception as e:
            # Deliberate best effort: any failure is reported as
            # "not watermarked" rather than propagated to the caller.
            self.logger.error("Detection failed: %s", e)
            return {
                'group_id': -1,
                'cosine_similarity': 0.0,
                'is_watermarked': False,
            }

162