Coverage for visualize / videoshield / video_shield_visualizer.py: 97.20%
143 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1# Copyright 2025 THU-BPM MarkDiffusion.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16import torch
17import matplotlib.pyplot as plt
18from matplotlib.axes import Axes
19from matplotlib.gridspec import GridSpecFromSubplotSpec
20import numpy as np
21from typing import Optional
22from visualize.base import BaseVisualizer
23from visualize.data_for_visualization import DataForVisualization
24from Crypto.Cipher import ChaCha20
class VideoShieldVisualizer(BaseVisualizer):
    """VideoShield watermark visualization class.

    This visualizer handles watermark visualization for the VideoShield algorithm,
    which extends Gaussian Shading to the video domain by adding a frame dimension.

    Key Members for VideoShieldVisualizer:
        - self.data.orig_watermarked_latents: [B, C, F, H, W]
        - self.data.reversed_latents: List[[B, C, F, H, W]]
    """

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1, is_video: bool = True) -> None:
        """Initialize the visualizer by forwarding every argument to the base class.

        No VideoShield-specific state is set up here; the method only fixes the
        default values listed in its signature.

        Args:
            data_for_visualization: Data container for the plots. The methods of
                this class read it through ``self.data`` (presumably assigned by
                the base constructor - confirm in ``visualize.base``).
            dpi: Forwarded unchanged to the base constructor.
            watermarking_step: Forwarded unchanged. ``draw_reconstructed_watermark_bits``
                uses ``self.watermarking_step`` to index ``self.data.reversed_latents``,
                so the default of -1 would select the last entry (assuming the base
                class stores the value under that name - confirm).
            is_video: Forwarded unchanged; defaults to True for this visualizer.
        """
        # Positional forwarding: the order must match the base constructor's signature.
        super().__init__(data_for_visualization, dpi, watermarking_step, is_video)
