Coverage for visualize / gm / gm_visualizer.py: 90.91%
264 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1from typing import Optional, List
2import torch
3import matplotlib.pyplot as plt
4from matplotlib.axes import Axes
5from matplotlib.gridspec import GridSpecFromSubplotSpec
6import numpy as np
7from visualize.base import BaseVisualizer
8from visualize.data_for_visualization import DataForVisualization
class GaussMarkerVisualizer(BaseVisualizer):
    """GaussMarker watermark visualization class.

    Every ``draw_*`` method renders onto ``ax`` (the current pyplot axes when
    ``ax`` is None) and returns it. When exactly one channel is selected the
    image is drawn directly on ``ax``; otherwise ``ax`` is cleared, its axis is
    switched off and it becomes the container of a grid of per-channel
    sub-axes.
    """

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        super().__init__(data_for_visualization, dpi, watermarking_step)

    # ------------------------------------------------------------------
    # Helper utilities
    # ------------------------------------------------------------------
    @staticmethod
    def _resolve_ax(ax: Optional[Axes]) -> Axes:
        """Return ``ax``, or the current pyplot axes when ``ax`` is None."""
        # The draw_* methods default ``ax`` to None; without this fallback the
        # default would fail with AttributeError on ``None.imshow``.
        return plt.gca() if ax is None else ax

    @staticmethod
    def _check_channel(index: int, num_channels: int, label: str = "Channel") -> None:
        """Raise ValueError unless ``0 <= index < num_channels``.

        Args:
            index: Channel index to validate.
            num_channels: Number of available channels.
            label: Prefix used in the error message ("Channel" or "w_channel").
        """
        if index < 0 or index >= num_channels:
            raise ValueError(f"{label} {index} is out of range. Max channel: {num_channels - 1}")

    def _select_channels(self, num_channels: int, channel: Optional[int]) -> List[int]:
        """Return ``[channel]`` after validating it, or every channel when None."""
        if channel is None:
            return list(range(num_channels))
        self._check_channel(channel, num_channels)
        return [channel]

    def _get_boolean_mask(self) -> torch.Tensor:
        """Return ``self.data.watermarking_mask`` as a boolean tensor (non-zero -> True)."""
        mask = self.data.watermarking_mask
        if mask.dtype == torch.bool:
            return mask
        return mask != 0

    def _get_batched_mask(self) -> torch.Tensor:
        """Return the boolean watermarking mask, adding a leading batch dim to 3-D masks."""
        mask = self._get_boolean_mask()
        return mask.unsqueeze(0) if mask.dim() == 3 else mask

    def _resolve_channels(self, mask: torch.Tensor, requested_channel: Optional[int] = None) -> List[int]:
        """Decide which channels to draw for mask-based visualizations.

        Priority: an explicit ``requested_channel``; then ``self.data.w_channel``
        when present and not -1/None; then every channel whose mask (batch
        element 0) has at least one True entry; finally all channels.

        Raises:
            ValueError: If the requested channel or ``w_channel`` is out of range.
        """
        if mask.dim() == 3:
            mask = mask.unsqueeze(0)
        num_channels = mask.shape[1]

        if requested_channel is not None:
            self._check_channel(requested_channel, num_channels)
            return [requested_channel]

        default_channel = getattr(self.data, "w_channel", -1)
        if default_channel not in (-1, None):
            self._check_channel(default_channel, num_channels, label="w_channel")
            return [int(default_channel)]

        # reshape (not view) so non-contiguous masks are accepted as well.
        mask_flat = mask.reshape(mask.shape[0], num_channels, -1).any(dim=-1)
        active_channels = [idx for idx, flag in enumerate(mask_flat[0].tolist()) if flag]
        return active_channels or list(range(num_channels))

    def _grid_shape(self, count: int) -> tuple[int, int]:
        """Return a near-square (rows, cols) grid able to hold ``count`` panels."""
        # Guard against count == 0, which would otherwise divide by zero.
        count = max(int(count), 1)
        rows = int(np.ceil(np.sqrt(count)))
        cols = int(np.ceil(count / rows))
        return rows, cols

    def _get_reversed_latent(self, step: Optional[int] = None) -> torch.Tensor:
        """Return the inverted latent for ``step`` (defaults to ``self.watermarking_step``).

        When ``self.data.reversed_latents`` is a list/tuple, negative indices
        count from the end and out-of-range indices are clamped to the first or
        last entry; any other value is returned as-is.

        Raises:
            IndexError: If ``reversed_latents`` is an empty list/tuple.
        """
        reversed_series = self.data.reversed_latents
        if not isinstance(reversed_series, (list, tuple)):
            return reversed_series
        total = len(reversed_series)
        if total == 0:
            raise IndexError("reversed_latents is empty; there is no inverted latent to visualize")
        index = self.watermarking_step if step is None else step
        if index < 0:
            index = total + index
        index = max(0, min(total - 1, index))
        return reversed_series[index]

    def _prepare_watermark_tensor(self) -> torch.Tensor:
        """Return the generator's watermark as a float32 CPU tensor with a batch dim."""
        generator = self.data.watermark_generator
        watermark = generator.watermark_tensor().to(torch.float32)
        if watermark.dim() == 3:
            watermark = watermark.unsqueeze(0)
        return watermark.cpu()

    @staticmethod
    def _masked_amplitude(magnitude: torch.Tensor, mask: torch.Tensor, ch: int) -> np.ndarray:
        """Return channel ``ch`` of ``magnitude`` (batch 0) with everything outside ``mask`` zeroed.

        Both tensors are expected on the CPU and indexed as ``[0, ch]``.
        """
        amplitude = torch.zeros_like(magnitude[0, ch], dtype=torch.float32)
        selection = mask[0, ch]
        if selection.any():
            # Cast explicitly: index_put requires matching dtypes, and the
            # magnitude may be float16/float64 depending on the source tensor.
            amplitude[selection] = magnitude[0, ch][selection].to(torch.float32)
        return amplitude.numpy()

    @staticmethod
    def _add_hidden_colorbar(ax: Axes, im) -> None:
        """Attach an invisible colour bar to ``ax``.

        Presumably this reserves the same space a visible colour bar would, so
        single-channel panels line up with panels that show one.
        """
        cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
        cbar.ax.set_visible(False)

    def _draw_channel_grid(self, ax: Axes, channels: List[int], header: str, render, use_color_bar: bool = True) -> Axes:
        """Turn ``ax`` into a container holding one sub-axes per channel.

        Args:
            ax: Parent axes; it is cleared and its axis switched off.
            channels: Channel indices to draw, laid out row-major.
            header: Title for the parent axes; skipped when empty.
            render: Callable ``render(sub_ax, ch)`` that draws channel ``ch``
                on ``sub_ax`` and returns the image handle used for the colour bar.
            use_color_bar: Whether each sub-axes gets its own colour bar.

        Returns:
            The parent ``ax``.
        """
        ax.clear()
        if header:
            ax.set_title(header, pad=20)
        ax.axis('off')

        rows, cols = self._grid_shape(len(channels))
        gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(), wspace=0.3, hspace=0.4)
        for i, ch in enumerate(channels):
            sub_ax = ax.figure.add_subplot(gs[i // cols, i % cols])
            im = render(sub_ax, ch)
            sub_ax.set_title(f'Channel {ch}', fontsize=8, pad=3)
            sub_ax.axis('off')
            if use_color_bar:
                cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                cbar.ax.tick_params(labelsize=6)
        return ax

    # ------------------------------------------------------------------
    # Bit-level visualizations
    # ------------------------------------------------------------------
    def draw_watermark_bits(
        self,
        channel: Optional[int] = None,
        title: str = "Original Watermark Bits",
        cmap: str = "binary",
        ax: Optional[Axes] = None,
    ) -> Axes:
        """Visualize the original watermark bits generated by GaussMarker.

        Args:
            channel: Single channel to draw; None draws every channel in a grid.
            title: Plot title; an empty string suppresses it.
            cmap: Matplotlib colormap (values are shown on a fixed 0..1 scale).
            ax: Target axes; the current pyplot axes when None.

        Raises:
            ValueError: If ``channel`` is out of range.
        """
        watermark = self._prepare_watermark_tensor()
        channels = self._select_channels(watermark.shape[1], channel)
        ax = self._resolve_ax(ax)

        def render(target: Axes, ch: int):
            return target.imshow(watermark[0, ch].numpy(), cmap=cmap, vmin=0, vmax=1, interpolation="nearest")

        if len(channels) != 1:
            return self._draw_channel_grid(ax, channels, title, render)

        ch = channels[0]
        im = render(ax, ch)
        if title != "":
            ax.set_title(f"{title} - Channel {ch}", fontsize=10)
        ax.axis('off')
        self._add_hidden_colorbar(ax, im)
        return ax

    def draw_reconstructed_watermark_bits(
        self,
        channel: Optional[int] = None,
        step: Optional[int] = None,
        title: str = "Reconstructed Watermark Bits",
        cmap: str = "binary",
        ax: Optional[Axes] = None,
    ) -> Axes:
        """Visualize watermark bits reconstructed from inverted latents.

        The bits come from ``watermark_generator.pred_w_from_latent`` applied to
        the inverted latent for ``step``. The title reports the bit accuracy,
        i.e. the fraction of entries equal to ``watermark_generator.watermark_tensor``.

        Args:
            channel: Single channel to draw; None draws every channel in a grid.
            step: Index into ``reversed_latents``; defaults to ``watermarking_step``.
            title: Plot title prefix; an empty string keeps only the accuracy text.
            cmap: Matplotlib colormap (values are shown on a fixed 0..1 scale).
            ax: Target axes; the current pyplot axes when None.

        Raises:
            ValueError: If ``channel`` is out of range.
        """
        reversed_latent = self._get_reversed_latent(step)
        generator = self.data.watermark_generator
        reconstructed = generator.pred_w_from_latent(reversed_latent).to(torch.float32)
        reference = generator.watermark_tensor(reconstructed.device)
        bit_acc = (reconstructed == reference).float().mean().item()

        if reconstructed.dim() == 3:
            reconstructed = reconstructed.unsqueeze(0)
        bits = reconstructed.detach().cpu()
        channels = self._select_channels(bits.shape[1], channel)
        ax = self._resolve_ax(ax)

        acc_text = f"(Bit Acc: {bit_acc:.3f})"

        def render(target: Axes, ch: int):
            return target.imshow(bits[0, ch].numpy(), cmap=cmap, vmin=0, vmax=1, interpolation="nearest")

        if len(channels) != 1:
            header = f"{title} {acc_text}" if title != "" else acc_text
            return self._draw_channel_grid(ax, channels, header, render)

        ch = channels[0]
        im = render(ax, ch)
        label = f"Channel {ch} {acc_text}"
        ax.set_title(f"{title} - {label}" if title != "" else label, fontsize=10)
        ax.axis('off')
        self._add_hidden_colorbar(ax, im)
        return ax

    # ------------------------------------------------------------------
    # Frequency-domain visualizations
    # ------------------------------------------------------------------
    def draw_pattern_fft(
        self,
        channel: Optional[int] = None,
        title: str = "GaussMarker Target Pattern (FFT)",
        cmap: str = "viridis",
        use_color_bar: bool = True,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        ax: Optional[Axes] = None,
        **kwargs,
    ) -> Axes:
        """Visualize the intended watermark pattern in the Fourier domain.

        Shows ``|gt_patch|`` inside the watermarking mask; values outside the
        mask are drawn as zero.

        Args:
            channel: Single channel to draw; None lets ``_resolve_channels`` choose.
            title: Plot title; an empty string suppresses it.
            cmap: Matplotlib colormap.
            use_color_bar: Whether to draw colour bars.
            vmin: Lower colour limit forwarded to ``imshow``.
            vmax: Upper colour limit forwarded to ``imshow``.
            ax: Target axes; the current pyplot axes when None.
            **kwargs: Extra keyword arguments forwarded to ``imshow``.

        Raises:
            ValueError: If the selected channel is out of range.
        """
        pattern = self.data.gt_patch
        if pattern.dim() == 3:
            pattern = pattern.unsqueeze(0)
        mask_bool = self._get_batched_mask()
        channels = self._resolve_channels(mask_bool, channel)
        magnitude = torch.abs(pattern).detach().cpu()
        mask_cpu = mask_bool.cpu()
        ax = self._resolve_ax(ax)

        def render(target: Axes, ch: int):
            amplitude = self._masked_amplitude(magnitude, mask_cpu, ch)
            return target.imshow(amplitude, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)

        if len(channels) != 1:
            return self._draw_channel_grid(ax, channels, title, render, use_color_bar)

        ch = channels[0]
        im = render(ax, ch)
        if title != "":
            ax.set_title(f"{title} - Channel {ch}")
        if use_color_bar:
            cbar = ax.figure.colorbar(im, ax=ax)
            cbar.ax.tick_params(labelsize=8)
        ax.axis('off')
        return ax

    def draw_inverted_pattern_fft(
        self,
        channel: Optional[int] = None,
        step: Optional[int] = None,
        title: str = "Recovered Watermark Pattern (FFT)",
        cmap: str = "viridis",
        use_color_bar: bool = True,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        ax: Optional[Axes] = None,
        **kwargs,
    ) -> Axes:
        """Visualize the recovered watermark region in the Fourier domain.

        Takes ``fftshift(fft2(latent))`` of the inverted latent for ``step`` and
        shows its magnitude inside the watermarking mask. The title also
        reports the mean absolute difference ("L1") to ``gt_patch``, computed
        over the whole mask (all channels), or 0 when the mask is empty.

        Args:
            channel: Single channel to draw; None lets ``_resolve_channels`` choose.
            step: Index into ``reversed_latents``; defaults to ``watermarking_step``.
            title: Plot title prefix; an empty string keeps only the L1 text.
            cmap: Matplotlib colormap.
            use_color_bar: Whether to draw colour bars.
            vmin: Lower colour limit forwarded to ``imshow``.
            vmax: Upper colour limit forwarded to ``imshow``.
            ax: Target axes; the current pyplot axes when None.
            **kwargs: Extra keyword arguments forwarded to ``imshow``.

        Raises:
            ValueError: If the selected channel is out of range.
        """
        reversed_latent = self._get_reversed_latent(step)
        fft_latents = torch.fft.fftshift(torch.fft.fft2(reversed_latent), dim=(-1, -2))
        # Match the batched mask layout used everywhere else in this class.
        if fft_latents.dim() == 3:
            fft_latents = fft_latents.unsqueeze(0)
        mask_bool = self._get_batched_mask()
        channels = self._resolve_channels(mask_bool, channel)

        target_patch = self.data.gt_patch.to(fft_latents.device)
        region = mask_bool.to(fft_latents.device)
        if region.sum() > 0:
            complex_l1 = torch.abs(fft_latents - target_patch)[region].mean().item()
        else:
            complex_l1 = 0.0

        magnitude = torch.abs(fft_latents).detach().cpu()
        mask_cpu = mask_bool.cpu()
        title_suffix = f" (L1: {complex_l1:.3e})"
        ax = self._resolve_ax(ax)

        def render(target: Axes, ch: int):
            amplitude = self._masked_amplitude(magnitude, mask_cpu, ch)
            return target.imshow(amplitude, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)

        if len(channels) != 1:
            header = f"{title}{title_suffix}" if title != "" else title_suffix.strip()
            return self._draw_channel_grid(ax, channels, header, render, use_color_bar)

        ch = channels[0]
        im = render(ax, ch)
        if title != "":
            ax.set_title(f"{title} - Channel {ch}{title_suffix}")
        else:
            ax.set_title(title_suffix.strip(), fontsize=10)
        if use_color_bar:
            cbar = ax.figure.colorbar(im, ax=ax)
            cbar.ax.tick_params(labelsize=8)
        ax.axis('off')
        return ax

    def draw_watermark_mask(
        self,
        channel: Optional[int] = None,
        title: str = "Watermark Mask",
        cmap: str = "viridis",
        ax: Optional[Axes] = None,
    ) -> Axes:
        """Visualize the spatial region where the watermark is applied.

        Args:
            channel: Single channel to draw; None lets ``_resolve_channels`` choose.
            title: Plot title; an empty string suppresses it.
            cmap: Matplotlib colormap (mask values shown on a fixed 0..1 scale).
            ax: Target axes; the current pyplot axes when None.

        Raises:
            ValueError: If the selected channel is out of range.
        """
        mask_bool = self._get_batched_mask()
        channels = self._resolve_channels(mask_bool, channel)
        mask_cpu = mask_bool.float().cpu()
        ax = self._resolve_ax(ax)

        def render(target: Axes, ch: int):
            return target.imshow(mask_cpu[0, ch].numpy(), cmap=cmap, vmin=0, vmax=1)

        if len(channels) != 1:
            return self._draw_channel_grid(ax, channels, title, render)

        ch = channels[0]
        im = render(ax, ch)
        if title != "":
            ax.set_title(f"{title} - Channel {ch}")
        self._add_hidden_colorbar(ax, im)
        ax.axis('off')
        return ax