Coverage for detection / videomark / videomark_detection.py: 94.41%

179 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 10:24 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import torch 

17import numpy as np 

18from typing import Tuple, Type 

19from scipy.special import erf 

20from detection.base import BaseDetector 

21from ldpc import bp_decoder 

22from galois import FieldArray 

23import sys 

24from Levenshtein import hamming 

25import logging 

26 

27logger = logging.getLogger(__name__) 

28 

29class VideoMarkDetector(BaseDetector): 

30 

31 def __init__(self, 

32 message_sequence: np.ndarray, 

33 watermark: np.ndarray, 

34 num_frames: int, 

35 var: int, 

36 decoding_key: tuple, 

37 GF: Type[FieldArray], 

38 threshold: float, 

39 device: torch.device): 

40 super().__init__(threshold, device) 

41 

42 self.message_sequence = message_sequence 

43 self.watermark = watermark 

44 self.num_frames = num_frames 

45 self.var = var 

46 self.decoding_key = decoding_key 

47 self.GF = GF 

48 self.threshold = threshold 

49 self.code_length = self.decoding_key[0].shape[0] 

50 

51 def _recover_posteriors(self, z, basis=None, variances=None): 

52 if variances is None: 

53 default_variance = 1.5 

54 denominators = np.sqrt(2 * default_variance * (1+default_variance)) * torch.ones_like(z) 

55 elif type(variances) is float: 

56 denominators = np.sqrt(2 * variances * (1 + variances)) 

57 else: 

58 denominators = torch.sqrt(2 * variances * (1 + variances)) 

59 

60 if basis is None: 

61 return erf(z / denominators) 

62 else: 

63 return erf((z @ basis) / denominators) 

64 

65 def _align_posteriors_length(self, posteriors: torch.Tensor) -> torch.Tensor: 

66 """Ensure posterior vector length matches the code length expected by LDPC decoder.""" 

67 current_len = posteriors.numel() 

68 if current_len == self.code_length: 

69 return posteriors 

70 if current_len > self.code_length: 

71 logger.warning( 

72 "Detected per-frame feature length %s is greater than the encoding length %s. The excess part has been truncated.", 

73 current_len, 

74 self.code_length, 

75 ) 

76 return posteriors[:self.code_length] 

77 

78 pad_len = self.code_length - current_len 

79 logger.warning( 

80 "Detected per-frame feature length %s is less than the encoding length %s. %s elements have been padded with zeros at the end.", 

81 current_len, 

82 self.code_length, 

83 pad_len, 

84 ) 

85 pad = torch.zeros(pad_len, dtype=posteriors.dtype, device=posteriors.device) 

86 return torch.cat([posteriors, pad], dim=0) 

87 

88 def _detect_watermark(self, posteriors: torch.Tensor) -> Tuple[bool, float]: 

89 """Detect watermark in posteriors.""" 

90 generator_matrix, parity_check_matrix, one_time_pad, false_positive_rate, noise_rate, test_bits, g, max_bp_iter, t = self.decoding_key 

91 posteriors = self._align_posteriors_length(posteriors) 

92 posteriors = (1 - 2 * noise_rate) * (1 - 2 * np.array(one_time_pad, dtype=float)) * posteriors.numpy(force=True) 

93 

94 r = parity_check_matrix.shape[0] 

95 Pi = np.prod(posteriors[parity_check_matrix.indices.reshape(r, t)], axis=1) 

96 log_plus = np.log((1 + Pi) / 2) 

97 log_minus = np.log((1 - Pi) / 2) 

98 log_prod = log_plus + log_minus 

99 

100 const = 0.5 * np.sum(np.power(log_plus, 2) + np.power(log_minus, 2) - 0.5 * np.power(log_prod, 2)) 

101 threshold = np.sqrt(2 * const * np.log(1 / false_positive_rate)) + 0.5 * log_prod.sum() 

102 #print(f"threshold: {threshold}") 

103 return log_plus.sum() >= threshold#, log_plus.sum() 

104 

105 def _boolean_row_reduce(self, A, print_progress=False): 

106 """Given a GF(2) matrix, do row elimination and return the first k rows of A that form an invertible matrix 

107 

108 Args: 

109 A (np.ndarray): A GF(2) matrix 

110 print_progress (bool, optional): Whether to print the progress. Defaults to False. 

111 

112 Returns: 

113 np.ndarray: The first k rows of A that form an invertible matrix 

114 """ 

115 n, k = A.shape 

116 A_rr = A.copy() 

117 perm = np.arange(n) 

118 for j in range(k): 

