Coverage for detection / wind / wind_detection.py: 98.44%
64 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from typing import Any, Dict, List, Optional

import numpy as np
import torch
import torch.nn.functional as F

from detection.base import BaseDetector
class WINDetector(BaseDetector):
    """WIND watermark detector (two-stage robust detection).

    Stage 1 identifies the watermark group by comparing the magnitude of the
    input's centred 2-D Fourier spectrum, restricted to a ring-shaped mask,
    against each stored group pattern. Stage 2 subtracts that group's pattern
    from the spectrum, transforms back to the spatial domain, and finds the
    stored noise of that group with the highest cosine similarity.
    """

    def __init__(self,
                 noise_groups: Dict[int, List[torch.Tensor]],
                 group_patterns: Dict[int, torch.Tensor],
                 threshold: float,
                 device: torch.device,
                 group_radius: int = 10):
        """
        Args:
            noise_groups: Mapping from group id to the candidate noise tensors
                of that group. Keys are coerced to ``int`` and every tensor is
                moved to ``device``.
            group_patterns: Mapping from group id to that group's
                Fourier-domain pattern. Keys are coerced to ``int`` and
                tensors are moved to ``device``.
            threshold: Detection threshold, passed to ``BaseDetector``;
                ``eval_watermark`` reports a watermark when the best noise
                similarity is strictly greater than ``self.threshold``.
            device: Device on which stored tensors and masks are placed.
            group_radius: Outer radius of the ring mask used both for group
                identification and for pattern removal.
        """
        super().__init__(threshold, device)

        self.noise_groups = {
            int(g): [noise.to(device) for noise in noise_list]
            for g, noise_list in noise_groups.items()
        }

        self.group_patterns = {
            int(g): pattern.to(device)
            for g, pattern in group_patterns.items()
        }

        self.group_radius = group_radius
        self.device = device
        # NOTE(review): not read anywhere in this class; kept in case external
        # callers rely on the attribute.
        self.group_threshold = 0.7
        self.logger = logging.getLogger(__name__)

    def _circle_mask(self, size: int, r: int) -> torch.Tensor:
        """Return a ``size x size`` float ring mask centred at ``size // 2``.

        An element is 1.0 when its squared distance from the centre lies in
        ``[(r - 2) ** 2, r ** 2]`` and 0.0 otherwise. Intended to match the
        ring mask used by the watermark utilities (``Utils``); keep them in sync.
        """
        y, x = torch.meshgrid(
            torch.arange(size, device=self.device),
            torch.arange(size, device=self.device),
            indexing='ij'
        )
        center = size // 2
        dist = (x - center) ** 2 + (y - center) ** 2
        return ((dist >= (r - 2) ** 2) & (dist <= r ** 2)).float()

    def _fft_transform(self, latents: torch.Tensor) -> torch.Tensor:
        """Return the 2-D FFT over the last two dims, zero frequency centred."""
        return torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))

    def _retrieve_group(self, z_fft: torch.Tensor) -> int:
        """Identify the group whose pattern best matches ``z_fft``.

        Args:
            z_fft: Centred spectrum of the input (as from ``_fft_transform``).
                Only the last dimension sizes the mask, so the spatial dims
                are presumably square — TODO confirm with callers.

        Returns:
            The id of the group whose masked pattern magnitude has the highest
            cosine similarity with the masked input magnitude (first one wins
            on ties), or -1 if no group could be scored.
        """
        mask = self._circle_mask(z_fft.shape[-1], self.group_radius)
        # The masked input magnitude is the same for every group; compute once.
        z_vec = (torch.abs(z_fft) * mask).flatten().unsqueeze(0)

        similarities = []
        for group_id, pattern in self.group_patterns.items():
            try:
                pattern_vec = (torch.abs(pattern) * mask).flatten().unsqueeze(0)
                sim = F.cosine_similarity(z_vec, pattern_vec).item()
            except Exception as e:
                # Best-effort: a malformed pattern should not abort detection.
                self.logger.warning("Error processing group %s: %s", group_id, e)
                continue
            similarities.append((group_id, sim))

        return max(similarities, key=lambda item: item[1])[0] if similarities else -1

    def _match_noise(self, z: torch.Tensor, group_id: int) -> Dict[str, Any]:
        """Find the stored noise of ``group_id`` most similar to ``z``.

        The group's pattern is subtracted from ``z``'s centred spectrum inside
        the ring mask, the result is transformed back to the spatial domain,
        and each candidate noise of the group is scored by cosine similarity.

        Returns:
            ``{'cosine_similarity': float, 'best_match': Tensor | None}``.
            The similarity is clamped to be >= 0. An unknown group (or -1)
            yields ``0.0`` and ``None``.
        """
        if (group_id == -1
                or group_id not in self.noise_groups
                or group_id not in self.group_patterns):
            return {'cosine_similarity': 0.0, 'best_match': None}

        mask = self._circle_mask(z.shape[-1], self.group_radius)
        z_fft = self._fft_transform(z) - self.group_patterns[group_id] * mask
        # Undo the shift on the same dims it was applied to. The default
        # (all dims) would also roll the channel/batch dims and misalign the
        # reconstruction with the candidate noises.
        z_cleaned = torch.fft.ifft2(torch.fft.ifftshift(z_fft, dim=(-1, -2))).real
        z_vec = z_cleaned.flatten().unsqueeze(0)

        max_sim = -1.0
        best_noise = None
        for candidate in self.noise_groups[group_id]:
            sim = F.cosine_similarity(z_vec, candidate.flatten().unsqueeze(0)).item()
            if sim > max_sim:
                max_sim = sim
                best_noise = candidate

        return {
            'cosine_similarity': max(max_sim, 0.0),
            'best_match': best_noise
        }

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       reference_latents: Optional[torch.Tensor] = None,
                       detector_type: str = "cosine_similarity") -> Dict[str, Any]:
        """Run two-stage WIND watermark detection.

        Args:
            reversed_latents: Latents obtained through reverse diffusion,
                expected as [C, H, W]. Non-tensor inputs are converted with
                ``torch.tensor``; the result is moved to ``self.device``.
            reference_latents: Unused; accepted for API compatibility.
            detector_type: Detection method; only ``"cosine_similarity"`` is
                supported.

        Returns:
            Dictionary with:
                - ``group_id``: identified group id (-1 if none or on failure)
                - ``cosine_similarity``: best noise-match similarity (>= 0)
                - ``is_watermarked``: ``cosine_similarity > self.threshold``
            The best-matching noise tensor is intentionally not returned.

        Raises:
            ValueError: If ``detector_type`` is not ``"cosine_similarity"``.
                Any other error during detection is logged and reported as a
                non-watermarked result instead of being raised.
        """
        if detector_type != "cosine_similarity":
            raise ValueError(
                f"WIND detector only supports 'cosine_similarity' method, got {detector_type}"
            )

        try:
            # Input normalisation
            if not isinstance(reversed_latents, torch.Tensor):
                reversed_latents = torch.tensor(reversed_latents, device=self.device)
            reversed_latents = reversed_latents.to(self.device)

            # Stage 1: group identification from the Fourier-ring pattern
            group_id = self._retrieve_group(self._fft_transform(reversed_latents))

            # Stage 2: noise matching within the identified group
            similarity = self._match_noise(reversed_latents, group_id)['cosine_similarity']

            return {
                'group_id': group_id,
                'cosine_similarity': similarity,
                'is_watermarked': bool(similarity > self.threshold),
            }

        except Exception as e:
            # Deliberate best-effort: report "not watermarked" rather than
            # crash, but keep the traceback in the log for debugging.
            self.logger.error("Detection failed: %s", e, exc_info=True)
            return {
                'group_id': -1,
                'cosine_similarity': 0.0,
                'is_watermarked': False,
            }