Coverage for evaluation / pipelines / detection.py: 98.99%

99 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1from watermark.base import BaseWatermark 

2from evaluation.dataset import BaseDataset 

3from tqdm import tqdm 

4from enum import Enum, auto 

5from PIL import Image 

6from evaluation.tools.image_editor import ImageEditor 

7from evaluation.tools.video_editor import VideoEditor 

8from typing import List, Union 

9 

class DetectionPipelineReturnType(Enum):
    """Selects the shape of the value returned by a detection pipeline's ``evaluate``.

    Members:
        FULL: the list of per-item ``WatermarkDetectionResult`` objects.
        SCORES: the list of ``detect_result[detector_type]`` values.
        IS_WATERMARKED: the list of ``detect_result['is_watermarked']`` values.
    """

    # Explicit values equal to what ``auto()`` assigns (1, 2, 3 in definition order).
    FULL = 1
    SCORES = 2
    IS_WATERMARKED = 3

14 

class WatermarkDetectionResult:
    """Per-item record produced by a detection pipeline.

    Attributes:
        generated_or_retrieved_media: the media before any editing.
        edited_media: the media after the pipeline's editors were applied.
        detect_result: whatever the watermark's detector returned for the edited media.
    """

    def __init__(self,
                 generated_or_retrieved_media,
                 edited_media,
                 detect_result) -> None:
        # Values are stored as given; nothing is copied or validated here.
        self.generated_or_retrieved_media, self.edited_media, self.detect_result = (
            generated_or_retrieved_media,
            edited_media,
            detect_result,
        )

    def __str__(self):
        """Render every stored field using its own ``str`` form."""
        return f"WatermarkDetectionResult(generated_or_retrieved_media={self.generated_or_retrieved_media}, edited_media={self.edited_media}, detect_result={self.detect_result})"

28 

class WatermarkDetectionPipeline:
    """Base pipeline: obtain media per index, apply editors, then run watermark detection.

    Subclasses must implement ``_get_iterable`` (the indices to process) and
    ``_generate_or_retrieve_media`` (the list of media items for one index).
    """

    def __init__(self,
                 dataset: BaseDataset,
                 media_editor_list: Union[List[Union[ImageEditor, VideoEditor]], None] = None,
                 show_progress: bool = True,
                 detector_type: str = "l1_distance",
                 return_type: DetectionPipelineReturnType = DetectionPipelineReturnType.SCORES):
        """
        Args:
            dataset: source of prompts / reference media, used by subclasses.
            media_editor_list: editors applied in order to every item's media.
                ``None`` (the default) means no editing.
            show_progress: wrap the iteration in a tqdm progress bar.
            detector_type: forwarded to ``watermark.detect_watermark_in_media``.
                It is also the key read from each detection result when
                ``return_type`` is SCORES.
            return_type: what ``evaluate`` returns (see ``DetectionPipelineReturnType``).
        """
        self.dataset = dataset
        # Use a fresh list per instance. A shared ``[]`` default would leak editors
        # appended on one pipeline into every other pipeline built with the default.
        self.media_editor_list = [] if media_editor_list is None else media_editor_list
        self.show_progress = show_progress
        self.return_type = return_type
        self.detector_type = detector_type

    def _edit_media(self, media: List[Image.Image]) -> List[Image.Image]:
        """
        Edit the media using the media editor list.

        Args:
            media (List[Image.Image]): The media to edit. It is not modified.

        Raises:
            ValueError: If an editor is neither an ImageEditor nor a VideoEditor.

        Returns:
            List[Image.Image]: The edited media, as a new list.
        """
        # Work on a copy. The caller keeps ``media`` as the unedited original,
        # and ``evaluate`` stores it in WatermarkDetectionResult, so editing in
        # place would overwrite the "before" media with the edited version.
        results = list(media)

        for editor in self.media_editor_list:
            if isinstance(editor, ImageEditor):
                # Image editors handle one item at a time.
                results = [editor.edit(item) for item in results]
            elif isinstance(editor, VideoEditor):
                # Video editors take the whole sequence and return the edited sequence.
                results = editor.edit(results)
            else:
                raise ValueError(f"Invalid editor type: {type(editor)}")

        return results

    def _detect_watermark(self, media: List[Image.Image], watermark: BaseWatermark, **kwargs):
        """Run the watermark's detector on ``media`` with this pipeline's ``detector_type``."""
        return watermark.detect_watermark_in_media(media, detector_type=self.detector_type, **kwargs)

    def _get_iterable(self):
        """Return the indices to evaluate. Must be implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement _get_iterable()")

    def _get_progress_bar(self, iterable):
        """Wrap ``iterable`` in a tqdm bar when ``show_progress`` is set."""
        if self.show_progress:
            return tqdm(iterable, desc="Processing")
        return iterable

    def _generate_or_retrieve_media(self, index: int, watermark: BaseWatermark, **kwargs) -> List[Image.Image]:
        """Return the list of media items for ``index``. Must be implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement _generate_or_retrieve_media()")

    def evaluate(self, watermark: BaseWatermark,
                 detection_kwargs: Union[dict, None] = None,
                 generation_kwargs: Union[dict, None] = None):
        """Run generation/retrieval, editing and detection for every index.

        Args:
            watermark: the watermark implementation used for generation and detection.
            detection_kwargs: extra keyword arguments for detection (default: none).
            generation_kwargs: extra keyword arguments for generation (default: none).

        Returns:
            Depending on ``return_type``:
              - FULL: a list of WatermarkDetectionResult objects.
              - SCORES: a list of ``detect_result[detector_type]`` values.
              - IS_WATERMARKED: a list of ``detect_result['is_watermarked']`` values.

        Raises:
            ValueError: if ``return_type`` is not a supported member. This is
                checked before any (potentially expensive) work is done.
        """
        if self.return_type not in (DetectionPipelineReturnType.FULL,
                                    DetectionPipelineReturnType.SCORES,
                                    DetectionPipelineReturnType.IS_WATERMARKED):
            raise ValueError(f"Invalid return type: {self.return_type}")

        detection_kwargs = {} if detection_kwargs is None else detection_kwargs
        generation_kwargs = {} if generation_kwargs is None else generation_kwargs

        evaluation_results = []
        for index in self._get_progress_bar(self._get_iterable()):
            original_media = self._generate_or_retrieve_media(index, watermark, **generation_kwargs)
            edited_media = self._edit_media(original_media)
            detect_result = self._detect_watermark(edited_media, watermark, **detection_kwargs)
            evaluation_results.append(WatermarkDetectionResult(original_media, edited_media, detect_result))

        if self.return_type == DetectionPipelineReturnType.FULL:
            return evaluation_results
        if self.return_type == DetectionPipelineReturnType.SCORES:
            return [result.detect_result[self.detector_type] for result in evaluation_results]
        return [result.detect_result['is_watermarked'] for result in evaluation_results]