119 idxs = j + np.nonzero(A_rr[j:, j])[0] 

120 if idxs.size == 0: 

121 print("The given matrix is not invertible") 

122 return None 

123 A_rr[[j, idxs[0]]] = A_rr[[idxs[0], j]] # For matrices you have to swap them this way 

124 (perm[j], perm[idxs[0]]) = (perm[idxs[0]], perm[j]) # Weirdly, this is MUCH faster if you swap this way instead of using perm[[i,j]]=perm[[j,i]] 

125 A_rr[idxs[1:]] += A_rr[j] 

126 if print_progress and (j%5==0 or j+1==k): 

127 sys.stdout.write(f'\rDecoding progress: {j + 1} / {k}') 

128 sys.stdout.flush() 

129 if print_progress: print() 

130 return perm[:k] 

131 

132 def _decode_message(self, posteriors, print_progress=False, max_bp_iter=None): 

133 generator_matrix, parity_check_matrix, one_time_pad, false_positive_rate_key, noise_rate, test_bits, g, max_bp_iter_key, t = self.decoding_key 

134 if max_bp_iter is None: 

135 max_bp_iter = max_bp_iter_key 

136 

137 posteriors = self._align_posteriors_length(posteriors) 

138 posteriors = (1 - 2 * noise_rate) * (1 - 2 * np.array(one_time_pad, dtype=float)) * posteriors.numpy(force=True) 

139 channel_probs = (1 - np.abs(posteriors)) / 2 

140 x_recovered = (1 - np.sign(posteriors)) // 2 

141 

142 

143 # Apply the belief-propagation decoder. 

144 if print_progress: 

145 print("Running belief propagation...") 

146 bpd = bp_decoder(parity_check_matrix, channel_probs=channel_probs, max_iter=max_bp_iter, bp_method="product_sum") 

147 x_decoded = bpd.decode(x_recovered) 

148 

149 # Compute a confidence score. 

150 bpd_probs = 1 / (1 + np.exp(bpd.log_prob_ratios)) 

151 confidences = 2 * np.abs(0.5 - bpd_probs) 

152 

153 # Order codeword bits by confidence. 

154 confidence_order = np.argsort(-confidences) 

155 ordered_generator_matrix = generator_matrix[confidence_order] 

156 ordered_x_decoded = x_decoded[confidence_order].astype(int) 

157 

158 # Find the first (according to the confidence order) linearly independent set of rows of the generator matrix. 

159 top_invertible_rows = self._boolean_row_reduce(ordered_generator_matrix, print_progress=print_progress) 

160 if top_invertible_rows is None: 

161 return None 

162 

163 # Solve the system. 

164 if print_progress: 

165 print("Solving linear system...") 

166 recovered_string = np.linalg.solve(ordered_generator_matrix[top_invertible_rows], self.GF(ordered_x_decoded[top_invertible_rows])) 

167 

168 return np.array(recovered_string[len(test_bits) + g:]) 

169 

170 def bits_to_string(self, bits): 

171 return ''.join(map(str, bits)) 

172 

173 def recover(self, idx_list, message_list, distance_list, message_length): 

174 """ 

175 Recover the original message sequence from indices, messages, and distances. 

176 

177 - If idx != -1: use sorted valid entries (normal recovery) 

178 - If idx == -1: use the original message_list and distance_list directly 

179 """ 

180 

181 valid_entries = [ 

182 (idx, msg, dist) 

183 for idx, msg, dist in zip(idx_list, message_list, distance_list) 

184 if idx != -1 

185 ] 

186 valid_entries.sort(key=lambda x: x[0]) 

187 

188 sorted_order, recovered_message, recovered_distance = [], [], [] 

189 valid_idx = 0 

190 

191 for i, idx in enumerate(idx_list): 

192 if idx == -1: 

193 

194 sorted_order.append(-1) 

195 recovered_message.append(message_list[i]) 

196 recovered_distance.append(distance_list[i]) 

197 else: 

198 sorted_order.append(valid_entries[valid_idx][0]) 

199 recovered_message.append(valid_entries[valid_idx][1]) 

200 recovered_distance.append(valid_entries[valid_idx][2]) 

201 valid_idx += 1 

202 

203 return ( 

204 np.array(sorted_order), 

205 np.array(recovered_message), 

206 np.array(recovered_distance) 

207 ) 

208 

209 

210 

211 

212 def eval_watermark(self, reversed_latents: torch.Tensor, reference_latents: torch.Tensor = None, detector_type: str = "bit_acc") -> float: 

213 """Evaluate watermark in reversed latents.""" 

214 if detector_type != 'bit_acc': 

