Coverage for detection/videomark/videomark_detection.py: 94.41%
179 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 10:24 +0000
# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16import torch
17import numpy as np
18from typing import Tuple, Type
19from scipy.special import erf
20from detection.base import BaseDetector
21from ldpc import bp_decoder
22from galois import FieldArray
23import sys
24from Levenshtein import hamming
25import logging
27logger = logging.getLogger(__name__)
class VideoMarkDetector(BaseDetector):
    """Per-frame watermark detector and message decoder for VideoMark.

    For every frame of the reversed latents the detector:

    1. converts latent values into soft bit posteriors (``_recover_posteriors``),
    2. runs a parity-check likelihood test (``_detect_watermark``),
    3. decodes message bits with belief propagation followed by a GF(2)
       linear solve (``_decode_message``),
    4. matches the decoded bits against the rows of ``message_sequence`` by
       Hamming distance.

    The final verdict compares the bit accuracy of the recovered messages
    against ``self.watermark`` with ``threshold``.
    """

    def __init__(self,
                 message_sequence: np.ndarray,
                 watermark: np.ndarray,
                 num_frames: int,
                 var: float,
                 decoding_key: tuple,
                 GF: Type[FieldArray],
                 threshold: float,
                 device: torch.device):
        """Store the decoding configuration.

        Args:
            message_sequence: 2-D array of candidate messages, one per row;
                each decoded frame message is compared with every row.
            watermark: Reference messages; ``watermark[:frames]`` is compared
                element-wise with the recovered messages to get bit accuracy.
            num_frames: Expected frame count. A falsy value disables the
                frame-count heuristics in ``eval_watermark``.
            var: Variance used by ``_recover_posteriors``. Python/NumPy real
                scalars and tensors are all accepted.
            decoding_key: Tuple ``(generator_matrix, parity_check_matrix,
                one_time_pad, false_positive_rate, noise_rate, test_bits, g,
                max_bp_iter, t)``.
            GF: Galois field class used for the final linear solve.
            threshold: Minimum bit accuracy for ``is_watermarked`` to be True.
            device: Forwarded to ``BaseDetector``.
        """
        super().__init__(threshold, device)

        self.message_sequence = message_sequence
        self.watermark = watermark
        self.num_frames = num_frames
        self.var = var
        self.decoding_key = decoding_key
        self.GF = GF
        self.threshold = threshold
        # Codeword length: number of rows of the generator matrix (decoding_key[0]).
        self.code_length = self.decoding_key[0].shape[0]

    def _recover_posteriors(self, z, basis=None, variances=None):
        """Map latent values to soft bit posteriors in [-1, 1] via ``erf``.

        Args:
            z: Latent values (a flattened CPU tensor when called from
                ``eval_watermark``).
            basis: Optional matrix; when given, ``z @ basis`` is used in place
                of ``z``.
            variances: ``None`` (a default of 1.5 is used), a real scalar
                (``int``/``float``/NumPy scalar), or a tensor/array that
                broadcasts against the projected latents.

        Returns:
            ``erf(projected / sqrt(2 * v * (1 + v)))``.
        """
        projected = z if basis is None else z @ basis

        if variances is None:
            variances = 1.5

        if isinstance(variances, (int, float, np.integer, np.floating)):
            # A scalar denominator broadcasts against any shape, including the
            # projected shape when ``basis`` changes the length of ``z``.
            # (An ``int`` or NumPy scalar must not reach ``torch.sqrt``,
            # which only accepts tensors.)
            denominators = float(np.sqrt(2 * variances * (1 + variances)))
        else:
            denominators = torch.sqrt(torch.as_tensor(2 * variances * (1 + variances)))

        return erf(projected / denominators)

    def _align_posteriors_length(self, posteriors: torch.Tensor) -> torch.Tensor:
        """Ensure posterior vector length matches the code length expected by LDPC decoder.

        Longer vectors are truncated to ``self.code_length``; shorter ones are
        zero-padded at the end. A warning is logged in both cases.
        """
        current_len = posteriors.numel()
        if current_len == self.code_length:
            return posteriors
        if current_len > self.code_length:
            logger.warning(
                "Detected per-frame feature length %s is greater than the encoding length %s. The excess part has been truncated.",
                current_len,
                self.code_length,
            )
            return posteriors[:self.code_length]

        pad_len = self.code_length - current_len
        logger.warning(
            "Detected per-frame feature length %s is less than the encoding length %s. %s elements have been padded with zeros at the end.",
            current_len,
            self.code_length,
            pad_len,
        )
        pad = torch.zeros(pad_len, dtype=posteriors.dtype, device=posteriors.device)
        return torch.cat([posteriors, pad], dim=0)

    def _detect_watermark(self, posteriors: torch.Tensor) -> bool:
        """Run the parity-check likelihood test on one frame's posteriors.

        Args:
            posteriors: 1-D tensor of soft bits; aligned to ``code_length``
                before use.

        Returns:
            True when the summed ``log((1 + Pi) / 2)`` statistic over all
            parity checks reaches a threshold derived from the key's
            ``false_positive_rate``.
        """
        (_, parity_check_matrix, one_time_pad, false_positive_rate,
         noise_rate, _, _, _, t) = self.decoding_key
        posteriors = self._align_posteriors_length(posteriors)
        # Flip signs where the one-time-pad bit is 1 and scale by (1 - 2 * noise_rate).
        posteriors = (1 - 2 * noise_rate) * (1 - 2 * np.array(one_time_pad, dtype=float)) * posteriors.numpy(force=True)

        r = parity_check_matrix.shape[0]
        # NOTE(review): assumes parity_check_matrix is a CSR sparse matrix with
        # exactly t nonzeros per row, so ``indices`` reshapes to (r, t).
        Pi = np.prod(posteriors[parity_check_matrix.indices.reshape(r, t)], axis=1)
        log_plus = np.log((1 + Pi) / 2)
        log_minus = np.log((1 - Pi) / 2)
        log_prod = log_plus + log_minus

        const = 0.5 * np.sum(np.power(log_plus, 2) + np.power(log_minus, 2) - 0.5 * np.power(log_prod, 2))
        # Local test threshold (distinct from ``self.threshold``, which is a bit-accuracy cutoff).
        score_threshold = np.sqrt(2 * const * np.log(1 / false_positive_rate)) + 0.5 * log_prod.sum()
        return bool(log_plus.sum() >= score_threshold)

    def _boolean_row_reduce(self, A, print_progress=False):
        """Select k linearly independent rows of an (n, k) GF(2) matrix.

        Gaussian elimination runs column by column; for each column the first
        remaining row with a nonzero pivot is swapped into place, so rows
        appearing earlier in ``A`` tend to be preferred.

        Args:
            A: An (n, k) GF(2) matrix (row additions rely on field arithmetic).
            print_progress: Whether to print the progress. Defaults to False.

        Returns:
            np.ndarray: Indices of k rows of ``A`` that form an invertible
            k x k submatrix, or ``None`` if some column has no pivot.
        """
        n, k = A.shape
        A_rr = A.copy()
        perm = np.arange(n)
        for j in range(k):
            idxs = j + np.nonzero(A_rr[j:, j])[0]
            if idxs.size == 0:
                logger.warning("The given matrix is not invertible")
                return None
            A_rr[[j, idxs[0]]] = A_rr[[idxs[0], j]]  # For matrices you have to swap them this way
            (perm[j], perm[idxs[0]]) = (perm[idxs[0]], perm[j])  # Scalar tuple swap is much faster than perm[[i,j]]=perm[[j,i]]
            # Eliminate the pivot column from every other candidate row below j.
            A_rr[idxs[1:]] += A_rr[j]
            if print_progress and (j % 5 == 0 or j + 1 == k):
                sys.stdout.write(f'\rDecoding progress: {j + 1} / {k}')
                sys.stdout.flush()
        if print_progress:
            print()
        return perm[:k]

    def _decode_message(self, posteriors, print_progress=False, max_bp_iter=None):
        """Decode the embedded message bits from one frame's posteriors.

        Args:
            posteriors: 1-D tensor of soft bits; aligned to ``code_length``
                before use.
            print_progress: Print decoding progress to stdout.
            max_bp_iter: Belief-propagation iteration cap; defaults to the
                value stored in ``decoding_key``.

        Returns:
            The solved GF(2) vector with its first ``len(test_bits) + g``
            entries dropped, or ``None`` when no invertible set of generator
            rows can be found.
        """
        (generator_matrix, parity_check_matrix, one_time_pad, _,
         noise_rate, test_bits, g, max_bp_iter_key, _) = self.decoding_key
        if max_bp_iter is None:
            max_bp_iter = max_bp_iter_key

        posteriors = self._align_posteriors_length(posteriors)
        posteriors = (1 - 2 * noise_rate) * (1 - 2 * np.array(one_time_pad, dtype=float)) * posteriors.numpy(force=True)
        # Per-bit flip probability and hard decision (negative posterior -> 1, otherwise 0).
        channel_probs = (1 - np.abs(posteriors)) / 2
        x_recovered = (1 - np.sign(posteriors)) // 2

        # Apply the belief-propagation decoder.
        if print_progress:
            print("Running belief propagation...")
        bpd = bp_decoder(parity_check_matrix, channel_probs=channel_probs, max_iter=max_bp_iter, bp_method="product_sum")
        x_decoded = bpd.decode(x_recovered)

        # Compute a confidence score.
        bpd_probs = 1 / (1 + np.exp(bpd.log_prob_ratios))
        confidences = 2 * np.abs(0.5 - bpd_probs)

        # Order codeword bits by confidence.
        confidence_order = np.argsort(-confidences)
        ordered_generator_matrix = generator_matrix[confidence_order]
        ordered_x_decoded = x_decoded[confidence_order].astype(int)

        # Find the first (according to the confidence order) linearly independent set of rows of the generator matrix.
        top_invertible_rows = self._boolean_row_reduce(ordered_generator_matrix, print_progress=print_progress)
        if top_invertible_rows is None:
            return None

        # Solve the system.
        if print_progress:
            print("Solving linear system...")
        recovered_string = np.linalg.solve(ordered_generator_matrix[top_invertible_rows], self.GF(ordered_x_decoded[top_invertible_rows]))

        # Drop the leading test bits and the g extra entries; the rest is the message.
        return np.array(recovered_string[len(test_bits) + g:])

    def bits_to_string(self, bits):
        """Concatenate the ``str()`` of each bit, e.g. ``[1, 0, 1] -> "101"``."""
        return ''.join(map(str, bits))

    def recover(self, idx_list, message_list, distance_list, message_length):
        """
        Recover the original message sequence from indices, messages, and distances.

        - If idx != -1: use sorted valid entries (normal recovery)
        - If idx == -1: use the original message_list and distance_list directly

        Positions whose idx is -1 keep their own entry; the remaining positions
        are refilled, in order, with the valid entries sorted by idx.
        ``message_length`` is accepted for interface compatibility but unused.

        Returns:
            Tuple of arrays ``(sorted_order, recovered_message, recovered_distance)``.
        """
        valid_entries = [
            (idx, msg, dist)
            for idx, msg, dist in zip(idx_list, message_list, distance_list)
            if idx != -1
        ]
        valid_entries.sort(key=lambda x: x[0])

        sorted_order, recovered_message, recovered_distance = [], [], []
        valid_idx = 0

        for i, idx in enumerate(idx_list):
            if idx == -1:
                sorted_order.append(-1)
                recovered_message.append(message_list[i])
                recovered_distance.append(distance_list[i])
            else:
                sorted_order.append(valid_entries[valid_idx][0])
                recovered_message.append(valid_entries[valid_idx][1])
                recovered_distance.append(valid_entries[valid_idx][2])
                valid_idx += 1

        return (
            np.array(sorted_order),
            np.array(recovered_message),
            np.array(recovered_distance)
        )

    def eval_watermark(self, reversed_latents: torch.Tensor, reference_latents: torch.Tensor = None, detector_type: str = "bit_acc") -> dict:
        """Detect and decode the VideoMark watermark frame by frame.

        Args:
            reversed_latents: 4-D (unbatched) or 5-D latents. The first axis
                of size 4 (searched from axis 1) is moved to axis 1 as the
                channel axis; among axes 2-4, the one whose size is closest to
                ``num_frames`` becomes the frame axis.
            reference_latents: Unused; kept for interface compatibility.
            detector_type: Only ``"bit_acc"`` is supported.

        Returns:
            Dict with keys ``is_watermarked``, ``bit_acc``,
            ``recovered_index``, ``recovered_message`` and
            ``recovered_distance``.

        Raises:
            ValueError: For an unsupported ``detector_type`` or latent rank
                (3-D inputs pass the first check but are rejected as
                unsupported dimensionality).
        """
        if detector_type != 'bit_acc':
            raise ValueError(f'Detector type {detector_type} is not supported for VideoMark. Use "bit_acc" instead.')

        if reversed_latents.dim() < 3:
            raise ValueError("VideoMark detector expects at least 3D reversed latents for video inputs.")

        latents = reversed_latents

        # Add a batch axis to unbatched 4-D inputs.
        if latents.dim() == 4:
            latents = latents.unsqueeze(0)

        if latents.dim() != 5:
            raise ValueError("Unsupported reversed_latents dimensionality for VideoMark detector.")

        expected_frames = self.num_frames or 0

        # Heuristic: the first axis of size 4 is the channel axis; move it to axis 1.
        channel_axis = None
        for axis in range(1, 5):
            if latents.shape[axis] == 4:
                channel_axis = axis
                break
        if channel_axis is not None and channel_axis != 1:
            latents = latents.movedim(channel_axis, 1)

        # Heuristic: the frame axis is the one whose size best matches num_frames.
        candidate_axes = [axis for axis in range(2, 5)]
        if expected_frames:
            frame_axis = min(candidate_axes, key=lambda axis: abs(latents.shape[axis] - expected_frames))
        else:
            frame_axis = 2
        if frame_axis != 2:
            latents = latents.movedim(frame_axis, 2)

        available_frames = latents.shape[2]
        frames_to_use = available_frames

        if expected_frames:
            if available_frames != expected_frames:
                logger.warning(
                    "Frame count mismatch detected: received %s frames, expected %s frames.",
                    available_frames,
                    expected_frames,
                )
                frames_to_use = min(available_frames, expected_frames)
                logger.info("Truncated to the first %d frames for detection.", frames_to_use)

        if frames_to_use <= 1:
            logger.error("There are not enough frames for VideoMark detection.")
            return {
                'is_watermarked': False,
                "bit_acc": 0.0,
                "recovered_index": np.array([]),
                "recovered_message": np.array([]),
                "recovered_distance": np.array([]),
            }

        latents = latents[:, :, :frames_to_use, ...]

        idx_list, message_list, distance_list = [], [], []
        message_length = self.message_sequence.shape[1]
        message_sequence_str = [self.bits_to_string(msg) for msg in self.message_sequence]

        for frame_index in range(frames_to_use):
            frame_latents = latents[:, :, frame_index, ...].to(torch.float64)
            reversed_prc = self._recover_posteriors(
                frame_latents.flatten().cpu(),
                variances=self.var
            ).flatten().cpu()
            aligned_prc = self._align_posteriors_length(reversed_prc)
            # Exposed for inspection; holds the posteriors of the last processed frame.
            self.recovered_prc = aligned_prc

            detect_result = self._detect_watermark(aligned_prc)
            decode_message = self._decode_message(aligned_prc)
            if decode_message is None:
                # Decoding failed: fall back to an all-zero message.
                decode_message = np.zeros((message_length,), dtype=int)
                decode_message_str = "0" * message_length
            else:
                decode_message = np.asarray(decode_message).astype(int)
                decode_message_str = self.bits_to_string(decode_message)

            # Normalized Hamming distance to each candidate, mapped to [-1, 1].
            distances = np.array([
                2 * (hamming(decode_message_str, msg) / len(msg) - 0.5)
                for msg in message_sequence_str
            ])
            min_distance = distances.min()
            idx = -1 if not detect_result else distances.argmin()

            message_list.append(decode_message)
            distance_list.append(min_distance)
            idx_list.append(idx)

        recovered_index, recovered_message, recovered_distance = self.recover(
            idx_list, message_list, distance_list, message_length
        )

        watermark_reference = np.asarray(self.watermark[:frames_to_use])
        recovered_message = np.asarray(recovered_message)

        if recovered_message.size == 0 or watermark_reference.size == 0:
            bit_acc = 0.0
        else:
            frame_limit = min(recovered_message.shape[0], watermark_reference.shape[0])
            if frame_limit <= 0:
                bit_acc = 0.0
            else:
                bit_acc = np.mean(
                    recovered_message[:frame_limit] == watermark_reference[:frame_limit]
                )

        return {
            'is_watermarked': float(bit_acc) >= self.threshold,
            "bit_acc": float(bit_acc),
            "recovered_index": recovered_index,
            "recovered_message": recovered_message,
            "recovered_distance": recovered_distance,
        }