Coverage for evaluation / pipelines / video_quality_analysis.py: 98.34%
181 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 10:24 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 10:24 +0000
1from enum import Enum, auto
2from dataclasses import dataclass, field
3from typing import List, Union, Dict, Tuple
4from PIL import Image
5from evaluation.tools.image_editor import ImageEditor
6from evaluation.tools.video_editor import VideoEditor
7from evaluation.tools.video_quality_analyzer import VideoQualityAnalyzer
8from evaluation.dataset import BaseDataset
9from watermark.base import BaseWatermark
10import os
11import numpy as np
12from tqdm import tqdm
class SilentProgressBar:
    """Drop-in stand-in for a tqdm bar when progress output is disabled.

    It exposes only what the pipelines call on a progress bar: iteration over
    the wrapped iterable and ``set_description`` (which is ignored).
    """

    def __init__(self, iterable):
        # The wrapped iterable is kept as-is; nothing is consumed eagerly.
        self.iterable = iterable

    def __iter__(self):
        # Hand iteration straight to the wrapped object.
        wrapped = self.iterable
        return iter(wrapped)

    def set_description(self, desc):
        """Accept a description for API compatibility and discard it."""
        return None
class QualityPipelineReturnType(Enum):
    """Output format returned by a video quality analysis pipeline's ``evaluate``.

    FULL: the ``QualityComparisonResult`` object itself.
    SCORES: a dict with the per-analyzer watermarked/unwatermarked scores and the prompts.
    MEAN_SCORES: a dict with the per-analyzer mean watermarked/unwatermarked scores.
    """
    FULL = auto()
    SCORES = auto()
    MEAN_SCORES = auto()
@dataclass
class DatasetForEvaluation:
    """Media collected while preparing a quality evaluation.

    When filled by ``VideoQualityAnalysisPipeline._prepare_dataset``, the video
    lists and ``indexes`` grow together, one entry per processed dataset index.
    """
    # Frames of each watermarked video.
    watermarked_videos: List[List[Image.Image]] = field(default_factory=list)
    # Frames of each unwatermarked video.
    unwatermarked_videos: List[List[Image.Image]] = field(default_factory=list)
    # Frames of each reference video; may stay empty when no references exist.
    reference_videos: List[List[Image.Image]] = field(default_factory=list)
    # Dataset index that each stored entry was produced from.
    indexes: List[int] = field(default_factory=list)
class QualityComparisonResult:
    """Quality scores of watermarked vs. unwatermarked videos, grouped per analyzer."""

    def __init__(self,
                 store_path: Union[str, None],
                 watermarked_quality_scores: Dict[str, List[float]],
                 unwatermarked_quality_scores: Dict[str, List[float]],
                 prompts: List[str],
                 ) -> None:
        """
        Initialize the video quality comparison result.

        Parameters:
            store_path: Directory where the generated frames were stored, or None
                when nothing was stored.
            watermarked_quality_scores: Scores of the watermarked videos, keyed by
                analyzer name (the pipelines use the analyzer's class name).
            unwatermarked_quality_scores: Scores of the unwatermarked videos, keyed
                the same way.
            prompts: The prompts used to generate the videos.
        """
        self.store_path = store_path
        self.watermarked_quality_scores = watermarked_quality_scores
        self.unwatermarked_quality_scores = unwatermarked_quality_scores
        self.prompts = prompts

    def __repr__(self) -> str:
        """Return a compact summary (store path, analyzer names, prompt count)."""
        return (f"{type(self).__name__}(store_path={self.store_path!r}, "
                f"analyzers={list(self.watermarked_quality_scores)!r}, "
                f"num_prompts={len(self.prompts)})")
