Coverage for visualize / base.py: 97.17%

318 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16from abc import ABC, abstractmethod 

17from typing import Optional, Dict, Any, List 

18import torch 

19from PIL import Image 

20from visualize.data_for_visualization import DataForVisualization 

21import matplotlib.pyplot as plt 

22from matplotlib.axes import Axes 

23import numpy as np 

24from typing import Tuple 

25from numpy.fft import fft2, fftshift, ifft2, ifftshift 

26from PIL import Image 

27from matplotlib.gridspec import GridSpecFromSubplotSpec 

28 

class BaseVisualizer(ABC):
    """Base class for watermark visualization.

    Holds the data to visualize and provides drawing primitives for latents
    (spatial and Fourier domain), the watermarked image / video frames, and a
    `visualize` entry point that lays several drawing methods out on a grid.
    """

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1, is_video: bool = False):
        """
        Initialize with common attributes.

        Parameters:
            data_for_visualization (DataForVisualization): The data to draw from.
            dpi (int): Resolution used when `visualize` saves a figure.
            watermarking_step (int): Index into `data.reversed_latents` of the step at which
                the watermark was inserted. Defaults to -1 (the last step).
            is_video (bool): Whether this is for a T2V (video) or T2I (image) model.
        """
        self.data = data_for_visualization
        self.dpi = dpi
        # Honour the caller's value (previously the argument was ignored and -1 was always stored).
        self.watermarking_step = watermarking_step
        self.is_video = is_video

    def _fft_transform(self, latent: torch.Tensor) -> np.ndarray:
        """
        Apply a 2D FFT to the latent tensor and shift the zero frequency to the center.
        """
        return fftshift(fft2(latent.cpu().numpy()))

    def _ifft_transform(self, fft_data: np.ndarray) -> np.ndarray:
        """
        Invert `_fft_transform`: undo the frequency shift, then apply the inverse 2D FFT.
        """
        return ifft2(ifftshift(fft_data))

    def _get_latent_data(self, latents: torch.Tensor, channel: Optional[int] = None, frame: Optional[int] = None) -> torch.Tensor:
        """
        Extract latent data with proper indexing for both T2I and T2V models.

        Only the first batch element is used.

        Parameters:
            latents: The latent tensor [B, C, H, W] for T2I or [B, C, F, H, W] for T2V
            channel: The channel index to extract. If None, all channels are returned.
            frame: The frame index to extract (only used for T2V models). If None,
                the middle frame is used.

        Returns:
            The extracted latent tensor: [H, W] when a channel is given, otherwise [C, H, W].
        """
        if not self.is_video:
            # T2I model: [B, C, H, W]
            return latents[0] if channel is None else latents[0, channel]

        # T2V model: [B, C, F, H, W]
        if frame is None:
            frame = latents.shape[2] // 2
        if channel is None:
            return latents[0, :, frame]
        return latents[0, channel, frame]

    @staticmethod
    def _resolve_ax(ax: Optional[Axes]) -> Axes:
        """Return `ax`, falling back to the current pyplot axes when none was given."""
        return plt.gca() if ax is None else ax

    def _draw_latent_panels(self,
                            panel_fn,
                            channel: Optional[int],
                            title: str,
                            cmap: str,
                            use_color_bar: bool,
                            vmin: Optional[float],
                            vmax: Optional[float],
                            ax: Optional[Axes],
                            imshow_kwargs: Dict[str, Any]) -> Axes:
        """
        Shared renderer for all latent plots.

        Parameters:
            panel_fn: Callable mapping a channel index to the 2D array to display.
            channel (Optional[int]): Channel to draw. If None, channels 0-3 are drawn in a 2x2 grid.
            title (str): The title of the plot ("" for no title).
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display a colorbar for each image.
            vmin (Optional[float]): Minimum value for the colormap.
            vmax (Optional[float]): Maximum value for the colormap.
            ax (Optional[Axes]): The axes to plot on; the current axes are used if None.
            imshow_kwargs (Dict[str, Any]): Extra keyword arguments forwarded to `imshow`.

        Returns:
            Axes: The (outer) plotted axes.
        """
        ax = self._resolve_ax(ax)

        if channel is not None:
            # Single channel visualization
            im = ax.imshow(panel_fn(channel), cmap=cmap, vmin=vmin, vmax=vmax, **imshow_kwargs)
            if title != "":
                ax.set_title(title)
            if use_color_bar:
                ax.figure.colorbar(im, ax=ax)
            ax.axis('off')
            return ax

        # Multi-channel visualization: the layout is fixed at 2x2, i.e. 4 latent channels.
        rows, cols = 2, 2

        # The outer axes only carries the title; the images live in nested sub-axes.
        ax.clear()
        if title != "":
            ax.set_title(title, pad=20)
        ax.axis('off')

        gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                     wspace=0.3, hspace=0.4)

        for i in range(rows * cols):
            sub_ax = ax.figure.add_subplot(gs[i // cols, i % cols])
            im = sub_ax.imshow(panel_fn(i), cmap=cmap, vmin=vmin, vmax=vmax, **imshow_kwargs)
            sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
            sub_ax.axis('off')
            # Add small colorbar for each subplot
            if use_color_bar:
                cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                cbar.ax.tick_params(labelsize=6)

        return ax

    def draw_orig_latents(self,
                          channel: Optional[int] = None,
                          frame: Optional[int] = None,
                          title: str = "Original Latents",
                          cmap: str = "viridis",
                          use_color_bar: bool = True,
                          vmin: Optional[float] = None,
                          vmax: Optional[float] = None,
                          ax: Optional[Axes] = None,
                          **kwargs) -> Axes:
        """
        Draw the original latents of the watermarked image.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses middle frame for videos.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Minimum value for colormap.
            vmax (Optional[float]): Maximum value for colormap.
            ax (Optional[Axes]): The axes to plot on. If None, the current axes are used.
            **kwargs: Extra keyword arguments forwarded to `imshow`.

        Returns:
            Axes: The plotted axes.
        """
        latents = self.data.orig_watermarked_latents

        def panel(idx):
            return self._get_latent_data(latents, idx, frame).cpu().numpy()

        return self._draw_latent_panels(panel, channel, title, cmap, use_color_bar, vmin, vmax, ax, kwargs)

    def draw_orig_latents_fft(self,
                              channel: Optional[int] = None,
                              frame: Optional[int] = None,
                              title: str = "Original Latents in Fourier Domain",
                              cmap: str = "viridis",
                              use_color_bar: bool = True,
                              vmin: Optional[float] = None,
                              vmax: Optional[float] = None,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the magnitude spectrum of the original latents of the watermarked image.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses middle frame for videos.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Minimum value for colormap.
            vmax (Optional[float]): Maximum value for colormap.
            ax (Optional[Axes]): The axes to plot on. If None, the current axes are used.
            **kwargs: Extra keyword arguments forwarded to `imshow`.

        Returns:
            Axes: The plotted axes.
        """
        latents = self.data.orig_watermarked_latents

        def panel(idx):
            return np.abs(self._fft_transform(self._get_latent_data(latents, idx, frame)))

        return self._draw_latent_panels(panel, channel, title, cmap, use_color_bar, vmin, vmax, ax, kwargs)

    def draw_inverted_latents(self,
                              channel: Optional[int] = None,
                              frame: Optional[int] = None,
                              step: Optional[int] = None,
                              title: str = "Inverted Latents",
                              cmap: str = "viridis",
                              use_color_bar: bool = True,
                              vmin: Optional[float] = None,
                              vmax: Optional[float] = None,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the inverted latents of the watermarked image.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses middle frame for videos.
            step (Optional[int]): The timestep of the inverted latents. If None, `self.watermarking_step` is used.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Minimum value for colormap.
            vmax (Optional[float]): Maximum value for colormap.
            ax (Optional[Axes]): The axes to plot on. If None, the current axes are used.
            **kwargs: Extra keyword arguments forwarded to `imshow`.

        Returns:
            Axes: The plotted axes.
        """
        # The step does not depend on the channel, so look it up once.
        reversed_latents = self.data.reversed_latents[self.watermarking_step if step is None else step]

        def panel(idx):
            return self._get_latent_data(reversed_latents, idx, frame).cpu().numpy()

        return self._draw_latent_panels(panel, channel, title, cmap, use_color_bar, vmin, vmax, ax, kwargs)

    def draw_inverted_latents_fft(self,
                                  channel: Optional[int] = None,
                                  frame: Optional[int] = None,
                                  step: Optional[int] = -1,
                                  title: str = "Inverted Latents in Fourier Domain",
                                  cmap: str = "viridis",
                                  use_color_bar: bool = True,
                                  vmin: Optional[float] = None,
                                  vmax: Optional[float] = None,
                                  ax: Optional[Axes] = None,
                                  **kwargs) -> Axes:
        """
        Draw the magnitude spectrum of the inverted latents of the watermarked image.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses middle frame for videos.
            step (Optional[int]): The timestep of the inverted latents. Defaults to -1 (the last timestep).
                If None, `self.watermarking_step` is used, matching `draw_inverted_latents`.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Minimum value for colormap.
            vmax (Optional[float]): Maximum value for colormap.
            ax (Optional[Axes]): The axes to plot on. If None, the current axes are used.
            **kwargs: Extra keyword arguments forwarded to `imshow`.

        Returns:
            Axes: The plotted axes.
        """
        reversed_latents = self.data.reversed_latents[self.watermarking_step if step is None else step]

        def panel(idx):
            return np.abs(self._fft_transform(self._get_latent_data(reversed_latents, idx, frame)))

        return self._draw_latent_panels(panel, channel, title, cmap, use_color_bar, vmin, vmax, ax, kwargs)

    def draw_diff_latents_fft(self,
                              channel: Optional[int] = None,
                              frame: Optional[int] = None,
                              title: str = "Difference between Original and Inverted Latents in Fourier Domain",
                              cmap: str = "coolwarm",
                              use_color_bar: bool = True,
                              vmin: Optional[float] = None,
                              vmax: Optional[float] = None,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the magnitude spectrum of (original latents - inverted latents).

        The inverted latents are taken at `self.watermarking_step`.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses middle frame for videos.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Minimum value for colormap.
            vmax (Optional[float]): Maximum value for colormap.
            ax (Optional[Axes]): The axes to plot on. If None, the current axes are used.
            **kwargs: Extra keyword arguments forwarded to `imshow`.

        Returns:
            Axes: The plotted axes.
        """
        orig_latents = self.data.orig_watermarked_latents
        reversed_latents = self.data.reversed_latents[self.watermarking_step]

        def panel(idx):
            orig_data = self._get_latent_data(orig_latents, idx, frame).cpu().numpy()
            inv_data = self._get_latent_data(reversed_latents, idx, frame).cpu().numpy()
            # _fft_transform expects a tensor, so wrap the numpy difference again.
            return np.abs(self._fft_transform(torch.tensor(orig_data - inv_data)))

        return self._draw_latent_panels(panel, channel, title, cmap, use_color_bar, vmin, vmax, ax, kwargs)

    def draw_watermarked_image(self,
                               title: str = "Watermarked Image",
                               num_frames: int = 4,
                               vmin: Optional[float] = None,
                               vmax: Optional[float] = None,
                               ax: Optional[Axes] = None,
                               **kwargs) -> Axes:
        """
        Draw the watermarked image or video frames.

        For images (is_video=False), displays a single image.
        For videos (is_video=True), displays a grid of video frames.

        Parameters:
            title (str): The title of the plot.
            num_frames (int): Number of frames to display for videos (default: 4).
            vmin (Optional[float]): Minimum value for colormap (images only).
            vmax (Optional[float]): Maximum value for colormap (images only).
            ax (Optional[Axes]): The axes to plot on. If None, the current axes are used.

        Returns:
            Axes: The plotted axes.
        """
        if self.is_video:
            return self._draw_video_frames(title=title, num_frames=num_frames, ax=ax, **kwargs)
        return self._draw_single_image(title=title, vmin=vmin, vmax=vmax, ax=ax, **kwargs)

    def _draw_single_image(self,
                           title: str = "Watermarked Image",
                           vmin: Optional[float] = None,
                           vmax: Optional[float] = None,
                           ax: Optional[Axes] = None,
                           **kwargs) -> Axes:
        """
        Draw a single watermarked image (`self.data.image`).

        Tensor images are moved to channels-last, rescaled to [0, 1] and clipped;
        any other image object is converted with `np.array` and drawn as-is.

        Parameters:
            title (str): The title of the plot.
            vmin (Optional[float]): Minimum value for colormap.
            vmax (Optional[float]): Maximum value for colormap.
            ax (Optional[Axes]): The axes to plot on. If None, the current axes are used.

        Returns:
            Axes: The plotted axes.
        """
        ax = self._resolve_ax(ax)
        image = self.data.image

        if torch.is_tensor(image):
            if image.dim() == 4:  # [B, C, H, W] -> first sample
                image = image[0]
            if image.dim() == 3:  # [C, H, W] -> [H, W, C]
                image = image.permute(1, 2, 0)
            image_array = image.cpu().numpy()

            # Values above 1 are treated as a 0-255 range.
            if image_array.max() > 1.0:
                image_array = image_array / 255.0

            # Normalize [-1, 1] range to [0, 1] for imshow
            if image_array.min() < 0:
                image_array = (image_array + 1.0) / 2.0

            # Clip to valid range
            image_array = np.clip(image_array, 0, 1)
        else:
            # Handle PIL Image format
            image_array = np.array(image)

        im = ax.imshow(image_array, vmin=vmin, vmax=vmax, **kwargs)
        if title != "":
            ax.set_title(title, fontsize=12)
        ax.axis('off')

        # An invisible colorbar reserves the same space as the latent plots' colorbars.
        cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
        cbar.ax.set_visible(False)

        return ax

    @staticmethod
    def _to_displayable_frame(frame):
        """Convert one video frame (tensor, PIL image or array) into an array `imshow` can render."""
        if hasattr(frame, 'cpu'):  # PyTorch tensor
            frame = frame.cpu().numpy()
        elif hasattr(frame, 'numpy'):  # Other tensor types
            frame = frame.numpy()
        elif hasattr(frame, 'convert'):  # PIL Image
            frame = np.array(frame)

        if not isinstance(frame, np.ndarray):
            return frame

        # Handle channels-first format (C, H, W) -> (H, W, C)
        if frame.ndim == 3 and frame.shape[0] in (1, 3, 4):
            frame = np.transpose(frame, (1, 2, 0))

        # uint8 frames are already in imshow's 0-255 range; clipping them to [0, 1]
        # would turn them almost black, so they are returned untouched.
        if frame.dtype == np.uint8:
            return frame

        if frame.dtype == np.float64:
            frame = frame.astype(np.float32)
        elif frame.dtype != np.float32:
            # Other dtypes: convert to float32 and treat values above 1 as a 0-255 range.
            frame = frame.astype(np.float32)
            if frame.max() > 1.0:
                frame = frame / 255.0

        # Normalize [-1, 1] range to [0, 1] for imshow
        if frame.min() < 0:
            frame = (frame + 1.0) / 2.0

        return np.clip(frame, 0, 1)

    def _draw_video_frames(self,
                           title: str = "Watermarked Video Frames",
                           num_frames: int = 4,
                           ax: Optional[Axes] = None,
                           **kwargs) -> Axes:
        """
        Draw multiple frames from the watermarked video.

        This method displays a grid of evenly spaced frames from
        `self.data.video_frames` (the middle frame when only one is shown).

        Parameters:
            title (str): The title of the plot.
            num_frames (int): Number of frames to display (default: 4). Capped at the
                number of available frames.
            ax (Optional[Axes]): The axes to plot on. If None, the current axes are used.
            **kwargs: Accepted for signature compatibility; not used.

        Returns:
            Axes: The plotted axes.

        Raises:
            ValueError: If no video frames are available or `num_frames` is less than 1.
        """
        if not hasattr(self.data, 'video_frames') or self.data.video_frames is None:
            raise ValueError("No video frames available for visualization. Please ensure video_frames is provided in data_for_visualization.")

        video_frames = self.data.video_frames
        total_frames = len(video_frames)

        # Limit num_frames to available frames
        num_frames = min(num_frames, total_frames)
        if num_frames < 1:
            # Without this guard the grid arithmetic below divides by zero.
            raise ValueError(f"At least one video frame must be displayed (available frames: {total_frames}, requested: {num_frames}).")

        # Calculate which frames to show (evenly distributed)
        if num_frames == 1:
            frame_indices = [total_frames // 2]  # Middle frame
        else:
            frame_indices = [int(i * (total_frames - 1) / (num_frames - 1)) for i in range(num_frames)]

        # Calculate grid layout
        rows = int(np.ceil(np.sqrt(num_frames)))
        cols = int(np.ceil(num_frames / rows))

        ax = self._resolve_ax(ax)

        # The outer axes only carries the title; the frames live in nested sub-axes.
        ax.clear()
        if title != "":
            ax.set_title(title, pad=20, fontsize=12)
        ax.axis('off')

        gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                     wspace=0.1, hspace=0.4)

        for i, frame_idx in enumerate(frame_indices):
            sub_ax = ax.figure.add_subplot(gs[i // cols, i % cols])

            # Best effort: a frame that cannot be converted leaves its panel empty
            # instead of aborting the whole figure.
            try:
                sub_ax.imshow(self._to_displayable_frame(video_frames[frame_idx]))
            except Exception as e:
                print(f"Error displaying frame {frame_idx}: {e}")

            sub_ax.set_title(f'Frame {frame_idx}', fontsize=10, pad=5)
            sub_ax.axis('off')

        # Hide unused grid cells
        for i in range(num_frames, rows * cols):
            empty_ax = ax.figure.add_subplot(gs[i // cols, i % cols])
            empty_ax.axis('off')

        return ax

    def visualize(self,
                  rows: int,
                  cols: int,
                  methods: List[str],
                  figsize: Optional[Tuple[int, int]] = None,
                  method_kwargs: Optional[List[Dict[str, Any]]] = None,
                  save_path: Optional[str] = None) -> plt.Figure:
        """
        Comprehensive visualization of watermark analysis.

        Each entry of `methods` is looked up on this object and called with the
        axes of its grid cell (row-major order) plus its entry of `method_kwargs`.

        Parameters:
            rows (int): The number of rows of the subplots.
            cols (int): The number of columns of the subplots.
            methods (List[str]): List of methods to call; must contain rows * cols names.
            figsize (Tuple[int, int]): The size of the figure. Defaults to (cols * 5, rows * 5).
            method_kwargs (Optional[List[Dict[str, Any]]]): List of keyword arguments for each method.
            save_path (Optional[str]): The path to save the figure.

        Returns:
            plt.Figure: The matplotlib figure object.

        Raises:
            ValueError: If the layout does not match the number of methods, if
                `method_kwargs` has fewer entries than `methods`, if a method name is
                unknown, or if a method raises TypeError for the given arguments.
        """
        # Check if the rows and cols are compatible with the number of methods
        if len(methods) != rows * cols:
            raise ValueError(f"The number of methods ({len(methods)}) is not compatible with the layout ({rows}x{cols})")

        if method_kwargs is None:
            method_kwargs = [{} for _ in methods]
        elif len(method_kwargs) < len(methods):
            raise ValueError(f"method_kwargs has {len(method_kwargs)} entries but {len(methods)} methods were given")

        if figsize is None:
            figsize = (cols * 5, rows * 5)

        # squeeze=False always yields a 2D array of axes, whatever the layout.
        fig, axes = plt.subplots(rows, cols, figsize=figsize, squeeze=False)

        try:
            for i, method_name in enumerate(methods):
                ax = axes[i // cols, i % cols]

                try:
                    method = getattr(self, method_name)
                except AttributeError as err:
                    raise ValueError(f"Method '{method_name}' not found in {self.__class__.__name__}") from err

                try:
                    method(ax=ax, **method_kwargs[i])
                except TypeError as err:
                    # Chain the cause so a TypeError raised inside the method stays diagnosable.
                    raise ValueError(f"Method '{method_name}' does not accept the given arguments: {method_kwargs[i]}") from err
        except BaseException:
            # Do not leave a half-drawn figure registered with pyplot.
            plt.close(fig)
            raise

        fig.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)

        if save_path is not None:
            fig.savefig(save_path, bbox_inches='tight', dpi=self.dpi)

        return fig