Coverage for visualize / tr / tr_visualizer.py: 97.37%
38 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1import torch
2import matplotlib.pyplot as plt
3from matplotlib.axes import Axes
4from typing import Optional
5import numpy as np
6from visualize.base import BaseVisualizer
7from visualize.data_for_visualization import DataForVisualization
class TreeRingVisualizer(BaseVisualizer):
    """Tree-Ring watermark visualization class.

    Plots the magnitude of a latent channel's FFT with every coefficient
    outside ``self.data.watermarking_mask`` set to zero, so only the
    watermark region (original or recovered by inversion) is visible.
    """

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        super().__init__(data_for_visualization, dpi, watermarking_step)

    def _draw_masked_fft(self,
                         latent,
                         title: str,
                         cmap: str,
                         use_color_bar: bool,
                         vmin: Optional[float],
                         vmax: Optional[float],
                         ax: Optional[Axes],
                         **kwargs) -> Axes:
        """Plot |FFT(latent)| restricted to the watermark mask.

        Shared implementation behind ``draw_pattern_fft`` and
        ``draw_inverted_pattern_fft``.

        Parameters:
            latent: A single latent channel, indexed as ``[0, w_channel]`` by
                the callers (presumably a 2-D tensor — TODO confirm).
            title (str): Plot title; an empty string suppresses the title.
            cmap (str): Colormap passed to ``imshow``.
            use_color_bar (bool): Whether to attach a colorbar.
            vmin (Optional[float]): Lower color limit passed to ``imshow``.
            vmax (Optional[float]): Upper color limit passed to ``imshow``.
            ax (Optional[Axes]): Axes to draw on; the current axes
                (``plt.gca()``) are used when None.
            **kwargs: Extra keyword arguments forwarded to ``ax.imshow``.

        Returns:
            Axes: The axes that were drawn on.
        """
        if ax is None:
            # ``ax`` is advertised as optional by both public methods; fall back
            # to the current axes instead of failing on ``None.imshow``.
            ax = plt.gca()

        if isinstance(latent, torch.Tensor):
            # The FFT result is consumed by ``torch.from_numpy``, so the input
            # must live on the host; ``.cpu()`` is a no-op for CPU tensors.
            latent = latent.cpu()
        mask = self.data.watermarking_mask[0, self.data.w_channel].cpu()

        fft_data = torch.from_numpy(self._fft_transform(latent))
        # Zero background: keep only the coefficients inside the watermark mask.
        fft_vis = torch.zeros_like(fft_data)
        fft_vis[mask] = fft_data[mask]

        # ``fft_vis`` comes from a numpy array, so it is already on the CPU.
        im = ax.imshow(np.abs(fft_vis.numpy()), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
        if title != "":
            ax.set_title(title)
        if use_color_bar:
            ax.figure.colorbar(im, ax=ax)
        ax.axis('off')

        return ax

    def draw_pattern_fft(self,
                         title: str = "Tree-Ring FFT with Watermark Area",
                         cmap: str = "viridis",
                         use_color_bar: bool = True,
                         vmin: Optional[float] = None,
                         vmax: Optional[float] = None,
                         ax: Optional[Axes] = None,
                         **kwargs) -> Axes:
        """
        Draw the FFT of the original watermarked latent inside the watermark
        mask, with an all-zero background.

        Parameters:
            title (str): The title of the plot; "" disables it.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower color limit passed to ``imshow``.
            vmax (Optional[float]): Upper color limit passed to ``imshow``.
            ax (Optional[Axes]): The axes to plot on; the current axes are
                used when None.
            **kwargs: Extra keyword arguments forwarded to ``ax.imshow``.

        Returns:
            Axes: The plotted axes.
        """
        orig_latent = self.data.orig_watermarked_latents[0, self.data.w_channel]
        return self._draw_masked_fft(orig_latent, title, cmap, use_color_bar,
                                     vmin, vmax, ax, **kwargs)

    def draw_inverted_pattern_fft(self,
                                  step: Optional[int] = None,
                                  title: str = "Tree-Ring FFT with Inverted Watermark Area",
                                  cmap: str = "viridis",
                                  use_color_bar: bool = True,
                                  vmin: Optional[float] = None,
                                  vmax: Optional[float] = None,
                                  ax: Optional[Axes] = None,
                                  **kwargs) -> Axes:
        """
        Draw the FFT of an inverted (reversed) latent inside the watermark
        mask, with an all-zero background.

        Parameters:
            step (Optional[int]): Index into ``self.data.reversed_latents``.
                If None, ``self.watermarking_step`` is used (the constructor
                default is -1, i.e. the last entry).
            title (str): The title of the plot; "" disables it.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower color limit passed to ``imshow``.
            vmax (Optional[float]): Upper color limit passed to ``imshow``.
            ax (Optional[Axes]): The axes to plot on; the current axes are
                used when None.
            **kwargs: Extra keyword arguments forwarded to ``ax.imshow``.

        Returns:
            Axes: The plotted axes.
        """
        index = self.watermarking_step if step is None else step
        inverted_latent = self.data.reversed_latents[index][0, self.data.w_channel]
        return self._draw_masked_fft(inverted_latent, title, cmap, use_color_bar,
                                     vmin, vmax, ax, **kwargs)