class VideoQualityAnalysisPipeline:
    """Base pipeline for video quality analysis.

    For every index yielded by ``_get_iterable`` it generates a watermarked and an
    unwatermarked video, applies the configured video- and frame-level editors,
    optionally saves the frames, scores the videos with each analyzer and formats
    the result according to ``return_type``. Subclasses must implement
    ``_get_iterable`` and ``analyze_quality``, and usually
    ``_prepare_input_for_quality_analyzer``.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_video_editor_list: Union[List[VideoEditor], None] = None,
                 unwatermarked_video_editor_list: Union[List[VideoEditor], None] = None,
                 watermarked_frame_editor_list: Union[List[ImageEditor], None] = None,
                 unwatermarked_frame_editor_list: Union[List[ImageEditor], None] = None,
                 analyzers: Union[List[VideoQualityAnalyzer], None] = None,
                 show_progress: bool = True,
                 store_path: Union[str, None] = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:
        """Initialize the video quality analysis pipeline.

        Args:
            dataset (BaseDataset): The dataset for evaluation.
            watermarked_video_editor_list (List[VideoEditor], optional): Video editors applied, in order, to each watermarked video. Defaults to None (no editors).
            unwatermarked_video_editor_list (List[VideoEditor], optional): Video editors applied, in order, to each unwatermarked video. Defaults to None (no editors).
            watermarked_frame_editor_list (List[ImageEditor], optional): Image editors applied, in order, to every individual watermarked frame. Defaults to None (no editors).
            unwatermarked_frame_editor_list (List[ImageEditor], optional): Image editors applied, in order, to every individual unwatermarked frame. Defaults to None (no editors).
            analyzers (List[VideoQualityAnalyzer], optional): Quality analyzers used to score the videos. Defaults to None (no analyzers).
            show_progress (bool, optional): Whether to show tqdm progress bars. Defaults to True.
            store_path (str, optional): Directory in which generated frames are saved as PNG files; nothing is saved when None. Defaults to None.
            return_type (QualityPipelineReturnType, optional): The return type of the pipeline. Defaults to QualityPipelineReturnType.MEAN_SCORES.
        """
        self.dataset = dataset
        # None defaults (instead of []) so that separate pipelines never share one list object.
        self.watermarked_video_editor_list = [] if watermarked_video_editor_list is None else watermarked_video_editor_list
        self.unwatermarked_video_editor_list = [] if unwatermarked_video_editor_list is None else unwatermarked_video_editor_list
        self.watermarked_frame_editor_list = [] if watermarked_frame_editor_list is None else watermarked_frame_editor_list
        self.unwatermarked_frame_editor_list = [] if unwatermarked_frame_editor_list is None else unwatermarked_frame_editor_list
        self.analyzers = analyzers or []
        self.show_progress = show_progress
        self.store_path = store_path
        self.return_type = return_type

    def _check_compatibility(self):
        """Hook for checking that the pipeline is compatible with the dataset.

        The base implementation accepts every dataset; subclasses may override it.
        """
        pass

    def _get_iterable(self):
        """Return the dataset indexes to process. Must be implemented by subclasses."""
        raise NotImplementedError(f"{type(self).__name__} must implement _get_iterable()")

    def _get_progress_bar(self, iterable):
        """Wrap ``iterable`` in a tqdm bar, or in a silent wrapper when progress is disabled."""
        if self.show_progress:
            return tqdm(iterable, desc="Processing", leave=True)
        return SilentProgressBar(iterable)

    def _get_prompt(self, index: int) -> str:
        """Return the dataset prompt at ``index``."""
        return self.dataset.get_prompt(index)

    def _get_watermarked_video(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> List[Image.Image]:
        """Generate the watermarked video (its frames) for the prompt at ``index``."""
        prompt = self._get_prompt(index)
        return watermark.generate_watermarked_media(input_data=prompt, **generation_kwargs)

    def _get_unwatermarked_video(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> List[Image.Image]:
        """Generate the unwatermarked video (its frames) for the prompt at ``index``."""
        prompt = self._get_prompt(index)
        return watermark.generate_unwatermarked_media(input_data=prompt, **generation_kwargs)

    @staticmethod
    def _apply_editors(frames: List[Image.Image],
                       video_editors: List[VideoEditor],
                       frame_editors: List[ImageEditor]) -> List[Image.Image]:
        """Apply every video editor to the whole clip, then every image editor to each frame."""
        for video_editor in video_editors:
            frames = video_editor.edit(frames)
        for frame_editor in frame_editors:
            frames = [frame_editor.edit(frame) for frame in frames]
        return frames

    def _edit_watermarked_video(self, frames: List[Image.Image]) -> List[Image.Image]:
        """Edit a watermarked video with the watermarked video and frame editors."""
        return self._apply_editors(frames,
                                   self.watermarked_video_editor_list,
                                   self.watermarked_frame_editor_list)

    def _edit_unwatermarked_video(self, frames: List[Image.Image]) -> List[Image.Image]:
        """Edit an unwatermarked video with the unwatermarked video and frame editors."""
        return self._apply_editors(frames,
                                   self.unwatermarked_video_editor_list,
                                   self.unwatermarked_frame_editor_list)

    def _prepare_dataset(self, watermark: BaseWatermark, **generation_kwargs) -> DatasetForEvaluation:
        """
        Generate and edit all videos needed for quality analysis.

        For each index from ``_get_iterable`` a watermarked and an unwatermarked video
        are generated and edited. When the dataset has references
        (``num_references > 0``), the reference video is fetched as well.

        Parameters:
            watermark: The watermark algorithm instance.
            generation_kwargs: Additional generation parameters forwarded to the watermark.

        Returns:
            DatasetForEvaluation object containing all prepared data.
        """
        dataset_eval = DatasetForEvaluation()

        bar = self._get_progress_bar(self._get_iterable())
        bar.set_description("Generating videos for quality analysis")
        for index in bar:
            watermarked_frames = self._get_watermarked_video(watermark, index, **generation_kwargs)
            watermarked_frames = self._edit_watermarked_video(watermarked_frames)

            unwatermarked_frames = self._get_unwatermarked_video(watermark, index, **generation_kwargs)
            unwatermarked_frames = self._edit_unwatermarked_video(unwatermarked_frames)

            dataset_eval.watermarked_videos.append(watermarked_frames)
            dataset_eval.unwatermarked_videos.append(unwatermarked_frames)
            dataset_eval.indexes.append(index)

            if self.dataset.num_references > 0:
                dataset_eval.reference_videos.append(self.dataset.get_reference(index))

        return dataset_eval

    def _prepare_input_for_quality_analyzer(self,
                                            watermarked_videos: List[List[Image.Image]],
                                            unwatermarked_videos: List[List[Image.Image]],
                                            reference_videos: List[List[Image.Image]]):
        """Build the object passed to ``analyze_quality`` for every analyzer.

        The base implementation returns None; subclasses override it.

        Args:
            watermarked_videos (List[List[Image.Image]]): Watermarked video(s)
            unwatermarked_videos (List[List[Image.Image]]): Unwatermarked video(s)
            reference_videos (List[List[Image.Image]]): Reference video(s), if available
        """
        return None

    @staticmethod
    def _save_frames(frames, save_dir: str) -> None:
        """Save each frame to ``save_dir`` as ``frame_<i>.png``, creating the directory if needed."""
        os.makedirs(save_dir, exist_ok=True)
        for i, frame in enumerate(frames):
            frame.save(os.path.join(save_dir, f"frame_{i}.png"))

    def _store_results(self, prepared_dataset: DatasetForEvaluation):
        """Save the frames of every prepared video as PNG files under ``self.store_path``.

        Each video gets its own directory
        ``<ClassName>_<dataset name>_<kind>_prompt<index>`` where ``kind`` is
        ``watermarked``, ``unwatermarked`` or (when the dataset has references)
        ``reference``.
        """
        os.makedirs(self.store_path, exist_ok=True)
        prefix = f"{self.__class__.__name__}_{self.dataset.name}"

        for index, watermarked_video, unwatermarked_video in zip(prepared_dataset.indexes,
                                                                 prepared_dataset.watermarked_videos,
                                                                 prepared_dataset.unwatermarked_videos):
            self._save_frames(watermarked_video,
                              os.path.join(self.store_path, f"{prefix}_watermarked_prompt{index}"))
            self._save_frames(unwatermarked_video,
                              os.path.join(self.store_path, f"{prefix}_unwatermarked_prompt{index}"))

            if self.dataset.num_references > 0:
                # Reference frames are read from the dataset again, not from prepared_dataset.
                self._save_frames(self.dataset.get_reference(index),
                                  os.path.join(self.store_path, f"{prefix}_reference_prompt{index}"))

    def analyze_quality(self, prepared_data, analyzer):
        """Return ``(watermarked_scores, unwatermarked_scores)`` for one analyzer.

        Must be implemented by subclasses.
        """
        raise NotImplementedError(f"{type(self).__name__} must implement analyze_quality()")

    def evaluate(self, watermark: BaseWatermark, generation_kwargs=None):
        """Run the full pipeline and return results formatted per ``return_type``.

        Args:
            watermark: The watermark algorithm used to generate the videos.
            generation_kwargs (dict, optional): Extra keyword arguments forwarded to
                the watermark's generation methods. Defaults to None (no extras).
        """
        generation_kwargs = generation_kwargs or {}

        self._check_compatibility()

        prepared_dataset = self._prepare_dataset(watermark, **generation_kwargs)

        if self.store_path:
            self._store_results(prepared_dataset)

        prepared_data = self._prepare_input_for_quality_analyzer(
            prepared_dataset.watermarked_videos,
            prepared_dataset.unwatermarked_videos,
            prepared_dataset.reference_videos
        )

        # Scores are keyed by analyzer class name; two analyzers of the same
        # class therefore share a key and the later one wins.
        watermarked_scores = {}
        unwatermarked_scores = {}
        for analyzer in self.analyzers:
            w_scores, u_scores = self.analyze_quality(prepared_data, analyzer)
            analyzer_name = analyzer.__class__.__name__
            watermarked_scores[analyzer_name] = w_scores
            unwatermarked_scores[analyzer_name] = u_scores

        prompts = [self._get_prompt(idx) for idx in prepared_dataset.indexes]

        result = QualityComparisonResult(
            store_path=self.store_path,
            watermarked_quality_scores=watermarked_scores,
            unwatermarked_quality_scores=unwatermarked_scores,
            prompts=prompts,
        )

        return self._format_results(result)

    @staticmethod
    def _mean_per_analyzer(scores_by_analyzer):
        """Average each non-empty list of scores; any other value is passed through unchanged."""
        return {
            name: np.mean(scores) if isinstance(scores, list) and len(scores) > 0 else scores
            for name, scores in scores_by_analyzer.items()
        }

    def _format_results(self, result: QualityComparisonResult):
        """Format ``result`` according to ``self.return_type``.

        Raises:
            ValueError: If ``return_type`` is not a supported QualityPipelineReturnType.
        """
        if self.return_type is QualityPipelineReturnType.FULL:
            return result
        if self.return_type is QualityPipelineReturnType.SCORES:
            return {
                'watermarked': result.watermarked_quality_scores,
                'unwatermarked': result.unwatermarked_quality_scores,
                'prompts': result.prompts
            }
        if self.return_type is QualityPipelineReturnType.MEAN_SCORES:
            return {
                'watermarked': self._mean_per_analyzer(result.watermarked_quality_scores),
                'unwatermarked': self._mean_per_analyzer(result.unwatermarked_quality_scores)
            }
        # Previously an unknown return type silently produced None.
        raise ValueError(f"Unsupported return type: {self.return_type!r}")
class DirectVideoQualityAnalysisPipeline(VideoQualityAnalysisPipeline):
    """Pipeline that scores each watermarked and unwatermarked video on its own.

    Every dataset sample is processed, each video is passed individually to
    ``analyzer.analyze``, and reference videos are not used for scoring.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_video_editor_list: Union[List[VideoEditor], None] = None,
                 unwatermarked_video_editor_list: Union[List[VideoEditor], None] = None,
                 watermarked_frame_editor_list: Union[List[ImageEditor], None] = None,
                 unwatermarked_frame_editor_list: Union[List[ImageEditor], None] = None,
                 analyzers: Union[List[VideoQualityAnalyzer], None] = None,
                 show_progress: bool = True,
                 store_path: Union[str, None] = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:
        """Initialize the direct video quality analysis pipeline.

        Args:
            dataset (BaseDataset): The dataset for evaluation.
            watermarked_video_editor_list (List[VideoEditor], optional): Video editors applied, in order, to each watermarked video. Defaults to None (no editors).
            unwatermarked_video_editor_list (List[VideoEditor], optional): Video editors applied, in order, to each unwatermarked video. Defaults to None (no editors).
            watermarked_frame_editor_list (List[ImageEditor], optional): Image editors applied, in order, to every individual watermarked frame. Defaults to None (no editors).
            unwatermarked_frame_editor_list (List[ImageEditor], optional): Image editors applied, in order, to every individual unwatermarked frame. Defaults to None (no editors).
            analyzers (List[VideoQualityAnalyzer], optional): Quality analyzers used to score the videos. Defaults to None (no analyzers).
            show_progress (bool, optional): Whether to show progress. Defaults to True.
            store_path (str, optional): Directory in which generated frames are saved; nothing is saved when None. Defaults to None.
            return_type (QualityPipelineReturnType, optional): The return type of the pipeline. Defaults to QualityPipelineReturnType.MEAN_SCORES.
        """
        # Give each instance its own fresh list instead of a shared mutable default.
        super().__init__(
            dataset,
            watermarked_video_editor_list=[] if watermarked_video_editor_list is None else watermarked_video_editor_list,
            unwatermarked_video_editor_list=[] if unwatermarked_video_editor_list is None else unwatermarked_video_editor_list,
            watermarked_frame_editor_list=[] if watermarked_frame_editor_list is None else watermarked_frame_editor_list,
            unwatermarked_frame_editor_list=[] if unwatermarked_frame_editor_list is None else unwatermarked_frame_editor_list,
            analyzers=analyzers,
            show_progress=show_progress,
            store_path=store_path,
            return_type=return_type,
        )

    def _get_iterable(self):
        """Return every sample index of the dataset."""
        return range(self.dataset.num_samples)

    def _prepare_input_for_quality_analyzer(self,
                                            watermarked_videos: List[List[Image.Image]],
                                            unwatermarked_videos: List[List[Image.Image]],
                                            reference_videos: List[List[Image.Image]]):
        """Return ``(watermarked_videos, unwatermarked_videos)``; reference videos are ignored."""
        return watermarked_videos, unwatermarked_videos

    def analyze_quality(self,
                        prepared_data: Tuple[List[List[Image.Image]], List[List[Image.Image]]],
                        analyzer: VideoQualityAnalyzer):
        """Score each watermarked and unwatermarked video independently with ``analyzer``.

        Args:
            prepared_data: ``(watermarked_videos, unwatermarked_videos)`` as returned by
                ``_prepare_input_for_quality_analyzer``.
            analyzer: The analyzer whose ``analyze`` method scores a single video.

        Returns:
            Tuple ``(w_scores, u_scores)`` of lists, one score per video pair.
        """
        watermarked_videos, unwatermarked_videos = prepared_data

        # Materialized so that tqdm can report a total.
        video_pairs = list(zip(watermarked_videos, unwatermarked_videos))

        bar = self._get_progress_bar(video_pairs)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        for watermarked_video, unwatermarked_video in bar:
            w_scores.append(analyzer.analyze(watermarked_video))
            u_scores.append(analyzer.analyze(unwatermarked_video))
        return w_scores, u_scores