Coverage for visualize / sfw / sfw_visualizer.py: 96.20%

79 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1from typing import Optional 

2import torch 

3from matplotlib.axes import Axes 

4import numpy as np 

5from visualize.base import BaseVisualizer 

6from visualize.data_for_visualization import DataForVisualization 

7 

class SFWVisualizer(BaseVisualizer):
    """SFW watermark visualization class.

    Both public drawing methods display the magnitude of a Fourier-domain
    array that is zero everywhere except in the watermark region, so the
    watermark area stands out against an all-zero background.
    """

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        """Forward all arguments unchanged to ``BaseVisualizer.__init__``."""
        super().__init__(data_for_visualization, dpi, watermarking_step)

    def _gt_patch_magnitude(self) -> np.ndarray:
        """
        Return the magnitude of ``self.data.gt_patch`` as a 2-D numpy array.

        A 4-D patch is reduced by taking index 0 of its first axis, and a 3-D
        patch by taking index 0 again, so the first entry along each leading
        axis is the one visualized.

        Raises:
            ValueError: If the patch is not 2-D after that reduction.
        """
        gt_patch = self.data.gt_patch
        if torch.is_tensor(gt_patch):
            patch = gt_patch.detach().cpu().numpy()
        else:
            patch = np.asarray(gt_patch)

        if patch.ndim == 4:
            patch = patch[0]
        if patch.ndim == 3:
            patch = patch[0]
        if patch.ndim != 2:
            raise ValueError(
                f"Expected a 2D watermark pattern after squeezing, got shape {patch.shape}"
            )
        # np.abs yields the modulus for complex input and the absolute value
        # for real input, so one call covers both cases.
        return np.abs(patch)

    def _masked_fft(self, latent) -> torch.Tensor:
        """
        Build the array to display: zeros outside the watermark area.

        Parameters:
            latent: Single-channel latent handed to ``self._fft_transform``.

        Returns:
            torch.Tensor: Tensor with the shape and dtype of the transformed
            latent. When ``self.data.wm_type == "HSQR"`` the ground-truth
            patch magnitude is pasted into its centre (the transformed latent
            only supplies shape and dtype in that mode); otherwise the
            transformed latent's values are copied where
            ``self.data.watermarking_mask`` selects them.

        Raises:
            ValueError: If the HSQR patch is not 2-D or is larger than the
                transformed latent.
        """
        fft_data = torch.from_numpy(self._fft_transform(latent))
        fft_vis = torch.zeros_like(fft_data)

        if self.data.wm_type == "HSQR":
            patch = self._gt_patch_magnitude()
            height, width = fft_data.shape
            patch_h, patch_w = patch.shape
            # A patch larger than the plane would give negative offsets and an
            # opaque slice-assignment failure; report the real cause instead.
            if patch_h > height or patch_w > width:
                raise ValueError(
                    f"Pattern shape {(patch_h, patch_w)} does not fit into FFT shape {(height, width)}"
                )
            top = (height - patch_h) // 2
            left = (width - patch_w) // 2
            patch_t = torch.from_numpy(patch.astype(np.float32)).to(dtype=fft_vis.dtype)
            fft_vis[top:top + patch_h, left:left + patch_w] = patch_t
        else:
            mask = self.data.watermarking_mask[0, self.data.w_channel].cpu()
            fft_vis[mask] = fft_data[mask]

        return fft_vis

    @staticmethod
    def _show_magnitude(fft_vis: torch.Tensor,
                        title: str,
                        cmap: str,
                        use_color_bar: bool,
                        vmin: Optional[float],
                        vmax: Optional[float],
                        ax: Optional[Axes],
                        imshow_kwargs: dict) -> Axes:
        """Render ``|fft_vis|`` with ``imshow`` on ``ax`` and return the axes used."""
        if ax is None:
            # The public signatures default ``ax`` to None; use the current
            # pyplot axes rather than failing on ``None.imshow``.
            import matplotlib.pyplot as plt
            ax = plt.gca()

        im = ax.imshow(np.abs(fft_vis.cpu().numpy()), cmap=cmap, vmin=vmin, vmax=vmax, **imshow_kwargs)
        if title:
            ax.set_title(title)
        if use_color_bar:
            ax.figure.colorbar(im, ax=ax)
        ax.axis('off')
        return ax

    def draw_pattern_fft(self,
                         title: str = "SFW FFT with Watermark Area",
                         cmap: str = "viridis",
                         use_color_bar: bool = True,
                         vmin: Optional[float] = None,
                         vmax: Optional[float] = None,
                         ax: Optional[Axes] = None,
                         **kwargs) -> Axes:
        """
        Draw FFT visualization with original watermark pattern, with all 0 background.

        The latent used is ``self.data.orig_watermarked_latents[0, self.data.w_channel]``.

        Parameters:
            title (str): The title of the plot. An empty string omits the title.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower colour limit passed to ``imshow``.
            vmax (Optional[float]): Upper colour limit passed to ``imshow``.
            ax (Axes): The axes to plot on. If None, the current pyplot axes are used.
            **kwargs: Extra keyword arguments passed to ``imshow``.

        Returns:
            Axes: The plotted axes.

        Raises:
            ValueError: If the HSQR pattern is not 2-D or does not fit into the FFT plane.
        """
        orig_latent = self.data.orig_watermarked_latents[0, self.data.w_channel].cpu()
        fft_vis = self._masked_fft(orig_latent)
        return self._show_magnitude(fft_vis, title, cmap, use_color_bar, vmin, vmax, ax, kwargs)

    def draw_inverted_pattern_fft(self,
                                  step: Optional[int] = None,
                                  title: str = "SFW FFT with Inverted Watermark Area",
                                  cmap: str = "viridis",
                                  use_color_bar: bool = True,
                                  vmin: Optional[float] = None,
                                  vmax: Optional[float] = None,
                                  ax: Optional[Axes] = None,
                                  **kwargs) -> Axes:
        """
        Draw FFT visualization with inverted pattern, with all 0 background.

        The latent used is ``self.data.reversed_latents[step][0, self.data.w_channel]``.

        Parameters:
            step (Optional[int]): The timestep of the inverted latents. If None, ``self.watermarking_step`` is used.
            title (str): The title of the plot. An empty string omits the title.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower colour limit passed to ``imshow``.
            vmax (Optional[float]): Upper colour limit passed to ``imshow``.
            ax (Axes): The axes to plot on. If None, the current pyplot axes are used.
            **kwargs: Extra keyword arguments passed to ``imshow``.

        Returns:
            Axes: The plotted axes.

        Raises:
            ValueError: If the HSQR pattern is not 2-D or does not fit into the FFT plane.
        """
        latent_index = self.watermarking_step if step is None else step
        inverted_latent = self.data.reversed_latents[latent_index][0, self.data.w_channel]
        fft_vis = self._masked_fft(inverted_latent)
        return self._show_magnitude(fft_vis, title, cmap, use_color_bar, vmin, vmax, ax, kwargs)

147