Coverage for detection / videoshield / videoshield_detection.py: 92.92%

113 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 10:24 +0000

1import torch 

2import numpy as np 

3from typing import Dict, Optional, Tuple, Union 

4from detection.base import BaseDetector 

5from Crypto.Cipher import ChaCha20 

6from scipy.stats import truncnorm, norm 

7from functools import reduce 

8import logging 

9 

10logger = logging.getLogger(__name__) 

11 

12 

class VideoShieldDetector(BaseDetector):
    """VideoShield watermark detector class.

    Recovers watermark bits from reversed video latents: the latents are
    binarised by sign, optionally decrypted with ChaCha20, and the repeated
    copies along the channel / frame / height / width axes are merged by
    voting. The recovered bits are then compared with the reference
    watermark to obtain a bit accuracy.
    """

    def __init__(self,
                 watermark: torch.Tensor,
                 threshold: float,
                 device: torch.device,
                 chacha_key: Optional[bytes] = None,
                 chacha_nonce: Optional[bytes] = None,
                 height: int = 64,
                 width: int = 64,
                 num_frames: int = 0,
                 k_f: int = 8,
                 k_c: int = 1,
                 k_h: int = 4,
                 k_w: int = 4) -> None:
        """Initialize the VideoShield detector.

        Args:
            watermark: The watermarking bits
            threshold: Threshold for watermark detection
            device: The device to use for computation
            chacha_key: ChaCha20 encryption key (optional)
            chacha_nonce: ChaCha20 nonce (optional)
            height: Height of the video
            width: Width of the video
            num_frames: Number of frames in the video (0 disables the
                frame-count check during detection)
            k_f: Frame repetition factor
            k_c: Channel repetition factor
            k_h: Height repetition factor
            k_w: Width repetition factor
        """
        # Presumably BaseDetector stores these as self.threshold / self.device,
        # which eval_watermark reads later - verify against detection.base.
        super().__init__(threshold, device)
        self.watermark = watermark.to(device)
        self.chacha_key = chacha_key
        self.chacha_nonce = chacha_nonce
        self.num_frames = num_frames
        self.height = height
        self.width = width
        # Repetition factors
        self.k_f = k_f
        self.k_c = k_c
        self.k_h = k_h
        self.k_w = k_w

        # Calculate voting threshold
        # NOTE(review): without repetition a single sample contributes at most
        # one vote per position, so a threshold of 1 combined with the strict
        # ``>`` comparison used when voting cannot yield a 1 bit. Kept as-is to
        # preserve existing behaviour - confirm this special case is intended.
        if k_f == 1 and k_c == 1 and k_h == 1 and k_w == 1:
            self.vote_threshold = 1
        else:
            self.vote_threshold = (k_f * k_c * k_h * k_w) // 2

    def _stream_key_decrypt(self, reversed_m: np.ndarray) -> np.ndarray:
        """Decrypt the watermark bits using the ChaCha20 cipher.

        Args:
            reversed_m: Array of 0/1 bits recovered from the latents.

        Returns:
            The input unchanged when no key/nonce is configured; otherwise a
            uint8 array of decrypted bits with the same number of elements as
            ``reversed_m``.
        """
        if self.chacha_key is None or self.chacha_nonce is None:
            # If no encryption keys provided, return as-is
            return reversed_m

        cipher = ChaCha20.new(key=self.chacha_key, nonce=self.chacha_nonce)
        sd_byte = cipher.decrypt(np.packbits(reversed_m).tobytes())
        sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))

        # packbits zero-pads to a whole number of bytes; drop the padding so
        # callers can reshape the result back to the original layout.
        return sd_bit[:reversed_m.size]

    def _diffusion_inverse(self, watermark_r: torch.Tensor, is_video: bool = False) -> torch.Tensor:
        """Inverse the diffusion process to extract the watermark through voting.

        Args:
            watermark_r: The reversed watermark tensor
            is_video: Whether this is video (5D) data

        Returns:
            Extracted watermark through voting

        Raises:
            ValueError: If the input is not flagged as video or is not 5D.
                Only the video path is implemented.
        """
        if is_video and watermark_r.dim() == 5:
            return self._video_diffusion_inverse(watermark_r)
        # Fail loudly instead of implicitly returning None, which previously
        # surfaced as an unrelated AttributeError in the caller.
        raise ValueError(
            'VideoShield detection only supports 5D video latents, '
            f'got a {watermark_r.dim()}D tensor.'
        )

    @staticmethod
    def _chunk_sizes(length: int, repeats: int) -> Tuple[int, ...]:
        """Return the split sizes that cut ``length`` into repeated chunks.

        The axis is cut into ``min(repeats, length)`` chunks of
        ``max(1, length // repeats)`` elements; any leftover elements are
        appended to the last chunk.
        """
        stride = max(1, length // repeats)
        sizes = [stride] * min(repeats, length)
        leftover = length - sum(sizes)
        if leftover > 0:
            sizes[-1] += leftover
        return tuple(sizes)

    def _video_diffusion_inverse(self, watermark_r: torch.Tensor) -> torch.Tensor:
        """Video-specific diffusion inverse with frame dimension handling.

        Args:
            watermark_r: 5D tensor of recovered bits, unpacked here as
                (batch, channels, frames, height, width).

        Returns:
            The voted bits (the leading batch/copy axis is summed away), or a
            zero tensor shaped like the reference watermark when the input
            cannot be processed.
        """
        batch, channels, frames, height, width = watermark_r.shape

        expected_frames = getattr(self, "num_frames", 0)
        if expected_frames and frames != expected_frames:
            logger.warning(
                "Frame count mismatch detected: received %s frames, expected %s frames.",
                frames,
                expected_frames,
            )
            # Only surplus frames can be dropped; a short clip is used as-is.
            if frames > expected_frames:
                watermark_r = watermark_r[:, :, :expected_frames, :, :]
                frames = expected_frames
                logger.info("Truncated to the first %d frames for detection.", frames)

        if frames < self.k_f:
            logger.error(
                "VideoShield detector cannot process %s frames with repetition factor %s.",
                frames,
                self.k_f,
            )
            return torch.zeros_like(self.watermark)

        # Trim trailing frames so the frame axis divides evenly by k_f.
        # frames >= k_f here, so the aligned count is always positive.
        remainder = frames % self.k_f
        if remainder:
            aligned_frames = frames - remainder
            logger.info(
                "Aligning detection frames to %s frames to satisfy repetition factor %s.",
                aligned_frames,
                self.k_f,
            )
            watermark_r = watermark_r[:, :, :aligned_frames, :, :]
            frames = aligned_frames

        # (axis, chunk sizes) for every repeated axis, in split order.
        split_plan = (
            (1, self._chunk_sizes(channels, self.k_c)),
            (2, self._chunk_sizes(frames, self.k_f)),
            (3, self._chunk_sizes(height, self.k_h)),
            (4, self._chunk_sizes(width, self.k_w)),
        )

        try:
            # Move every repeated copy onto dim 0 so one sum counts the votes.
            stacked = watermark_r
            for axis, sizes in split_plan:
                stacked = torch.cat(torch.split(stacked, sizes, dim=axis), dim=0)

            votes = torch.sum(stacked, dim=0)
            # A bit is 1 only when strictly more than vote_threshold copies agree.
            return (votes > self.vote_threshold).to(votes.dtype)
        except Exception as e:
            # Best-effort: e.g. unevenly sized chunks cannot be concatenated.
            logger.error("Video diffusion inverse failed: %s", e)
            return torch.zeros_like(self.watermark)

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       detector_type: str = "bit_acc") -> Dict[str, Union[bool, float]]:
        """Evaluate the watermark in the reversed latents.

        Args:
            reversed_latents: The reversed latents from forward diffusion
            detector_type: The type of detector to use (only 'bit_acc' is
                supported)

        Returns:
            Dict with 'is_watermarked' and 'bit_acc'. For empty input the
            dict additionally carries 'confidence' and reports no watermark.

        Raises:
            ValueError: If ``detector_type`` is unsupported, or if the latents
                are not 5D video latents.
        """
        if detector_type not in ['bit_acc']:
            raise ValueError(f'Detector type {detector_type} is not supported for VideoShield.')

        # Basic validation
        if reversed_latents.numel() == 0:
            return {'is_watermarked': False, 'bit_acc': 0.0, 'confidence': 0.0}

        # Convert latents to binary bits
        reversed_m = (reversed_latents > 0).int()

        # Always go through NumPy: _stream_key_decrypt returns the bits
        # unchanged when no key/nonce is configured, and torch.from_numpy
        # below requires an ndarray rather than a tensor.
        reversed_sd = self._stream_key_decrypt(reversed_m.flatten().cpu().numpy())

        # Reshape back to the layout of the incoming latents
        # ([B, C, F, H, W] for T2V models, [B, F, C, H, W] for I2V models).
        reversed_sd_tensor = torch.from_numpy(reversed_sd).reshape(
            reversed_latents.shape
        ).to(torch.uint8).to(self.device)

        # Extract watermark through voting
        is_video = reversed_latents.dim() == 5
        reversed_watermark = self._diffusion_inverse(reversed_sd_tensor, is_video)

        correct = (reversed_watermark == self.watermark).float().mean().item()

        return {
            'is_watermarked': bool(correct > self.threshold),
            'bit_acc': correct,
        }