Coverage for visualize / videoshield / video_shield_visualizer.py: 97.20%

143 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import torch 

17import matplotlib.pyplot as plt 

18from matplotlib.axes import Axes 

19from matplotlib.gridspec import GridSpecFromSubplotSpec 

20import numpy as np 

21from typing import Optional 

22from visualize.base import BaseVisualizer 

23from visualize.data_for_visualization import DataForVisualization 

24from Crypto.Cipher import ChaCha20 

25 

26 

class VideoShieldVisualizer(BaseVisualizer):
    """VideoShield watermark visualization class.

    This visualizer handles watermark visualization for the VideoShield
    algorithm, which extends Gaussian Shading to the video domain by adding a
    frame dimension to the latent watermark layout.

    Key Members for VideoShieldVisualizer:
        - self.data.orig_watermarked_latents: [B, C, F, H, W]
        - self.data.reversed_latents: List[[B, C, F, H, W]]
    """

    # Latent channel count assumed by the watermark layout (the original code
    # hard-coded ``4`` in every stride computation).
    # NOTE(review): confirm this matches the latent space of the pipeline in use.
    _LATENT_CHANNELS = 4

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300,
                 watermarking_step: int = -1, is_video: bool = True):
        super().__init__(data_for_visualization, dpi, watermarking_step, is_video)

    # ------------------------------------------------------------------ #
    # Private helpers
    # ------------------------------------------------------------------ #

    def _watermark_strides(self) -> tuple:
        """Return the (channel, frame, height, width) size of one watermark tile.

        Each latent dimension is divided by its repetition factor
        (``k_c``, ``k_f``, ``k_h``, ``k_w``); the quotient is the extent of the
        un-repeated watermark along that dimension.
        """
        return (
            self._LATENT_CHANNELS // self.data.k_c,
            self.data.num_frames // self.data.k_f,
            self.data.latents_height // self.data.k_h,
            self.data.latents_width // self.data.k_w,
        )

    @staticmethod
    def _resolve_frame(frame: Optional[int], frame_stride: int) -> tuple:
        """Validate ``frame`` and return ``(frame_index, frame_label)``.

        When ``frame`` is None the middle frame of the watermark tile is used.

        Raises:
            ValueError: If ``frame`` is not smaller than ``frame_stride``.
        """
        if frame is None:
            mid_frame = frame_stride // 2
            return mid_frame, f"Frame {mid_frame} (middle)"
        if frame >= frame_stride:
            raise ValueError(f"Frame {frame} is out of range. Max frame: {frame_stride - 1}")
        return frame, f"Frame {frame}"

    def _draw_tile(self,
                   tile: torch.Tensor,
                   channel: Optional[int],
                   frame_idx: int,
                   plot_title: Optional[str],
                   cmap: str,
                   ax: Axes) -> Axes:
        """Render one frame of a ``[1, C', F', H', W']`` watermark tile.

        With ``channel`` set, that channel is drawn directly on ``ax``.
        Otherwise ``ax`` is cleared and filled with a grid of per-channel
        subplots. ``plot_title`` is applied to ``ax`` unless it is None.
        Arguments are expected to have been validated by the caller.
        """
        if channel is not None:
            image = tile[0, channel, frame_idx].cpu().numpy()
            im = ax.imshow(image, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
            if plot_title is not None:
                ax.set_title(plot_title, fontsize=10)
            ax.axis('off')
            # Fully transparent, hidden colorbar — presumably so this panel keeps
            # the same geometry as panels that show a visible colorbar.
            cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
            cbar.ax.set_visible(False)
            return ax

        # Multi-channel: near-square grid of sub-axes carved out of ``ax``.
        num_channels = tile.shape[1]
        rows = int(np.ceil(np.sqrt(num_channels)))
        cols = int(np.ceil(num_channels / rows))

        ax.clear()
        if plot_title is not None:
            ax.set_title(plot_title, pad=20, fontsize=10)
        ax.axis('off')

        grid = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                       wspace=0.3, hspace=0.4)
        for ch_idx in range(num_channels):
            row_idx, col_idx = divmod(ch_idx, cols)
            sub_ax = ax.figure.add_subplot(grid[row_idx, col_idx])
            sub_ax.imshow(tile[0, ch_idx, frame_idx].cpu().numpy(), cmap=cmap,
                          vmin=0, vmax=1, interpolation='nearest')
            sub_ax.set_title(f'Channel {ch_idx}', fontsize=8, pad=3)
            sub_ax.axis('off')
        return ax

    def _stream_key_decrypt(self, reversed_m: np.ndarray) -> np.ndarray:
        """Decrypt recovered message bits with ChaCha20.

        Args:
            reversed_m: Array of 0/1 bits (any shape; flattened by ``np.packbits``).

        Returns:
            1-D uint8 array of decrypted bits with exactly ``reversed_m.size`` entries.
        """
        cipher = ChaCha20.new(key=self.data.chacha_key, nonce=self.data.chacha_nonce)

        sd_byte = cipher.decrypt(np.packbits(reversed_m).tobytes())
        sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))

        # np.packbits zero-pads the final byte when the bit count is not a multiple
        # of 8; drop those padding bits so callers can reshape to the latent shape.
        return sd_bit[:reversed_m.size]

    def _diffusion_inverse(self, reversed_sd: torch.Tensor) -> torch.Tensor:
        """Recover the watermark tile from decrypted bits by majority voting.

        The decrypted tensor is cut into its ``k_c * k_f * k_h * k_w`` repeated
        copies of the watermark tile. The copies are summed element-wise and
        compared against ``self.data.vote_threshold``.

        Args:
            reversed_sd: Decrypted bit tensor with shape (B, C, F, H, W).

        Returns:
            0/1 tensor with shape (C // k_c, F // k_f, H // k_h, W // k_w).
            Batch entries (dim 0) are pooled into the vote together with the tiles.
        """
        ch_stride, frame_stride, h_stride, w_stride = self._watermark_strides()

        # Split dims 1-4 into k equal tiles each and stack them along dim 0, so
        # that afterwards dim 0 enumerates every repeated copy of the tile.
        # Explicit section sizes make torch.split raise if a dimension is not
        # exactly k * stride long.
        tiles = reversed_sd
        for dim, stride, repeats in (
            (1, ch_stride, self.data.k_c),
            (2, frame_stride, self.data.k_f),
            (3, h_stride, self.data.k_h),
            (4, w_stride, self.data.k_w),
        ):
            tiles = torch.cat(torch.split(tiles, [stride] * repeats, dim=dim), dim=0)

        # Majority vote: 1 where the summed votes exceed the threshold, else 0.
        votes = torch.sum(tiles, dim=0)
        return (votes > self.data.vote_threshold).to(votes.dtype)

    # ------------------------------------------------------------------ #
    # Public drawing API
    # ------------------------------------------------------------------ #

    def draw_watermark_bits(self,
                            channel: Optional[int] = None,
                            frame: Optional[int] = None,
                            title: str = "Original Watermark Bits",
                            cmap: str = "binary",
                            ax: Optional[Axes] = None) -> Axes:
        """Draw the original watermark bits for VideoShield.

        The watermark is reshaped to its tile layout
        ``[1, C // k_c, F // k_f, H // k_h, W // k_w]``, and a single frame of
        that tile is drawn.

        Args:
            channel: The channel to visualize. If None, all channels are shown.
            frame: The frame to visualize. If None, uses the middle frame.
            title: The title of the plot. An empty string suppresses the title.
            cmap: The colormap to use.
            ax: The axes to plot on. If None, a new figure and axes are created.

        Returns:
            The plotted axes.

        Raises:
            ValueError: If ``channel`` or ``frame`` is out of range.
        """
        ch_stride, frame_stride, h_stride, w_stride = self._watermark_strides()

        # Validate everything before touching the axes.
        if channel is not None and channel >= ch_stride:
            raise ValueError(f"Channel {channel} is out of range. Max channel: {ch_stride - 1}")
        frame_idx, frame_label = self._resolve_frame(frame, frame_stride)

        watermark = self.data.watermark.reshape(1, ch_stride, frame_stride, h_stride, w_stride)

        if title == "":
            plot_title = None
        elif channel is not None:
            plot_title = f"{title} - Channel {channel} - {frame_label}"
        else:
            plot_title = f"{title} - {frame_label}"

        if ax is None:
            _, ax = plt.subplots()
        return self._draw_tile(watermark, channel, frame_idx, plot_title, cmap, ax)

    def draw_reconstructed_watermark_bits(self,
                                          channel: Optional[int] = None,
                                          frame: Optional[int] = None,
                                          title: str = "Reconstructed Watermark Bits",
                                          cmap: str = "binary",
                                          ax: Optional[Axes] = None) -> Axes:
        """Draw the reconstructed watermark bits for VideoShield.

        The reversed latents at ``self.watermarking_step`` are binarized
        (``> 0``), decrypted with ChaCha20 and majority-voted back into a
        watermark tile. The bit accuracy against ``self.data.watermark`` is
        always shown in the title.

        Args:
            channel: The channel to visualize. If None, all channels are shown.
            frame: The frame to visualize. If None, uses the middle frame.
            title: The title prefix. An empty string keeps only the frame and
                accuracy information.
            cmap: The colormap to use.
            ax: The axes to plot on. If None, a new figure and axes are created.

        Returns:
            The plotted axes.

        Raises:
            ValueError: If ``channel`` or ``frame`` is out of range.
        """
        ch_stride, frame_stride, h_stride, w_stride = self._watermark_strides()

        # Validate before doing the (comparatively expensive) decryption.
        if channel is not None and channel >= ch_stride:
            raise ValueError(f"Channel {channel} is out of range. Max channel: {ch_stride - 1}")
        frame_idx, frame_label = self._resolve_frame(frame, frame_stride)

        # Step 1: binarize the reversed latents and decrypt them.
        reversed_latent = self.data.reversed_latents[self.watermarking_step]
        reversed_m = (reversed_latent > 0).int()
        reversed_sd_flat = self._stream_key_decrypt(reversed_m.flatten().cpu().numpy())
        reversed_sd = torch.from_numpy(reversed_sd_flat).reshape(reversed_latent.shape).to(torch.uint8)

        # Step 2: extract the watermark through the voting mechanism.
        reversed_watermark = self._diffusion_inverse(reversed_sd.to(self.data.device))

        # Step 3: bit accuracy. Compare flattened tensors on the same device so the
        # result does not depend on how self.data.watermark happens to be shaped.
        expected = self.data.watermark.reshape(-1).to(reversed_watermark.device)
        bit_acc = (reversed_watermark.reshape(-1) == expected).float().mean().item()

        reconstructed_watermark = reversed_watermark.reshape(1, ch_stride, frame_stride, h_stride, w_stride)

        scope = f"Channel {channel} - {frame_label}" if channel is not None else frame_label
        if title != "":
            scope = f"{title} - {scope}"
        plot_title = f"{scope} (Bit Acc: {bit_acc:.3f})"

        if ax is None:
            _, ax = plt.subplots()
        return self._draw_tile(reconstructed_watermark, channel, frame_idx, plot_title, cmap, ax)

    def draw_watermarked_video_frames(self,
                                      num_frames: int = 4,
                                      title: str = "Watermarked Video Frames",
                                      ax: Optional[Axes] = None) -> Axes:
        """Draw multiple frames from the watermarked video.

        DEPRECATED:
            This method is deprecated and will be removed in a future version.
            Please use `draw_watermarked_image` instead. Calling it emits a
            ``DeprecationWarning``.

        This method displays a grid of video frames to show the temporal
        consistency of the watermarked video.

        Args:
            num_frames: Number of frames to display (default: 4).
            title: The title of the plot.
            ax: The axes to plot on.

        Returns:
            The plotted axes.
        """
        import warnings

        warnings.warn(
            "draw_watermarked_video_frames is deprecated and will be removed in a "
            "future version; use draw_watermarked_image instead.",
            DeprecationWarning,
            stacklevel=2,
        )
        return self._draw_video_frames(
            title=title,
            num_frames=num_frames,
            ax=ax,
        )