Coverage for detection / videoshield / videoshield_detection.py: 92.92%
113 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 10:24 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 10:24 +0000
1import torch
2import numpy as np
3from typing import Dict, Optional, Tuple, Union
4from detection.base import BaseDetector
5from Crypto.Cipher import ChaCha20
6from scipy.stats import truncnorm, norm
7from functools import reduce
8import logging
10logger = logging.getLogger(__name__)
class VideoShieldDetector(BaseDetector):
    """VideoShield watermark detector.

    Reversed latents are binarized by sign, optionally decrypted with
    ChaCha20, and the watermark is recovered by voting over the repeated
    copies laid out along the channel, frame, height and width axes.

    Only 5D (video) latents are supported; the image voting path that used to
    live here is disabled.
    """

    def __init__(self,
                 watermark: torch.Tensor,
                 threshold: float,
                 device: torch.device,
                 chacha_key: Optional[bytes] = None,
                 chacha_nonce: Optional[bytes] = None,
                 height: int = 64,
                 width: int = 64,
                 num_frames: int = 0,
                 k_f: int = 8,
                 k_c: int = 1,
                 k_h: int = 4,
                 k_w: int = 4) -> None:
        """Initialize the VideoShield detector.

        Args:
            watermark: The watermarking bits (moved to ``device``).
            threshold: Threshold for watermark detection.
            device: The device to use for computation.
            chacha_key: ChaCha20 encryption key (optional).
            chacha_nonce: ChaCha20 nonce (optional).
            height: Height of the video.
            width: Width of the video.
            num_frames: Expected number of frames; 0 disables the
                frame-count check during detection.
            k_f: Frame repetition factor.
            k_c: Channel repetition factor.
            k_h: Height repetition factor.
            k_w: Width repetition factor.
        """
        super().__init__(threshold, device)
        self.watermark = watermark.to(device)
        self.chacha_key = chacha_key
        self.chacha_nonce = chacha_nonce
        self.num_frames = num_frames
        self.height = height
        self.width = width

        # Repetition factors
        self.k_f = k_f
        self.k_c = k_c
        self.k_h = k_h
        self.k_w = k_w

        # Voting threshold: half the number of copies (ties resolve to 0).
        # NOTE(review): with a single copy the votes are only 0 or 1, so a
        # threshold of 1 maps every vote to 0 -- confirm this is intended.
        if (k_f, k_c, k_h, k_w) == (1, 1, 1, 1):
            self.vote_threshold = 1
        else:
            self.vote_threshold = (k_f * k_c * k_h * k_w) // 2

    def _stream_key_decrypt(self, reversed_m: np.ndarray) -> np.ndarray:
        """Decrypt the watermark bits using the ChaCha20 cipher.

        Args:
            reversed_m: Array of 0/1 bits.

        Returns:
            The decrypted bits, with as many elements as ``reversed_m``. When
            no key or nonce is configured the input is returned unchanged.
        """
        if self.chacha_key is None or self.chacha_nonce is None:
            # If no encryption keys provided, return as-is
            return reversed_m

        cipher = ChaCha20.new(key=self.chacha_key, nonce=self.chacha_nonce)
        plain_bytes = cipher.decrypt(np.packbits(reversed_m).tobytes())
        plain_bits = np.unpackbits(np.frombuffer(plain_bytes, dtype=np.uint8))

        # packbits zero-pads the last byte; drop the padding so the result can
        # be reshaped back even when the bit count is not a multiple of 8.
        return plain_bits[:reversed_m.size]

    def _diffusion_inverse(self, watermark_r: torch.Tensor, is_video: bool = False) -> torch.Tensor:
        """Inverse the diffusion process to extract the watermark through voting.

        Args:
            watermark_r: The reversed watermark tensor.
            is_video: Whether this is video (5D) or image (4D) data.

        Returns:
            Extracted watermark through voting.

        Raises:
            ValueError: If the input is not a 5D video tensor. Previously this
                case fell through and returned ``None``.
        """
        if is_video and watermark_r.dim() == 5:
            return self._video_diffusion_inverse(watermark_r)

        # NOTE: the image (4D) voting path is disabled in this detector.
        raise ValueError(
            f'VideoShield detector only supports 5D video latents, '
            f'got a {watermark_r.dim()}D tensor (is_video={is_video}).'
        )

    @staticmethod
    def _split_sizes(size: int, repeats: int) -> tuple:
        """Return the chunk sizes used to cut one axis into watermark copies.

        The axis is cut into ``min(repeats, size)`` chunks of
        ``max(1, size // repeats)`` elements; any leftover elements are
        appended to the last chunk.
        """
        stride = max(1, size // repeats)
        sizes = [stride] * min(repeats, size)
        leftover = size - sum(sizes)
        if leftover > 0:
            sizes[-1] += leftover
        return tuple(sizes)

    def _video_diffusion_inverse(self, watermark_r: torch.Tensor) -> torch.Tensor:
        """Video-specific diffusion inverse with frame dimension handling.

        Args:
            watermark_r: 5D tensor of bits. Dimension 1 is split by ``k_c``,
                dimension 2 by ``k_f``, dimension 3 by ``k_h`` and dimension 4
                by ``k_w``.

        Returns:
            The voted 0/1 tensor. On failure (too few frames, or the copies
            cannot be stacked) a zero tensor shaped like ``self.watermark``
            is returned and the problem is logged.
        """
        _, channels, frames, height, width = watermark_r.shape

        expected_frames = getattr(self, "num_frames", 0)
        if expected_frames and frames != expected_frames:
            logger.warning(
                "Frame count mismatch detected: received %s frames, expected %s frames.",
                frames,
                expected_frames,
            )
            # Only report (and perform) a truncation when there are extra frames.
            if frames > expected_frames:
                frames = expected_frames
                logger.info("Truncated to the first %d frames for detection.", frames)
                watermark_r = watermark_r[:, :, :frames, :, :]

        if frames < self.k_f:
            logger.error(
                "VideoShield detector cannot process %s frames with repetition factor %s.",
                frames,
                self.k_f,
            )
            return torch.zeros_like(self.watermark)

        # Drop trailing frames so the frame axis divides evenly into k_f copies.
        remainder = frames % self.k_f
        if remainder:
            frames -= remainder
            logger.info(
                "Aligning detection frames to %s frames to satisfy repetition factor %s.",
                frames,
                self.k_f,
            )
            watermark_r = watermark_r[:, :, :frames, :, :]

        split_plan = (
            (1, self._split_sizes(channels, self.k_c)),
            (2, self._split_sizes(frames, self.k_f)),
            (3, self._split_sizes(height, self.k_h)),
            (4, self._split_sizes(width, self.k_w)),
        )

        try:
            # Stack every copy along dim 0, one axis at a time.
            stacked = watermark_r
            for dim, sizes in split_plan:
                stacked = torch.cat(torch.split(stacked, sizes, dim=dim), dim=0)

            vote = torch.sum(stacked, dim=0)
            vote[vote <= self.vote_threshold] = 0
            vote[vote > self.vote_threshold] = 1
            return vote
        except Exception as e:
            # Deliberate best-effort: unequal chunk shapes (e.g. leftover
            # elements on an axis) make torch.cat fail; report and fall back.
            logger.error("Video diffusion inverse failed: %s", e)
            return torch.zeros_like(self.watermark)

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       detector_type: str = "bit_acc") -> Dict[str, Union[bool, float]]:
        """Evaluate the watermark in the reversed latents.

        Args:
            reversed_latents: The reversed latents from forward diffusion.
            detector_type: The type of detector to use; only 'bit_acc' is
                supported.

        Returns:
            Dict with 'is_watermarked' and 'bit_acc'. For empty input the
            dict additionally carries 'confidence': 0.0.

        Raises:
            ValueError: If ``detector_type`` is unsupported, or if the latents
                are not a 5D video tensor.
        """
        if detector_type not in ['bit_acc']:
            raise ValueError(f'Detector type {detector_type} is not supported for VideoShield.')

        # Basic validation
        if reversed_latents.numel() == 0:
            return {'is_watermarked': False, 'bit_acc': 0.0, 'confidence': 0.0}

        # Convert latents to binary bits
        reversed_m = (reversed_latents > 0).int()

        # Always go through numpy: _stream_key_decrypt passes the bits through
        # unchanged when no keys are configured, and torch.from_numpy below
        # requires an ndarray (a torch tensor would raise TypeError).
        reversed_sd = self._stream_key_decrypt(reversed_m.flatten().cpu().numpy())

        # Reshape back to the layout of the input latents, e.g.
        # [B, C, F, H, W] for T2V model, [B, F, C, H, W] for I2V model.
        reversed_sd_tensor = torch.from_numpy(reversed_sd).reshape(
            reversed_latents.shape
        ).to(torch.uint8).to(self.device)

        # Extract watermark through voting
        is_video = reversed_latents.dim() == 5
        reversed_watermark = self._diffusion_inverse(reversed_sd_tensor, is_video)

        correct = (reversed_watermark == self.watermark).float().mean().item()

        return {
            'is_watermarked': bool(correct > self.threshold),
            'bit_acc': correct,
        }