Coverage for visualize / wind / wind_visualizer.py: 100.00%

182 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1from typing import Optional 

2import torch 

3import matplotlib.pyplot as plt 

4from matplotlib.axes import Axes 

5import numpy as np 

6from visualize.base import BaseVisualizer 

7from visualize.data_for_visualization import DataForVisualization 

8from matplotlib.gridspec import GridSpecFromSubplotSpec 

9 

10class WINDVisualizer(BaseVisualizer): 

11 """WIND watermark visualization class""" 

12 

13 def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1): 

14 super().__init__(data_for_visualization, dpi, watermarking_step) 

15 index = self.data.current_index % self.data.M 

16 self.group_pattern = self.data.group_patterns[index] # shape: [4, 64, 64] 

17 

18 def draw_group_pattern_fft(self, 

19 channel: Optional[int] = None, 

20 title: str = "Group Pattern in Fourier Domain", 

21 cmap: str = "viridis", 

22 use_color_bar: bool = True, 

23 vmin: Optional[float] = None, 

24 vmax: Optional[float] = None, 

25 ax: Optional[Axes] = None, 

26 **kwargs) -> Axes: 

27 """ 

28 Draw the group pattern in Fourier Domain. 

29  

30 Parameters: 

31 channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown. 

32 title (str): The title of the plot. 

33 cmap (str): The colormap to use. 

34 use_color_bar (bool): Whether to display the colorbar. 

35 ax (Axes): The axes to plot on. 

36  

37 Returns: 

38 Axes: The plotted axes. 

39 """ 

40 if channel is not None: 

41 im = ax.imshow(np.abs(self.group_pattern[channel].cpu().numpy()), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) 

42 if title != "": 

43 ax.set_title(title) 

44 if use_color_bar: 

45 ax.figure.colorbar(im, ax=ax) 

46 ax.axis('off') 

47 else: 

48 # Multi-channel visualization 

49 num_channels = 4 

50 rows = 2 

51 cols = 2 

52 

53 # Clear the axis and set title 

54 ax.clear() 

55 if title != "": 

56 ax.set_title(title, pad=20) 

57 ax.axis('off') 

58 

59 # Use gridspec for better control 

60 gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(), 

61 wspace=0.3, hspace=0.4) 

62 

63 # Create subplots for each channel 

64 for i in range(num_channels): 

65 row_idx = i // cols 

66 col_idx = i % cols 

67 

68 # Create subplot using gridspec 

69 sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx]) 

70 

71 # Draw the latent channel 

72 latent_data = self.group_pattern[i].cpu().numpy() 

73 im = sub_ax.imshow(np.abs(latent_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) 

74 sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3) 

75 sub_ax.axis('off') 

76 # Add small colorbar for each subplot 

77 if use_color_bar: 

78 cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04) 

79 cbar.ax.tick_params(labelsize=6) 

80 

81 return ax 

82 

83 

84 def draw_orig_noise_wo_group_pattern(self, 

85 channel: Optional[int] = None, 

86 title: str = "Original Noise without Group Pattern", 

87 cmap: str = "viridis", 

88 use_color_bar: bool = True, 

89 vmin: Optional[float] = None, 

90 vmax: Optional[float] = None, 

91 ax: Optional[Axes] = None, 

92 **kwargs) -> Axes: 

93 """ 

94 Draw the original noise without group pattern. 

95  

96 Parameters: 

97 channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown. 

98 title (str): The title of the plot. 

99 cmap (str): The colormap to use. 

100 use_color_bar (bool): Whether to display the colorbar. 

101 ax (plt.Axes): The axes to plot on. 

102  

103 Returns: 

104 plt.axes.Axes: The plotted axes. 

105 """ 

106 if channel is not None: 

107 # Single channel visualization 

108 orig_noise_fft = self._fft_transform(self.data.orig_watermarked_latents[0, channel]) 

109 z_fft = orig_noise_fft - self.group_pattern[channel].cpu().numpy() 

110 z_cleaned = self._ifft_transform(z_fft).real 

111 im = ax.imshow(z_cleaned, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) 

112 if title != "": 

113 ax.set_title(title) 

114 if use_color_bar: 

115 ax.figure.colorbar(im, ax=ax) 

116 ax.axis('off') 

117 else: 

118 # Multi-channel visualization 

119 num_channels = 4 

120 rows = 2 

121 cols = 2 

122 

123 # Clear the axis and set title 

124 ax.clear() 

125 if title != "": 

126 ax.set_title(title, pad=20) 

127 ax.axis('off') 

128 

129 # Use gridspec for better control 

130 gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(), 

131 wspace=0.3, hspace=0.4) 

132 

133 # Create subplots for each channel 

134 for i in range(num_channels): 

135 row_idx = i // cols 

136 col_idx = i % cols 

137 

138 # Create subplot using gridspec 

139 sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx]) 

140 

141 # Compute original noise without group pattern for this channel 

