Coverage for detection / prc / prc_detection.py: 97.85%

93 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import torch 

17import numpy as np 

18from typing import Tuple, Type 

19from scipy.special import erf 

20from detection.base import BaseDetector 

21from ldpc import bp_decoder 

22from galois import FieldArray 

23 

24class PRCDetector(BaseDetector): 

25 

def __init__(self,
             var: int,
             decoding_key: tuple,
             GF: Type[FieldArray],
             threshold: float,
             device: torch.device):
    """Store the PRC decoding configuration on the detector.

    Args:
        var: Noise variance later passed to ``_recover_posteriors`` as
            ``variances``.
        decoding_key: Tuple that the detection and decoding methods unpack
            into nine fields (generator matrix, parity-check matrix,
            one-time pad, false-positive rate, noise rate, test bits, ``g``,
            maximum BP iterations, ``t``).
        GF: Galois field array class used to wrap the decoded bits before
            the linear solve in ``_decode_message``.
        threshold: Forwarded unchanged to the base detector.
        device: Forwarded unchanged to the base detector.
    """
    super().__init__(threshold, device)

    self.GF = GF
    self.decoding_key = decoding_key
    self.var = var

37 

38 def _recover_posteriors(self, z, basis=None, variances=None): 

39 if variances is None: 

40 default_variance = 1.5 

41 denominators = np.sqrt(2 * default_variance * (1+default_variance)) * torch.ones_like(z) 

42 elif type(variances) is float: 

43 denominators = np.sqrt(2 * variances * (1 + variances)) 

44 else: 

45 denominators = torch.sqrt(2 * variances * (1 + variances)) 

46 

47 if basis is None: 

48 return erf(z / denominators) 

49 else: 

50 return erf((z @ basis) / denominators) 

51 

52 def _detect_watermark(self, posteriors: torch.Tensor) -> Tuple[bool, float]: 

53 """Detect watermark in posteriors.""" 

54 generator_matrix, parity_check_matrix, one_time_pad, false_positive_rate, noise_rate, test_bits, g, max_bp_iter, t = self.decoding_key 

55 posteriors = (1 - 2 * noise_rate) * (1 - 2 * np.array(one_time_pad, dtype=float)) * posteriors.numpy(force=True) 

56 

57 r = parity_check_matrix.shape[0] 

58 Pi = np.prod(posteriors[parity_check_matrix.indices.reshape(r, t)], axis=1) 

59 log_plus = np.log((1 + Pi) / 2) 

60 log_minus = np.log((1 - Pi) / 2) 

61 log_prod = log_plus + log_minus 

62 

63 const = 0.5 * np.sum(np.power(log_plus, 2) + np.power(log_minus, 2) - 0.5 * np.power(log_prod, 2)) 

64 threshold = np.sqrt(2 * const * np.log(1 / false_positive_rate)) + 0.5 * log_prod.sum() 

65 #print(f"threshold: {threshold}") 

66 return log_plus.sum() >= threshold, log_plus.sum() 

67 

68 def _boolean_row_reduce(self, A, print_progress=False): 

69 """Given a GF(2) matrix, do row elimination and return the first k rows of A that form an invertible matrix 

70 

71 Args: 

72 A (np.ndarray): A GF(2) matrix 

73 print_progress (bool, optional): Whether to print the progress. Defaults to False. 

74 

75 Returns: 

76 np.ndarray: The first k rows of A that form an invertible matrix 

77 """ 

78 n, k = A.shape 

79 A_rr = A.copy() 

80 perm = np.arange(n) 

81 for j in range(k): 

82 idxs = j + np.nonzero(A_rr[j:, j])[0] 

83 if idxs.size == 0: 

84 print("The given matrix is not invertible") 

85 return None 

86 A_rr[[j, idxs[0]]] = A_rr[[idxs[0], j]] # For matrices you have to swap them this way 

87 (perm[j], perm[idxs[0]]) = (perm[idxs[0]], perm[j]) # Weirdly, this is MUCH faster if you swap this way instead of using perm[[i,j]]=perm[[j,i]] 

88 A_rr[idxs[1:]] += A_rr[j] 

89 if print_progress and (j%5==0 or j+1==k): 

90 sys.stdout.write(f'\rDecoding progress: {j + 1} / {k}') 

91 sys.stdout.flush() 

92 if print_progress: print() 

93 return perm[:k] 

94 

