Coverage for detection / prc / prc_detection.py: 97.85%
93 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1# Copyright 2025 THU-BPM MarkDiffusion.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16import torch
17import numpy as np
18from typing import Tuple, Type
19from scipy.special import erf
20from detection.base import BaseDetector
21from ldpc import bp_decoder
22from galois import FieldArray
class PRCDetector(BaseDetector):
    """Detect and decode a PRC watermark from reversed latents.

    The detector is driven by ``decoding_key``, a 9-tuple that is unpacked as
    ``(generator_matrix, parity_check_matrix, one_time_pad,
    false_positive_rate, noise_rate, test_bits, g, max_bp_iter, t)``.
    """

    def __init__(self,
                 var: int,
                 decoding_key: tuple,
                 GF: Type[FieldArray],
                 threshold: float,
                 device: torch.device):
        """Store the detector configuration.

        Args:
            var: Variance forwarded to ``_recover_posteriors`` by
                ``eval_watermark``.
            decoding_key: The 9-tuple described in the class docstring.
            GF: Galois field array class used to wrap the decoded bits before
                solving the linear system in ``_decode_message``.
            threshold: Forwarded to ``BaseDetector.__init__``.
            device: Forwarded to ``BaseDetector.__init__``.
        """
        super().__init__(threshold, device)

        self.var = var
        self.decoding_key = decoding_key
        self.GF = GF

    def _recover_posteriors(self, z, basis=None, variances=None):
        """Map latents to values in [-1, 1] via ``erf(z / sqrt(2*v*(1+v)))``.

        Args:
            z: Latent tensor (``eval_watermark`` passes a flattened float64
                CPU tensor).
            basis: Optional matrix; when given, ``z @ basis`` is used instead
                of ``z``.
            variances: ``None`` (a default variance of 1.5 is used), a scalar
                (Python or NumPy number), or a ``torch.Tensor`` of variances.

        Returns:
            The result of ``scipy.special.erf`` applied to the scaled latents.
        """
        if variances is None:
            default_variance = 1.5
            denominators = np.sqrt(2 * default_variance * (1 + default_variance)) * torch.ones_like(z)
        elif isinstance(variances, torch.Tensor):
            denominators = torch.sqrt(2 * variances * (1 + variances))
        else:
            # Any non-tensor scalar (float, int, np.floating, ...) goes through
            # NumPy: torch.sqrt only accepts tensors, so the previous
            # `type(variances) is float` check crashed for an int variance.
            denominators = np.sqrt(2 * variances * (1 + variances))

        projected = z if basis is None else z @ basis
        return erf(projected / denominators)

    def _detect_watermark(self, posteriors: torch.Tensor) -> Tuple[bool, float]:
        """Run the parity-check based detection test on ``posteriors``.

        Args:
            posteriors: Tensor of per-bit posteriors, as produced by
                ``_recover_posteriors``.

        Returns:
            ``(is_detected, score)`` where ``score`` is the sum of
            ``log((1 + Pi) / 2)`` over all parity checks and ``is_detected``
            is ``score >= threshold`` for the threshold derived from the key's
            false-positive rate.
        """
        (_generator_matrix, parity_check_matrix, one_time_pad, false_positive_rate,
         noise_rate, _test_bits, _g, _max_bp_iter, t) = self.decoding_key

        # Undo the one-time pad (bit b -> sign 1 - 2b) and damp by the noise rate.
        pad_signs = 1 - 2 * np.array(one_time_pad, dtype=float)
        posteriors = (1 - 2 * noise_rate) * pad_signs * posteriors.numpy(force=True)

        # One product per parity check; the reshape relies on every row of the
        # parity-check matrix having exactly t stored indices.
        r = parity_check_matrix.shape[0]
        check_products = np.prod(posteriors[parity_check_matrix.indices.reshape(r, t)], axis=1)

        log_plus = np.log((1 + check_products) / 2)
        log_minus = np.log((1 - check_products) / 2)
        log_prod = log_plus + log_minus

        const = 0.5 * np.sum(np.power(log_plus, 2) + np.power(log_minus, 2) - 0.5 * np.power(log_prod, 2))
        threshold = np.sqrt(2 * const * np.log(1 / false_positive_rate)) + 0.5 * log_prod.sum()

        score = log_plus.sum()
        return score >= threshold, score

    def _boolean_row_reduce(self, A, print_progress=False):
        """Select k linearly independent rows of an (n, k) matrix by elimination.

        Works on a copy of ``A``; for each column it swaps the first row with
        a nonzero entry into the pivot position and adds the pivot row to the
        remaining rows that have a nonzero entry in that column.

        Args:
            A: Matrix of shape (n, k). The ``+=`` elimination step assumes
                GF(2) arithmetic (e.g. a galois field array).
            print_progress (bool, optional): Whether to print the progress.
                Defaults to False.

        Returns:
            np.ndarray | None: Indices (into the rows of ``A``) of the k rows
            chosen as pivots, or ``None`` if some column has no pivot.
        """
        n, k = A.shape
        A_rr = A.copy()
        perm = np.arange(n)
        for j in range(k):
            idxs = j + np.nonzero(A_rr[j:, j])[0]
            if idxs.size == 0:
                print("The given matrix is not invertible")
                return None
            pivot = idxs[0]
            A_rr[[j, pivot]] = A_rr[[pivot, j]]  # Row swap on a matrix needs fancy indexing
            # Scalar tuple swap is much faster here than perm[[i, j]] = perm[[j, i]].
            perm[j], perm[pivot] = perm[pivot], perm[j]
            A_rr[idxs[1:]] += A_rr[j]
            if print_progress and (j % 5 == 0 or j + 1 == k):
                # print(..., end='', flush=True) writes the same bytes as the
                # former sys.stdout.write/flush pair, without needing `sys`
                # (which this module never imported).
                print(f'\rDecoding progress: {j + 1} / {k}', end='', flush=True)
        if print_progress:
            print()
        return perm[:k]

    def _decode_message(self, posteriors, print_progress=False, max_bp_iter=None):
        """Try to decode the embedded message bits from ``posteriors``.

        Args:
            posteriors: Tensor of per-bit posteriors.
            print_progress: Whether to print progress messages.
            max_bp_iter: Belief-propagation iteration cap; defaults to the
                value stored in the decoding key.

        Returns:
            np.ndarray | None: The recovered bits after the ``test_bits`` and
            the following ``g`` bits, or ``None`` if no invertible row set was
            found or the recovered test bits do not match.
        """
        (generator_matrix, parity_check_matrix, one_time_pad, _false_positive_rate,
         noise_rate, test_bits, g, max_bp_iter_key, _t) = self.decoding_key
        if max_bp_iter is None:
            max_bp_iter = max_bp_iter_key

        pad_signs = 1 - 2 * np.array(one_time_pad, dtype=float)
        posteriors = (1 - 2 * noise_rate) * pad_signs * posteriors.numpy(force=True)
        channel_probs = (1 - np.abs(posteriors)) / 2
        x_recovered = (1 - np.sign(posteriors)) // 2

        # Apply the belief-propagation decoder.
        if print_progress:
            print("Running belief propagation...")
        bpd = bp_decoder(parity_check_matrix, channel_probs=channel_probs, max_iter=max_bp_iter, bp_method="product_sum")
        x_decoded = bpd.decode(x_recovered)

        # Compute a confidence score.
        bpd_probs = 1 / (1 + np.exp(bpd.log_prob_ratios))
        confidences = 2 * np.abs(0.5 - bpd_probs)

        # Order codeword bits by confidence.
        confidence_order = np.argsort(-confidences)
        ordered_generator_matrix = generator_matrix[confidence_order]
        ordered_x_decoded = x_decoded[confidence_order].astype(int)

        # Find the first (according to the confidence order) linearly independent set of rows of the generator matrix.
        top_invertible_rows = self._boolean_row_reduce(ordered_generator_matrix, print_progress=print_progress)
        if top_invertible_rows is None:
            return None

        # Solve the system.
        if print_progress:
            print("Solving linear system...")
        recovered_string = np.linalg.solve(
            ordered_generator_matrix[top_invertible_rows],
            self.GF(ordered_x_decoded[top_invertible_rows]),
        )

        # The leading test bits act as a checksum for a successful decode.
        if not (recovered_string[:len(test_bits)] == test_bits).all():
            return None
        return np.array(recovered_string[len(test_bits) + g:])

    def _binary_array_to_str(self, binary_array: np.ndarray) -> str:
        """Convert an array of bits into a string, one character per 8 bits.

        Each row of 8 bits is read most-significant-bit first and passed to
        ``chr``.

        Raises:
            AssertionError: If the array length is not a multiple of 8.
        """
        # Ensure the binary array length is divisible by 8 (1 byte = 8 bits)
        assert len(binary_array) % 8 == 0, "Binary array length must be a multiple of 8"

        # Group the binary array into chunks of 8 bits
        byte_chunks = binary_array.reshape(-1, 8)

        # Convert each byte (8 bits) to a character and join them
        return ''.join(chr(int(''.join(map(str, byte)), 2)) for byte in byte_chunks)

    def eval_watermark(self, reversed_latents: torch.Tensor, reference_latents: torch.Tensor = None, detector_type: str = "is_watermarked") -> dict:
        """Evaluate watermark in reversed latents.

        Args:
            reversed_latents: Latents to test; flattened and cast to float64
                on the CPU before processing.
            reference_latents: Unused by this detector.
            detector_type: Must be ``"is_watermarked"``.

        Returns:
            dict with keys ``'is_watermarked'`` (bool), ``'score'`` (detection
            statistic), ``'decoding_result'`` (decoded bit array or ``None``)
            and ``'decoded_message'`` (str or ``None``).

        Raises:
            ValueError: If ``detector_type`` is not ``"is_watermarked"``.
        """
        if detector_type != 'is_watermarked':
            raise ValueError(f'Detector type {detector_type} is not supported for PRC. Use "is_watermarked" instead.')

        flat_latents = reversed_latents.to(torch.float64).flatten().cpu()
        reversed_prc = self._recover_posteriors(flat_latents, variances=self.var).flatten().cpu()
        self.recovered_prc = reversed_prc

        detect_result, score = self._detect_watermark(reversed_prc)
        decoding_result = self._decode_message(reversed_prc)

        if decoding_result is None:
            # NOTE(review): detect_result is ignored here, so a latent that
            # passes the statistical test but fails decoding is reported as
            # not watermarked. Behaviour kept as-is; confirm this is intended.
            return {
                'is_watermarked': False,
                "score": score,  # Keep the score for potential future use
                'decoding_result': decoding_result,
                "decoded_message": None
            }

        decoded_message = self._binary_array_to_str(decoding_result)
        return {
            # Decoding succeeded, so `detect_result or decoded` is always True.
            'is_watermarked': True,
            "score": score,  # Keep the score for potential future use
            'decoding_result': decoding_result,
            "decoded_message": decoded_message
        }