Coverage for evaluation/pipelines/image_quality_analysis.py: 95.02%
281 statements
coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from watermark.base import BaseWatermark
from evaluation.dataset import BaseDataset
from tqdm import tqdm
from enum import Enum, auto
from PIL import Image
from evaluation.tools.image_editor import ImageEditor
from typing import List, Dict, Union, Tuple, Any, Optional
import numpy as np
from dataclasses import dataclass, field
import os
import random
from evaluation.tools.image_quality_analyzer import (
    ImageQualityAnalyzer
)
import lpips


class SilentProgressBar:
    """A silent progress bar wrapper that supports set_description but shows no output."""

    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

    def set_description(self, desc):
        """No-op for silent mode."""
        pass


class QualityPipelineReturnType(Enum):
    """Return type of the image quality analysis pipeline."""
    FULL = auto()
    SCORES = auto()
    MEAN_SCORES = auto()


@dataclass
class DatasetForEvaluation:
    """Dataset for evaluation."""
    watermarked_images: List[Union[Image.Image, List[Image.Image]]] = field(default_factory=list)
    unwatermarked_images: List[Union[Image.Image, List[Image.Image]]] = field(default_factory=list)
    reference_images: List[Image.Image] = field(default_factory=list)
    indexes: List[int] = field(default_factory=list)
    prompts: List[str] = field(default_factory=list)


class QualityComparisonResult:
    """Result of image quality comparison."""

    def __init__(self,
                 store_path: str,
                 watermarked_quality_scores: Dict[str, List[float]],
                 unwatermarked_quality_scores: Dict[str, List[float]],
                 prompts: List[str],
                 ) -> None:
        """
        Initialize the image quality comparison result.

        Parameters:
            store_path: The path to store the results.
            watermarked_quality_scores: The quality scores of the watermarked images.
            unwatermarked_quality_scores: The quality scores of the unwatermarked images.
            prompts: The prompts used to generate the images.
        """
        self.store_path = store_path
        self.watermarked_quality_scores = watermarked_quality_scores
        self.unwatermarked_quality_scores = unwatermarked_quality_scores
        self.prompts = prompts


class ImageQualityAnalysisPipeline:
    """Pipeline for image quality analysis."""

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_image_editor_list: List[ImageEditor] = [],
                 unwatermarked_image_editor_list: List[ImageEditor] = [],
                 analyzers: List[ImageQualityAnalyzer] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: str = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:
        """
        Initialize the image quality analysis pipeline.

        Parameters:
            dataset: The dataset for evaluation.
            watermarked_image_editor_list: The list of image editors for watermarked images.
            unwatermarked_image_editor_list: The list of image editors for unwatermarked images.
            analyzers: List of quality analyzers for images.
            unwatermarked_image_source: The source of unwatermarked images ('natural' or 'generated').
            reference_image_source: The source of reference images ('natural' or 'generated').
            show_progress: Whether to show progress.
            store_path: The path to store the results. If None, the generated images will not be stored.
            return_type: The return type of the pipeline.
        """
        if unwatermarked_image_source not in ['natural', 'generated']:
            raise ValueError(f"Invalid unwatermarked_image_source: {unwatermarked_image_source}")

        self.dataset = dataset
        self.watermarked_image_editor_list = watermarked_image_editor_list
        self.unwatermarked_image_editor_list = unwatermarked_image_editor_list
        self.analyzers = analyzers or []
        self.unwatermarked_image_source = unwatermarked_image_source
        self.reference_image_source = reference_image_source
        self.show_progress = show_progress
        self.store_path = store_path
        self.return_type = return_type

    def _check_compatibility(self):
        """Check if the pipeline is compatible with the dataset."""
        pass

    def _get_iterable(self):
        """Return an iterable for the dataset."""
        pass

    def _get_progress_bar(self, iterable):
        """Return an iterable possibly wrapped with a progress bar."""
        if self.show_progress:
            return tqdm(iterable, desc="Processing", leave=True)
        return SilentProgressBar(iterable)

    def _get_prompt(self, index: int) -> str:
        """Get prompt from dataset."""
        return self.dataset.get_prompt(index)

    def _get_watermarked_image(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> Union[Image.Image, List[Image.Image]]:
        """Generate watermarked image from dataset."""
        prompt = self._get_prompt(index)
        image = watermark.generate_watermarked_media(input_data=prompt, **generation_kwargs)
        return image

    def _get_unwatermarked_image(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> Union[Image.Image, List[Image.Image]]:
        """Generate or retrieve unwatermarked image from dataset."""
        if self.unwatermarked_image_source == 'natural':
            return self.dataset.get_reference(index)
        elif self.unwatermarked_image_source == 'generated':
            prompt = self._get_prompt(index)
            image = watermark.generate_unwatermarked_media(input_data=prompt, **generation_kwargs)
            return image

    def _edit_watermarked_image(self, image: Union[Image.Image, List[Image.Image]]) -> Union[Image.Image, List[Image.Image]]:
        """Edit watermarked image using image editors."""
        if isinstance(image, list):
            edited_images = []
            for img in image:
                for image_editor in self.watermarked_image_editor_list:
                    img = image_editor.edit(img)
                edited_images.append(img)
            return edited_images
        else:
            for image_editor in self.watermarked_image_editor_list:
                image = image_editor.edit(image)
            return image

    def _edit_unwatermarked_image(self, image: Union[Image.Image, List[Image.Image]]) -> Union[Image.Image, List[Image.Image]]:
        """Edit unwatermarked image using image editors."""
        if isinstance(image, list):
            edited_images = []
            for img in image:
                for image_editor in self.unwatermarked_image_editor_list:
                    img = image_editor.edit(img)
                edited_images.append(img)
            return edited_images
        else:
            for image_editor in self.unwatermarked_image_editor_list:
                image = image_editor.edit(image)
            return image
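
    # Illustrative note on editor chaining (sketch; `jpeg_editor` and `blur_editor` are
    # hypothetical ImageEditor instances, not names from this repo): with
    # watermarked_image_editor_list = [jpeg_editor, blur_editor], each watermarked image is
    # passed through the editors in list order, i.e. the result is equivalent to
    # blur_editor.edit(jpeg_editor.edit(image)); lists of images are edited element-wise.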

    def _prepare_dataset(self, watermark: BaseWatermark, **generation_kwargs) -> DatasetForEvaluation:
        """
        Prepare and generate all necessary data for quality analysis.

        This method should be overridden by subclasses to implement specific
        data preparation logic based on the analysis requirements.

        Parameters:
            watermark: The watermark algorithm instance.
            generation_kwargs: Additional generation parameters.

        Returns:
            DatasetForEvaluation object containing all prepared data.
        """
        dataset_eval = DatasetForEvaluation()

        # Generate all images
        bar = self._get_progress_bar(self._get_iterable())
        bar.set_description("Generating images for quality analysis")
        for index in bar:
            # Generate and edit watermarked image
            watermarked_image = self._get_watermarked_image(watermark, index, **generation_kwargs)
            watermarked_image = self._edit_watermarked_image(watermarked_image)

            # Generate and edit unwatermarked image
            unwatermarked_image = self._get_unwatermarked_image(watermark, index, **generation_kwargs)
            unwatermarked_image = self._edit_unwatermarked_image(unwatermarked_image)

            dataset_eval.watermarked_images.append(watermarked_image)
            dataset_eval.unwatermarked_images.append(unwatermarked_image)

            # Resolve the prompt with the raw loop index before mapping repeated indices back
            # to their prompt index (pipelines with `prompt_per_image` generate several images
            # per prompt); calling _get_prompt after the division would divide twice.
            prompt = self._get_prompt(index)
            if hasattr(self, "prompt_per_image"):
                index = index // self.prompt_per_image
            dataset_eval.indexes.append(index)
            dataset_eval.prompts.append(prompt)

            if self.reference_image_source == 'natural':
                if self.dataset.num_references > 0:
                    reference_image = self.dataset.get_reference(index)
                    dataset_eval.reference_images.append(reference_image)
                else:
                    # For text-based analyzers, add None placeholder
                    dataset_eval.reference_images.append(None)
            else:
                dataset_eval.reference_images.append(unwatermarked_image)

        return dataset_eval

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """
        Prepare input for quality analyzer.

        Parameters:
            prepared_dataset: The prepared dataset.
        """
        pass

    def _store_results(self, prepared_dataset: DatasetForEvaluation):
        """Store results."""
        os.makedirs(self.store_path, exist_ok=True)
        dataset_name = self.dataset.name

        for (index, watermarked_image, unwatermarked_image, prompt) in zip(prepared_dataset.indexes, prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images, prepared_dataset.prompts):
            watermarked_image.save(os.path.join(self.store_path, f"{self.__class__.__name__}_{dataset_name}_watermarked_prompt_{index}.png"))
            unwatermarked_image.save(os.path.join(self.store_path, f"{self.__class__.__name__}_{dataset_name}_unwatermarked_prompt_{index}.png"))
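
    # Illustrative output of _store_results, assuming store_path='results' and a dataset
    # named 'coco' (both names are examples only):
    #
    #     results/DirectImageQualityAnalysisPipeline_coco_watermarked_prompt_0.png
    #     results/DirectImageQualityAnalysisPipeline_coco_unwatermarked_prompt_0.png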

    def analyze_quality(self, prepared_data, analyzer):
        """Analyze quality of watermarked and unwatermarked images."""
        pass

    def evaluate(self, watermark: BaseWatermark, generation_kwargs={}):
        """Conduct evaluation utilizing the pipeline."""
        # Check compatibility
        self._check_compatibility()
        print(self.store_path)

        # Prepare dataset
        prepared_dataset = self._prepare_dataset(watermark, **generation_kwargs)

        # Store results
        if self.store_path:
            self._store_results(prepared_dataset)

        # Prepare input for quality analyzer
        prepared_data = self._prepare_input_for_quality_analyzer(
            prepared_dataset
        )

        # Analyze quality
        all_scores = {}
        for analyzer in self.analyzers:
            w_scores, u_scores = self.analyze_quality(prepared_data, analyzer)
            analyzer_name = analyzer.__class__.__name__
            all_scores[analyzer_name] = (w_scores, u_scores)

        # Recover the prompt for each stored index
        prompts = []
        for idx in prepared_dataset.indexes:
            prompts.append(self._get_prompt(idx))

        # Create result
        watermarked_scores = {}
        unwatermarked_scores = {}

        for analyzer_name, (w_scores, u_scores) in all_scores.items():
            watermarked_scores[analyzer_name] = w_scores
            unwatermarked_scores[analyzer_name] = u_scores

        result = QualityComparisonResult(
            store_path=self.store_path,
            watermarked_quality_scores=watermarked_scores,
            unwatermarked_quality_scores=unwatermarked_scores,
            prompts=prompts,
        )

        # Format results based on return_type
        return self._format_results(result)

    def _format_results(self, result: QualityComparisonResult):
        """Format results based on return_type."""
        if self.return_type == QualityPipelineReturnType.FULL:
            return result
        elif self.return_type == QualityPipelineReturnType.SCORES:
            return {
                'watermarked': result.watermarked_quality_scores,
                'unwatermarked': result.unwatermarked_quality_scores,
                'prompts': result.prompts
            }
        elif self.return_type == QualityPipelineReturnType.MEAN_SCORES:
            # Calculate mean scores for each analyzer
            mean_watermarked = {}
            mean_unwatermarked = {}

            for analyzer_name, scores in result.watermarked_quality_scores.items():
                if isinstance(scores, list) and len(scores) > 0:
                    mean_watermarked[analyzer_name] = np.mean(scores)
                else:
                    mean_watermarked[analyzer_name] = scores

            for analyzer_name, scores in result.unwatermarked_quality_scores.items():
                if isinstance(scores, list) and len(scores) > 0:
                    mean_unwatermarked[analyzer_name] = np.mean(scores)
                else:
                    mean_unwatermarked[analyzer_name] = scores

            return {
                'watermarked': mean_watermarked,
                'unwatermarked': mean_unwatermarked
            }
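
# Shape of the value returned by evaluate(), depending on return_type (illustrative
# placeholder values only; the inner keys are the analyzer class names):
#
#   FULL        -> the QualityComparisonResult object itself
#   SCORES      -> {'watermarked':   {'SomeAnalyzer': [0.91, 0.88, ...]},
#                   'unwatermarked': {'SomeAnalyzer': [0.93, 0.90, ...]},
#                   'prompts':       ['a photo of ...', ...]}
#   MEAN_SCORES -> {'watermarked':   {'SomeAnalyzer': 0.895},
#                   'unwatermarked': {'SomeAnalyzer': 0.915}}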


class DirectImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for direct image quality analysis.

    This class analyzes the quality of images by directly comparing the characteristics
    of watermarked images with unwatermarked images. It evaluates metrics such as PSNR,
    SSIM, LPIPS, FID, and BRISQUE without requiring any external reference image; if an
    analyzer does expect a reference, the counterpart image is used instead.

    Use this pipeline to assess the impact of watermarking on image quality directly.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_image_editor_list: List[ImageEditor] = [],
                 unwatermarked_image_editor_list: List[ImageEditor] = [],
                 analyzers: List[ImageQualityAnalyzer] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: str = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:

        super().__init__(dataset, watermarked_image_editor_list, unwatermarked_image_editor_list,
                         analyzers, unwatermarked_image_source, reference_image_source, show_progress, store_path, return_type)

    def _get_iterable(self):
        """Return an iterable for the dataset."""
        return range(self.dataset.num_samples)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Prepare input for quality analyzer."""
        return [(watermarked_image, unwatermarked_image) for watermarked_image, unwatermarked_image in zip(prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images)]

    def analyze_quality(self,
                        prepared_data: List[Tuple[Image.Image, Image.Image]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze quality of watermarked and unwatermarked images."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # For direct analyzers, we analyze each image independently
        for watermarked_image, unwatermarked_image in bar:
            # watermarked score
            try:
                w_score = analyzer.analyze(watermarked_image)
            except TypeError:
                # analyzer expects a reference -> use unwatermarked_image as reference
                w_score = analyzer.analyze(watermarked_image, unwatermarked_image)
            # unwatermarked score
            try:
                u_score = analyzer.analyze(unwatermarked_image)
            except TypeError:
                u_score = analyzer.analyze(unwatermarked_image, watermarked_image)
            w_scores.append(w_score)
            u_scores.append(u_score)

        return w_scores, u_scores
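
# Illustrative usage sketch (not part of the module). `prompt_dataset` stands in for a
# concrete BaseDataset, `my_watermark` for a BaseWatermark implementation, and
# `no_reference_analyzer` for an ImageQualityAnalyzer whose analyze() accepts a single
# image (e.g., a BRISQUE-style analyzer); all three names are assumptions.
#
#     pipeline = DirectImageQualityAnalysisPipeline(
#         dataset=prompt_dataset,
#         analyzers=[no_reference_analyzer],
#         unwatermarked_image_source='generated',
#         return_type=QualityPipelineReturnType.MEAN_SCORES,
#     )
#     mean_scores = pipeline.evaluate(my_watermark)
#     # mean_scores -> {'watermarked': {...}, 'unwatermarked': {...}}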


class ReferencedImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for referenced image quality analysis.

    This pipeline assesses image quality by comparing both watermarked and unwatermarked
    images against a common reference image. It measures the degree of similarity or
    deviation from the reference.

    Ideal for scenarios where the impact of watermarking on image quality needs to be
    assessed, particularly in relation to specific reference images or ground truth.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_image_editor_list: List[ImageEditor] = [],
                 unwatermarked_image_editor_list: List[ImageEditor] = [],
                 analyzers: List[ImageQualityAnalyzer] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: str = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:

        super().__init__(dataset, watermarked_image_editor_list, unwatermarked_image_editor_list,
                         analyzers, unwatermarked_image_source, reference_image_source, show_progress, store_path, return_type)

    def _check_compatibility(self):
        """Check if the pipeline is compatible with the dataset."""
        # Check if we have analyzers that use text as reference
        has_text_analyzer = any(hasattr(analyzer, 'reference_source') and analyzer.reference_source == 'text'
                                for analyzer in self.analyzers)

        # If at least one analyzer uses a text reference, reference images are not required
        if not has_text_analyzer and self.dataset.num_references == 0:
            raise ValueError(f"Reference images are required for referenced image quality analysis. Dataset {self.dataset.name} has no reference images.")

    def _get_iterable(self):
        """Return an iterable for the dataset."""
        return range(self.dataset.num_samples)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Prepare input for quality analyzer."""
        return [(watermarked_image, unwatermarked_image, reference_image, prompt)
                for watermarked_image, unwatermarked_image, reference_image, prompt in
                zip(prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images, prepared_dataset.reference_images, prepared_dataset.prompts)
                ]

    def analyze_quality(self,
                        prepared_data: List[Tuple[Image.Image, Image.Image, Image.Image, str]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze quality of watermarked and unwatermarked images."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # For referenced analyzers, we compare against the reference
        for watermarked_image, unwatermarked_image, reference_image, prompt in bar:
            if analyzer.reference_source == "image":
                w_score = analyzer.analyze(watermarked_image, reference_image)
                u_score = analyzer.analyze(unwatermarked_image, reference_image)
            elif analyzer.reference_source == "text":
                w_score = analyzer.analyze(watermarked_image, prompt)
                u_score = analyzer.analyze(unwatermarked_image, prompt)
            else:
                raise ValueError(f"Invalid reference source: {analyzer.reference_source}")
            w_scores.append(w_score)
            u_scores.append(u_score)
        return w_scores, u_scores
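
# Illustrative usage sketch (assumed names, not part of the module): a referenced run needs
# either a dataset with reference images or analyzers whose reference_source is 'text',
# in which case the prompt itself serves as the reference.
#
#     pipeline = ReferencedImageQualityAnalysisPipeline(
#         dataset=dataset_with_references,        # hypothetical BaseDataset with num_references > 0
#         analyzers=[image_referenced_analyzer],  # hypothetical analyzer with reference_source == 'image'
#         reference_image_source='natural',
#         return_type=QualityPipelineReturnType.SCORES,
#     )
#     scores = pipeline.evaluate(my_watermark)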


class GroupImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for group-based image quality analysis.

    This pipeline analyzes quality metrics that require comparing distributions
    of multiple images (e.g., FID). It generates all images upfront and then
    performs a single analysis on the entire collection.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_image_editor_list: List[ImageEditor] = [],
                 unwatermarked_image_editor_list: List[ImageEditor] = [],
                 analyzers: List[ImageQualityAnalyzer] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: str = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:

        super().__init__(dataset, watermarked_image_editor_list, unwatermarked_image_editor_list,
                         analyzers, unwatermarked_image_source, reference_image_source, show_progress, store_path, return_type)

    def _get_iterable(self):
        """Return an iterable over dataset indices."""
        return range(self.dataset.num_samples)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Prepare input for group analyzer."""
        return [(prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images, prepared_dataset.reference_images)]

    def analyze_quality(self,
                        prepared_data: List[Tuple[List[Image.Image], List[Image.Image], List[Image.Image]]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze quality of image groups."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # For group analyzers, we pass the entire collection
        for watermarked_images, unwatermarked_images, reference_images in bar:
            w_score = analyzer.analyze(watermarked_images, reference_images)
            u_score = analyzer.analyze(unwatermarked_images, reference_images)
            w_scores.append(w_score)
            u_scores.append(u_score)
        return w_scores, u_scores
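
# Illustrative usage sketch (assumed names): distribution-level metrics such as FID receive
# the whole image collections at once, so each analyzer yields a single score per run rather
# than one score per image.
#
#     pipeline = GroupImageQualityAnalysisPipeline(
#         dataset=prompt_dataset,
#         analyzers=[fid_like_analyzer],  # hypothetical group-level ImageQualityAnalyzer
#         reference_image_source='natural',
#     )
#     result = pipeline.evaluate(my_watermark)
#     # With MEAN_SCORES, the mean over a single group score is simply that score.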


class RepeatImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for repeat-based image quality analysis.

    This pipeline analyzes diversity metrics by generating multiple images
    for each prompt (e.g., LPIPS diversity). It generates multiple versions
    per prompt and analyzes the diversity within each group.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 prompt_per_image: int = 20,
                 watermarked_image_editor_list: List[ImageEditor] = [],
                 unwatermarked_image_editor_list: List[ImageEditor] = [],
                 analyzers: List[ImageQualityAnalyzer] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: str = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:

        super().__init__(dataset, watermarked_image_editor_list, unwatermarked_image_editor_list,
                         analyzers, unwatermarked_image_source, reference_image_source, show_progress, store_path, return_type)

        # Number of images generated per prompt (each generation draws a fresh random seed)
        self.prompt_per_image = prompt_per_image

    def _get_iterable(self):
        """Return an iterable for the dataset."""
        return range(self.dataset.num_samples * self.prompt_per_image)

    def _get_prompt(self, index: int) -> str:
        """Get prompt from dataset."""
        prompt_index = index // self.prompt_per_image
        return self.dataset.get_prompt(prompt_index)

    def _get_watermarked_image(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> Union[Image.Image, List[Image.Image]]:
        """Get watermarked image."""
        prompt = self._get_prompt(index)
        # Randomly select a generation seed
        generation_kwargs['gen_seed'] = random.randint(0, 1000000)
        return watermark.generate_watermarked_media(input_data=prompt, **generation_kwargs)

    def _get_unwatermarked_image(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> Union[Image.Image, List[Image.Image]]:
        """Get unwatermarked image."""
        prompt = self._get_prompt(index)
        # Randomly select a generation seed
        generation_kwargs['gen_seed'] = random.randint(0, 1000000)
        if self.unwatermarked_image_source == 'natural':
            return self.dataset.get_reference(index)
        else:
            return watermark.generate_unwatermarked_media(input_data=prompt, **generation_kwargs)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Prepare input for diversity analyzer."""
        # Group images by prompt
        watermarked_images = prepared_dataset.watermarked_images
        unwatermarked_images = prepared_dataset.unwatermarked_images

        grouped = []
        for i in range(0, len(watermarked_images), self.prompt_per_image):
            grouped.append(
                (watermarked_images[i:i+self.prompt_per_image],
                 unwatermarked_images[i:i+self.prompt_per_image])
            )
        return grouped

    def analyze_quality(self,
                        prepared_data: List[Tuple[List[Image.Image], List[Image.Image]]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze diversity of image batches."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing diversity for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # For diversity analyzers, we analyze each batch
        for watermarked_images, unwatermarked_images in bar:
            w_score = analyzer.analyze(watermarked_images)
            u_score = analyzer.analyze(unwatermarked_images)
            w_scores.append(w_score)
            u_scores.append(u_score)

        return w_scores, u_scores
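
# Illustrative usage sketch (assumed names): diversity metrics compare several generations of
# the same prompt, so prompt_per_image images are produced per prompt, each with a fresh
# random gen_seed, and the analyzer scores each per-prompt batch.
#
#     pipeline = RepeatImageQualityAnalysisPipeline(
#         dataset=prompt_dataset,
#         prompt_per_image=4,
#         analyzers=[lpips_diversity_analyzer],  # hypothetical batch-level ImageQualityAnalyzer
#     )
#     diversity = pipeline.evaluate(my_watermark)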


class ComparedImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for compared image quality analysis.

    This pipeline directly compares watermarked and unwatermarked images
    to compute metrics like PSNR, SSIM, VIF, FSIM and MS-SSIM. The analyzer receives
    both images and outputs a single comparison score.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_image_editor_list: List[ImageEditor] = [],
                 unwatermarked_image_editor_list: List[ImageEditor] = [],
                 analyzers: List[ImageQualityAnalyzer] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: str = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:

        super().__init__(dataset, watermarked_image_editor_list, unwatermarked_image_editor_list,
                         analyzers, unwatermarked_image_source, reference_image_source, show_progress, store_path, return_type)

    def _get_iterable(self):
        """Return an iterable for the dataset."""
        return range(self.dataset.num_samples)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Prepare input for comparison analyzer."""
        return [(watermarked_image, unwatermarked_image) for watermarked_image, unwatermarked_image in zip(prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images)]

    def analyze_quality(self,
                        prepared_data: List[Tuple[Image.Image, Image.Image]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze quality by comparing watermarked and unwatermarked images."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # For comparison analyzers, we compute similarity/difference
        for watermarked_image, unwatermarked_image in bar:
            # Compare watermarked images with unwatermarked images
            w_score = analyzer.analyze(watermarked_image, unwatermarked_image)
            w_scores.append(w_score)
        return w_scores, u_scores  # u_scores is not used for comparison analyzers
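
# Illustrative usage sketch (assumed names): compared analyzers score each watermarked image
# directly against its unwatermarked counterpart, so only the 'watermarked' entry of the
# result carries scores; the per-analyzer 'unwatermarked' entry remains an empty list.
#
#     pipeline = ComparedImageQualityAnalysisPipeline(
#         dataset=prompt_dataset,
#         analyzers=[psnr_like_analyzer],  # hypothetical pairwise ImageQualityAnalyzer
#         return_type=QualityPipelineReturnType.MEAN_SCORES,
#     )
#     result = pipeline.evaluate(my_watermark)
#     # result['watermarked'][<analyzer name>] holds the mean pairwise score;
#     # result['unwatermarked'][<analyzer name>] stays [].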