Coverage for evaluation / pipelines / video_quality_analysis.py: 98.34%

181 statements  

coverage.py v7.13.0, created at 2025-12-22 10:24 +0000

from enum import Enum, auto
from dataclasses import dataclass, field
from typing import List, Union, Dict, Tuple
from PIL import Image
from evaluation.tools.image_editor import ImageEditor
from evaluation.tools.video_editor import VideoEditor
from evaluation.tools.video_quality_analyzer import VideoQualityAnalyzer
from evaluation.dataset import BaseDataset
from watermark.base import BaseWatermark
import os
import numpy as np
from tqdm import tqdm


class SilentProgressBar:
    """A silent progress bar wrapper that supports set_description but shows no output."""

    def __init__(self, iterable):
        self.iterable = iterable

    def __iter__(self):
        return iter(self.iterable)

    def set_description(self, desc):
        """No-op for silent mode."""
        pass


class QualityPipelineReturnType(Enum):
    """Return type of the video quality analysis pipeline."""
    FULL = auto()
    SCORES = auto()
    MEAN_SCORES = auto()

@dataclass
class DatasetForEvaluation:
    """Dataset for evaluation."""
    watermarked_videos: List[List[Image.Image]] = field(default_factory=list)
    unwatermarked_videos: List[List[Image.Image]] = field(default_factory=list)
    reference_videos: List[List[Image.Image]] = field(default_factory=list)
    indexes: List[int] = field(default_factory=list)


class QualityComparisonResult:
    """Result of quality comparison."""

    def __init__(self,
                 store_path: str,
                 watermarked_quality_scores: Dict[str, List[float]],
                 unwatermarked_quality_scores: Dict[str, List[float]],
                 prompts: List[str],
                 ) -> None:
        """
        Initialize the video quality comparison result.

        Parameters:
            store_path: The path to store the results.
            watermarked_quality_scores: The quality scores of the watermarked videos, keyed by analyzer name.
            unwatermarked_quality_scores: The quality scores of the unwatermarked videos, keyed by analyzer name.
            prompts: The prompts used to generate the videos.
        """
        self.store_path = store_path
        self.watermarked_quality_scores = watermarked_quality_scores
        self.unwatermarked_quality_scores = unwatermarked_quality_scores
        self.prompts = prompts


class VideoQualityAnalysisPipeline:
    """Pipeline for video quality analysis."""

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_video_editor_list: List[VideoEditor] = [],
                 unwatermarked_video_editor_list: List[VideoEditor] = [],
                 watermarked_frame_editor_list: List[ImageEditor] = [],
                 unwatermarked_frame_editor_list: List[ImageEditor] = [],
                 analyzers: List[VideoQualityAnalyzer] = None,
                 show_progress: bool = True,
                 store_path: str = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:
        """Initialize the video quality analysis pipeline.

        Args:
            dataset (BaseDataset): The dataset for evaluation.
            watermarked_video_editor_list (List[VideoEditor], optional): The list of video editors applied to watermarked videos. Defaults to [].
            unwatermarked_video_editor_list (List[VideoEditor], optional): The list of video editors applied to unwatermarked videos. Defaults to [].
            watermarked_frame_editor_list (List[ImageEditor], optional): The list of image editors applied to individual watermarked frames. Defaults to [].
            unwatermarked_frame_editor_list (List[ImageEditor], optional): The list of image editors applied to individual unwatermarked frames. Defaults to [].
            analyzers (List[VideoQualityAnalyzer], optional): The list of video quality analyzers. Defaults to None.
            show_progress (bool, optional): Whether to show a progress bar. Defaults to True.
            store_path (str, optional): The path to store the results. Defaults to None.
            return_type (QualityPipelineReturnType, optional): The return type of the pipeline. Defaults to QualityPipelineReturnType.MEAN_SCORES.
        """
        self.dataset = dataset
        self.watermarked_video_editor_list = watermarked_video_editor_list
        self.unwatermarked_video_editor_list = unwatermarked_video_editor_list
        self.watermarked_frame_editor_list = watermarked_frame_editor_list
        self.unwatermarked_frame_editor_list = unwatermarked_frame_editor_list
        self.analyzers = analyzers or []
        self.show_progress = show_progress
        self.store_path = store_path
        self.return_type = return_type

    def _check_compatibility(self):
        """Check if the pipeline is compatible with the dataset."""
        pass

    def _get_iterable(self):
        """Return an iterable over dataset indexes. Subclasses should override this."""
        pass

    def _get_progress_bar(self, iterable):
        """Return the iterable, wrapped with a tqdm progress bar when show_progress is enabled."""
        if self.show_progress:
            return tqdm(iterable, desc="Processing", leave=True)
        return SilentProgressBar(iterable)

    def _get_prompt(self, index: int) -> str:
        """Get the prompt at the given index from the dataset."""
        return self.dataset.get_prompt(index)

    def _get_watermarked_video(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> List[Image.Image]:
        """Generate a watermarked video (as a list of frames) for the prompt at the given index."""
        prompt = self._get_prompt(index)
        frames = watermark.generate_watermarked_media(input_data=prompt, **generation_kwargs)
        return frames

    def _get_unwatermarked_video(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> List[Image.Image]:
        """Generate or retrieve an unwatermarked video (as a list of frames) for the prompt at the given index."""
        prompt = self._get_prompt(index)
        frames = watermark.generate_unwatermarked_media(input_data=prompt, **generation_kwargs)
        return frames

    def _edit_watermarked_video(self, frames: List[Image.Image]) -> List[Image.Image]:
        """Edit a watermarked video using the configured video and frame editors."""
        # Step 1: Edit the whole clip using video editors
        for video_editor in self.watermarked_video_editor_list:
            frames = video_editor.edit(frames)
        # Step 2: Edit individual frames using image editors
        for frame_editor in self.watermarked_frame_editor_list:
            frames = [frame_editor.edit(frame) for frame in frames]
        return frames

    def _edit_unwatermarked_video(self, frames: List[Image.Image]) -> List[Image.Image]:
        """Edit an unwatermarked video using the configured video and frame editors."""
        # Step 1: Edit the whole clip using video editors
        for video_editor in self.unwatermarked_video_editor_list:
            frames = video_editor.edit(frames)
        # Step 2: Edit individual frames using image editors
        for frame_editor in self.unwatermarked_frame_editor_list:
            frames = [frame_editor.edit(frame) for frame in frames]
        return frames
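
    # Editor ordering note: whole-clip VideoEditor instances run first, then per-frame
    # ImageEditor instances. A hypothetical configuration (the editor class names below
    # are placeholders, not necessarily provided by the toolkit) might look like:
    #   VideoQualityAnalysisPipeline(
    #       dataset,
    #       watermarked_video_editor_list=[SomeFrameDropEditor()],        # applied to the frame list
    #       watermarked_frame_editor_list=[SomeJPEGCompressionEditor()],  # applied frame by frame
    #   )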

    def _prepare_dataset(self, watermark: BaseWatermark, **generation_kwargs) -> DatasetForEvaluation:
        """
        Prepare and generate all necessary data for quality analysis.

        This method should be overridden by subclasses to implement specific
        data preparation logic based on the analysis requirements.

        Parameters:
            watermark: The watermark algorithm instance.
            generation_kwargs: Additional generation parameters.

        Returns:
            DatasetForEvaluation object containing all prepared data.
        """
        dataset_eval = DatasetForEvaluation()

        # Generate all videos
        bar = self._get_progress_bar(self._get_iterable())
        bar.set_description("Generating videos for quality analysis")
        for index in bar:
            # Generate and edit the watermarked video
            watermarked_frames = self._get_watermarked_video(watermark, index, **generation_kwargs)
            watermarked_frames = self._edit_watermarked_video(watermarked_frames)

            # Generate and edit the unwatermarked video
            unwatermarked_frames = self._get_unwatermarked_video(watermark, index, **generation_kwargs)
            unwatermarked_frames = self._edit_unwatermarked_video(unwatermarked_frames)

            dataset_eval.watermarked_videos.append(watermarked_frames)
            dataset_eval.unwatermarked_videos.append(unwatermarked_frames)
            dataset_eval.indexes.append(index)

            if self.dataset.num_references > 0:
                reference_frames = self.dataset.get_reference(index)
                dataset_eval.reference_videos.append(reference_frames)

        return dataset_eval

    def _prepare_input_for_quality_analyzer(self,
                                            watermarked_videos: List[List[Image.Image]],
                                            unwatermarked_videos: List[List[Image.Image]],
                                            reference_videos: List[List[Image.Image]]):
        """Prepare input for the quality analyzer.

        Args:
            watermarked_videos (List[List[Image.Image]]): Watermarked videos.
            unwatermarked_videos (List[List[Image.Image]]): Unwatermarked videos.
            reference_videos (List[List[Image.Image]]): Reference videos, if available.
        """
        pass

    def _store_results(self, prepared_dataset: DatasetForEvaluation):
        """Store the generated videos by saving each frame as a PNG under store_path."""
        os.makedirs(self.store_path, exist_ok=True)
        dataset_name = self.dataset.name

        for (index, watermarked_video, unwatermarked_video) in zip(prepared_dataset.indexes, prepared_dataset.watermarked_videos, prepared_dataset.unwatermarked_videos):
            # Each video is a List[Image.Image]; save its frames into a per-prompt directory
            save_dir = os.path.join(self.store_path, f"{self.__class__.__name__}_{dataset_name}_watermarked_prompt{index}")
            os.makedirs(save_dir, exist_ok=True)
            for i, frame in enumerate(watermarked_video):
                frame.save(os.path.join(save_dir, f"frame_{i}.png"))

            save_dir = os.path.join(self.store_path, f"{self.__class__.__name__}_{dataset_name}_unwatermarked_prompt{index}")
            os.makedirs(save_dir, exist_ok=True)
            for i, frame in enumerate(unwatermarked_video):
                frame.save(os.path.join(save_dir, f"frame_{i}.png"))

            if self.dataset.num_references > 0:
                reference_frames = self.dataset.get_reference(index)
                save_dir = os.path.join(self.store_path, f"{self.__class__.__name__}_{dataset_name}_reference_prompt{index}")
                os.makedirs(save_dir, exist_ok=True)
                for i, frame in enumerate(reference_frames):
                    frame.save(os.path.join(save_dir, f"frame_{i}.png"))
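
    # Illustrative on-disk layout after _store_results (dataset name and index values are
    # placeholders; the class-name prefix comes from the concrete pipeline class in use):
    #   <store_path>/<ClassName>_<dataset_name>_watermarked_prompt<index>/frame_0.png, frame_1.png, ...
    #   <store_path>/<ClassName>_<dataset_name>_unwatermarked_prompt<index>/frame_0.png, ...
    #   <store_path>/<ClassName>_<dataset_name>_reference_prompt<index>/...   (only when references exist)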

    def analyze_quality(self, prepared_data, analyzer):
        """Analyze quality of watermarked and unwatermarked videos."""
        pass

    def evaluate(self, watermark: BaseWatermark, generation_kwargs={}):
        """Conduct evaluation utilizing the pipeline."""
        # Check compatibility
        self._check_compatibility()

        # Prepare dataset
        prepared_dataset = self._prepare_dataset(watermark, **generation_kwargs)

        # Store results
        if self.store_path:
            self._store_results(prepared_dataset)

        # Prepare input for quality analyzer
        prepared_data = self._prepare_input_for_quality_analyzer(
            prepared_dataset.watermarked_videos,
            prepared_dataset.unwatermarked_videos,
            prepared_dataset.reference_videos
        )

        # Analyze quality with each analyzer
        all_scores = {}
        for analyzer in self.analyzers:
            w_scores, u_scores = self.analyze_quality(prepared_data, analyzer)
            analyzer_name = analyzer.__class__.__name__
            all_scores[analyzer_name] = (w_scores, u_scores)

        # Collect prompts for the evaluated indexes
        prompts = []
        for idx in prepared_dataset.indexes:
            prompts.append(self._get_prompt(idx))

        # Create result
        watermarked_scores = {}
        unwatermarked_scores = {}

        for analyzer_name, (w_scores, u_scores) in all_scores.items():
            watermarked_scores[analyzer_name] = w_scores
            unwatermarked_scores[analyzer_name] = u_scores

        result = QualityComparisonResult(
            store_path=self.store_path,
            watermarked_quality_scores=watermarked_scores,
            unwatermarked_quality_scores=unwatermarked_scores,
            prompts=prompts,
        )

        # Format results based on return_type
        return self._format_results(result)

    def _format_results(self, result: QualityComparisonResult):
        """Format results based on return_type."""
        if self.return_type == QualityPipelineReturnType.FULL:
            return result
        elif self.return_type == QualityPipelineReturnType.SCORES:
            return {
                'watermarked': result.watermarked_quality_scores,
                'unwatermarked': result.unwatermarked_quality_scores,
                'prompts': result.prompts
            }
        elif self.return_type == QualityPipelineReturnType.MEAN_SCORES:
            # Calculate mean scores for each analyzer
            mean_watermarked = {}
            mean_unwatermarked = {}

            for analyzer_name, scores in result.watermarked_quality_scores.items():
                if isinstance(scores, list) and len(scores) > 0:
                    mean_watermarked[analyzer_name] = np.mean(scores)
                else:
                    mean_watermarked[analyzer_name] = scores

            for analyzer_name, scores in result.unwatermarked_quality_scores.items():
                if isinstance(scores, list) and len(scores) > 0:
                    mean_unwatermarked[analyzer_name] = np.mean(scores)
                else:
                    mean_unwatermarked[analyzer_name] = scores

            return {
                'watermarked': mean_watermarked,
                'unwatermarked': mean_unwatermarked
            }


class DirectVideoQualityAnalysisPipeline(VideoQualityAnalysisPipeline):
    """Pipeline for direct video quality analysis."""

    def __init__(self,
                 dataset: BaseDataset,
                 watermarked_video_editor_list: List[VideoEditor] = [],
                 unwatermarked_video_editor_list: List[VideoEditor] = [],
                 watermarked_frame_editor_list: List[ImageEditor] = [],
                 unwatermarked_frame_editor_list: List[ImageEditor] = [],
                 analyzers: List[VideoQualityAnalyzer] = None,
                 show_progress: bool = True,
                 store_path: str = None,
                 return_type: QualityPipelineReturnType = QualityPipelineReturnType.MEAN_SCORES) -> None:
        """Initialize the direct video quality analysis pipeline.

        Args:
            dataset (BaseDataset): The dataset for evaluation.
            watermarked_video_editor_list (List[VideoEditor], optional): The list of video editors applied to watermarked videos. Defaults to [].
            unwatermarked_video_editor_list (List[VideoEditor], optional): The list of video editors applied to unwatermarked videos. Defaults to [].
            watermarked_frame_editor_list (List[ImageEditor], optional): The list of image editors applied to individual watermarked frames. Defaults to [].
            unwatermarked_frame_editor_list (List[ImageEditor], optional): The list of image editors applied to individual unwatermarked frames. Defaults to [].
            analyzers (List[VideoQualityAnalyzer], optional): The list of video quality analyzers. Defaults to None.
            show_progress (bool, optional): Whether to show progress. Defaults to True.
            store_path (str, optional): The path to store the results. Defaults to None.
            return_type (QualityPipelineReturnType, optional): The return type of the pipeline. Defaults to QualityPipelineReturnType.MEAN_SCORES.
        """
        super().__init__(dataset, watermarked_video_editor_list, unwatermarked_video_editor_list, watermarked_frame_editor_list, unwatermarked_frame_editor_list, analyzers, show_progress, store_path, return_type)

    def _get_iterable(self):
        """Return an iterable over the dataset sample indexes."""
        return range(self.dataset.num_samples)

    def _get_prompt(self, index: int) -> str:
        """Get the prompt at the given index from the dataset."""
        return self.dataset.get_prompt(index)

    def _get_watermarked_video(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> List[Image.Image]:
        """Generate a watermarked video for the prompt at the given index."""
        prompt = self._get_prompt(index)
        frames = watermark.generate_watermarked_media(input_data=prompt, **generation_kwargs)
        return frames

    def _get_unwatermarked_video(self, watermark: BaseWatermark, index: int, **generation_kwargs) -> List[Image.Image]:
        """Generate or retrieve an unwatermarked video for the prompt at the given index."""
        prompt = self._get_prompt(index)
        frames = watermark.generate_unwatermarked_media(input_data=prompt, **generation_kwargs)
        return frames

    def _prepare_input_for_quality_analyzer(self,
                                            watermarked_videos: List[List[Image.Image]],
                                            unwatermarked_videos: List[List[Image.Image]],
                                            reference_videos: List[List[Image.Image]]):
        """Prepare input for the quality analyzer."""
        # Direct analysis scores each watermarked video against its unwatermarked counterpart,
        # so reference videos are not used here.
        return watermarked_videos, unwatermarked_videos

    def analyze_quality(self,
                        prepared_data: Tuple[List[List[Image.Image]], List[List[Image.Image]]],
                        analyzer: VideoQualityAnalyzer):
        """Analyze quality of watermarked and unwatermarked videos."""
        watermarked_videos, unwatermarked_videos = prepared_data

        # Create pairs of watermarked and unwatermarked videos
        video_pairs = list(zip(watermarked_videos, unwatermarked_videos))

        bar = self._get_progress_bar(video_pairs)
        bar.set_description(f"Analyzing quality for {analyzer.__class__.__name__}")
        w_scores, u_scores = [], []
        for watermarked_video, unwatermarked_video in bar:
            w_score = analyzer.analyze(watermarked_video)
            u_score = analyzer.analyze(unwatermarked_video)
            w_scores.append(w_score)
            u_scores.append(u_score)
        return w_scores, u_scores
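

# Minimal usage sketch. The dataset, watermark, and analyzer objects below are hypothetical
# placeholders, and the supported generation_kwargs depend on the underlying video model.
#
#   dataset = ...      # a BaseDataset subclass providing prompts (and optional references)
#   watermark = ...    # a BaseWatermark implementation that can generate (un)watermarked videos
#   analyzer = ...     # a VideoQualityAnalyzer implementation
#
#   pipeline = DirectVideoQualityAnalysisPipeline(
#       dataset=dataset,
#       analyzers=[analyzer],
#       store_path="results/quality",  # optional; generated frames are saved here when set
#       return_type=QualityPipelineReturnType.MEAN_SCORES,
#   )
#   scores = pipeline.evaluate(watermark, generation_kwargs={})
#   # -> {'watermarked': {<analyzer class name>: <mean score>},
#   #     'unwatermarked': {<analyzer class name>: <mean score>}}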