Coverage for visualize / sfw / sfw_visualizer.py: 96.20%
79 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1from typing import Optional
2import torch
3from matplotlib.axes import Axes
4import numpy as np
5from visualize.base import BaseVisualizer
6from visualize.data_for_visualization import DataForVisualization
class SFWVisualizer(BaseVisualizer):
    """SFW watermark visualization class.

    Draws the magnitude of the FFT (as produced by ``self._fft_transform``) of
    one latent channel with everything outside the watermark area zeroed, for
    either the watermarked latent or an inverted latent.
    """

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        super().__init__(data_for_visualization, dpi, watermarking_step)

    def _pattern_magnitude(self) -> np.ndarray:
        """Return the magnitude of ``self.data.gt_patch`` reduced to a 2-D array.

        A 4-D pattern is indexed with ``[0]`` twice and a 3-D pattern once,
        i.e. the first entry of each leading axis is kept.

        Raises:
            ValueError: If the pattern is not 2-D after that reduction.
        """
        gt_patch = self.data.gt_patch
        if torch.is_tensor(gt_patch):
            patt_np = gt_patch.detach().cpu().numpy()
        else:
            patt_np = np.array(gt_patch)
        # NOTE(review): the first channel is always used, not self.data.w_channel;
        # confirm this matches how gt_patch is stored for HSQR.
        if patt_np.ndim == 4:
            patt_np = patt_np[0]
        if patt_np.ndim == 3:
            patt_np = patt_np[0]
        if patt_np.ndim != 2:
            raise ValueError(
                f"Expected a 2-D watermark pattern after dropping leading axes, got shape {patt_np.shape}"
            )
        # np.abs covers both real and complex patterns (magnitude for complex).
        return np.abs(patt_np)

    def _watermark_fft(self, latent) -> torch.Tensor:
        """Return an all-zero FFT-sized tensor with the watermark area filled in.

        For ``wm_type == "HSQR"`` the magnitude of the ground-truth pattern is
        written into the centre of the array. Otherwise the FFT values of
        ``latent`` are copied wherever ``watermarking_mask`` (for ``w_channel``)
        is set.

        Raises:
            ValueError: If the HSQR pattern is not 2-D or is larger than the FFT.
        """
        fft_data = torch.from_numpy(self._fft_transform(latent))
        fft_vis = torch.zeros_like(fft_data)
        if self.data.wm_type == "HSQR":
            # NOTE(review): this branch shows the ground-truth pattern, not the
            # FFT values of `latent`; `latent` only determines the canvas size.
            patt_vis = self._pattern_magnitude()
            H, W = fft_data.shape
            ph, pw = patt_vis.shape
            if ph > H or pw > W:
                raise ValueError(
                    f"Watermark pattern of shape {(ph, pw)} does not fit inside FFT of shape {(H, W)}"
                )
            # Centre the pattern inside the FFT canvas.
            sh = (H - ph) // 2
            sw = (W - pw) // 2
            patt_t = torch.from_numpy(patt_vis.astype(np.float32)).to(dtype=fft_vis.dtype, device=fft_vis.device)
            fft_vis[sh:sh + ph, sw:sw + pw] = patt_t
        else:
            watermarking_mask = self.data.watermarking_mask[0, self.data.w_channel].cpu()
            fft_vis[watermarking_mask] = fft_data[watermarking_mask]
        return fft_vis

    @staticmethod
    def _render_fft(fft_vis: torch.Tensor,
                    *,
                    title: str,
                    cmap: str,
                    use_color_bar: bool,
                    vmin: Optional[float],
                    vmax: Optional[float],
                    ax: Optional[Axes],
                    imshow_kwargs: dict) -> Axes:
        """Plot ``|fft_vis|`` on *ax* (or the current axes when *ax* is None) and return the axes."""
        if ax is None:
            # Honour the Optional[Axes] default instead of failing on None.
            import matplotlib.pyplot as plt
            ax = plt.gca()
        im = ax.imshow(np.abs(fft_vis.cpu().numpy()), cmap=cmap, vmin=vmin, vmax=vmax, **imshow_kwargs)
        if title != "":
            ax.set_title(title)
        if use_color_bar:
            ax.figure.colorbar(im, ax=ax)
        ax.axis('off')
        return ax

    def draw_pattern_fft(self,
                         title: str = "SFW FFT with Watermark Area",
                         cmap: str = "viridis",
                         use_color_bar: bool = True,
                         vmin: Optional[float] = None,
                         vmax: Optional[float] = None,
                         ax: Optional[Axes] = None,
                         **kwargs) -> Axes:
        """
        Draw FFT visualization with original watermark pattern, with all 0 background.

        Parameters:
            title (str): The title of the plot. An empty string skips the title.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower limit of the color scale, passed to ``imshow``.
            vmax (Optional[float]): Upper limit of the color scale, passed to ``imshow``.
            ax (Optional[Axes]): The axes to plot on. If None, the current axes are used.
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``.

        Returns:
            Axes: The plotted axes.

        Raises:
            ValueError: For HSQR, if the pattern is not 2-D or does not fit in the FFT.
        """
        orig_latent = self.data.orig_watermarked_latents[0, self.data.w_channel].cpu()
        fft_vis = self._watermark_fft(orig_latent)
        return self._render_fft(fft_vis, title=title, cmap=cmap, use_color_bar=use_color_bar,
                                vmin=vmin, vmax=vmax, ax=ax, imshow_kwargs=kwargs)

    def draw_inverted_pattern_fft(self,
                                  step: Optional[int] = None,
                                  title: str = "SFW FFT with Inverted Watermark Area",
                                  cmap: str = "viridis",
                                  use_color_bar: bool = True,
                                  vmin: Optional[float] = None,
                                  vmax: Optional[float] = None,
                                  ax: Optional[Axes] = None,
                                  **kwargs) -> Axes:
        """
        Draw FFT visualization with inverted pattern, with all 0 background.

        Parameters:
            step (Optional[int]): Index into ``reversed_latents``. If None, the
                visualizer's ``watermarking_step`` is used (the constructor
                default of -1 selects the last entry).
            title (str): The title of the plot. An empty string skips the title.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower limit of the color scale, passed to ``imshow``.
            vmax (Optional[float]): Upper limit of the color scale, passed to ``imshow``.
            ax (Optional[Axes]): The axes to plot on. If None, the current axes are used.
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``.

        Returns:
            Axes: The plotted axes.

        Raises:
            ValueError: For HSQR, if the pattern is not 2-D or does not fit in the FFT.
        """
        index = self.watermarking_step if step is None else step
        inverted_latent = self.data.reversed_latents[index][0, self.data.w_channel]
        fft_vis = self._watermark_fft(inverted_latent)
        return self._render_fft(fft_vis, title=title, cmap=cmap, use_color_bar=use_color_bar,
                                vmin=vmin, vmax=vmax, ax=ax, imshow_kwargs=kwargs)