41 def _stream_key_decrypt(self, reversed_m: np.ndarray) -> torch.Tensor:
42 """Decrypt the watermark using ChaCha20 cipher.
44 Args:
45 reversed_m: Encrypted binary message array
47 Returns:
48 Decrypted watermark tensor
49 """
50 cipher = ChaCha20.new(key=self.data.chacha_key, nonce=self.data.chacha_nonce)
52 sd_byte = cipher.decrypt(np.packbits(reversed_m).tobytes())
53 sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
55 return sd_bit
57 def _diffusion_inverse(self, reversed_sd: torch.Tensor) -> torch.Tensor:
58 """Video-specific diffusion inverse with frame dimension handling.
60 Args:
61 reversed_sd: Video watermark tensor with shape (B, C, F, H, W)
63 Returns:
64 Extracted watermark pattern
65 """
66 ch_stride = 4 // self.data.k_c
67 frame_stride = self.data.num_frames // self.data.k_f
68 h_stride = self.data.latents_height // self.data.k_h
69 w_stride = self.data.latents_width // self.data.k_w
71 ch_list = [ch_stride] * self.data.k_c
72 frame_list = [frame_stride] * self.data.k_f
73 h_list = [h_stride] * self.data.k_h
74 w_list = [w_stride] * self.data.k_w
76 # Split and reorganize dimensions for voting
77 split_dim1 = torch.cat(torch.split(reversed_sd, tuple(ch_list), dim=1), dim=0)
78 split_dim2 = torch.cat(torch.split(split_dim1, tuple(frame_list), dim=2), dim=0)
79 split_dim3 = torch.cat(torch.split(split_dim2, tuple(h_list), dim=3), dim=0)
80 split_dim4 = torch.cat(torch.split(split_dim3, tuple(w_list), dim=4), dim=0)
82 # Voting
83 vote = torch.sum(split_dim4, dim=0).clone()
84 vote[vote <= self.data.vote_threshold] = 0
85 vote[vote > self.data.vote_threshold] = 1
87 return vote
89 def draw_watermark_bits(self,
90 channel: Optional[int] = None,
91 frame: Optional[int] = None,
92 title: str = "Original Watermark Bits",
93 cmap: str = "binary",
94 ax: Optional[Axes] = None) -> Axes:
95 """Draw the original watermark bits for VideoShield.
97 For video watermarks, this method can visualize specific frames or average
98 across frames to create a 2D visualization.
100 Args:
101 channel: The channel to visualize. If None, all channels are shown.
102 frame: The frame to visualize. If None, uses middle frame for videos.
103 title: The title of the plot.
104 cmap: The colormap to use.
105 ax: The axes to plot on.
107 Returns:
108 The plotted axes.
109 """
110 # Reshape watermark to video dimensions based on repetition factors
111 # VideoShield watermark shape: [1, C//k_c, F//k_f, H//k_h, W//k_w]
112 ch_stride = 4 // self.data.k_c
113 frame_stride = self.data.num_frames // self.data.k_f
114 h_stride = self.data.latents_height // self.data.k_h
115 w_stride = self.data.latents_width // self.data.k_w
117 watermark = self.data.watermark.reshape(1, ch_stride, frame_stride, h_stride, w_stride)
119 if channel is not None:
120 # Single channel visualization
121 if channel >= ch_stride:
122 raise ValueError(f"Channel {channel} is out of range. Max channel: {ch_stride - 1}")
124 # Select specific frame or use middle frame
125 if frame is not None:
126 if frame >= frame_stride:
127 raise ValueError(f"Frame {frame} is out of range. Max frame: {frame_stride - 1}")
128 watermark_data = watermark[0, channel, frame].cpu().numpy()
129 frame_info = f" - Frame {frame}"
130 else:
131 # Use middle frame
132 mid_frame = frame_stride // 2
133 watermark_data = watermark[0, channel, mid_frame].cpu().numpy()
134 frame_info = f" - Frame {mid_frame} (middle)"
136 im=ax.imshow(watermark_data, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
137 if title != "":
138 ax.set_title(f"{title} - Channel {channel}{frame_info}", fontsize=10)
139 ax.axis('off')
141 cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
142 cbar.ax.set_visible(False)
143 else:
144 # Multi-channel visualization
145 num_channels = ch_stride
147 # Calculate grid layout
148 rows = int(np.ceil(np.sqrt(num_channels)))
149 cols = int(np.ceil(num_channels / rows))
151 # Clear the axis and set title
152 ax.clear()
153 if title != "":
154 if frame is not None:
155 ax.set_title(f"{title} - Frame {frame}", pad=20, fontsize=10)
156 else:
157 mid_frame = frame_stride // 2
158 ax.set_title(f"{title} - Frame {mid_frame} (middle)", pad=20, fontsize=10)
159 ax.axis('off')
161 # Use gridspec for better control
162 gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
163 wspace=0.3, hspace=0.4)
165 # Create subplots
166 for i in range(num_channels):
167 row_idx = i // cols
168 col_idx = i % cols
170 # Create subplot using gridspec
171 sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])
173 # Select specific frame or use middle frame
174 if frame is not None:
175 if frame >= frame_stride:
176 raise ValueError(f"Frame {frame} is out of range. Max frame: {frame_stride - 1}")
177 watermark_data = watermark[0, i, frame].cpu().numpy()
178 else:
179 mid_frame = frame_stride // 2
180 watermark_data = watermark[0, i, mid_frame].cpu().numpy()
182 # Draw the watermark channel
183 sub_ax.imshow(watermark_data, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
184 sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
185 sub_ax.axis('off')
187 return ax
189 def draw_reconstructed_watermark_bits(self,
190 channel: Optional[int] = None,
191 frame: Optional[int] = None,
192 title: str = "Reconstructed Watermark Bits",
193 cmap: str = "binary",
194 ax: Optional[Axes] = None) -> Axes:
195 """Draw the reconstructed watermark bits for VideoShield.
197 Args:
198 channel: The channel to visualize. If None, all channels are shown.
199 frame: The frame to visualize. If None, uses middle frame for videos.
200 title: The title of the plot.
201 cmap: The colormap to use.
202 ax: The axes to plot on.
204 Returns:
205 The plotted axes.
206 """
207 # Step 1: Get reversed latents and reconstruct the watermark bits
208 reversed_latent = self.data.reversed_latents[self.watermarking_step]
210 # Convert to binary bits
211 reversed_m = (reversed_latent > 0).int()
213 # Decrypt
214 reversed_sd_flat = self._stream_key_decrypt(reversed_m.flatten().cpu().numpy())
215 # Reshape back to video tensor
216 reversed_sd = torch.from_numpy(reversed_sd_flat).reshape(reversed_latent.shape).to(torch.uint8)
218 # Extract watermark through voting mechanism
219 reversed_watermark = self._diffusion_inverse(reversed_sd.to(self.data.device))
221 # Calculate bit accuracy
222 bit_acc = (reversed_watermark == self.data.watermark).float().mean().item()
224 # Reshape to video dimensions for visualization
225 ch_stride = 4 // self.data.k_c
226 frame_stride = self.data.num_frames // self.data.k_f
227 h_stride = self.data.latents_height // self.data.k_h
228 w_stride = self.data.latents_width // self.data.k_w
230 reconstructed_watermark = reversed_watermark.reshape(1, ch_stride, frame_stride, h_stride, w_stride)
232 if channel is not None:
233 # Single channel visualization
234 if channel >= ch_stride:
235 raise ValueError(f"Channel {channel} is out of range. Max channel: {ch_stride - 1}")
237 # Select specific frame or use middle frame
238 if frame is not None:
239 if frame >= frame_stride:
240 raise ValueError(f"Frame {frame} is out of range. Max frame: {frame_stride - 1}")
241 reconstructed_watermark_data = reconstructed_watermark[0, channel, frame].cpu().numpy()
242 frame_info = f" - Frame {frame}"
243 else:
244 # Use middle frame
245 mid_frame = frame_stride // 2
246 reconstructed_watermark_data = reconstructed_watermark[0, channel, mid_frame].cpu().numpy()
247 frame_info = f" - Frame {mid_frame} (middle)"
249 im=ax.imshow(reconstructed_watermark_data, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
250 if title != "":
251 ax.set_title(f"{title} - Channel {channel}{frame_info} (Bit Acc: {bit_acc:.3f})", fontsize=10)
252 else:
253 ax.set_title(f"Channel {channel}{frame_info} (Bit Acc: {bit_acc:.3f})", fontsize=10)
254 ax.axis('off')
255 cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
256 cbar.ax.set_visible(False)
257 else:
258 # Multi-channel visualization
259 num_channels = ch_stride
261 # Calculate grid layout
262 rows = int(np.ceil(np.sqrt(num_channels)))
263 cols = int(np.ceil(num_channels / rows))
265 # Clear the axis and set title with bit accuracy
266 ax.clear()
267 if title != "":
268 if frame is not None:
269 ax.set_title(f'{title} - Frame {frame} (Bit Acc: {bit_acc:.3f})', pad=20, fontsize=10)
270 else:
271 mid_frame = frame_stride // 2
272 ax.set_title(f'{title} - Frame {mid_frame} (middle) (Bit Acc: {bit_acc:.3f})', pad=20, fontsize=10)
273 else:
274 if frame is not None:
275 ax.set_title(f'Frame {frame} (Bit Acc: {bit_acc:.3f})', pad=20, fontsize=10)
276 else:
277 mid_frame = frame_stride // 2
278 ax.set_title(f'Frame {mid_frame} (middle) (Bit Acc: {bit_acc:.3f})', pad=20, fontsize=10)
279 ax.axis('off')
281 # Use gridspec for better control
282 gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
283 wspace=0.3, hspace=0.4)
285 # Create subplots
286 for i in range(num_channels):
287 row_idx = i // cols
288 col_idx = i % cols
290 # Create subplot using gridspec
291 sub_ax = ax.figure.add_subplot(gs[row_idx, col_idx])
293 # Select specific frame or use middle frame
294 if frame is not None:
295 if frame >= frame_stride:
296 raise ValueError(f"Frame {frame} is out of range. Max frame: {frame_stride - 1}")
297 reconstructed_watermark_data = reconstructed_watermark[0, i, frame].cpu().numpy()
298 else:
299 mid_frame = frame_stride // 2
300 reconstructed_watermark_data = reconstructed_watermark[0, i, mid_frame].cpu().numpy()
302 # Draw the reconstructed watermark channel
303 sub_ax.imshow(reconstructed_watermark_data, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
304 sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
305 sub_ax.axis('off')
307 return ax
309 def draw_watermarked_video_frames(self,
310 num_frames: int = 4,
311 title: str = "Watermarked Video Frames",
312 ax: Optional[Axes] = None) -> Axes:
313 """
314 Draw multiple frames from the watermarked video.
316 DEPRECATED:
317 This method is deprecated and will be removed in a future version.
318 Please use `draw_watermarked_image` instead.
320 This method displays a grid of video frames to show the temporal
321 consistency of the watermarked video.
323 Args:
324 num_frames: Number of frames to display (default: 4)
325 title: The title of the plot
326 ax: The axes to plot on
328 Returns:
329 The plotted axes
330 """
331 return self._draw_video_frames(
332 title=title,
333 num_frames=num_frames,
334 ax=ax
335 )