Coverage for evaluation/tools/success_rate_calculator.py: 98.20%

111 statements  

coverage.py v7.13.0, created at 2025-12-22 10:24 +0000

from typing import Dict, List, Optional, Union

from exceptions.exceptions import TypeMismatchException, ConfigurationError

from sklearn.metrics import roc_auc_score, roc_curve


class DetectionResult:
    """Pairs a gold label with a detector's output (a score or a boolean decision) for one sample."""

    def __init__(self,
                 gold_label: bool,
                 detection_result: Union[bool, float],
                 ):
        self.gold_label = gold_label
        self.detection_result = detection_result


class BaseSuccessRateCalculator:
    """Base class providing input validation and metric filtering for success rate calculators."""

    def __init__(self,
                 labels: List[str] = ['TPR', 'TNR', 'FPR', 'FNR', 'P', 'R', 'F1', 'ACC', 'AUC'],
                 ):
        self.labels = labels

    def _check_instance(self,
                        data: List[Union[bool, float]],
                        expected_type: type):
        """Raise TypeMismatchException if any item in `data` is not of `expected_type`."""
        for item in data:
            if not isinstance(item, expected_type):
                raise TypeMismatchException(expected_type, type(item))

    def _filter_metrics(self,
                        metrics: Dict[str, float]) -> Dict[str, float]:
        """Keep only the metrics whose labels were requested at construction time."""
        return {label: metrics[label] for label in self.labels if label in metrics}

    def calculate(self,
                  watermarked_results: List[DetectionResult],
                  non_watermarked_results: List[DetectionResult]) -> Dict[str, float]:
        raise NotImplementedError


class FundamentalSuccessRateCalculator(BaseSuccessRateCalculator):
    """
    Calculator for fundamental success rates of watermark detection.

    This class handles the calculation of success rates for scenarios in which
    watermark detection has already been reduced to boolean decisions by a fixed
    threshold. It provides metrics based on comparisons between the expected
    (gold) labels and the actual detection outputs.

    Use this class when you need to evaluate the effectiveness of watermark
    detection algorithms under fixed-threshold conditions.
    """

    def __init__(self, labels: List[str] = ['TPR', 'TNR', 'FPR', 'FNR', 'P', 'R', 'F1', 'ACC']) -> None:
        """
        Initialize the fundamental success rate calculator.

        Parameters:
            labels (List[str]): The list of metric labels to include in the output.
        """
        super().__init__(labels)

    def _compute_metrics(self, inputs: List[DetectionResult]) -> Dict[str, float]:
        """Compute confusion-matrix metrics from the provided inputs."""
        TP = sum(1 for d in inputs if d.detection_result and d.gold_label)
        TN = sum(1 for d in inputs if not d.detection_result and not d.gold_label)
        FP = sum(1 for d in inputs if d.detection_result and not d.gold_label)
        FN = sum(1 for d in inputs if not d.detection_result and d.gold_label)

        # Each rate guards against an empty denominator.
        TPR = TP / (TP + FN) if TP + FN else 0.0
        FPR = FP / (FP + TN) if FP + TN else 0.0
        TNR = TN / (TN + FP) if TN + FP else 0.0
        FNR = FN / (FN + TP) if FN + TP else 0.0
        P = TP / (TP + FP) if TP + FP else 0.0
        R = TP / (TP + FN) if TP + FN else 0.0
        F1 = 2 * (P * R) / (P + R) if P + R else 0.0
        ACC = (TP + TN) / len(inputs) if inputs else 0.0

        # Calculate AUC (boolean detection results are treated as 0/1 scores).
        y_true = [x.gold_label for x in inputs]
        y_score = [x.detection_result for x in inputs]
        auc = roc_auc_score(y_true=y_true, y_score=y_score)

        # Calculate the FPR and TPR points of the ROC curve.
        fpr, tpr, _ = roc_curve(y_true=y_true, y_score=y_score)

        return {
            'TPR': TPR, 'TNR': TNR, 'FPR': FPR, 'FNR': FNR,
            'P': P, 'R': R, 'F1': F1, 'ACC': ACC,
            'AUC': auc,
            'FPR_ROC': fpr, 'TPR_ROC': tpr
        }

    def calculate(self, watermarked_results: List[Union[bool, DetectionResult]], non_watermarked_results: List[Union[bool, DetectionResult]]) -> Dict[str, float]:
        """Calculate success rates of watermark detection based on the provided results."""
        # Convert boolean inputs to DetectionResult objects if needed.
        if watermarked_results and isinstance(watermarked_results[0], bool):
            self._check_instance(watermarked_results, bool)
            inputs = [DetectionResult(True, x) for x in watermarked_results]
        else:
            # Assume they are already DetectionResult objects.
            inputs = list(watermarked_results)

        if non_watermarked_results and isinstance(non_watermarked_results[0], bool):
            self._check_instance(non_watermarked_results, bool)
            inputs.extend([DetectionResult(False, x) for x in non_watermarked_results])
        else:
            # Assume they are already DetectionResult objects.
            inputs.extend(list(non_watermarked_results))

        metrics = self._compute_metrics(inputs)
        return self._filter_metrics(metrics)

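# Usage sketch (illustrative, not part of the covered module): with boolean
# detector decisions, FundamentalSuccessRateCalculator pools both sample sets
# into one confusion matrix. The inputs below are made-up examples.
#
#   calc = FundamentalSuccessRateCalculator(labels=['TPR', 'FPR', 'F1', 'ACC'])
#   rates = calc.calculate(
#       watermarked_results=[True, True, False, True],        # decisions on watermarked text
#       non_watermarked_results=[False, False, True, False],  # decisions on clean text
#   )
#   # TP=3, FN=1, FP=1, TN=3 -> {'TPR': 0.75, 'FPR': 0.25, 'F1': 0.75, 'ACC': 0.75}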

class DynamicThresholdSuccessRateCalculator(BaseSuccessRateCalculator):
    """Calculator that selects a detection threshold from the data before computing success rates."""

    def __init__(self,
                 labels: List[str] = ['TPR', 'TNR', 'FPR', 'FNR', 'P', 'R', 'F1', 'ACC', 'AUC'],
                 rule: str = 'best',
                 target_fpr: Optional[float] = None,
                 reverse: bool = False,
                 ):
        super().__init__(labels)
        self.rule = rule
        self.target_fpr = target_fpr
        self.reverse = reverse

        if self.rule not in ['best', 'target_fpr']:
            raise ConfigurationError(f"Invalid rule: {self.rule}")

        if self.target_fpr is not None and not (0 <= self.target_fpr <= 1):
            raise ConfigurationError(f"Invalid target_fpr: {self.target_fpr}")
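    # Configuration sketch (illustrative): 'best' sweeps candidate thresholds and
    # keeps the one with the highest F1; 'target_fpr' instead stops at the first
    # threshold whose empirical FPR is at or below `target_fpr`, e.g.:
    #
    #   calc = DynamicThresholdSuccessRateCalculator(rule='target_fpr', target_fpr=0.01)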

    def _compute_metrics(self,
                         inputs: List[DetectionResult],
                         threshold: float) -> Dict[str, float]:
        # Higher scores indicate watermarking by default; `reverse` flips the comparison.
        if not self.reverse:
            TP = sum(1 for x in inputs if x.gold_label and x.detection_result >= threshold)
            FP = sum(1 for x in inputs if x.detection_result >= threshold and not x.gold_label)
            TN = sum(1 for x in inputs if x.detection_result < threshold and not x.gold_label)
            FN = sum(1 for x in inputs if x.detection_result < threshold and x.gold_label)
        else:
            TP = sum(1 for x in inputs if x.gold_label and x.detection_result <= threshold)
            FP = sum(1 for x in inputs if x.detection_result <= threshold and not x.gold_label)
            TN = sum(1 for x in inputs if x.detection_result > threshold and not x.gold_label)
            FN = sum(1 for x in inputs if x.detection_result > threshold and x.gold_label)

        # Calculate AUC; negate scores when lower values indicate watermarking.
        y_true = [1 if x.gold_label else 0 for x in inputs]
        if not self.reverse:
            y_score = [x.detection_result for x in inputs]
        else:
            y_score = [-x.detection_result for x in inputs]
        auc = roc_auc_score(y_true=y_true, y_score=y_score)

        # Get the ROC curve.
        fpr, tpr, _ = roc_curve(y_true=y_true, y_score=y_score)

        # Each rate guards against an empty denominator.
        metrics = {
            'TPR': TP / (TP + FN) if TP + FN else 0.0,
            'TNR': TN / (TN + FP) if TN + FP else 0.0,
            'FPR': FP / (TN + FP) if TN + FP else 0.0,
            'FNR': FN / (TP + FN) if TP + FN else 0.0,
            'P': TP / (TP + FP) if TP + FP else 0.0,
            'R': TP / (TP + FN) if TP + FN else 0.0,
            'F1': 2 * TP / (2 * TP + FP + FN) if TP + FP + FN else 0.0,
            'ACC': (TP + TN) / (TP + TN + FP + FN) if inputs else 0.0,
            'AUC': auc,
            'FPR_ROC': fpr,
            'TPR_ROC': tpr,
        }
        return metrics


    def _find_best_threshold(self, inputs: List[DetectionResult]) -> float:
        """Sweep midpoints between consecutive sorted scores and keep the threshold with the best F1."""
        best_threshold = 0
        best_metrics = None
        for i in range(len(inputs) - 1):
            threshold = (inputs[i].detection_result + inputs[i + 1].detection_result) / 2
            metrics = self._compute_metrics(inputs, threshold)
            if best_metrics is None or metrics['F1'] > best_metrics['F1']:
                best_threshold = threshold
                best_metrics = metrics
        return best_threshold

    def _find_threshold_by_fpr(self, inputs: List[DetectionResult]) -> float:
        """Return the first swept threshold whose empirical FPR is at or below the target."""
        threshold = 0
        for i in range(len(inputs) - 1):
            threshold = (inputs[i].detection_result + inputs[i + 1].detection_result) / 2
            metrics = self._compute_metrics(inputs, threshold)
            if metrics['FPR'] <= self.target_fpr:
                break
        return threshold

    def _find_threshold(self, inputs: List[DetectionResult]) -> float:
        # Sort ascending by score (descending when `reverse` is set) before sweeping.
        sorted_inputs = sorted(inputs, key=lambda x: x.detection_result, reverse=self.reverse)

        if self.rule == 'best':
            return self._find_best_threshold(sorted_inputs)
        elif self.rule == 'target_fpr':
            return self._find_threshold_by_fpr(sorted_inputs)

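    # Worked example (illustrative): for sorted scores [0.1, 0.4, 0.6, 0.9] the
    # sweep tries midpoints 0.25, 0.5, and 0.75. If the watermarked texts scored
    # {0.6, 0.9} and the non-watermarked ones {0.1, 0.4}, the midpoint 0.5
    # separates them perfectly (TP=2, TN=2, FP=FN=0), so F1=1.0 and
    # _find_best_threshold returns 0.5.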

    def calculate(self,
                  watermarked_results: List[float],
                  non_watermarked_results: List[float]) -> Dict[str, float]:
        # Check if inputs are boolean values (which suggests PRC or similar fixed-threshold algorithms)
        if (watermarked_results and isinstance(watermarked_results[0], bool)) or \
           (non_watermarked_results and isinstance(non_watermarked_results[0], bool)):
            raise ValueError(
                "DynamicThresholdSuccessRateCalculator received boolean values. "
                "For algorithms like PRC that use fixed thresholds, please use "
                "FundamentalSuccessRateCalculator instead."
            )

        self._check_instance(watermarked_results, float)
        self._check_instance(non_watermarked_results, float)

        inputs = [DetectionResult(True, d) for d in watermarked_results] + [DetectionResult(False, d) for d in non_watermarked_results]

        threshold = self._find_threshold(inputs)
        metrics = self._compute_metrics(inputs, threshold)
        return self._filter_metrics(metrics)
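# Usage sketch (illustrative, not part of the covered module): with raw float
# scores, the calculator picks its own threshold before reporting metrics.
# The scores below are made-up examples.
#
#   calc = DynamicThresholdSuccessRateCalculator(labels=['TPR', 'FPR', 'ACC'], rule='best')
#   rates = calc.calculate(
#       watermarked_results=[0.92, 0.85, 0.77],      # detector scores on watermarked text
#       non_watermarked_results=[0.12, 0.33, 0.41],  # detector scores on clean text
#   )
#   # Perfectly separable scores give {'TPR': 1.0, 'FPR': 0.0, 'ACC': 1.0}.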