Coverage for evaluation / pipelines / detection.py: 98.99%
99 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
from enum import Enum, auto
from typing import List, Optional, Union

from PIL import Image
from tqdm import tqdm

from evaluation.dataset import BaseDataset
from evaluation.tools.image_editor import ImageEditor
from evaluation.tools.video_editor import VideoEditor
from watermark.base import BaseWatermark
class DetectionPipelineReturnType(Enum):
    """Selects what ``WatermarkDetectionPipeline.evaluate`` returns."""

    # The list of WatermarkDetectionResult objects, one per sample.
    FULL = auto()
    # Per sample: detect_result[detector_type].
    SCORES = auto()
    # Per sample: detect_result['is_watermarked'].
    IS_WATERMARKED = auto()
class WatermarkDetectionResult:
    """Outcome of one pipeline sample.

    Holds the media before editing, the media after editing, and the raw value
    returned by the watermark detector. Values are stored as given, without
    copying or validation.
    """

    def __init__(self,
                 generated_or_retrieved_media,
                 edited_media,
                 detect_result) -> None:
        self.generated_or_retrieved_media = generated_or_retrieved_media
        self.edited_media = edited_media
        self.detect_result = detect_result

    def __str__(self):
        # Adjacent literals are joined at compile time into one f-string.
        return (
            "WatermarkDetectionResult("
            f"generated_or_retrieved_media={self.generated_or_retrieved_media}, "
            f"edited_media={self.edited_media}, "
            f"detect_result={self.detect_result})"
        )
