Coverage for evaluation / pipelines / image_quality_analysis.py: 95.02%


# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


from watermark.base import BaseWatermark
from evaluation.dataset import BaseDataset
from tqdm import tqdm
from enum import Enum, auto
from PIL import Image
from evaluation.tools.image_editor import ImageEditor
from typing import List, Dict, Union, Tuple, Optional
import numpy as np
from dataclasses import dataclass, field
import os
import random
from evaluation.tools.image_quality_analyzer import (
    ImageQualityAnalyzer
)


class SilentProgressBar:
    """A silent progress bar wrapper that supports set_description but shows no output."""

    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

    def set_description(self, desc):
        """No-op in silent mode."""
        pass


class QualityPipelineReturnType(Enum):
    """Return type of the image quality analysis pipeline."""
    FULL = auto()
    SCORES = auto()
    MEAN_SCORES = auto()

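# A rough sketch (hypothetical analyzer name and scores) of what evaluate()
# yields for each return type:
#
#   FULL        -> a QualityComparisonResult (all scores, prompts, store path)
#   SCORES      -> {'watermarked': {'PSNRAnalyzer': [31.8, 32.4]},
#                   'unwatermarked': {'PSNRAnalyzer': [33.1, 32.9]},
#                   'prompts': ['a cat', 'a dog']}
#   MEAN_SCORES -> {'watermarked': {'PSNRAnalyzer': 32.1},
#                   'unwatermarked': {'PSNRAnalyzer': 33.0}}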

@dataclass
class DatasetForEvaluation:
    """Dataset prepared for evaluation."""
    watermarked_images: List[Union[Image.Image, List[Image.Image]]] = field(default_factory=list)
    unwatermarked_images: List[Union[Image.Image, List[Image.Image]]] = field(default_factory=list)
    reference_images: List[Image.Image] = field(default_factory=list)
    indexes: List[int] = field(default_factory=list)
    prompts: List[str] = field(default_factory=list)


class QualityComparisonResult:
    """Result of an image quality comparison."""

    def __init__(self,
                 store_path: str,
                 watermarked_quality_scores: Dict[str, List[float]],
                 unwatermarked_quality_scores: Dict[str, List[float]],
                 prompts: List[str],
                 ) -> None:
        """
        Initialize the image quality comparison result.

        Parameters:
            store_path: The path where results are stored.
            watermarked_quality_scores: The quality scores of the watermarked images, keyed by analyzer name.
            unwatermarked_quality_scores: The quality scores of the unwatermarked images, keyed by analyzer name.
            prompts: The prompts used to generate the images.
        """
        self.store_path = store_path
        self.watermarked_quality_scores = watermarked_quality_scores
        self.unwatermarked_quality_scores = unwatermarked_quality_scores
        self.prompts = prompts


class ImageQualityAnalysisPipeline:
    """Pipeline for image quality analysis."""

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_image_editor_list: Optional[List[ImageEditor]] = None,
                 unwatermarked_image_editor_list: Optional[List[ImageEditor]] = None,
                 analyzers: Optional[List[ImageQualityAnalyzer]] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: Optional[str] = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:
        """
        Initialize the image quality analysis pipeline.

        Parameters:
            dataset: The dataset for evaluation.
            watermarked_image_editor_list: The image editors applied to watermarked images.
            unwatermarked_image_editor_list: The image editors applied to unwatermarked images.
            analyzers: The quality analyzers to run.
            unwatermarked_image_source: The source of unwatermarked images ('natural' or 'generated').
            reference_image_source: The source of reference images ('natural' or 'generated').
            show_progress: Whether to show a progress bar.
            store_path: The path to store the results. If None, the generated images are not stored.
            return_type: The return type of the pipeline.
        """
        if unwatermarked_image_source not in ['natural', 'generated']:
            raise ValueError(f"Invalid unwatermarked_image_source: {unwatermarked_image_source}")

        self.dataset = dataset
        # None defaults plus `or []` avoid the mutable-default-argument pitfall:
        # a literal `[]` default would be shared across all instances.
        self.watermarked_image_editor_list = watermarked_image_editor_list or []
        self.unwatermarked_image_editor_list = unwatermarked_image_editor_list or []
        self.analyzers = analyzers or []
        self.unwatermarked_image_source = unwatermarked_image_source
        self.reference_image_source = reference_image_source
        self.show_progress = show_progress
        self.store_path = store_path
        self.return_type = return_type

    def _check_compatibility(self):
        """Check if the pipeline is compatible with the dataset."""
        pass

    def _get_iterable(self):
        """Return an iterable over the dataset."""
        pass

    def _get_progress_bar(self, iterable):
        """Return the iterable, wrapped in a progress bar when enabled."""
        if self.show_progress:
            return tqdm(iterable, desc="Processing", leave=True)
        return SilentProgressBar(iterable)

    def _get_prompt(self, index: int) -> str:
        """Get the prompt at the given index from the dataset."""
        return self.dataset.get_prompt(index)

    def _get_watermarked_image(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> Union[Image.Image, List[Image.Image]]:
        """Generate a watermarked image for the prompt at the given index."""
        prompt = self._get_prompt(index)
        image = watermark.generate_watermarked_media(input_data=prompt, **generation_kwargs)
        return image

    def _get_unwatermarked_image(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> Union[Image.Image, List[Image.Image]]:
        """Generate or retrieve an unwatermarked image for the prompt at the given index."""
        if self.unwatermarked_image_source == 'natural':
            return self.dataset.get_reference(index)
        elif self.unwatermarked_image_source == 'generated':
            prompt = self._get_prompt(index)
            image = watermark.generate_unwatermarked_media(input_data=prompt, **generation_kwargs)
            return image

    def _edit_watermarked_image(self, image: Union[Image.Image, List[Image.Image]]) -> Union[Image.Image, List[Image.Image]]:
        """Apply the watermarked-image editors to a single image or a list of images."""
        if isinstance(image, list):
            edited_images = []
            for img in image:
                for image_editor in self.watermarked_image_editor_list:
                    img = image_editor.edit(img)
                edited_images.append(img)
            return edited_images
        else:
            for image_editor in self.watermarked_image_editor_list:
                image = image_editor.edit(image)
            return image

    def _edit_unwatermarked_image(self, image: Union[Image.Image, List[Image.Image]]) -> Union[Image.Image, List[Image.Image]]:
        """Apply the unwatermarked-image editors to a single image or a list of images."""
        if isinstance(image, list):
            edited_images = []
            for img in image:
                for image_editor in self.unwatermarked_image_editor_list:
                    img = image_editor.edit(img)
                edited_images.append(img)
            return edited_images
        else:
            for image_editor in self.unwatermarked_image_editor_list:
                image = image_editor.edit(image)
            return image

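    # A minimal sketch of how editor lists compose (hypothetical JPEGCompression
    # and GaussianBlur editor classes): editors run in list order, each one
    # consuming the previous editor's output.
    #
    #     pipeline = DirectImageQualityAnalysisPipeline(
    #         dataset,
    #         watermarked_image_editor_list=[JPEGCompression(quality=75), GaussianBlur(radius=2)],
    #     )
    #     # pipeline._edit_watermarked_image(img)
    #     # == GaussianBlur(radius=2).edit(JPEGCompression(quality=75).edit(img))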

    def _prepare_dataset(self, watermark: BaseWatermark, **generation_kwargs) -> DatasetForEvaluation:
        """
        Prepare and generate all data needed for quality analysis.

        This method can be overridden by subclasses to implement specific
        data preparation logic based on the analysis requirements.

        Parameters:
            watermark: The watermark algorithm instance.
            generation_kwargs: Additional generation parameters.

        Returns:
            A DatasetForEvaluation object containing all prepared data.
        """
        dataset_eval = DatasetForEvaluation()

        # Generate all images
        bar = self._get_progress_bar(self._get_iterable())
        bar.set_description("Generating images for quality analysis")
        for index in bar:
            # Generate and edit the watermarked image
            watermarked_image = self._get_watermarked_image(watermark, index, **generation_kwargs)
            watermarked_image = self._edit_watermarked_image(watermarked_image)

            # Generate and edit the unwatermarked image
            unwatermarked_image = self._get_unwatermarked_image(watermark, index, **generation_kwargs)
            unwatermarked_image = self._edit_unwatermarked_image(unwatermarked_image)

            dataset_eval.watermarked_images.append(watermarked_image)
            dataset_eval.unwatermarked_images.append(unwatermarked_image)
            # Resolve the prompt from the raw index first: in repeat-based
            # pipelines _get_prompt already divides by prompt_per_image, so
            # calling it with an already collapsed index would fetch the
            # wrong prompt.
            prompt = self._get_prompt(index)
            if hasattr(self, "prompt_per_image"):
                index = index // self.prompt_per_image
            dataset_eval.indexes.append(index)
            dataset_eval.prompts.append(prompt)

            if self.reference_image_source == 'natural':
                if self.dataset.num_references > 0:
                    reference_image = self.dataset.get_reference(index)
                    dataset_eval.reference_images.append(reference_image)
                else:
                    # For text-based analyzers, add a None placeholder
                    dataset_eval.reference_images.append(None)
            else:
                dataset_eval.reference_images.append(unwatermarked_image)

        return dataset_eval

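    # Worked example of the bookkeeping above for a repeat-based pipeline with
    # num_samples=2 and prompt_per_image=3: the iterable yields raw indices
    # 0..5, prompts resolve to [p0, p0, p0, p1, p1, p1], and the collapsed
    # indexes stored for result naming become [0, 0, 0, 1, 1, 1].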

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """
        Prepare input for the quality analyzers.

        Parameters:
            prepared_dataset: The prepared dataset.
        """
        pass

    def _store_results(self, prepared_dataset: DatasetForEvaluation):
        """Store the generated images under store_path."""
        os.makedirs(self.store_path, exist_ok=True)
        dataset_name = self.dataset.name

        for (index, watermarked_image, unwatermarked_image, prompt) in zip(prepared_dataset.indexes, prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images, prepared_dataset.prompts):
            watermarked_image.save(os.path.join(self.store_path, f"{self.__class__.__name__}_{dataset_name}_watermarked_prompt_{index}.png"))
            unwatermarked_image.save(os.path.join(self.store_path, f"{self.__class__.__name__}_{dataset_name}_unwatermarked_prompt_{index}.png"))

    def analyze_quality(self, prepared_data, analyzer):
        """Analyze quality of watermarked and unwatermarked images."""
        pass

    def evaluate(self, watermark: BaseWatermark, generation_kwargs: Optional[dict] = None):
        """Conduct the evaluation using the pipeline."""
        # Avoid a shared mutable default for generation_kwargs.
        generation_kwargs = generation_kwargs or {}

        # Check compatibility
        self._check_compatibility()

        # Prepare dataset
        prepared_dataset = self._prepare_dataset(watermark, **generation_kwargs)

        # Store results
        if self.store_path:
            self._store_results(prepared_dataset)

        # Prepare input for the quality analyzers
        prepared_data = self._prepare_input_for_quality_analyzer(
            prepared_dataset
        )

        # Analyze quality
        all_scores = {}
        for analyzer in self.analyzers:
            w_scores, u_scores = self.analyze_quality(prepared_data, analyzer)
            analyzer_name = analyzer.__class__.__name__
            all_scores[analyzer_name] = (w_scores, u_scores)

        # Reuse the prompts collected during preparation instead of
        # recomputing them from the (possibly collapsed) indexes.
        prompts = prepared_dataset.prompts

        # Create the result
        watermarked_scores = {}
        unwatermarked_scores = {}

        for analyzer_name, (w_scores, u_scores) in all_scores.items():
            watermarked_scores[analyzer_name] = w_scores
            unwatermarked_scores[analyzer_name] = u_scores

        result = QualityComparisonResult(
            store_path=self.store_path,
            watermarked_quality_scores=watermarked_scores,
            unwatermarked_quality_scores=unwatermarked_scores,
            prompts=prompts,
        )

        # Format results based on return_type
        return self._format_results(result)

    def _format_results(self, result: QualityComparisonResult):
        """Format results based on return_type."""
        if self.return_type == QualityPipelineReturnType.FULL:
            return result
        elif self.return_type == QualityPipelineReturnType.SCORES:
            return {
                'watermarked': result.watermarked_quality_scores,
                'unwatermarked': result.unwatermarked_quality_scores,
                'prompts': result.prompts
            }
        elif self.return_type == QualityPipelineReturnType.MEAN_SCORES:
            # Calculate the mean score for each analyzer
            mean_watermarked = {}
            mean_unwatermarked = {}

            for analyzer_name, scores in result.watermarked_quality_scores.items():
                if isinstance(scores, list) and len(scores) > 0:
                    mean_watermarked[analyzer_name] = np.mean(scores)
                else:
                    mean_watermarked[analyzer_name] = scores

            for analyzer_name, scores in result.unwatermarked_quality_scores.items():
                if isinstance(scores, list) and len(scores) > 0:
                    mean_unwatermarked[analyzer_name] = np.mean(scores)
                else:
                    mean_unwatermarked[analyzer_name] = scores

            return {
                'watermarked': mean_watermarked,
                'unwatermarked': mean_unwatermarked
            }

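# A minimal end-to-end sketch, assuming a concrete subclass and hypothetical
# dataset / watermark / analyzer objects built elsewhere:
#
#     pipeline = DirectImageQualityAnalysisPipeline(
#         dataset=my_dataset,                   # a BaseDataset
#         analyzers=[BRISQUEAnalyzer()],        # hypothetical ImageQualityAnalyzer
#         store_path="results/quality",         # None skips saving images
#         return_type=QualityPipelineReturnType.MEAN_SCORES,
#     )
#     scores = pipeline.evaluate(my_watermark)  # my_watermark: a BaseWatermark
#     # -> {'watermarked': {'BRISQUEAnalyzer': ...}, 'unwatermarked': {...}}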

class DirectImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for direct image quality analysis.

    This pipeline scores each watermarked and each unwatermarked image
    independently, with no external reference image, which suits no-reference
    metrics such as BRISQUE. Analyzers that do require a second image (e.g.,
    PSNR, SSIM, LPIPS) fall back to using the paired image as the reference.

    Use this pipeline to assess the impact of watermarking on image quality directly.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_image_editor_list: Optional[List[ImageEditor]] = None,
                 unwatermarked_image_editor_list: Optional[List[ImageEditor]] = None,
                 analyzers: Optional[List[ImageQualityAnalyzer]] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: Optional[str] = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:

        super().__init__(dataset, watermarked_image_editor_list, unwatermarked_image_editor_list,
                         analyzers, unwatermarked_image_source, reference_image_source, show_progress, store_path, return_type)

    def _get_iterable(self):
        """Return an iterable over dataset indices."""
        return range(self.dataset.num_samples)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Pair each watermarked image with its unwatermarked counterpart."""
        return [(watermarked_image, unwatermarked_image) for watermarked_image, unwatermarked_image in zip(prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images)]

    def analyze_quality(self,
                        prepared_data: List[Tuple[Image.Image, Image.Image]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze quality of watermarked and unwatermarked images."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # Direct analyzers score each image independently
        for watermarked_image, unwatermarked_image in bar:
            # Watermarked score
            try:
                w_score = analyzer.analyze(watermarked_image)
            except TypeError:
                # The analyzer expects a reference -> use the unwatermarked image
                w_score = analyzer.analyze(watermarked_image, unwatermarked_image)
            # Unwatermarked score
            try:
                u_score = analyzer.analyze(unwatermarked_image)
            except TypeError:
                u_score = analyzer.analyze(unwatermarked_image, watermarked_image)
            w_scores.append(w_score)
            u_scores.append(u_score)

        return w_scores, u_scores

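# Usage sketch (hypothetical NIQEAnalyzer, a no-reference metric): each image
# is scored on its own, giving one score per dataset sample on both sides.
#
#     pipeline = DirectImageQualityAnalysisPipeline(dataset, analyzers=[NIQEAnalyzer()],
#                                                   return_type=QualityPipelineReturnType.SCORES)
#     out = pipeline.evaluate(watermark)
#     assert len(out['watermarked']['NIQEAnalyzer']) == dataset.num_samples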

class ReferencedImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for referenced image quality analysis.

    This pipeline assesses image quality by comparing both watermarked and
    unwatermarked images against a common reference, measuring the degree of
    similarity to (or deviation from) that reference.

    Ideal for scenarios where the impact of watermarking on image quality must
    be assessed relative to specific reference images or ground truth.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_image_editor_list: Optional[List[ImageEditor]] = None,
                 unwatermarked_image_editor_list: Optional[List[ImageEditor]] = None,
                 analyzers: Optional[List[ImageQualityAnalyzer]] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: Optional[str] = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:

        super().__init__(dataset, watermarked_image_editor_list, unwatermarked_image_editor_list,
                         analyzers, unwatermarked_image_source, reference_image_source, show_progress, store_path, return_type)

    def _check_compatibility(self):
        """Check if the pipeline is compatible with the dataset."""
        # Check whether any analyzer uses text as its reference
        has_text_analyzer = any(hasattr(analyzer, 'reference_source') and analyzer.reference_source == 'text'
                                for analyzer in self.analyzers)

        # Without a text-referenced analyzer, reference images are mandatory
        if not has_text_analyzer and self.dataset.num_references == 0:
            raise ValueError(f"Reference images are required for referenced image quality analysis. Dataset {self.dataset.name} has no reference images.")

    def _get_iterable(self):
        """Return an iterable over dataset indices."""
        return range(self.dataset.num_samples)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Bundle each watermarked/unwatermarked pair with its reference image and prompt."""
        return [(watermarked_image, unwatermarked_image, reference_image, prompt)
                for watermarked_image, unwatermarked_image, reference_image, prompt in
                zip(prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images, prepared_dataset.reference_images, prepared_dataset.prompts)
                ]

    def analyze_quality(self,
                        prepared_data: List[Tuple[Image.Image, Image.Image, Image.Image, str]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze quality of watermarked and unwatermarked images."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # Referenced analyzers compare against the reference image or the prompt
        for watermarked_image, unwatermarked_image, reference_image, prompt in bar:
            if analyzer.reference_source == "image":
                w_score = analyzer.analyze(watermarked_image, reference_image)
                u_score = analyzer.analyze(unwatermarked_image, reference_image)
            elif analyzer.reference_source == "text":
                w_score = analyzer.analyze(watermarked_image, prompt)
                u_score = analyzer.analyze(unwatermarked_image, prompt)
            else:
                raise ValueError(f"Invalid reference source: {analyzer.reference_source}")
            w_scores.append(w_score)
            u_scores.append(u_score)
        return w_scores, u_scores

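# Usage sketch: each analyzer's reference_source attribute selects the second
# argument handed to analyze(). Analyzer classes here are hypothetical.
#
#     pipeline = ReferencedImageQualityAnalysisPipeline(
#         dataset,
#         analyzers=[LPIPSAnalyzer(),       # reference_source == 'image': scored against the reference image
#                    CLIPScoreAnalyzer()],  # reference_source == 'text':  scored against the prompt
#         reference_image_source='natural',
#     )
#     means = pipeline.evaluate(watermark)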

class GroupImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for group-based image quality analysis.

    This pipeline analyzes quality metrics that require comparing distributions
    of multiple images (e.g., FID). It generates all images upfront and then
    performs a single analysis on the entire collection.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_image_editor_list: Optional[List[ImageEditor]] = None,
                 unwatermarked_image_editor_list: Optional[List[ImageEditor]] = None,
                 analyzers: Optional[List[ImageQualityAnalyzer]] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: Optional[str] = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:

        super().__init__(dataset, watermarked_image_editor_list, unwatermarked_image_editor_list,
                         analyzers, unwatermarked_image_source, reference_image_source, show_progress, store_path, return_type)

    def _get_iterable(self):
        """Return an iterable over dataset indices."""
        return range(self.dataset.num_samples)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Wrap the full image collections in a single analyzer input."""
        return [(prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images, prepared_dataset.reference_images)]

    def analyze_quality(self,
                        prepared_data: List[Tuple[List[Image.Image], List[Image.Image], List[Image.Image]]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze quality of image groups."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # Group analyzers receive the entire collection in one call
        for watermarked_images, unwatermarked_images, reference_images in bar:
            w_score = analyzer.analyze(watermarked_images, reference_images)
            u_score = analyzer.analyze(unwatermarked_images, reference_images)
            w_scores.append(w_score)
            u_scores.append(u_score)
        return w_scores, u_scores

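# Usage sketch (hypothetical FIDAnalyzer): the entire collection is passed to
# the analyzer in a single call, so each score list holds one
# distribution-level value instead of one value per image.
#
#     pipeline = GroupImageQualityAnalysisPipeline(dataset, analyzers=[FIDAnalyzer()],
#                                                  reference_image_source='natural')
#     fid = pipeline.evaluate(watermark)['watermarked']['FIDAnalyzer']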

class RepeatImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for repeat-based image quality analysis.

    This pipeline analyzes diversity metrics by generating multiple images for
    each prompt (e.g., LPIPS diversity). It produces prompt_per_image versions
    per prompt and analyzes the diversity within each group.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 prompt_per_image: int = 20,
                 watermarked_image_editor_list: Optional[List[ImageEditor]] = None,
                 unwatermarked_image_editor_list: Optional[List[ImageEditor]] = None,
                 analyzers: Optional[List[ImageQualityAnalyzer]] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: Optional[str] = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:

        super().__init__(dataset, watermarked_image_editor_list, unwatermarked_image_editor_list,
                         analyzers, unwatermarked_image_source, reference_image_source, show_progress, store_path, return_type)

        # Number of images generated per prompt
        self.prompt_per_image = prompt_per_image

    def _get_iterable(self):
        """Return an iterable over prompt_per_image generations per dataset sample."""
        return range(self.dataset.num_samples * self.prompt_per_image)

    def _get_prompt(self, index: int) -> str:
        """Map a raw generation index to its prompt."""
        prompt_index = index // self.prompt_per_image
        return self.dataset.get_prompt(prompt_index)

    def _get_watermarked_image(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> Union[Image.Image, List[Image.Image]]:
        """Generate a watermarked image with a fresh random seed."""
        prompt = self._get_prompt(index)
        # Randomly select a generation seed so repeated generations differ
        generation_kwargs['gen_seed'] = random.randint(0, 1000000)
        return watermark.generate_watermarked_media(input_data=prompt, **generation_kwargs)

    def _get_unwatermarked_image(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> Union[Image.Image, List[Image.Image]]:
        """Generate or retrieve an unwatermarked image with a fresh random seed."""
        prompt = self._get_prompt(index)
        # Randomly select a generation seed so repeated generations differ
        generation_kwargs['gen_seed'] = random.randint(0, 1000000)
        if self.unwatermarked_image_source == 'natural':
            return self.dataset.get_reference(index)
        else:
            return watermark.generate_unwatermarked_media(input_data=prompt, **generation_kwargs)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Group the generated images by prompt for the diversity analyzers."""
        watermarked_images = prepared_dataset.watermarked_images
        unwatermarked_images = prepared_dataset.unwatermarked_images

        grouped = []
        for i in range(0, len(watermarked_images), self.prompt_per_image):
            grouped.append(
                (watermarked_images[i:i + self.prompt_per_image],
                 unwatermarked_images[i:i + self.prompt_per_image])
            )
        return grouped

    def analyze_quality(self,
                        prepared_data: List[Tuple[List[Image.Image], List[Image.Image]]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze diversity of per-prompt image batches."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing diversity for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # Diversity analyzers score each per-prompt batch
        for watermarked_images, unwatermarked_images in bar:
            w_score = analyzer.analyze(watermarked_images)
            u_score = analyzer.analyze(unwatermarked_images)
            w_scores.append(w_score)
            u_scores.append(u_score)

        return w_scores, u_scores

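# Usage sketch (hypothetical LPIPSDiversityAnalyzer): with num_samples prompts
# and prompt_per_image generations each, num_samples * prompt_per_image images
# are generated and each per-prompt batch is scored, yielding one diversity
# value per prompt.
#
#     pipeline = RepeatImageQualityAnalysisPipeline(dataset, prompt_per_image=4,
#                                                   analyzers=[LPIPSDiversityAnalyzer()])
#     diversity = pipeline.evaluate(watermark)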

class ComparedImageQualityAnalysisPipeline(ImageQualityAnalysisPipeline):
    """
    Pipeline for compared image quality analysis.

    This pipeline directly compares watermarked and unwatermarked images to
    compute metrics like PSNR, SSIM, VIF, FSIM and MS-SSIM. The analyzer
    receives both images and outputs a single comparison score.
    """

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_image_editor_list: Optional[List[ImageEditor]] = None,
                 unwatermarked_image_editor_list: Optional[List[ImageEditor]] = None,
                 analyzers: Optional[List[ImageQualityAnalyzer]] = None,
                 unwatermarked_image_source: str = 'generated',
                 reference_image_source: str = 'natural',
                 show_progress: bool = True,
                 store_path: Optional[str] = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:

        super().__init__(dataset, watermarked_image_editor_list, unwatermarked_image_editor_list,
                         analyzers, unwatermarked_image_source, reference_image_source, show_progress, store_path, return_type)

    def _get_iterable(self):
        """Return an iterable over dataset indices."""
        return range(self.dataset.num_samples)

    def _prepare_input_for_quality_analyzer(self,
                                            prepared_dataset: DatasetForEvaluation):
        """Pair each watermarked image with its unwatermarked counterpart."""
        return [(watermarked_image, unwatermarked_image) for watermarked_image, unwatermarked_image in zip(prepared_dataset.watermarked_images, prepared_dataset.unwatermarked_images)]

    def analyze_quality(self,
                        prepared_data: List[Tuple[Image.Image, Image.Image]],
                        analyzer: ImageQualityAnalyzer):
        """Analyze quality by comparing watermarked and unwatermarked images."""
        bar = self._get_progress_bar(prepared_data)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        # Comparison analyzers compute one similarity/difference score per pair
        for watermarked_image, unwatermarked_image in bar:
            # Compare each watermarked image with its unwatermarked counterpart
            w_score = analyzer.analyze(watermarked_image, unwatermarked_image)
            w_scores.append(w_score)
        return w_scores, u_scores  # u_scores stays empty for comparison analyzers

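# Usage sketch (hypothetical PSNRAnalyzer): each watermarked image is scored
# against its unwatermarked counterpart, so only the 'watermarked' entry
# carries scores; 'unwatermarked' is empty by design.
#
#     pipeline = ComparedImageQualityAnalysisPipeline(dataset, analyzers=[PSNRAnalyzer()],
#                                                     return_type=QualityPipelineReturnType.MEAN_SCORES)
#     psnr = pipeline.evaluate(watermark)['watermarked']['PSNRAnalyzer']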