Coverage for visualize / gs / gs_visualizer.py: 94.85%

97 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1import torch 

2import matplotlib.pyplot as plt 

3from matplotlib.axes import Axes 

4from matplotlib.gridspec import GridSpecFromSubplotSpec 

5from typing import Optional 

6import numpy as np 

7from visualize.base import BaseVisualizer 

8from visualize.data_for_visualization import DataForVisualization 

9from Crypto.Cipher import ChaCha20 

10 

class GaussianShadingVisualizer(BaseVisualizer):
    """Gaussian Shading watermark visualization class.

    Draws either the original watermark bits (``self.data.watermark``) or the
    bits reconstructed from ``self.data.reversed_latents``. Each can be shown
    as a single channel or as a grid of all channels inside one axes.

    NOTE(review): ``self.data`` and ``self.watermarking_step`` are presumably
    set by ``BaseVisualizer.__init__`` from the constructor arguments; this
    class only reads them.
    """

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        super().__init__(data_for_visualization, dpi, watermarking_step)

    # ------------------------------------------------------------------ #
    # Shape helpers
    # ------------------------------------------------------------------ #

    def _num_watermark_channels(self) -> int:
        """Return the number of distinct watermark channels (4 // channel_copy)."""
        return 4 // self.data.channel_copy

    def _watermark_side(self) -> int:
        """Return the side length of one square watermark channel (64 // hw_copy)."""
        return 64 // self.data.hw_copy

    def _validate_channel(self, channel: int) -> None:
        """Raise ValueError unless 0 <= channel < number of watermark channels.

        Negative indices are rejected explicitly. Otherwise they would be
        silently Python-indexed and mislabelled in the plot title.
        """
        num_channels = self._num_watermark_channels()
        if not 0 <= channel < num_channels:
            raise ValueError(f"Channel {channel} is out of range. Max channel: {num_channels - 1}")

    # ------------------------------------------------------------------ #
    # Watermark reconstruction
    # ------------------------------------------------------------------ #

    def _stream_key_decrypt(self, reversed_m):
        """Decrypt the watermark using ChaCha20 cipher.

        Parameters:
            reversed_m: Flat array of 0/1 values. It is bit-packed with
                ``np.packbits`` before decryption.

        Returns:
            torch.Tensor: uint8 bit tensor reshaped to (1, 4, 64, 64) and moved
            to ``self.data.device``.
        """
        cipher = ChaCha20.new(key=self.data.chacha_key, nonce=self.data.chacha_nonce)
        sd_byte = cipher.decrypt(np.packbits(reversed_m).tobytes())
        sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
        sd_tensor = torch.from_numpy(sd_bit).reshape(1, 4, 64, 64).to(torch.uint8)
        return sd_tensor.to(self.data.device)

    def _diffusion_inverse(self, reversed_sd):
        """Inverse the diffusion process to extract the watermark.

        The bit tensor is split into ``channel_copy`` equal pieces along dim 1
        and ``hw_copy`` equal pieces along dims 2 and 3. All pieces are
        concatenated along dim 0 and summed, so each position holds a vote
        count. A position becomes 1 when its count exceeds
        ``self.data.vote_threshold`` and 0 otherwise.

        Parameters:
            reversed_sd: Integer bit tensor. The split sizes sum to 4 along
                dim 1 and 64 along dims 2/3, so it is expected to be
                (1, 4, 64, 64) when channel_copy and hw_copy divide 4 and 64.

        Returns:
            torch.Tensor: 0/1 tensor of shape
            (4 // channel_copy, 64 // hw_copy, 64 // hw_copy).
        """
        ch_stride = self._num_watermark_channels()
        hw_stride = self._watermark_side()
        ch_list = [ch_stride] * self.data.channel_copy
        hw_list = [hw_stride] * self.data.hw_copy
        split_dim1 = torch.cat(torch.split(reversed_sd, tuple(ch_list), dim=1), dim=0)
        split_dim2 = torch.cat(torch.split(split_dim1, tuple(hw_list), dim=2), dim=0)
        split_dim3 = torch.cat(torch.split(split_dim2, tuple(hw_list), dim=3), dim=0)
        vote = torch.sum(split_dim3, dim=0)
        # Threshold in a single vectorized step, keeping the integer dtype of
        # the summed votes (equivalent to the previous two masked assignments,
        # since vote counts are never negative).
        return (vote > self.data.vote_threshold).to(vote.dtype)

    def _reconstruct_watermark_bits(self):
        """Reconstruct the watermark from the inverted latents.

        Uses ``self.data.reversed_latents[self.watermarking_step]``. The latent
        is binarised with ``> 0``, then unmasked either by ChaCha20 decryption
        (when ``self.data.chacha`` is truthy) or by adding ``self.data.key``
        modulo 2, and finally majority-voted by ``_diffusion_inverse``.

        Returns:
            tuple: (reversed_watermark, bit_acc). ``bit_acc`` is the mean of
            element-wise equality between the reconstructed bits and
            ``self.data.watermark`` over all channels, as a Python float.
        """
        reversed_latent = self.data.reversed_latents[self.watermarking_step]
        reversed_m = (reversed_latent > 0).int()
        if self.data.chacha:
            reversed_sd = self._stream_key_decrypt(reversed_m.flatten().cpu().numpy())
        else:
            reversed_sd = (reversed_m + self.data.key) % 2
        reversed_watermark = self._diffusion_inverse(reversed_sd)
        bit_acc = (reversed_watermark == self.data.watermark).float().mean().item()
        return reversed_watermark, bit_acc

    # ------------------------------------------------------------------ #
    # Rendering helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    def _ensure_axes(ax: Optional[Axes]) -> Axes:
        """Return ``ax``, or a new single axes on a fresh figure when it is None."""
        if ax is None:
            _, ax = plt.subplots()
        return ax

    @staticmethod
    def _draw_single_channel(ax: Axes, bits, channel: int, title: str, cmap: str) -> Axes:
        """Draw ``bits[0, channel]`` into ``ax`` as a 0/1 image.

        Parameters:
            ax (Axes): Target axes.
            bits: 4-D tensor indexed as [0, channel].
            channel (int): Channel index to draw.
            title (str): Full title text; no title is set when empty.
            cmap (str): Matplotlib colormap name.

        Returns:
            Axes: ``ax``.
        """
        im = ax.imshow(bits[0, channel].cpu().numpy(), cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
        if title:
            ax.set_title(title, fontsize=10)
        ax.axis('off')

        # An invisible colorbar is attached, presumably so this image takes the
        # same width as neighbouring plots that do show a colorbar — TODO confirm.
        cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
        cbar.ax.set_visible(False)
        return ax

    @staticmethod
    def _draw_channel_grid(ax: Axes, bits, title: str, cmap: str) -> Axes:
        """Draw every channel of ``bits`` as a grid of sub-axes inside ``ax``.

        ``ax`` is cleared and hidden. Its subplot spec is subdivided into a
        near-square grid (rows = ceil(sqrt(C)), cols = ceil(C / rows)), and one
        sub-axes per channel is added to ``ax.figure``.

        Parameters:
            ax (Axes): Host axes; must come from a subplot spec.
            bits: 4-D tensor; channels are read from dim 1.
            title (str): Overall title; no title is set when empty.
            cmap (str): Matplotlib colormap name.

        Returns:
            Axes: ``ax`` (the host axes, not the sub-axes).
        """
        num_channels = bits.shape[1]
        rows = int(np.ceil(np.sqrt(num_channels)))
        cols = int(np.ceil(num_channels / rows))

        ax.clear()
        if title:
            ax.set_title(title, pad=20)
        ax.axis('off')

        gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                     wspace=0.3, hspace=0.4)
        for i in range(num_channels):
            row_idx, col_idx = divmod(i, cols)
            sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])
            sub_ax.imshow(bits[0, i].cpu().numpy(), cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
            sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
            sub_ax.axis('off')
        return ax

    # ------------------------------------------------------------------ #
    # Public drawing API
    # ------------------------------------------------------------------ #

    def draw_watermark_bits(self,
                            channel: Optional[int] = None,
                            title: str = "Original Watermark Bits",
                            cmap: str = "binary",
                            ax: Optional[Axes] = None) -> Axes:
        """
        Draw the original watermark bits (``self.data.watermark``).

        ``self.data.watermark`` is reshaped to
        (1, 4 // channel_copy, 64 // hw_copy, 64 // hw_copy). Either one channel
        or all channels are drawn.

        Parameters:
            channel (Optional[int]): The channel to visualize. If None, all channels are shown.
            title (str): The title of the plot. An empty string suppresses it.
            cmap (str): The colormap to use.
            ax (Optional[Axes]): The axes to plot on. A new figure/axes is created when None.

        Returns:
            Axes: The plotted axes.

        Raises:
            ValueError: If ``channel`` is outside [0, 4 // channel_copy).
        """
        if channel is not None:
            self._validate_channel(channel)
        ax = self._ensure_axes(ax)

        side = self._watermark_side()
        watermark = self.data.watermark.reshape(1, self._num_watermark_channels(), side, side)

        if channel is not None:
            single_title = f"{title} - Channel {channel}" if title else ""
            return self._draw_single_channel(ax, watermark, channel, single_title, cmap)
        return self._draw_channel_grid(ax, watermark, title, cmap)

    def draw_reconstructed_watermark_bits(self,
                                          channel: Optional[int] = None,
                                          title: str = "Reconstructed Watermark Bits",
                                          cmap: str = "binary",
                                          ax: Optional[Axes] = None) -> Axes:
        """
        Draw the watermark bits reconstructed from the inverted latents.

        The bits are recovered from ``self.data.reversed_latents`` at
        ``self.watermarking_step``. The bit accuracy against
        ``self.data.watermark``, computed over all channels, is appended to the
        title in both single-channel and grid mode.

        Parameters:
            channel (Optional[int]): The channel to visualize. If None, all channels are shown.
            title (str): The title of the plot. When empty, only the channel / bit accuracy is shown.
            cmap (str): The colormap to use.
            ax (Optional[Axes]): The axes to plot on. A new figure/axes is created when None.

        Returns:
            Axes: The plotted axes.

        Raises:
            ValueError: If ``channel`` is outside [0, 4 // channel_copy).
        """
        if channel is not None:
            self._validate_channel(channel)
        ax = self._ensure_axes(ax)

        reversed_watermark, bit_acc = self._reconstruct_watermark_bits()
        side = self._watermark_side()
        reconstructed_watermark = reversed_watermark.reshape(1, self._num_watermark_channels(), side, side)

        if channel is not None:
            if title:
                single_title = f"{title} - Channel {channel} (Bit Acc: {bit_acc:.3f})"
            else:
                single_title = f"Channel {channel} (Bit Acc: {bit_acc:.3f})"
            return self._draw_single_channel(ax, reconstructed_watermark, channel, single_title, cmap)

        grid_title = f'{title} (Bit Acc: {bit_acc:.3f})' if title else f'(Bit Acc: {bit_acc:.3f})'
        return self._draw_channel_grid(ax, reconstructed_watermark, grid_title, cmap)