142 orig_noise_fft = self._fft_transform(self.data.orig_watermarked_latents[0, i]) 

143 z_fft = orig_noise_fft - self.group_pattern[i].cpu().numpy() 

144 z_cleaned = self._ifft_transform(z_fft).real 

145 

146 # Draw the cleaned noise channel 

147 im = sub_ax.imshow(z_cleaned, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) 

148 sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3) 

149 sub_ax.axis('off') 

150 # Add small colorbar for each subplot 

151 if use_color_bar: 

152 cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04) 

153 cbar.ax.tick_params(labelsize=6) 

154 

155 return ax 

156 

157 def draw_inverted_noise_wo_group_pattern(self, 

158 channel: Optional[int] = None, 

159 title: str = "Inverted Noise without Group Pattern", 

160 cmap: str = "viridis", 

161 use_color_bar: bool = True, 

162 vmin: Optional[float] = None, 

163 vmax: Optional[float] = None, 

164 ax: Optional[Axes] = None, 

165 **kwargs) -> Axes: 

166 """ 

167 Draw the inverted noise without group pattern. 

168  

169 Parameters: 

170 channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown. 

171 title (str): The title of the plot. 

172 cmap (str): The colormap to use. 

173 use_color_bar (bool): Whether to display the colorbar. 

174 ax (plt.Axes): The axes to plot on. 

175  

176 Returns: 

177 plt.axes.Axes: The plotted axes. 

178 """ 

179 if channel is not None: 

180 # Single channel visualization 

181 reversed_latent = self.data.reversed_latents[self.watermarking_step] 

182 reversed_latent_fft = self._fft_transform(reversed_latent[0, channel]) 

183 z_fft = reversed_latent_fft - self.group_pattern[channel].cpu().numpy() 

184 z_cleaned = self._ifft_transform(z_fft).real 

185 im = ax.imshow(z_cleaned, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) 

186 if title != "": 

187 ax.set_title(title) 

188 if use_color_bar: 

189 ax.figure.colorbar(im, ax=ax) 

190 ax.axis('off') 

191 else: 

192 # Multi-channel visualization 

193 num_channels = 4 

194 rows = 2 

195 cols = 2 

196 

197 # Clear the axis and set title 

198 ax.clear() 

199 if title != "": 

200 ax.set_title(title, pad=20) 

201 ax.axis('off') 

202 

203 # Use gridspec for better control 

204 gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(), 

205 wspace=0.3, hspace=0.4) 

206 

207 # Create subplots for each channel 

208 for i in range(num_channels): 

209 row_idx = i // cols 

210 col_idx = i % cols 

211 

212 # Create subplot using gridspec 

213 sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx]) 

214 

215 # Compute inverted noise without group pattern for this channel 

216 reversed_latent = self.data.reversed_latents[self.watermarking_step] 

217 reversed_latent_fft = self._fft_transform(reversed_latent[0, i]) 

218 z_fft = reversed_latent_fft - self.group_pattern[i].cpu().numpy() 

219 z_cleaned = self._ifft_transform(z_fft).real 

220 

221 # Draw the cleaned inverted noise channel 

222 im = sub_ax.imshow(z_cleaned, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) 

223 sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3) 

224 sub_ax.axis('off') 

225 # Add small colorbar for each subplot 

226 if use_color_bar: 

227 cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04) 

228 cbar.ax.tick_params(labelsize=6) 

229 

230 return ax 

231 

232 def draw_diff_noise_wo_group_pattern(self, 

233 channel: Optional[int] = None, 

234 title: str = "Difference map without Group Pattern", 

235 cmap: str = "coolwarm", 

236 use_color_bar: bool = True, 

237 vmin: Optional[float] = None, 

238 vmax: Optional[float] = None, 

239 ax: Optional[Axes] = None, 

240 **kwargs) -> Axes: 

241 """ 

242 Draw the difference between original and inverted noise after removing group pattern. 

243  

244 Parameters: 

245 channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown. 

246 title (str): The title of the plot. 

247 cmap (str): The colormap to use. 

248 use_color_bar (bool): Whether to display the colorbar. 

249 ax (plt.Axes): The axes to plot on. 

250  

251 Returns: 

252 plt.axes.Axes: The plotted axes. 

253 """ 

254 if channel is not None: 

255 # Single channel visualization 

256 # Process original latents 

257 orig_latent_channel = self.data.orig_watermarked_latents[0, channel] 

258 orig_noise_fft = self._fft_transform(orig_latent_channel) 

259 orig_z_fft = orig_noise_fft - self.group_pattern[channel].cpu().numpy() 

260 orig_z_cleaned = self._ifft_transform(orig_z_fft).real 

261 

262 # Process inverted latents 

263 reversed_latent = self.data.reversed_latents[self.watermarking_step] 

264 reversed_latent_fft = self._fft_transform(reversed_latent[0, channel]) 

