Coverage for evaluation/tools/success_rate_calculator.py: 98.20% (111 statements)
from typing import Dict, List, Optional, Union

from exceptions.exceptions import TypeMismatchException, ConfigurationError
from sklearn.metrics import roc_auc_score, roc_curve

class DetectionResult:
    """A single detection outcome: the gold label paired with the detector's output.

    detection_result is a boolean decision for fixed-threshold detectors and a
    real-valued score for dynamic-threshold detectors.
    """

    def __init__(self,
                 gold_label: bool,
                 detection_result: Union[bool, float],
                 ):
        self.gold_label = gold_label
        self.detection_result = detection_result

class BaseSuccessRateCalculator:

    def __init__(self,
                 labels: List[str] = ['TPR', 'TNR', 'FPR', 'FNR', 'P', 'R', 'F1', 'ACC', 'AUC'],
                 ):
        self.labels = labels

    def _check_instance(self,
                        data: List[Union[bool, float]],
                        expected_type: type):
        for item in data:
            if not isinstance(item, expected_type):
                raise TypeMismatchException(expected_type, type(item))

    def _filter_metrics(self,
                        metrics: Dict[str, float]) -> Dict[str, float]:
        return {label: metrics[label] for label in self.labels if label in metrics}
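
    # _filter_metrics keeps only the requested labels, in the order of self.labels.
    # For example (illustrative values):
    #     self.labels = ['TPR', 'F1']
    #     self._filter_metrics({'TPR': 0.9, 'FPR': 0.1, 'F1': 0.8})  # -> {'TPR': 0.9, 'F1': 0.8}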

    def calculate(self,
                  watermarked_results: List[DetectionResult],
                  non_watermarked_results: List[DetectionResult]) -> Dict[str, float]:
        """Subclasses must override this with their own success-rate computation."""
        raise NotImplementedError

class FundamentalSuccessRateCalculator(BaseSuccessRateCalculator):
    """
    Calculator for fundamental success rates of watermark detection.

    This class specifically handles the calculation of success rates for scenarios involving
    watermark detection after fixed thresholding. It provides metrics based on comparisons
    between expected watermarked results and actual detection outputs.

    Use this class when you need to evaluate the effectiveness of watermark detection algorithms
    under fixed thresholding conditions.
    """

    def __init__(self, labels: List[str] = ['TPR', 'TNR', 'FPR', 'FNR', 'P', 'R', 'F1', 'ACC']) -> None:
        """
        Initialize the fundamental success rate calculator.

        Parameters:
            labels (List[str]): The list of metric labels to include in the output.
        """
        super().__init__(labels)

    def _compute_metrics(self, inputs: List[DetectionResult]) -> Dict[str, float]:
        """Compute metrics based on the provided inputs."""
        # Here detection_result holds the detector's boolean decision.
        TP = sum(1 for d in inputs if d.detection_result and d.gold_label)
        TN = sum(1 for d in inputs if not d.detection_result and not d.gold_label)
        FP = sum(1 for d in inputs if d.detection_result and not d.gold_label)
        FN = sum(1 for d in inputs if not d.detection_result and d.gold_label)

        TPR = TP / (TP + FN) if TP + FN else 0.0
        FPR = FP / (FP + TN) if FP + TN else 0.0
        TNR = TN / (TN + FP) if TN + FP else 0.0
        FNR = FN / (FN + TP) if FN + TP else 0.0
        P = TP / (TP + FP) if TP + FP else 0.0
        R = TP / (TP + FN) if TP + FN else 0.0
        F1 = 2 * (P * R) / (P + R) if P + R else 0.0
        ACC = (TP + TN) / len(inputs) if inputs else 0.0

        # Calculate AUC
        y_true = [x.gold_label for x in inputs]
        y_score = [x.detection_result for x in inputs]
        auc = roc_auc_score(y_true=y_true, y_score=y_score)

        # Calculate FPR and TPR for ROC curve
        fpr, tpr, _ = roc_curve(y_true=y_true, y_score=y_score)

        return {
            'TPR': TPR, 'TNR': TNR, 'FPR': FPR, 'FNR': FNR,
            'P': P, 'R': R, 'F1': F1, 'ACC': ACC,
            'AUC': auc,
            'FPR_ROC': fpr, 'TPR_ROC': tpr
        }
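
    # Worked example of the confusion-matrix arithmetic above (illustrative numbers):
    # with TP=8, FN=2, FP=1, TN=9 over 20 inputs, TPR = 8/10 = 0.8, FPR = 1/10 = 0.1,
    # P = 8/9 ≈ 0.889, R = 0.8, F1 = 2*(0.889*0.8)/(0.889+0.8) ≈ 0.842, ACC = 17/20 = 0.85.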

    def calculate(self, watermarked_result: List[Union[bool, DetectionResult]], non_watermarked_result: List[Union[bool, DetectionResult]]) -> Dict[str, float]:
        """Calculate success rates of watermark detection based on the provided results."""

        # Convert boolean inputs to DetectionResult objects if needed
        if watermarked_result and isinstance(watermarked_result[0], bool):
            self._check_instance(watermarked_result, bool)
            inputs = [DetectionResult(True, x) for x in watermarked_result]
        else:
            # Assume they are DetectionResult objects
            inputs = list(watermarked_result)

        if non_watermarked_result and isinstance(non_watermarked_result[0], bool):
            self._check_instance(non_watermarked_result, bool)
            inputs.extend([DetectionResult(False, x) for x in non_watermarked_result])
        else:
            # Assume they are DetectionResult objects
            inputs.extend(list(non_watermarked_result))

        metrics = self._compute_metrics(inputs)
        return self._filter_metrics(metrics)
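
    # Illustrative usage (a minimal sketch; the detector decisions are made up):
    #     calculator = FundamentalSuccessRateCalculator(labels=['TPR', 'FPR', 'F1'])
    #     rates = calculator.calculate(
    #         watermarked_result=[True, True, False],      # decisions on watermarked texts
    #         non_watermarked_result=[False, False, True], # decisions on clean texts
    #     )
    #     # -> {'TPR': 0.667, 'FPR': 0.333, 'F1': 0.667} (rounded)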

class DynamicThresholdSuccessRateCalculator(BaseSuccessRateCalculator):
    """
    Calculator for success rates of watermark detection under dynamic thresholding.

    The detection threshold is chosen from the observed scores, either to maximize
    F1 (rule='best') or to meet a target false positive rate (rule='target_fpr').
    Set reverse=True when lower scores indicate watermarked text.
    """

    def __init__(self,
                 labels: List[str] = ['TPR', 'TNR', 'FPR', 'FNR', 'P', 'R', 'F1', 'ACC', 'AUC'],
                 rule: str = 'best',
                 target_fpr: Optional[float] = None,
                 reverse: bool = False,
                 ):
        super().__init__(labels)
        self.rule = rule
        self.target_fpr = target_fpr
        self.reverse = reverse

        if self.rule not in ['best', 'target_fpr']:
            raise ConfigurationError(f"Invalid rule: {self.rule}")

        if self.target_fpr is not None and not (0 <= self.target_fpr <= 1):
            raise ConfigurationError(f"Invalid target_fpr: {self.target_fpr}")
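
    # Illustrative configuration (a sketch; the 1% target is made up):
    #     calc = DynamicThresholdSuccessRateCalculator(rule='target_fpr', target_fpr=0.01)
    # selects the lowest candidate threshold whose empirical FPR is at most 1%, which
    # maximizes TPR subject to the FPR constraint.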

    def _compute_metrics(self,
                         inputs: List[DetectionResult],
                         threshold: float) -> Dict[str, float]:
        # In reverse mode, lower scores indicate watermarked text, so the
        # comparison direction flips.
        if not self.reverse:
            TP = sum(1 for x in inputs if x.gold_label and x.detection_result >= threshold)
            FP = sum(1 for x in inputs if x.detection_result >= threshold and not x.gold_label)
            TN = sum(1 for x in inputs if x.detection_result < threshold and not x.gold_label)
            FN = sum(1 for x in inputs if x.detection_result < threshold and x.gold_label)
        else:
            TP = sum(1 for x in inputs if x.gold_label and x.detection_result <= threshold)
            FP = sum(1 for x in inputs if x.detection_result <= threshold and not x.gold_label)
            TN = sum(1 for x in inputs if x.detection_result > threshold and not x.gold_label)
            FN = sum(1 for x in inputs if x.detection_result > threshold and x.gold_label)

        # Calculate AUC (negate scores in reverse mode so that higher always means "watermarked")
        y_true = [1 if x.gold_label else 0 for x in inputs]
        if not self.reverse:
            y_score = [x.detection_result for x in inputs]
        else:
            y_score = [-x.detection_result for x in inputs]
        auc = roc_auc_score(y_true=y_true, y_score=y_score)

        # Get ROC curve
        fpr, tpr, _ = roc_curve(y_true=y_true, y_score=y_score)

        metrics = {
            'TPR': TP / (TP + FN) if TP + FN else 0.0,
            'TNR': TN / (TN + FP) if TN + FP else 0.0,
            'FPR': FP / (TN + FP) if TN + FP else 0.0,
            'FNR': FN / (TP + FN) if TP + FN else 0.0,
            'P': TP / (TP + FP) if TP + FP else 0.0,
            'R': TP / (TP + FN) if TP + FN else 0.0,
            'F1': 2 * TP / (2 * TP + FP + FN) if 2 * TP + FP + FN else 0.0,
            'ACC': (TP + TN) / (TP + TN + FP + FN) if inputs else 0.0,
            'AUC': auc,
            'FPR_ROC': fpr,
            'TPR_ROC': tpr,
        }
        return metrics

    def _find_best_threshold(self, inputs: List[DetectionResult]) -> float:
        """Return the midpoint threshold between consecutive sorted scores that maximizes F1."""
        best_threshold = 0
        best_metrics = None
        for i in range(len(inputs) - 1):
            threshold = (inputs[i].detection_result + inputs[i + 1].detection_result) / 2
            metrics = self._compute_metrics(inputs, threshold)
            if best_metrics is None or metrics['F1'] > best_metrics['F1']:
                best_threshold = threshold
                best_metrics = metrics
        return best_threshold
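
    # For example (illustrative scores): sorted scores [0.2, 0.9, 3.1] yield the
    # candidate thresholds (0.2 + 0.9) / 2 = 0.55 and (0.9 + 3.1) / 2 = 2.0.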

    def _find_threshold_by_fpr(self, inputs: List[DetectionResult]) -> float:
        """Return the first candidate threshold whose empirical FPR is at most target_fpr."""
        threshold = 0
        for i in range(len(inputs) - 1):
            threshold = (inputs[i].detection_result + inputs[i + 1].detection_result) / 2
            metrics = self._compute_metrics(inputs, threshold)
            if metrics['FPR'] <= self.target_fpr:
                break
        return threshold

    def _find_threshold(self, inputs: List[DetectionResult]) -> float:
        """Sort the inputs by score and delegate to the configured threshold rule."""
        sorted_inputs = sorted(inputs, key=lambda x: x.detection_result, reverse=self.reverse)

        if self.rule == 'best':
            return self._find_best_threshold(sorted_inputs)
        elif self.rule == 'target_fpr':
            return self._find_threshold_by_fpr(sorted_inputs)

    def calculate(self,
                  watermarked_results: List[float],
                  non_watermarked_results: List[float]) -> Dict[str, float]:
        # Check if inputs are boolean values (which suggests PRC or similar fixed-threshold algorithms)
        if (watermarked_results and isinstance(watermarked_results[0], bool)) or \
           (non_watermarked_results and isinstance(non_watermarked_results[0], bool)):
            raise ValueError(
                "DynamicThresholdSuccessRateCalculator received boolean values. "
                "For algorithms like PRC that use fixed thresholds, please use "
                "FundamentalSuccessRateCalculator instead."
            )

        self._check_instance(watermarked_results, float)
        self._check_instance(non_watermarked_results, float)

        inputs = ([DetectionResult(True, d) for d in watermarked_results]
                  + [DetectionResult(False, d) for d in non_watermarked_results])

        threshold = self._find_threshold(inputs)
        metrics = self._compute_metrics(inputs, threshold)
        return self._filter_metrics(metrics)
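
    # Illustrative usage (a minimal sketch; the z-score-style values are made up):
    #     calculator = DynamicThresholdSuccessRateCalculator(labels=['TPR', 'FPR', 'AUC'], rule='best')
    #     rates = calculator.calculate(
    #         watermarked_results=[4.1, 3.8, 5.2],
    #         non_watermarked_results=[0.3, 1.1, 0.9],
    #     )
    #     # -> {'TPR': 1.0, 'FPR': 0.0, 'AUC': 1.0} (the scores separate perfectly)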