Coverage for visualize/base.py: 97.17% (318 statements)
coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from abc import ABC, abstractmethod
from typing import Optional, Dict, Any, List, Tuple
import torch
from PIL import Image
from visualize.data_for_visualization import DataForVisualization
import matplotlib.pyplot as plt
from matplotlib.axes import Axes
from matplotlib.gridspec import GridSpecFromSubplotSpec
import numpy as np
from numpy.fft import fft2, fftshift, ifft2, ifftshift

class BaseVisualizer(ABC):
    """Base class for watermark visualizers."""

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1, is_video: bool = False):
        """Initialize with common attributes"""
        self.data = data_for_visualization
        self.dpi = dpi
        self.watermarking_step = watermarking_step  # The step at which the watermark is inserted; -1 means the last step
        self.is_video = is_video  # Whether this is for a T2V (video) or T2I (image) model

    def _fft_transform(self, latent: torch.Tensor) -> np.ndarray:
        """
        Apply the FFT transform to the latent tensor of the watermarked image.
        """
        return fftshift(fft2(latent.cpu().numpy()))

    def _ifft_transform(self, fft_data: np.ndarray) -> np.ndarray:
        """
        Apply the inverse FFT transform to the FFT data.
        """
        return ifft2(ifftshift(fft_data))
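
    # The draw_*_fft methods below display np.abs(self._fft_transform(latent)), i.e. the
    # centered magnitude spectrum of a latent slice (fft2 acts on the last two axes);
    # _ifft_transform is its inverse up to floating-point error, since
    # ifft2(ifftshift(x)) undoes fftshift(fft2(x)).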

    def _get_latent_data(self, latents: torch.Tensor, channel: Optional[int] = None, frame: Optional[int] = None) -> torch.Tensor:
        """
        Extract latent data with proper indexing for both T2I and T2V models.

        Parameters:
            latents: The latent tensor, [B, C, H, W] for T2I or [B, C, F, H, W] for T2V
            channel: The channel index to extract
            frame: The frame index to extract (only for T2V models)

        Returns:
            The extracted latent tensor
        """
        if self.is_video:
            # T2V model: [B, C, F, H, W]
            if frame is not None:
                if channel is not None:
                    return latents[0, channel, frame]  # [H, W]
                else:
                    return latents[0, :, frame]  # [C, H, W]
            else:
                # If no frame is specified, use the middle frame
                mid_frame = latents.shape[2] // 2
                if channel is not None:
                    return latents[0, channel, mid_frame]  # [H, W]
                else:
                    return latents[0, :, mid_frame]  # [C, H, W]
        else:
            # T2I model: [B, C, H, W]
            if channel is not None:
                return latents[0, channel]  # [H, W]
            else:
                return latents[0]  # [C, H, W]
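
    # For example (hypothetical shapes), given a T2V latent of shape [1, 4, 16, 64, 64]:
    #   _get_latent_data(latents, channel=0, frame=3) -> latents[0, 0, 3], shape [64, 64]
    #   _get_latent_data(latents, channel=2)          -> latents[0, 2, 8] (middle frame), shape [64, 64]
    # and given a T2I latent of shape [1, 4, 64, 64]:
    #   _get_latent_data(latents)                     -> latents[0], shape [4, 64, 64]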

    def draw_orig_latents(self,
                          channel: Optional[int] = None,
                          frame: Optional[int] = None,
                          title: str = "Original Latents",
                          cmap: str = "viridis",
                          use_color_bar: bool = True,
                          vmin: Optional[float] = None,
                          vmax: Optional[float] = None,
                          ax: Optional[Axes] = None,
                          **kwargs) -> Axes:
        """
        Draw the original latents of the watermarked image.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses the middle frame for videos.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Minimum value for the colormap.
            vmax (Optional[float]): Maximum value for the colormap.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        if channel is not None:
            # Single-channel visualization
            latent_data = self._get_latent_data(self.data.orig_watermarked_latents, channel, frame).cpu().numpy()
            im = ax.imshow(latent_data, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
            if title != "":
                ax.set_title(title)
            if use_color_bar:
                ax.figure.colorbar(im, ax=ax)
            ax.axis('off')
        else:
            # Multi-channel visualization
            num_channels = 4
            rows = 2
            cols = 2

            # Clear the axis and set the title
            ax.clear()
            if title != "":
                ax.set_title(title, pad=20)
            ax.axis('off')

            # Use gridspec for finer layout control
            gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                         wspace=0.3, hspace=0.4)

            # Create a subplot for each channel
            for i in range(num_channels):
                row_idx = i // cols
                col_idx = i % cols

                # Create the subplot from the gridspec
                sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

                # Draw the latent channel
                latent_data = self._get_latent_data(self.data.orig_watermarked_latents, i, frame).cpu().numpy()
                im = sub_ax.imshow(latent_data, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
                sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
                sub_ax.axis('off')
                # Add a small colorbar for each subplot
                if use_color_bar:
                    cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                    cbar.ax.tick_params(labelsize=6)

        return ax

    def draw_orig_latents_fft(self,
                              channel: Optional[int] = None,
                              frame: Optional[int] = None,
                              title: str = "Original Latents in Fourier Domain",
                              cmap: str = "viridis",
                              use_color_bar: bool = True,
                              vmin: Optional[float] = None,
                              vmax: Optional[float] = None,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the original latents of the watermarked image in the Fourier domain.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses the middle frame for videos.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Minimum value for the colormap.
            vmax (Optional[float]): Maximum value for the colormap.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        if channel is not None:
            # Single-channel visualization
            latent_data = self._get_latent_data(self.data.orig_watermarked_latents, channel, frame)
            fft_data = self._fft_transform(latent_data)

            im = ax.imshow(np.abs(fft_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
            if title != "":
                ax.set_title(title)
            if use_color_bar:
                ax.figure.colorbar(im, ax=ax)
            ax.axis('off')
        else:
            # Multi-channel visualization
            num_channels = 4
            rows = 2
            cols = 2

            # Clear the axis and set the title
            ax.clear()
            if title != "":
                ax.set_title(title, pad=20)
            ax.axis('off')

            # Use gridspec for finer layout control
            gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                         wspace=0.3, hspace=0.4)

            # Create a subplot for each channel
            for i in range(num_channels):
                row_idx = i // cols
                col_idx = i % cols

                # Create the subplot from the gridspec
                sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

                # Draw the FFT magnitude of the latent channel
                latent_data = self._get_latent_data(self.data.orig_watermarked_latents, i, frame)
                fft_data = self._fft_transform(latent_data)
                im = sub_ax.imshow(np.abs(fft_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
                sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
                sub_ax.axis('off')
                # Add a small colorbar for each subplot
                if use_color_bar:
                    cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                    cbar.ax.tick_params(labelsize=6)

        return ax

    def draw_inverted_latents(self,
                              channel: Optional[int] = None,
                              frame: Optional[int] = None,
                              step: Optional[int] = None,
                              title: str = "Inverted Latents",
                              cmap: str = "viridis",
                              use_color_bar: bool = True,
                              vmin: Optional[float] = None,
                              vmax: Optional[float] = None,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the inverted latents of the watermarked image.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses the middle frame for videos.
            step (Optional[int]): The timestep of the inverted latents. If None, the latents at the watermarking step are used.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Minimum value for the colormap.
            vmax (Optional[float]): Maximum value for the colormap.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        if channel is not None:
            # Single-channel visualization
            # Get the inverted latents
            if step is None:
                reversed_latents = self.data.reversed_latents[self.watermarking_step]
            else:
                reversed_latents = self.data.reversed_latents[step]

            latent_data = self._get_latent_data(reversed_latents, channel, frame).cpu().numpy()
            im = ax.imshow(latent_data, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
            if title != "":
                ax.set_title(title)
            if use_color_bar:
                ax.figure.colorbar(im, ax=ax)
            ax.axis('off')
        else:
            # Multi-channel visualization
            num_channels = 4
            rows = 2
            cols = 2

            # Clear the axis and set the title
            ax.clear()
            if title != "":
                ax.set_title(title, pad=20)
            ax.axis('off')

            # Use gridspec for finer layout control
            gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                         wspace=0.3, hspace=0.4)

            # Create a subplot for each channel
            for i in range(num_channels):
                row_idx = i // cols
                col_idx = i % cols

                # Create the subplot from the gridspec
                sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

                # Get the inverted latents
                if step is None:
                    reversed_latents = self.data.reversed_latents[self.watermarking_step]
                else:
                    reversed_latents = self.data.reversed_latents[step]

                latent_data = self._get_latent_data(reversed_latents, i, frame).cpu().numpy()

                # Draw the latent channel
                im = sub_ax.imshow(latent_data, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
                sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
                sub_ax.axis('off')
                # Add a small colorbar for each subplot
                if use_color_bar:
                    cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                    cbar.ax.tick_params(labelsize=6)

        return ax

    def draw_inverted_latents_fft(self,
                                  channel: Optional[int] = None,
                                  frame: Optional[int] = None,
                                  step: int = -1,
                                  title: str = "Inverted Latents in Fourier Domain",
                                  cmap: str = "viridis",
                                  use_color_bar: bool = True,
                                  vmin: Optional[float] = None,
                                  vmax: Optional[float] = None,
                                  ax: Optional[Axes] = None,
                                  **kwargs) -> Axes:
        """
        Draw the inverted latents of the watermarked image in the Fourier domain.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses the middle frame for videos.
            step (int): The timestep of the inverted latents. Defaults to -1 (the last timestep).
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Minimum value for the colormap.
            vmax (Optional[float]): Maximum value for the colormap.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        if channel is not None:
            # Single-channel visualization
            reversed_latents = self.data.reversed_latents[step]
            latent_data = self._get_latent_data(reversed_latents, channel, frame)
            fft_data = self._fft_transform(latent_data)

            im = ax.imshow(np.abs(fft_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
            if title != "":
                ax.set_title(title)
            if use_color_bar:
                ax.figure.colorbar(im, ax=ax)
            ax.axis('off')
        else:
            # Multi-channel visualization
            num_channels = 4
            rows = 2
            cols = 2

            # Clear the axis and set the title
            ax.clear()
            if title != "":
                ax.set_title(title, pad=20)
            ax.axis('off')

            # Use gridspec for finer layout control
            gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                         wspace=0.3, hspace=0.4)

            # Create a subplot for each channel
            for i in range(num_channels):
                row_idx = i // cols
                col_idx = i % cols

                # Create the subplot from the gridspec
                sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

                # Draw the FFT magnitude of the inverted latent channel
                reversed_latents = self.data.reversed_latents[step]
                latent_data = self._get_latent_data(reversed_latents, i, frame)
                fft_data = self._fft_transform(latent_data)
                im = sub_ax.imshow(np.abs(fft_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
                sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
                sub_ax.axis('off')
                # Add a small colorbar for each subplot
                if use_color_bar:
                    cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                    cbar.ax.tick_params(labelsize=6)

        return ax

    def draw_diff_latents_fft(self,
                              channel: Optional[int] = None,
                              frame: Optional[int] = None,
                              title: str = "Difference between Original and Inverted Latents in Fourier Domain",
                              cmap: str = "coolwarm",
                              use_color_bar: bool = True,
                              vmin: Optional[float] = None,
                              vmax: Optional[float] = None,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the difference between the original and inverted initial latents of the watermarked image in the Fourier domain.

        Parameters:
            channel (Optional[int]): The channel of the latent tensor to visualize. If None, all 4 channels are shown.
            frame (Optional[int]): The frame index for T2V models. If None, uses the middle frame for videos.
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Minimum value for the colormap.
            vmax (Optional[float]): Maximum value for the colormap.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        if channel is not None:
            # Single-channel visualization
            # Get the original and inverted latents
            orig_data = self._get_latent_data(self.data.orig_watermarked_latents, channel, frame).cpu().numpy()

            reversed_latents = self.data.reversed_latents[self.watermarking_step]
            inv_data = self._get_latent_data(reversed_latents, channel, frame).cpu().numpy()

            # Compute the difference
            diff_data = orig_data - inv_data

            # Convert back to a tensor for the FFT transform
            diff_tensor = torch.tensor(diff_data)
            fft_data = self._fft_transform(diff_tensor)

            im = ax.imshow(np.abs(fft_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
            if title != "":
                ax.set_title(title)
            if use_color_bar:
                ax.figure.colorbar(im, ax=ax)
            ax.axis('off')
        else:
            # Multi-channel visualization
            num_channels = 4
            rows = 2
            cols = 2

            # Clear the axis and set the title
            ax.clear()
            if title != "":
                ax.set_title(title, pad=20)
            ax.axis('off')

            # Use gridspec for finer layout control
            gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                         wspace=0.3, hspace=0.4)

            # Create a subplot for each channel
            for i in range(num_channels):
                row_idx = i // cols
                col_idx = i % cols

                # Create the subplot from the gridspec
                sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

                # Get the original and inverted latents
                orig_data = self._get_latent_data(self.data.orig_watermarked_latents, i, frame).cpu().numpy()

                reversed_latents = self.data.reversed_latents[self.watermarking_step]
                inv_data = self._get_latent_data(reversed_latents, i, frame).cpu().numpy()

                # Compute the difference and its FFT
                diff_data = orig_data - inv_data
                diff_tensor = torch.tensor(diff_data)
                fft_data = self._fft_transform(diff_tensor)

                # Draw the FFT magnitude of the difference
                im = sub_ax.imshow(np.abs(fft_data), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
                sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
                sub_ax.axis('off')
                # Add a small colorbar for each subplot
                if use_color_bar:
                    cbar = ax.figure.colorbar(im, ax=sub_ax, fraction=0.046, pad=0.04)
                    cbar.ax.tick_params(labelsize=6)

        return ax

    def draw_watermarked_image(self,
                               title: str = "Watermarked Image",
                               num_frames: int = 4,
                               vmin: Optional[float] = None,
                               vmax: Optional[float] = None,
                               ax: Optional[Axes] = None,
                               **kwargs) -> Axes:
        """
        Draw the watermarked image or video frames.

        For images (is_video=False), displays a single image.
        For videos (is_video=True), displays a grid of video frames.

        Parameters:
            title (str): The title of the plot.
            num_frames (int): Number of frames to display for videos (default: 4).
            vmin (Optional[float]): Minimum value for the colormap.
            vmax (Optional[float]): Maximum value for the colormap.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        if self.is_video:
            # Video visualization: display multiple frames
            return self._draw_video_frames(title=title, num_frames=num_frames, ax=ax, **kwargs)
        else:
            # Image visualization: display a single image
            return self._draw_single_image(title=title, vmin=vmin, vmax=vmax, ax=ax, **kwargs)

    def _draw_single_image(self,
                           title: str = "Watermarked Image",
                           vmin: Optional[float] = None,
                           vmax: Optional[float] = None,
                           ax: Optional[Axes] = None,
                           **kwargs) -> Axes:
        """
        Draw a single watermarked image.

        Parameters:
            title (str): The title of the plot.
            vmin (Optional[float]): Minimum value for the colormap.
            vmax (Optional[float]): Maximum value for the colormap.
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        # Convert the image data to a numpy array
        if torch.is_tensor(self.data.image):
            # Handle tensor format (like in the RI watermark)
            if self.data.image.dim() == 4:  # [B, C, H, W]
                image_array = self.data.image[0].permute(1, 2, 0).cpu().numpy()
            elif self.data.image.dim() == 3:  # [C, H, W]
                image_array = self.data.image.permute(1, 2, 0).cpu().numpy()
            else:
                image_array = self.data.image.cpu().numpy()

            # Normalize to 0-1 if needed
            if image_array.max() > 1.0:
                image_array = image_array / 255.0

            # Normalize a [-1, 1] range to [0, 1] for imshow
            if image_array.min() < 0:
                image_array = (image_array + 1.0) / 2.0

            # Clip to the valid range
            image_array = np.clip(image_array, 0, 1)
        else:
            # Handle PIL Image format
            image_array = np.array(self.data.image)

        im = ax.imshow(image_array, vmin=vmin, vmax=vmax, **kwargs)
        if title != "":
            ax.set_title(title, fontsize=12)
        ax.axis('off')

        # Hidden colorbar to keep the subplot layout consistent with plots that show one
        cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
        cbar.ax.set_visible(False)

        return ax
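
    # In short, tensor inputs are rescaled into imshow's expected range: values above 1.0
    # are divided by 255, a [-1, 1] range is shifted to [0, 1], and the result is clipped
    # to [0, 1]; PIL images are converted with np.array and displayed as-is.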

    def _draw_video_frames(self,
                           title: str = "Watermarked Video Frames",
                           num_frames: int = 4,
                           ax: Optional[Axes] = None,
                           **kwargs) -> Axes:
        """
        Draw multiple frames from the watermarked video.

        This method displays a grid of video frames to show the temporal
        consistency of the watermarked video.

        Parameters:
            title (str): The title of the plot.
            num_frames (int): Number of frames to display (default: 4).
            ax (Axes): The axes to plot on.

        Returns:
            Axes: The plotted axes.
        """
        if not hasattr(self.data, 'video_frames') or self.data.video_frames is None:
            raise ValueError("No video frames available for visualization. Please ensure video_frames is provided in data_for_visualization.")

        video_frames = self.data.video_frames
        total_frames = len(video_frames)

        # Limit num_frames to the number of available frames
        num_frames = min(num_frames, total_frames)

        # Calculate which frames to show (evenly distributed)
        if num_frames == 1:
            frame_indices = [total_frames // 2]  # Middle frame
        else:
            frame_indices = [int(i * (total_frames - 1) / (num_frames - 1)) for i in range(num_frames)]

        # Calculate the grid layout
        rows = int(np.ceil(np.sqrt(num_frames)))
        cols = int(np.ceil(num_frames / rows))

        # Clear the axis and set the title
        ax.clear()
        if title != "":
            ax.set_title(title, pad=20, fontsize=12)
        ax.axis('off')

        # Use gridspec for finer layout control
        gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                     wspace=0.1, hspace=0.4)

        # Create a subplot for each frame
        for i, frame_idx in enumerate(frame_indices):
            row_idx = i // cols
            col_idx = i % cols

            # Create the subplot from the gridspec
            sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])

            # Get the frame
            frame = video_frames[frame_idx]

            # Convert the frame to a displayable format
            try:
                # First, convert tensors to numpy if needed
                if hasattr(frame, 'cpu'):  # PyTorch tensor
                    frame = frame.cpu().numpy()
                elif hasattr(frame, 'numpy'):  # Other tensor types
                    frame = frame.numpy()
                elif hasattr(frame, 'convert'):  # PIL Image
                    frame = np.array(frame)

                # Handle channels-first format (C, H, W) -> (H, W, C) for numpy arrays
                if isinstance(frame, np.ndarray) and len(frame.shape) == 3:
                    if frame.shape[0] in [1, 3, 4]:  # Channels first
                        frame = np.transpose(frame, (1, 2, 0))

                # Ensure a proper data type for matplotlib
                if isinstance(frame, np.ndarray):
                    if frame.dtype == np.float64:
                        frame = frame.astype(np.float32)
                    elif frame.dtype not in [np.uint8, np.float32]:
                        # Convert to float32 and normalize if needed
                        frame = frame.astype(np.float32)
                        if frame.max() > 1.0:
                            frame = frame / 255.0

                # Normalize a [-1, 1] range to [0, 1] for imshow
                if frame.min() < 0:
                    frame = (frame + 1.0) / 2.0

                # Clip float data to the valid range [0, 1] (uint8 frames are already valid for imshow)
                if isinstance(frame, np.ndarray) and np.issubdtype(frame.dtype, np.floating):
                    frame = np.clip(frame, 0, 1)

                sub_ax.imshow(frame)
            except Exception as e:
                print(f"Error displaying frame {frame_idx}: {e}")

            sub_ax.set_title(f'Frame {frame_idx}', fontsize=10, pad=5)
            sub_ax.axis('off')

        # Hide unused subplots
        for i in range(num_frames, rows * cols):
            row_idx = i // cols
            col_idx = i % cols
            if row_idx < rows and col_idx < cols:
                empty_ax = ax.figure.add_subplot(gs[row_idx, col_idx])
                empty_ax.axis('off')

        return ax
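
    # For example, with total_frames = 16 and num_frames = 4, the evenly spaced selection
    # above yields frame_indices = [0, 5, 10, 15], shown on a 2x2 grid
    # (rows = ceil(sqrt(4)) = 2, cols = ceil(4 / 2) = 2).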

    def visualize(self,
                  rows: int,
                  cols: int,
                  methods: List[str],
                  figsize: Optional[Tuple[int, int]] = None,
                  method_kwargs: Optional[List[Dict[str, Any]]] = None,
                  save_path: Optional[str] = None) -> plt.Figure:
        """
        Comprehensive visualization of the watermark analysis.

        Parameters:
            rows (int): The number of rows of subplots.
            cols (int): The number of columns of subplots.
            methods (List[str]): List of drawing methods to call.
            figsize (Optional[Tuple[int, int]]): The size of the figure.
            method_kwargs (Optional[List[Dict[str, Any]]]): List of keyword arguments for each method.
            save_path (Optional[str]): The path to save the figure.

        Returns:
            plt.Figure: The matplotlib figure object.
        """
        # Check that rows and cols are compatible with the number of methods
        if len(methods) != rows * cols:
            raise ValueError(f"The number of methods ({len(methods)}) is not compatible with the layout ({rows}x{cols})")

        # Initialize the figure size if not provided
        if figsize is None:
            figsize = (cols * 5, rows * 5)

        # Create the figure and subplots
        fig, axes = plt.subplots(rows, cols, figsize=figsize)

        # Ensure axes is always a 2D array for consistent indexing
        if rows == 1 and cols == 1:
            axes = np.array([[axes]])
        elif rows == 1:
            axes = axes.reshape(1, -1)
        elif cols == 1:
            axes = axes.reshape(-1, 1)

        if method_kwargs is None:
            method_kwargs = [{} for _ in methods]

        # Plot each method
        for i, method_name in enumerate(methods):
            row = i // cols
            col = i % cols
            ax = axes[row, col]

            try:
                method = getattr(self, method_name)
            except AttributeError:
                raise ValueError(f"Method '{method_name}' not found in {self.__class__.__name__}")

            try:
                method(ax=ax, **method_kwargs[i])
            except TypeError:
                raise ValueError(f"Method '{method_name}' does not accept the given arguments: {method_kwargs[i]}")

        # If the number of methods is less than the number of axes, hide the unused axes
        for i in range(len(methods), rows * cols):
            row = i // cols
            col = i % cols
            axes[row, col].axis('off')

        plt.tight_layout(pad=2.0, w_pad=3.0, h_pad=2.0)

        if save_path is not None:
            plt.savefig(save_path, bbox_inches='tight', dpi=self.dpi)

        return fig
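
# Example usage (illustrative sketch only; `SomeWatermarkVisualizer` and `viz_data` are
# hypothetical placeholders for a concrete BaseVisualizer subclass and a populated
# DataForVisualization instance):
#
#     visualizer = SomeWatermarkVisualizer(viz_data, dpi=300, is_video=False)
#     fig = visualizer.visualize(
#         rows=2, cols=2,
#         methods=["draw_watermarked_image", "draw_orig_latents_fft",
#                  "draw_inverted_latents_fft", "draw_diff_latents_fft"],
#         method_kwargs=[{}, {"channel": 0}, {"channel": 0, "step": -1}, {"channel": 0}],
#         save_path="watermark_analysis.png",
#     )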