def _decode_message(self, posteriors, print_progress=False, max_bp_iter=None):
    """Try to recover the embedded message bits from the posteriors.

    Steps: adjust the posteriors with the one-time pad and noise rate, run
    belief propagation on the parity-check matrix, rank codeword bits by
    decoder confidence, choose an invertible set of generator-matrix rows in
    that order, and solve for the message over ``self.GF``.

    Args:
        posteriors: Tensor of soft bit estimates; converted with
            ``.numpy(force=True)``.
        print_progress: Whether to print stage messages.
        max_bp_iter: Belief-propagation iteration cap; ``None`` uses the
            value stored in the decoding key.

    Returns:
        The recovered bits after the first ``len(test_bits) + g`` positions
        as a NumPy array, or ``None`` when no invertible row set exists or
        the leading bits do not equal ``test_bits``.
    """
    (generator_matrix, parity_check_matrix, one_time_pad, _false_positive_rate,
     noise_rate, test_bits, g, key_bp_iter, t) = self.decoding_key
    if max_bp_iter is None:
        max_bp_iter = key_bp_iter

    pad_signs = 1 - 2 * np.array(one_time_pad, dtype=float)
    posteriors = (1 - 2 * noise_rate) * pad_signs * posteriors.numpy(force=True)

    # Per-bit flip probability and hard decision (non-negative -> 0, negative -> 1).
    channel_probs = (1 - np.abs(posteriors)) / 2
    hard_bits = (1 - np.sign(posteriors)) // 2

    # Belief propagation.
    if print_progress:
        print("Running belief propagation...")
    decoder = bp_decoder(parity_check_matrix, channel_probs=channel_probs, max_iter=max_bp_iter, bp_method="product_sum")
    bp_bits = decoder.decode(hard_bits)

    # Confidence is the distance of each decoded probability from 0.5, scaled to [0, 1].
    bp_probs = 1 / (1 + np.exp(decoder.log_prob_ratios))
    confidences = 2 * np.abs(0.5 - bp_probs)

    # Most confident bits first.
    by_confidence = np.argsort(-confidences)
    ranked_generator = generator_matrix[by_confidence]
    ranked_bits = bp_bits[by_confidence].astype(int)

    # First linearly independent rows of the generator matrix in confidence order.
    pivot_rows = self._boolean_row_reduce(ranked_generator, print_progress=print_progress)
    if pivot_rows is None:
        return None

    if print_progress:
        print("Solving linear system...")
    recovered_string = np.linalg.solve(ranked_generator[pivot_rows], self.GF(ranked_bits[pivot_rows]))

    # The leading bits are a known test pattern; a mismatch means decoding failed.
    prefix_len = len(test_bits)
    if not (recovered_string[:prefix_len] == test_bits).all():
        return None
    return np.array(recovered_string[prefix_len + g:])

133 

134 def _binary_array_to_str(self, binary_array: np.ndarray) -> str: 

135 """Convert binary array back to string.""" 

136 # Ensure the binary array length is divisible by 8 (1 byte = 8 bits) 

137 assert len(binary_array) % 8 == 0, "Binary array length must be a multiple of 8" 

138 

139 # Group the binary array into chunks of 8 bits 

140 byte_chunks = binary_array.reshape(-1, 8) 

141 

142 # Convert each byte (8 bits) to a character 

143 chars = [chr(int(''.join(map(str, byte)), 2)) for byte in byte_chunks] 

144 

145 # Join the characters to form the original string 

146 return ''.join(chars) 

147 

def eval_watermark(self, reversed_latents: torch.Tensor, reference_latents: torch.Tensor = None, detector_type: str = "is_watermarked") -> dict:
    """Evaluate the PRC watermark in reversed latents.

    Args:
        reversed_latents: Latents recovered by inversion; cast to float64,
            flattened and moved to CPU before processing.
        reference_latents: Unused by this detector.
        detector_type: Only ``"is_watermarked"`` is supported.

    Returns:
        dict with keys ``is_watermarked`` (bool), ``score`` (detection
        statistic), ``decoding_result`` (decoded bit array or ``None``) and
        ``decoded_message`` (decoded string or ``None``).

    Raises:
        ValueError: If ``detector_type`` is not ``"is_watermarked"``.
    """
    if detector_type != 'is_watermarked':
        raise ValueError(f'Detector type {detector_type} is not supported for PRC. Use "is_watermarked" instead.')

    flat_latents = reversed_latents.to(torch.float64).flatten().cpu()
    reversed_prc = self._recover_posteriors(flat_latents, variances=self.var).flatten().cpu()
    self.recovered_prc = reversed_prc  # kept on the instance for later inspection

    # The statistical decision is computed only for its score; the flag
    # below depends solely on whether decoding succeeds.
    _detect_result, score = self._detect_watermark(reversed_prc)
    decoding_result = self._decode_message(reversed_prc)

    if decoding_result is None:
        return {
            'is_watermarked': False,
            "score": score,  # Keep the score for potential future use
            'decoding_result': None,
            "decoded_message": None
        }

    decoded_message = self._binary_array_to_str(decoding_result)
    return {
        # Previously ``detect_result or (decoding_result is not None)``,
        # which is always true once decoding has succeeded.
        'is_watermarked': True,
        "score": score,  # Keep the score for potential future use
        'decoding_result': decoding_result,
        "decoded_message": decoded_message
    }

172 

173