Coverage for visualize / wind / wind_visualizer.py: 100.00%
182 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1from typing import Optional
2import torch
3import matplotlib.pyplot as plt
4from matplotlib.axes import Axes
5import numpy as np
6from visualize.base import BaseVisualizer
7from visualize.data_for_visualization import DataForVisualization
8from matplotlib.gridspec import GridSpecFromSubplotSpec
class WINDVisualizer(BaseVisualizer):
    """WIND watermark visualization class.

    Draws the group pattern selected for the current sample, plus several
    views of the latents after that pattern has been subtracted in the
    Fourier domain. Every ``draw_*`` method paints onto ``ax`` (falling back
    to the current pyplot axes when ``ax`` is ``None``) and returns it.
    """

    # Layout of the "all channels" view: this many channels, on a grid this wide.
    _NUM_CHANNELS = 4
    _GRID_COLS = 2

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        """
        Initialize the visualizer and select the group pattern to display.

        Parameters:
            data_for_visualization (DataForVisualization): Data to visualize; forwarded to the base class.
            dpi (int): Forwarded to the base class.
            watermarking_step (int): Forwarded to the base class; used here to index ``reversed_latents``.
        """
        super().__init__(data_for_visualization, dpi, watermarking_step)
        # NOTE(review): self.data is presumably populated by BaseVisualizer.__init__ - confirm.
        index = self.data.current_index % self.data.M
        self.group_pattern = self.data.group_patterns[index]  # expected shape: [4, 64, 64]

    # ------------------------------------------------------------------
    # Shared helpers
    # ------------------------------------------------------------------
    def _strip_group_pattern(self, latent_channel, channel: int):
        """Subtract the group pattern of ``channel`` from ``latent_channel`` in the Fourier domain.

        Returns the real part of the inverse transform of the cleaned spectrum.
        """
        # NOTE(review): _fft_transform / _ifft_transform are not defined in this class;
        # presumably inherited from BaseVisualizer and numpy-based - confirm.
        spectrum = self._fft_transform(latent_channel)
        cleaned = spectrum - self.group_pattern[channel].cpu().numpy()
        return self._ifft_transform(cleaned).real

    def _draw_channels(self, channel_image, channel, title, cmap, use_color_bar,
                       vmin, vmax, ax, imshow_kwargs) -> Axes:
        """Render one channel on ``ax``, or all channels on a grid nested inside ``ax``.

        Parameters:
            channel_image: Callable taking a channel index and returning the 2D array to show.
            channel (Optional[int]): Channel to draw; None draws every channel on a grid.
            title (str): Plot title; an empty string suppresses it.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar(s).
            vmin (Optional[float]): Lower color limit passed to ``imshow``.
            vmax (Optional[float]): Upper color limit passed to ``imshow``.
            ax (Optional[Axes]): Target axes; None falls back to the current pyplot axes.
            imshow_kwargs (dict): Extra keyword arguments forwarded to ``imshow``.

        Returns:
            Axes: The axes that were drawn on.
        """
        if ax is None:
            # The public signatures default ax to None; use the current axes
            # instead of failing on the first attribute access.
            ax = plt.gca()

        if channel is not None:
            # Single channel visualization
            im = ax.imshow(channel_image(channel), cmap=cmap, vmin=vmin, vmax=vmax, **imshow_kwargs)
            if title != "":
                ax.set_title(title)
            if use_color_bar:
                ax.figure.colorbar(im, ax=ax)
            ax.axis('off')
            return ax

        # Multi-channel visualization: the parent axes only carries the title;
        # the images live in sub-axes laid out inside the parent's slot.
        ax.clear()
        if title != "":
            ax.set_title(title, pad=20)
        ax.axis('off')

        cols = self._GRID_COLS
        rows = -(-self._NUM_CHANNELS // cols)  # ceiling division
        gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                     wspace=0.3, hspace=0.4)

        for i in range(self._NUM_CHANNELS):
            row_idx, col_idx = divmod(i, cols)
            sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

            im = sub_ax.imshow(channel_image(i), cmap=cmap, vmin=vmin, vmax=vmax, **imshow_kwargs)
            sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
            sub_ax.axis('off')
            # Add small colorbar for each subplot
            if use_color_bar:
                cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                cbar.ax.tick_params(labelsize=6)

        return ax

    # ------------------------------------------------------------------
    # Public drawing API
    # ------------------------------------------------------------------
    def draw_group_pattern_fft(self,
                               channel: Optional[int] = None,
                               title: str = "Group Pattern in Fourier Domain",
                               cmap: str = "viridis",
                               use_color_bar: bool = True,
                               vmin: Optional[float] = None,
                               vmax: Optional[float] = None,
                               ax: Optional[Axes] = None,
                               **kwargs) -> Axes:
        """
        Draw the magnitude of the group pattern in Fourier Domain.

        Parameters:
            channel (Optional[int]): The channel of the group pattern to visualize. If None, all 4 channels are shown.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower color limit passed to ``imshow``.
            vmax (Optional[float]): Upper color limit passed to ``imshow``.
            ax (Optional[Axes]): The axes to plot on. If None, the current pyplot axes are used.
            **kwargs: Extra keyword arguments forwarded to ``imshow``.

        Returns:
            Axes: The plotted axes.
        """
        def magnitude(c):
            return np.abs(self.group_pattern[c].cpu().numpy())

        return self._draw_channels(magnitude, channel, title, cmap, use_color_bar,
                                   vmin, vmax, ax, kwargs)

    def draw_orig_noise_wo_group_pattern(self,
                                         channel: Optional[int] = None,
                                         title: str = "Original Noise without Group Pattern",
                                         cmap: str = "viridis",
                                         use_color_bar: bool = True,
                                         vmin: Optional[float] = None,
                                         vmax: Optional[float] = None,
                                         ax: Optional[Axes] = None,
                                         **kwargs) -> Axes:
        """
        Draw the original noise without group pattern.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower color limit passed to ``imshow``.
            vmax (Optional[float]): Upper color limit passed to ``imshow``.
            ax (Optional[Axes]): The axes to plot on. If None, the current pyplot axes are used.
            **kwargs: Extra keyword arguments forwarded to ``imshow``.

        Returns:
            Axes: The plotted axes.
        """
        def cleaned_original(c):
            return self._strip_group_pattern(self.data.orig_watermarked_latents[0, c], c)

        return self._draw_channels(cleaned_original, channel, title, cmap, use_color_bar,
                                   vmin, vmax, ax, kwargs)

    def draw_inverted_noise_wo_group_pattern(self,
                                             channel: Optional[int] = None,
                                             title: str = "Inverted Noise without Group Pattern",
                                             cmap: str = "viridis",
                                             use_color_bar: bool = True,
                                             vmin: Optional[float] = None,
                                             vmax: Optional[float] = None,
                                             ax: Optional[Axes] = None,
                                             **kwargs) -> Axes:
        """
        Draw the inverted noise without group pattern.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower color limit passed to ``imshow``.
            vmax (Optional[float]): Upper color limit passed to ``imshow``.
            ax (Optional[Axes]): The axes to plot on. If None, the current pyplot axes are used.
            **kwargs: Extra keyword arguments forwarded to ``imshow``.

        Returns:
            Axes: The plotted axes.
        """
        # Same latent for every channel, so look it up once.
        reversed_latent = self.data.reversed_latents[self.watermarking_step]

        def cleaned_inverted(c):
            return self._strip_group_pattern(reversed_latent[0, c], c)

        return self._draw_channels(cleaned_inverted, channel, title, cmap, use_color_bar,
                                   vmin, vmax, ax, kwargs)

    def draw_diff_noise_wo_group_pattern(self,
                                         channel: Optional[int] = None,
                                         title: str = "Difference map without Group Pattern",
                                         cmap: str = "coolwarm",
                                         use_color_bar: bool = True,
                                         vmin: Optional[float] = None,
                                         vmax: Optional[float] = None,
                                         ax: Optional[Axes] = None,
                                         **kwargs) -> Axes:
        """
        Draw the difference between original and inverted noise after removing group pattern.

        The plotted value is (cleaned original) - (cleaned inverted).

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower color limit passed to ``imshow``.
            vmax (Optional[float]): Upper color limit passed to ``imshow``.
            ax (Optional[Axes]): The axes to plot on. If None, the current pyplot axes are used.
            **kwargs: Extra keyword arguments forwarded to ``imshow``.

        Returns:
            Axes: The plotted axes.
        """
        # Same latent for every channel, so look it up once.
        reversed_latent = self.data.reversed_latents[self.watermarking_step]

        def cleaned_difference(c):
            orig_z_cleaned = self._strip_group_pattern(self.data.orig_watermarked_latents[0, c], c)
            inv_z_cleaned = self._strip_group_pattern(reversed_latent[0, c], c)
            return orig_z_cleaned - inv_z_cleaned

        return self._draw_channels(cleaned_difference, channel, title, cmap, use_color_bar,
                                   vmin, vmax, ax, kwargs)

    def draw_inverted_group_pattern_fft(self,
                                        channel: Optional[int] = None,
                                        title: str = "WIND Two-Stage Detection Visualization",
                                        cmap: str = "viridis",
                                        use_color_bar: bool = True,
                                        ax: Optional[Axes] = None,
                                        **kwargs) -> Axes:
        """
        Draw the masked spectrum magnitude of the inverted latents after removing the group pattern.

        The inverted latent (one channel, or the mean over channels) is transformed
        with a shifted 2D FFT, the masked group pattern is subtracted, and the
        magnitude of the result inside the mask is shown.

        Parameters:
            channel (Optional[int]): The channel to visualize. If None, latents and group pattern are averaged over channels.
            title (str): The title of the plot; group index and radius are appended. Empty string suppresses it.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            ax (Optional[Axes]): The axes to plot on. If None, the current pyplot axes are used.
            **kwargs: Extra keyword arguments forwarded to ``imshow``.

        Returns:
            Axes: The plotted axes.
        """
        if ax is None:
            ax = plt.gca()

        # Get inverted latents
        reversed_latents = self.data.reversed_latents[self.watermarking_step]

        if channel is not None:
            # Single channel visualization
            latent_channel = reversed_latents[0, channel]
            group_pattern = self.group_pattern[channel]
        else:
            # Average across all channels for clearer visualization
            latent_channel = reversed_latents[0].mean(dim=0)
            group_pattern = self.group_pattern.mean(dim=0)

        # Convert to frequency domain
        z_fft = torch.fft.fftshift(torch.fft.fft2(latent_channel), dim=(-1, -2))

        # Size the mask from the latent itself rather than assuming 64x64, and
        # keep every operand on the spectrum's device (no-op when they already match).
        mask = self._create_circle_mask(latent_channel.shape[-1], self.data.group_radius).to(z_fft.device)
        group_pattern = group_pattern.to(z_fft.device)

        # Remove the group pattern inside the mask, then keep only the masked region
        detection_signal = torch.abs(z_fft - group_pattern * mask) * mask

        # Plot the detection signal
        im = ax.imshow(detection_signal.cpu().numpy(), cmap=cmap, **kwargs)

        if title != "":
            index = self.data.current_index % self.data.M
            detection_info = f" (Group {index}, Radius {self.data.group_radius})"
            ax.set_title(title + detection_info, fontsize=10)

        ax.axis('off')

        # Add colorbar
        if use_color_bar:
            cbar = ax.figure.colorbar(im, ax=ax)
            cbar.set_label('Detection Signal Magnitude', fontsize=8)

        return ax

    def _create_circle_mask(self, size: int, r: int) -> torch.Tensor:
        """Create the ring-shaped mask for the watermark region (same as in detector).

        Returns a ``size`` x ``size`` float tensor that is 1 where the squared
        distance from the center ``size // 2`` lies in [(r-2)**2, r**2] and 0
        elsewhere, placed on the device of ``orig_watermarked_latents``.
        """
        y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij')
        center = size // 2
        dist = (x - center)**2 + (y - center)**2
        return ((dist >= (r-2)**2) & (dist <= r**2)).float().to(self.data.orig_watermarked_latents.device)