101 

class WatermarkedMediaDetectionPipeline(WatermarkDetectionPipeline):
    """Detection pipeline over media generated *with* the watermark.

    One item is produced per dataset prompt, indexed ``0 .. dataset.num_samples - 1``.
    """

    def __init__(self, dataset: BaseDataset, media_editor_list: List[Union[ImageEditor, VideoEditor]],
                 show_progress: bool = True,
                 detector_type: str = "l1_distance",
                 return_type: DetectionPipelineReturnType = DetectionPipelineReturnType.SCORES,
                 *args, **kwargs):
        # Any extra positional/keyword arguments are passed straight through to the base class.
        super().__init__(dataset, media_editor_list, show_progress, detector_type, return_type, *args, **kwargs)

    def _get_iterable(self):
        """Return the range of prompt indices in the dataset."""
        sample_count = self.dataset.num_samples
        return range(sample_count)

    def _generate_or_retrieve_media(self, index: int, watermark: BaseWatermark, **generation_kwargs):
        """Generate watermarked media for prompt ``index`` and return it as a list.

        A single ``Image.Image`` is wrapped in a one-element list; a list is
        returned unchanged.

        Raises:
            ValueError: if the generator returns any other type.
        """
        media = watermark.generate_watermarked_media(
            input_data=self.dataset.get_prompt(index),
            **generation_kwargs,
        )
        # PIL images are not lists, so checking list first is equivalent.
        if isinstance(media, list):
            return media
        if isinstance(media, Image.Image):
            return [media]
        raise ValueError(f"Invalid media type: {type(media)}")

122 

class UnWatermarkedMediaDetectionPipeline(WatermarkDetectionPipeline):
    """Detection pipeline over media that does not carry the watermark.

    ``media_source_mode`` selects where the media comes from:
      * ``"ground_truth"``: ``dataset.get_reference(i)`` for
        ``i in range(dataset.num_references)``; each reference is wrapped in a list.
      * ``"generated"``: ``watermark.generate_unwatermarked_media`` on
        ``dataset.get_prompt(i)`` for ``i in range(dataset.num_samples)``.
    Any other mode raises ``ValueError`` when the pipeline is evaluated.
    """

    def __init__(self, dataset: BaseDataset, media_editor_list: List[Union[ImageEditor, VideoEditor]], media_source_mode : str ="ground_truth",
                 show_progress: bool = True,
                 detector_type: str = "l1_distance",
                 return_type: DetectionPipelineReturnType = DetectionPipelineReturnType.SCORES,
                 *args, **kwargs):
        # Any extra positional/keyword arguments are passed straight through to the base class.
        super().__init__(dataset, media_editor_list, show_progress, detector_type, return_type, *args, **kwargs)
        self.media_source_mode = media_source_mode

    def _get_iterable(self):
        """Return the indices to evaluate for the configured ``media_source_mode``.

        Raises:
            ValueError: if the mode is unknown, or if mode is ``"ground_truth"``
                and the dataset has no references.
        """
        if self.media_source_mode == "ground_truth":
            # Explicit check rather than ``assert``. Assertions are stripped under
            # ``python -O``, which would turn this into a silent empty evaluation.
            if self.dataset.num_references == 0:
                raise ValueError("This dataset does not have ground truth images or videos")
            return range(self.dataset.num_references)
        elif self.media_source_mode == "generated":
            return range(self.dataset.num_samples)
        else:
            raise ValueError(f"Invalid media source mode: {self.media_source_mode}")

    def _generate_or_retrieve_media(self, index: int, watermark: BaseWatermark, **generation_kwargs):
        """Return the unwatermarked media for ``index`` as a list.

        Raises:
            ValueError: if the mode is unknown, or if generated media is
                neither an ``Image.Image`` nor a list.
        """
        if self.media_source_mode == "ground_truth":
            return [self.dataset.get_reference(index)]
        elif self.media_source_mode == "generated":
            prompt = self.dataset.get_prompt(index)
            generated_media = watermark.generate_unwatermarked_media(input_data=prompt, **generation_kwargs)
            if isinstance(generated_media, Image.Image):
                return [generated_media]
            elif isinstance(generated_media, list):
                return generated_media
            else:
                raise ValueError(f"Invalid media type: {type(generated_media)}")
        else:
            raise ValueError(f"Invalid media source mode: {self.media_source_mode}")

155 

156 

157