265 inv_z_fft = reversed_latent_fft - self.group_pattern[channel].cpu().numpy() 

266 inv_z_cleaned = self._ifft_transform(inv_z_fft).real 

267 

268 # Compute difference 

269 diff_cleaned = orig_z_cleaned - inv_z_cleaned 

270 

271 im = ax.imshow(diff_cleaned, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) 

272 if title != "": 

273 ax.set_title(title) 

274 if use_color_bar: 

275 ax.figure.colorbar(im, ax=ax) 

276 ax.axis('off') 

277 else: 

278 # Multi-channel visualization 

279 num_channels = 4 

280 rows = 2 

281 cols = 2 

282 

283 # Clear the axis and set title 

284 ax.clear() 

285 if title != "": 

286 ax.set_title(title, pad=20) 

287 ax.axis('off') 

288 

289 # Use gridspec for better control 

290 gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(), 

291 wspace=0.3, hspace=0.4) 

292 

293 # Create subplots for each channel 

294 for i in range(num_channels): 

295 row_idx = i // cols 

296 col_idx = i % cols 

297 

298 # Create subplot using gridspec 

299 sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx]) 

300 

301 # Process original latents for this channel 

302 orig_latent_channel = self.data.orig_watermarked_latents[0, i] 

303 orig_noise_fft = self._fft_transform(orig_latent_channel) 

304 orig_z_fft = orig_noise_fft - self.group_pattern[i].cpu().numpy() 

305 orig_z_cleaned = self._ifft_transform(orig_z_fft).real 

306 

307 # Process inverted latents for this channel 

308 reversed_latent = self.data.reversed_latents[self.watermarking_step] 

309 reversed_latent_fft = self._fft_transform(reversed_latent[0, i]) 

310 inv_z_fft = reversed_latent_fft - self.group_pattern[i].cpu().numpy() 

311 inv_z_cleaned = self._ifft_transform(inv_z_fft).real 

312 

313 # Compute difference 

314 diff_cleaned = orig_z_cleaned - inv_z_cleaned 

315 

316 # Draw the difference channel 

317 im = sub_ax.imshow(diff_cleaned, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs) 

318 sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3) 

319 sub_ax.axis('off') 

320 # Add small colorbar for each subplot 

321 if use_color_bar: 

322 cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04) 

323 cbar.ax.tick_params(labelsize=6) 

324 

325 return ax 

326 

327 def draw_inverted_group_pattern_fft(self, 

328 channel: Optional[int] = None, 

329 title: str = "WIND Two-Stage Detection Visualization", 

330 cmap: str = "viridis", 

331 use_color_bar: bool = True, 

332 ax: Optional[Axes] = None, 

333 **kwargs) -> Axes: 

334 

335 # Get inverted latents 

336 reversed_latents = self.data.reversed_latents[self.watermarking_step] 

337 

338 if channel is not None: 

339 # Single channel visualization 

340 latent_channel = reversed_latents[0, channel] 

341 else: 

342 # Average across all channels for clearer visualization 

343 latent_channel = reversed_latents[0].mean(dim=0) 

344 

345 # Convert to frequency domain  

346 z_fft = torch.fft.fftshift(torch.fft.fft2(latent_channel), dim=(-1, -2)) 

347 

348 # Get the group pattern that would be detected 

349 index = self.data.current_index % self.data.M 

350 if channel is not None: 

351 group_pattern = self.group_pattern[channel] 

352 else: 

353 group_pattern = self.group_pattern.mean(dim=0) 

354 

355 # Create circular mask  

356 mask = self._create_circle_mask(64, self.data.group_radius) 

357 

358 # Remove group pattern 

359 z_fft_cleaned = z_fft - group_pattern * mask 

360 

361 detection_signal = torch.abs(z_fft_cleaned) 

362 

363 # Apply same mask that detector uses to focus on watermark region 

364 detection_signal = detection_signal * mask 

365 

366 # Plot the detection signal 

367 im = ax.imshow(detection_signal.cpu().numpy(), cmap=cmap, **kwargs) 

368 

369 if title != "": 

370 detection_info = f" (Group {index}, Radius {self.data.group_radius})" 

371 ax.set_title(title + detection_info, fontsize=10) 

372 

373 ax.axis('off') 

374 

375 # Add colorbar 

376 if use_color_bar: 

377 cbar = ax.figure.colorbar(im, ax=ax) 

378 cbar.set_label('Detection Signal Magnitude', fontsize=8) 

379 

380 return ax 

381 

382 def _create_circle_mask(self, size: int, r: int) -> torch.Tensor: 

383 """Create circular mask for watermark region (same as in detector)""" 

384 y, x = torch.meshgrid(torch.arange(size), torch.arange(size), indexing='ij') 

385 center = size // 2 

386 dist = (x - center)**2 + (y - center)**2 

387 return ((dist >= (r-2)**2) & (dist <= r**2)).float().to(self.data.orig_watermarked_latents.device)