215 raise ValueError(f'Detector type {detector_type} is not supported for VideoMark. Use "bit_acc" instead.') 

216 

217 if reversed_latents.dim() < 3: 

218 raise ValueError("VideoMark detector expects at least 3D reversed latents for video inputs.") 

219 

220 latents = reversed_latents 

221 

222 if latents.dim() == 4: 

223 latents = latents.unsqueeze(0) 

224 

225 if latents.dim() != 5: 

226 raise ValueError("Unsupported reversed_latents dimensionality for VideoMark detector.") 

227 

228 expected_frames = self.num_frames or 0 

229 

230 channel_axis = None 

231 for axis in range(1, 5): 

232 if latents.shape[axis] == 4: 

233 channel_axis = axis 

234 break 

235 if channel_axis is not None and channel_axis != 1: 

236 latents = latents.movedim(channel_axis, 1) 

237 

238 candidate_axes = [axis for axis in range(2, 5)] 

239 if expected_frames: 

240 frame_axis = min(candidate_axes, key=lambda axis: abs(latents.shape[axis] - expected_frames)) 

241 else: 

242 frame_axis = 2 

243 if frame_axis != 2: 

244 latents = latents.movedim(frame_axis, 2) 

245 

246 available_frames = latents.shape[2] 

247 frames_to_use = available_frames 

248 

249 if expected_frames: 

250 if available_frames != expected_frames: 

251 logger.warning( 

252 "Frame count mismatch detected: received %s frames, expected %s frames.", 

253 available_frames, 

254 expected_frames, 

255 ) 

256 frames_to_use = min(available_frames, expected_frames) 

257 logger.info("Truncated to the first %d frames for detection.", frames_to_use) 

258 

259 if frames_to_use <= 1: 

260 logger.error("There are not enough frames for VideoMark detection.") 

261 return { 

262 'is_watermarked': False, 

263 "bit_acc": 0.0, 

264 "recovered_index": np.array([]), 

265 "recovered_message": np.array([]), 

266 "recovered_distance": np.array([]), 

267 } 

268 

269 latents = latents[:, :, :frames_to_use, ...] 

270 

271 idx_list, message_list, distance_list = [], [], [] 

272 message_length = self.message_sequence.shape[1] 

273 message_sequence_str = [self.bits_to_string(msg) for msg in self.message_sequence] 

274 

275 for frame_index in range(frames_to_use): 

276 frame_latents = latents[:, :, frame_index, ...].to(torch.float64) 

277 reversed_prc = self._recover_posteriors( 

278 frame_latents.flatten().cpu(), 

279 variances=self.var 

280 ).flatten().cpu() 

281 aligned_prc = self._align_posteriors_length(reversed_prc) 

282 self.recovered_prc = aligned_prc 

283 

284 detect_result = self._detect_watermark(aligned_prc) 

285 decode_message = self._decode_message(aligned_prc) 

286 if decode_message is None: 

287 decode_message = np.zeros((message_length,), dtype=int) 

288 decode_message_str = "0" * message_length 

289 else: 

290 decode_message = np.asarray(decode_message).astype(int) 

291 decode_message_str = self.bits_to_string(decode_message) 

292 

293 distances = np.array([ 

294 2 * (hamming(decode_message_str, msg) / len(msg) - 0.5) 

295 for msg in message_sequence_str 

296 ]) 

297 min_distance = distances.min() 

298 idx = -1 if not detect_result else distances.argmin() 

299 

300 message_list.append(decode_message) 

301 distance_list.append(min_distance) 

302 idx_list.append(idx) 

303 

304 recovered_index, recovered_message, recovered_distance = self.recover( 

305 idx_list, message_list, distance_list, message_length 

306 ) 

307 

308 watermark_reference = np.asarray(self.watermark[:frames_to_use]) 

309 recovered_message = np.asarray(recovered_message) 

310 

311 if recovered_message.size == 0 or watermark_reference.size == 0: 

312 bit_acc = 0.0 

313 else: 

314 frame_limit = min(recovered_message.shape[0], watermark_reference.shape[0]) 

315 if frame_limit <= 0: 

316 bit_acc = 0.0 

317 else: 

318 bit_acc = np.mean( 

319 recovered_message[:frame_limit] == watermark_reference[:frame_limit] 

320 ) 

321 

322 return { 

323 'is_watermarked': float(bit_acc) >= self.threshold, 

324 "bit_acc": float(bit_acc), 

325 "recovered_index": recovered_index, 

326 "recovered_message": recovered_message, 

327 "recovered_distance": recovered_distance, 

328 } 

329