Coverage for visualize / gm / gm_visualizer.py: 90.91%

264 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1from typing import Optional, List 

2import torch 

3import matplotlib.pyplot as plt 

4from matplotlib.axes import Axes 

5from matplotlib.gridspec import GridSpecFromSubplotSpec 

6import numpy as np 

7from visualize.base import BaseVisualizer 

8from visualize.data_for_visualization import DataForVisualization 

9 

10 

class GaussMarkerVisualizer(BaseVisualizer):
    """GaussMarker watermark visualization class.

    Draws the watermark bits, the watermarking mask and the Fourier-domain
    patterns exposed by ``self.data`` onto matplotlib axes. Every ``draw_*``
    method renders a single channel directly on ``ax``; several channels are
    rendered as a grid of sub-axes nested in the subplot slot of ``ax``.
    When ``ax`` is omitted, the current pyplot axes are used.
    """

    # Colorbar modes understood by ``_draw_single_channel``.
    _CBAR_SHOWN = "shown"
    _CBAR_HIDDEN = "hidden"

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        """Forward all arguments unchanged to ``BaseVisualizer.__init__``."""
        super().__init__(data_for_visualization, dpi, watermarking_step)

    # ------------------------------------------------------------------
    # Helper utilities
    # ------------------------------------------------------------------
    def _get_boolean_mask(self) -> torch.Tensor:
        """Return ``self.data.watermarking_mask`` as a boolean tensor.

        A mask that is already ``torch.bool`` is returned as-is; any other
        dtype is converted with ``mask != 0``.
        """
        mask = self.data.watermarking_mask
        return mask if mask.dtype == torch.bool else mask != 0

    @staticmethod
    def _validate_channel(channel: int, num_channels: int, label: str = "Channel") -> int:
        """Return ``channel`` if it lies in ``[0, num_channels)``.

        Raises:
            ValueError: if the index is negative or not below ``num_channels``.
        """
        if channel < 0 or channel >= num_channels:
            raise ValueError(f"{label} {channel} is out of range. Max channel: {num_channels - 1}")
        return channel

    def _resolve_channels(self, mask: torch.Tensor, requested_channel: Optional[int] = None) -> List[int]:
        """Decide which channels a mask-based plot should show.

        Priority order:
        1. ``requested_channel`` when given (validated against the mask).
        2. ``self.data.w_channel`` when present and neither ``-1`` nor ``None``.
        3. Every channel of the first batch entry with at least one non-zero
           mask element; if there is none, all channels.

        Raises:
            ValueError: if the requested or configured channel is out of range.
        """
        if mask.dim() == 3:
            mask = mask.unsqueeze(0)
        num_channels = mask.shape[1]

        if requested_channel is not None:
            return [self._validate_channel(requested_channel, num_channels)]

        default_channel = getattr(self.data, "w_channel", -1)
        if default_channel not in (-1, None):
            self._validate_channel(default_channel, num_channels, label="w_channel")
            return [int(default_channel)]

        # reshape (not view) so non-contiguous masks are accepted as well.
        per_channel_active = mask.reshape(mask.shape[0], mask.shape[1], -1).any(dim=-1)
        active_channels = [idx for idx, flag in enumerate(per_channel_active[0].tolist()) if flag]
        return active_channels or list(range(num_channels))

    def _grid_shape(self, count: int) -> tuple[int, int]:
        """Return ``(rows, cols)`` of a near-square grid holding ``count`` cells.

        Raises:
            ValueError: if ``count`` is smaller than 1.
        """
        if count < 1:
            raise ValueError(f"Cannot build a grid for {count} channels")
        rows = int(np.ceil(np.sqrt(count)))
        cols = int(np.ceil(count / rows))
        return rows, cols

    def _get_reversed_latent(self, step: Optional[int] = None) -> torch.Tensor:
        """Fetch the reversed latent for ``step``.

        When ``self.data.reversed_latents`` is a list or tuple, the entry at
        ``step`` is returned (``self.watermarking_step`` when ``step`` is None).
        Negative indices count from the end and out-of-range indices are
        clamped to the first/last entry. Any other object is returned as-is.

        Raises:
            IndexError: if the latent sequence is empty.
        """
        reversed_series = self.data.reversed_latents
        if not isinstance(reversed_series, (list, tuple)):
            return reversed_series

        total = len(reversed_series)
        if total == 0:
            raise IndexError("reversed_latents is empty")
        index = self.watermarking_step if step is None else step
        if index < 0:
            index += total
        index = max(0, min(total - 1, index))
        return reversed_series[index]

    def _prepare_watermark_tensor(self) -> torch.Tensor:
        """Return the generator's watermark as a float32 CPU tensor.

        A 3-D tensor gets a leading dimension added so callers can index it
        as ``[0, channel]``.
        """
        watermark = self.data.watermark_generator.watermark_tensor().to(torch.float32)
        if watermark.dim() == 3:
            watermark = watermark.unsqueeze(0)
        return watermark.cpu()

    @staticmethod
    def _ensure_axes(ax: Optional[Axes]) -> Axes:
        """Return ``ax``, or the current pyplot axes when ``ax`` is None."""
        return plt.gca() if ax is None else ax

    def _all_or_one_channel(self, channel: Optional[int], num_channels: int) -> List[int]:
        """Return ``[channel]`` after validation, or all channel indices."""
        if channel is None:
            return list(range(num_channels))
        return [self._validate_channel(channel, num_channels)]

    @staticmethod
    def _masked_amplitude(magnitude: torch.Tensor, mask: torch.Tensor, ch: int) -> np.ndarray:
        """Return ``magnitude[0, ch]`` with everything outside ``mask[0, ch]`` zeroed."""
        amplitude = torch.zeros_like(magnitude[0, ch], dtype=torch.float32)
        selection = mask[0, ch]
        if selection.any():
            # Cast so non-float32 magnitudes (e.g. from complex128) assign cleanly.
            amplitude[selection] = magnitude[0, ch][selection].to(torch.float32)
        return amplitude.numpy()

    def _draw_single_channel(
        self,
        ax: Axes,
        image: np.ndarray,
        heading: Optional[str],
        heading_kwargs: dict,
        imshow_kwargs: dict,
        colorbar: Optional[str],
    ) -> Axes:
        """Draw one image directly on ``ax``.

        ``colorbar`` is ``_CBAR_SHOWN`` for a regular colorbar,
        ``_CBAR_HIDDEN`` for an invisible one, or ``None`` for no colorbar.
        """
        im = ax.imshow(image, **imshow_kwargs)
        if heading is not None:
            ax.set_title(heading, **heading_kwargs)
        if colorbar == self._CBAR_HIDDEN:
            # NOTE(review): the invisible colorbar presumably keeps the image the
            # same size as panels that show a real one — confirm.
            cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
            cbar.ax.set_visible(False)
        elif colorbar == self._CBAR_SHOWN:
            cbar = ax.figure.colorbar(im, ax=ax)
            cbar.ax.tick_params(labelsize=8)
        ax.axis('off')
        return ax

    def _draw_channel_grid(
        self,
        ax: Axes,
        channels: List[int],
        frames: dict,
        heading: Optional[str],
        imshow_kwargs: dict,
        use_color_bar: bool = True,
    ) -> Axes:
        """Draw ``frames[ch]`` for every channel on a grid nested inside ``ax``.

        ``ax`` itself is cleared and only carries the overall ``heading``; each
        channel gets its own sub-axes titled ``Channel <ch>``.
        """
        ax.clear()
        if heading is not None:
            ax.set_title(heading, pad=20)
        ax.axis('off')

        rows, cols = self._grid_shape(len(channels))
        grid = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(), wspace=0.3, hspace=0.4)
        for position, ch in enumerate(channels):
            row_idx, col_idx = divmod(position, cols)
            sub_ax = ax.figure.add_subplot(grid[row_idx, col_idx])
            im = sub_ax.imshow(frames[ch], **imshow_kwargs)
            sub_ax.set_title(f'Channel {ch}', fontsize=8, pad=3)
            sub_ax.axis('off')
            if use_color_bar:
                cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                cbar.ax.tick_params(labelsize=6)
        return ax

    # ------------------------------------------------------------------
    # Bit-level visualizations
    # ------------------------------------------------------------------
    def draw_watermark_bits(
        self,
        channel: Optional[int] = None,
        title: str = "Original Watermark Bits",
        cmap: str = "binary",
        ax: Optional[Axes] = None,
    ) -> Axes:
        """Visualize the original watermark bits generated by GaussMarker.

        Args:
            channel: Channel to show; ``None`` shows every channel in a grid.
            title: Plot title; an empty string suppresses it.
            cmap: Matplotlib colormap name.
            ax: Target axes; defaults to the current pyplot axes.

        Returns:
            The axes that were drawn on.

        Raises:
            ValueError: if ``channel`` is out of range.
        """
        watermark = self._prepare_watermark_tensor()
        channels = self._all_or_one_channel(channel, watermark.shape[1])
        frames = {ch: watermark[0, ch].numpy() for ch in channels}
        imshow_kwargs = dict(cmap=cmap, vmin=0, vmax=1, interpolation="nearest")
        ax = self._ensure_axes(ax)

        if len(channels) == 1:
            ch = channels[0]
            heading = f"{title} - Channel {ch}" if title != "" else None
            return self._draw_single_channel(
                ax, frames[ch], heading, {"fontsize": 10}, imshow_kwargs, self._CBAR_HIDDEN
            )
        return self._draw_channel_grid(ax, channels, frames, title if title != "" else None, imshow_kwargs)

    def draw_reconstructed_watermark_bits(
        self,
        channel: Optional[int] = None,
        step: Optional[int] = None,
        title: str = "Reconstructed Watermark Bits",
        cmap: str = "binary",
        ax: Optional[Axes] = None,
    ) -> Axes:
        """Visualize watermark bits reconstructed from inverted latents.

        The title always reports the bit accuracy, i.e. the fraction of
        reconstructed elements equal to the generator's reference watermark.

        Args:
            channel: Channel to show; ``None`` shows every channel in a grid.
            step: Index into the reversed latents (see ``_get_reversed_latent``).
            title: Title prefix; an empty string leaves only the accuracy.
            cmap: Matplotlib colormap name.
            ax: Target axes; defaults to the current pyplot axes.

        Returns:
            The axes that were drawn on.

        Raises:
            ValueError: if ``channel`` is out of range.
        """
        reversed_latent = self._get_reversed_latent(step)
        generator = self.data.watermark_generator
        reconstructed = generator.pred_w_from_latent(reversed_latent).to(torch.float32)
        reference = generator.watermark_tensor(reconstructed.device)
        bit_acc = (reconstructed == reference).float().mean().item()

        if reconstructed.dim() == 3:
            reconstructed = reconstructed.unsqueeze(0)

        channels = self._all_or_one_channel(channel, reconstructed.shape[1])
        frames = {ch: reconstructed[0, ch].cpu().numpy() for ch in channels}
        imshow_kwargs = dict(cmap=cmap, vmin=0, vmax=1, interpolation="nearest")
        accuracy_note = f"(Bit Acc: {bit_acc:.3f})"
        ax = self._ensure_axes(ax)

        if len(channels) == 1:
            ch = channels[0]
            if title != "":
                heading = f"{title} - Channel {ch} {accuracy_note}"
            else:
                heading = f"Channel {ch} {accuracy_note}"
            return self._draw_single_channel(
                ax, frames[ch], heading, {"fontsize": 10}, imshow_kwargs, self._CBAR_HIDDEN
            )

        heading = f"{title} {accuracy_note}" if title != "" else accuracy_note
        return self._draw_channel_grid(ax, channels, frames, heading, imshow_kwargs)

    # ------------------------------------------------------------------
    # Frequency-domain visualizations
    # ------------------------------------------------------------------
    def draw_pattern_fft(
        self,
        channel: Optional[int] = None,
        title: str = "GaussMarker Target Pattern (FFT)",
        cmap: str = "viridis",
        use_color_bar: bool = True,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        ax: Optional[Axes] = None,
        **kwargs,
    ) -> Axes:
        """Visualize the intended watermark pattern in the Fourier domain.

        Shows the magnitude of ``self.data.gt_patch`` inside the watermarking
        mask; everything outside the mask is drawn as zero.

        Args:
            channel: Channel to show; ``None`` defers to ``_resolve_channels``.
            title: Plot title; an empty string suppresses it.
            cmap: Matplotlib colormap name.
            use_color_bar: Whether to attach colorbars.
            vmin: Lower color limit passed to ``imshow``.
            vmax: Upper color limit passed to ``imshow``.
            ax: Target axes; defaults to the current pyplot axes.
            **kwargs: Extra keyword arguments forwarded to ``imshow``.

        Returns:
            The axes that were drawn on.

        Raises:
            ValueError: if the selected channel is out of range.
        """
        pattern = self.data.gt_patch
        if pattern.dim() == 3:
            pattern = pattern.unsqueeze(0)
        mask_bool = self._get_boolean_mask()
        if mask_bool.dim() == 3:
            mask_bool = mask_bool.unsqueeze(0)

        channels = self._resolve_channels(mask_bool, channel)
        pattern_abs = torch.abs(pattern).detach().cpu()
        mask_cpu = mask_bool.cpu()
        frames = {ch: self._masked_amplitude(pattern_abs, mask_cpu, ch) for ch in channels}
        imshow_kwargs = dict(cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
        ax = self._ensure_axes(ax)

        if len(channels) == 1:
            ch = channels[0]
            heading = f"{title} - Channel {ch}" if title != "" else None
            return self._draw_single_channel(
                ax, frames[ch], heading, {}, imshow_kwargs, self._CBAR_SHOWN if use_color_bar else None
            )
        return self._draw_channel_grid(
            ax, channels, frames, title if title != "" else None, imshow_kwargs, use_color_bar
        )

    def draw_inverted_pattern_fft(
        self,
        channel: Optional[int] = None,
        step: Optional[int] = None,
        title: str = "Recovered Watermark Pattern (FFT)",
        cmap: str = "viridis",
        use_color_bar: bool = True,
        vmin: Optional[float] = None,
        vmax: Optional[float] = None,
        ax: Optional[Axes] = None,
        **kwargs,
    ) -> Axes:
        """Visualize the recovered watermark region in the Fourier domain.

        Applies a shifted 2-D FFT to the reversed latent and shows its
        magnitude inside the watermarking mask. The title always reports the
        mean absolute difference to ``self.data.gt_patch`` over the masked
        elements (``0.0`` when the mask is empty).

        Args:
            channel: Channel to show; ``None`` defers to ``_resolve_channels``.
            step: Index into the reversed latents (see ``_get_reversed_latent``).
            title: Title prefix; an empty string leaves only the L1 note.
            cmap: Matplotlib colormap name.
            use_color_bar: Whether to attach colorbars.
            vmin: Lower color limit passed to ``imshow``.
            vmax: Upper color limit passed to ``imshow``.
            ax: Target axes; defaults to the current pyplot axes.
            **kwargs: Extra keyword arguments forwarded to ``imshow``.

        Returns:
            The axes that were drawn on.

        Raises:
            ValueError: if the selected channel is out of range.
        """
        reversed_latent = self._get_reversed_latent(step)
        fft_latents = torch.fft.fftshift(torch.fft.fft2(reversed_latent), dim=(-1, -2))
        mask_bool = self._get_boolean_mask()
        if mask_bool.dim() == 3:
            mask_bool = mask_bool.unsqueeze(0)
        channels = self._resolve_channels(mask_bool, channel)

        target_patch = self.data.gt_patch.to(fft_latents.device)
        device_mask = mask_bool.to(fft_latents.device)
        if device_mask.any():
            complex_l1 = torch.abs(fft_latents - target_patch)[device_mask].mean().item()
        else:
            complex_l1 = 0.0

        # assumes fft_latents is 4-D so that [0, ch] selects one image — TODO confirm
        fft_abs = torch.abs(fft_latents).detach().cpu()
        mask_cpu = mask_bool.cpu()
        frames = {ch: self._masked_amplitude(fft_abs, mask_cpu, ch) for ch in channels}
        imshow_kwargs = dict(cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
        title_suffix = f" (L1: {complex_l1:.3e})"
        ax = self._ensure_axes(ax)

        if len(channels) == 1:
            ch = channels[0]
            if title != "":
                heading, heading_kwargs = f"{title} - Channel {ch}{title_suffix}", {}
            else:
                heading, heading_kwargs = title_suffix.strip(), {"fontsize": 10}
            return self._draw_single_channel(
                ax, frames[ch], heading, heading_kwargs, imshow_kwargs,
                self._CBAR_SHOWN if use_color_bar else None,
            )

        heading = f"{title}{title_suffix}" if title != "" else title_suffix.strip()
        return self._draw_channel_grid(ax, channels, frames, heading, imshow_kwargs, use_color_bar)

    def draw_watermark_mask(
        self,
        channel: Optional[int] = None,
        title: str = "Watermark Mask",
        cmap: str = "viridis",
        ax: Optional[Axes] = None,
    ) -> Axes:
        """Visualize the spatial region where the watermark is applied.

        Args:
            channel: Channel to show; ``None`` defers to ``_resolve_channels``.
            title: Plot title; an empty string suppresses it.
            cmap: Matplotlib colormap name.
            ax: Target axes; defaults to the current pyplot axes.

        Returns:
            The axes that were drawn on.

        Raises:
            ValueError: if the selected channel is out of range.
        """
        mask_bool = self._get_boolean_mask()
        if mask_bool.dim() == 3:
            mask_bool = mask_bool.unsqueeze(0)
        channels = self._resolve_channels(mask_bool, channel)
        mask_cpu = mask_bool.float().cpu()
        frames = {ch: mask_cpu[0, ch].numpy() for ch in channels}
        imshow_kwargs = dict(cmap=cmap, vmin=0, vmax=1)
        ax = self._ensure_axes(ax)

        if len(channels) == 1:
            ch = channels[0]
            heading = f"{title} - Channel {ch}" if title != "" else None
            return self._draw_single_channel(ax, frames[ch], heading, {}, imshow_kwargs, self._CBAR_HIDDEN)
        return self._draw_channel_grid(ax, channels, frames, title if title != "" else None, imshow_kwargs)