class WatermarkDetectionPipeline:
    """Base class for watermark-detection evaluation pipelines.

    For each index yielded by ``_get_iterable`` the pipeline:

    1. obtains media via ``_generate_or_retrieve_media``,
    2. applies every editor in ``media_editor_list`` in order, and
    3. runs ``watermark.detect_watermark_in_media`` on the edited media.

    Subclasses must override ``_get_iterable`` and ``_generate_or_retrieve_media``.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 media_editor_list: Optional[List[Union[ImageEditor, VideoEditor]]] = None,
                 show_progress: bool = True,
                 detector_type: str = "l1_distance",
                 return_type: DetectionPipelineReturnType = DetectionPipelineReturnType.SCORES):
        """
        Args:
            dataset (BaseDataset): Source of prompts / reference media, read by subclasses.
            media_editor_list: Editors applied in order to each sample's media.
                ``None`` (the default) means no editing. A list passed in is
                stored as-is (not copied).
            show_progress (bool): Wrap iteration in a tqdm progress bar.
            detector_type (str): Forwarded to ``detect_watermark_in_media``. For the
                ``SCORES`` return type it is also the key read from each detection result.
            return_type (DetectionPipelineReturnType): Shape of ``evaluate``'s return value.
        """
        self.dataset = dataset
        # None sentinel instead of a shared mutable ``[]`` default.
        self.media_editor_list = [] if media_editor_list is None else media_editor_list
        self.show_progress = show_progress
        self.return_type = return_type
        self.detector_type = detector_type

    def _edit_media(self, media: List[Image.Image]) -> List[Image.Image]:
        """
        Edit the media using the media editor list.

        The input list is never modified. The caller keeps it as the
        "generated or retrieved" media, so editing it in place would make that
        record identical to the edited media.

        Args:
            media (List[Image.Image]): The media to edit.

        Raises:
            ValueError: If an editor is neither an ImageEditor nor a VideoEditor.

        Returns:
            List[Image.Image]: A new list holding the edited media.
        """
        results = list(media)
        for editor in self.media_editor_list:
            if isinstance(editor, ImageEditor):
                # Image editors transform one image at a time.
                results = [editor.edit(item) for item in results]
            elif isinstance(editor, VideoEditor):
                # Video editors take and return the whole sequence of frames.
                results = editor.edit(results)
            else:
                raise ValueError(f"Invalid editor type: {type(editor)}")
        return results

    def _detect_watermark(self, media: List[Image.Image], watermark: BaseWatermark, **kwargs):
        """Run the watermark's detector on *media* using ``self.detector_type``.

        Returns:
            Whatever ``watermark.detect_watermark_in_media`` returns. ``evaluate``
            indexes it by ``self.detector_type`` and ``'is_watermarked'``, so it
            is presumably dict-like (TODO confirm against BaseWatermark).
        """
        return watermark.detect_watermark_in_media(media, detector_type=self.detector_type, **kwargs)

    def _get_iterable(self):
        """Return the iterable of sample indices to evaluate (subclass hook)."""
        raise NotImplementedError(f"{type(self).__name__} must implement _get_iterable()")

    def _get_progress_bar(self, iterable):
        """Wrap *iterable* in a tqdm bar when ``show_progress`` is set."""
        if self.show_progress:
            return tqdm(iterable, desc="Processing")
        return iterable

    def _generate_or_retrieve_media(self, index: int, watermark: BaseWatermark, **kwargs) -> List[Image.Image]:
        """Return the list of media for sample *index* (subclass hook)."""
        raise NotImplementedError(f"{type(self).__name__} must implement _generate_or_retrieve_media()")

    def evaluate(self, watermark: BaseWatermark,
                 detection_kwargs: Optional[dict] = None,
                 generation_kwargs: Optional[dict] = None):
        """Run the pipeline over every sample and collect detection results.

        Args:
            watermark (BaseWatermark): Watermark used to generate and detect media.
            detection_kwargs: Extra keyword arguments for ``detect_watermark_in_media``.
            generation_kwargs: Extra keyword arguments for media generation.

        Raises:
            ValueError: If ``self.return_type`` is not a DetectionPipelineReturnType.
                This is checked before any (possibly expensive) generation runs.

        Returns:
            FULL: a list of WatermarkDetectionResult.
            SCORES: a list of ``detect_result[self.detector_type]``.
            IS_WATERMARKED: a list of ``detect_result['is_watermarked']``.
        """
        if not isinstance(self.return_type, DetectionPipelineReturnType):
            raise ValueError(f"Invalid return type: {self.return_type}")
        detection_kwargs = {} if detection_kwargs is None else detection_kwargs
        generation_kwargs = {} if generation_kwargs is None else generation_kwargs

        evaluation_results = []
        for index in self._get_progress_bar(self._get_iterable()):
            generated_or_retrieved_media = self._generate_or_retrieve_media(index, watermark, **generation_kwargs)
            edited_media = self._edit_media(generated_or_retrieved_media)
            detect_result = self._detect_watermark(edited_media, watermark, **detection_kwargs)
            evaluation_results.append(
                WatermarkDetectionResult(generated_or_retrieved_media, edited_media, detect_result))

        if self.return_type == DetectionPipelineReturnType.FULL:
            return evaluation_results
        if self.return_type == DetectionPipelineReturnType.SCORES:
            return [result.detect_result[self.detector_type] for result in evaluation_results]
        # After the up-front validation, only IS_WATERMARKED remains.
        return [result.detect_result['is_watermarked'] for result in evaluation_results]
class WatermarkedMediaDetectionPipeline(WatermarkDetectionPipeline):
    """Detection pipeline over freshly generated watermarked media.

    One sample per dataset prompt. Each prompt is passed to
    ``watermark.generate_watermarked_media`` and the output is normalized to a list.
    """

    def __init__(self, dataset: BaseDataset, media_editor_list: List[Union[ImageEditor, VideoEditor]],
                 show_progress: bool = True,
                 detector_type: str = "l1_distance",
                 return_type: DetectionPipelineReturnType = DetectionPipelineReturnType.SCORES,
                 *args, **kwargs):
        super().__init__(dataset, media_editor_list, show_progress, detector_type, return_type, *args, **kwargs)

    def _get_iterable(self):
        """Yield every prompt index of the dataset."""
        sample_count = self.dataset.num_samples
        return range(sample_count)

    def _generate_or_retrieve_media(self, index: int, watermark: BaseWatermark, **generation_kwargs):
        """Generate watermarked media for prompt *index*.

        Raises:
            ValueError: If the generator returns neither a PIL image nor a list.

        Returns:
            A list: a single image is wrapped, and a list is returned unchanged.
        """
        output = watermark.generate_watermarked_media(
            input_data=self.dataset.get_prompt(index), **generation_kwargs
        )
        if isinstance(output, Image.Image):
            return [output]
        if isinstance(output, list):
            return output
        raise ValueError(f"Invalid media type: {type(output)}")
class UnWatermarkedMediaDetectionPipeline(WatermarkDetectionPipeline):
    """Detection pipeline over media without a watermark.

    ``media_source_mode`` selects where each sample's media comes from:

    * ``"ground_truth"``: ``dataset.get_reference(index)``, one item per reference.
    * ``"generated"``: ``watermark.generate_unwatermarked_media`` applied to each
      dataset prompt.

    Any other mode raises ValueError when the pipeline is evaluated.
    """

    def __init__(self, dataset: BaseDataset, media_editor_list: List[Union[ImageEditor, VideoEditor]], media_source_mode: str = "ground_truth",
                 show_progress: bool = True,
                 detector_type: str = "l1_distance",
                 return_type: DetectionPipelineReturnType = DetectionPipelineReturnType.SCORES,
                 *args, **kwargs):
        super().__init__(dataset, media_editor_list, show_progress, detector_type, return_type, *args, **kwargs)
        self.media_source_mode = media_source_mode

    def _get_iterable(self):
        """Return the range of sample indices for the configured source mode.

        Raises:
            AssertionError: In ``"ground_truth"`` mode when the dataset has no references.
            ValueError: For an unknown ``media_source_mode``.
        """
        if self.media_source_mode == "ground_truth":
            # Explicit raise instead of ``assert`` so the check survives ``python -O``.
            # Without it, -O would silently evaluate zero samples. AssertionError
            # is kept for compatibility with callers that catch it.
            if self.dataset.num_references == 0:
                raise AssertionError("This dataset does not have ground truth images or videos")
            return range(self.dataset.num_references)
        if self.media_source_mode == "generated":
            return range(self.dataset.num_samples)
        raise ValueError(f"Invalid media source mode: {self.media_source_mode}")

    def _generate_or_retrieve_media(self, index: int, watermark: BaseWatermark, **generation_kwargs):
        """Return the list of unwatermarked media for sample *index*.

        Raises:
            ValueError: If generation returns neither a PIL image nor a list, or
                if ``media_source_mode`` is unknown.
        """
        if self.media_source_mode == "ground_truth":
            return [self.dataset.get_reference(index)]
        if self.media_source_mode == "generated":
            prompt = self.dataset.get_prompt(index)
            generated_media = watermark.generate_unwatermarked_media(input_data=prompt, **generation_kwargs)
            # Normalize the output to a list: wrap a single image, pass lists through.
            if isinstance(generated_media, Image.Image):
                return [generated_media]
            if isinstance(generated_media, list):
                return generated_media
            raise ValueError(f"Invalid media type: {type(generated_media)}")
        raise ValueError(f"Invalid media source mode: {self.media_source_mode}")