Coverage for evaluation/tools/video_quality_analyzer.py: 93.54% (325 statements)

coverage.py v7.13.0, created at 2025-12-22 13:30 +0000

from typing import List
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, CenterCrop
import numpy as np
from tqdm import tqdm
import cv2
import os
import subprocess
from utils.media_utils import pil_to_torch

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

from pathlib import Path

# NumPy 2.0 removed np.sctypes; restore it for dependencies that still expect it.
if not hasattr(np, 'sctypes'):
    np.sctypes = {
        'int': [np.int8, np.int16, np.int32, np.int64],
        'uint': [np.uint8, np.uint16, np.uint32, np.uint64],
        'float': [np.float16, np.float32, np.float64],
        'complex': [np.complex64, np.complex128],
        'others': [bool, object, bytes, str, np.void]
    }


def dino_transform_Image(n_px):
    """DINO transform for PIL Images."""
    return Compose([
        Resize(size=n_px, antialias=False),
        ToTensor(),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])


class VideoQualityAnalyzer:
    """Video quality analyzer base class."""

    def __init__(self):
        pass

    def analyze(self, frames: List[Image.Image]):
        """Analyze video quality.

        Args:
            frames: List of PIL Image frames representing the video

        Returns:
            Quality score(s)
        """
        raise NotImplementedError("Subclasses must implement analyze method")

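# The concrete analyzers below all follow the same interface: construct once
# (which loads any model weights), then call analyze(frames) with a list of PIL
# frames to get a scalar score. A minimal usage sketch, assuming a hypothetical
# list of decoded RGB frames is already available:
#
#   analyzer: VideoQualityAnalyzer = SubjectConsistencyAnalyzer(device="cuda")
#   score = analyzer.analyze(frames)  # float, higher is generally better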

class SubjectConsistencyAnalyzer(VideoQualityAnalyzer):
    """Analyzer for evaluating subject consistency across video frames using DINO features.

    This analyzer measures how consistently the main subject appears across frames by:
    1. Extracting DINO features from each frame
    2. Computing cosine similarity between consecutive frames and with the first frame
    3. Averaging these similarities to get a consistency score
    """

    def __init__(
        self,
        model_url: str = "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_full_checkpoint.pth",
        model_path: str = "dino_vitb16_full.pth",
        device: str = "cuda"
    ):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model_path = model_path
        self.model_url = model_url

        # Ensure weights exist; download automatically if missing
        self._download_weights()

        # Load model via timm
        self.model = self._load_dino_model()
        self.model.eval()
        self.model.to(self.device)

    def _download_weights(self):
        if not os.path.exists(self.model_path):
            import urllib.request
            print("Downloading DINO ViT-B/16 weights...")
            urllib.request.urlretrieve(self.model_url, self.model_path)
            print("Download complete:", self.model_path)
        else:
            print("Weights already exist:", self.model_path)

    def _load_dino_model(self):
        import timm
        # timm ViT-Base/16 architecture
        model = timm.create_model(
            "vit_base_patch16_224",
            pretrained=False,
            num_classes=0
        )

        # Load full checkpoint
        ckpt = torch.load(self.model_path, map_location="cpu")

        # For the full checkpoint the state dict is nested
        if "teacher" in ckpt:
            state_dict = ckpt["teacher"]
        elif "student" in ckpt:
            state_dict = ckpt["student"]
        else:
            state_dict = ckpt

        # Remove classifier head keys
        state_dict = {k: v for k, v in state_dict.items() if "head" not in k}

        model.load_state_dict(state_dict, strict=False)
        return model

    def transform(self, img: Image.Image) -> torch.Tensor:
        """Transform PIL Image to tensor for DINO model."""
        transform = dino_transform_Image(224)
        return transform(img)

    def analyze(self, frames: List[Image.Image]) -> float:
        """Analyze subject consistency across video frames.

        Args:
            frames: List of PIL Image frames representing the video

        Returns:
            Subject consistency score (higher is better, range [0, 1])
        """
        if len(frames) < 2:
            return 1.0  # Single frame is perfectly consistent with itself

        video_sim = 0.0
        frame_count = 0

        # Process frames and extract features
        with torch.no_grad():
            for i, frame in enumerate(frames):
                # Transform and prepare frame
                frame_tensor = self.transform(frame).unsqueeze(0).to(self.device)

                # Extract features
                features = self.model(frame_tensor)
                features = F.normalize(features, dim=-1, p=2)

                if i == 0:
                    # Store first frame features
                    first_frame_features = features
                else:
                    # Compute similarity with the previous frame
                    sim_prev = max(0.0, F.cosine_similarity(prev_features, features).item())

                    # Compute similarity with the first frame
                    sim_first = max(0.0, F.cosine_similarity(first_frame_features, features).item())

                    # Average the two similarities
                    frame_sim = (sim_prev + sim_first) / 2.0
                    video_sim += frame_sim
                    frame_count += 1

                # Store current features as previous for the next iteration
                prev_features = features

        # Return the average similarity across all frame pairs
        if frame_count > 0:
            return video_sim / frame_count
        else:
            return 1.0

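# Usage sketch for SubjectConsistencyAnalyzer (hypothetical frame directory;
# assumes the DINO checkpoint can be downloaded or already exists locally):
#
#   frames = [Image.open(p).convert("RGB")
#             for p in sorted(Path("samples/video_0").glob("*.png"))]
#   subject_score = SubjectConsistencyAnalyzer(device="cuda").analyze(frames)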

# from contextlib import contextmanager

# @contextmanager
# def isolated_import_context(code_dir, isolated_prefixes, prefix_tag=None):
#     """Context manager for isolated module imports to avoid conflicts with the main project.
#
#     Args:
#         code_dir: External code directory to add to sys.path
#         isolated_prefixes: List of module name prefixes to isolate (e.g., ['utils', 'networks'])
#         prefix_tag: Tag to prefix external modules with after loading (default: code_dir.name + '_ext_')
#
#     Example:
#         with isolated_import_context(CODE_DIR, ['utils', 'networks']):
#             # imports here will use CODE_DIR's modules
#             spec = importlib.util.spec_from_file_location("entry", CODE_DIR / "main.py")
#             ...
#         # after exiting, the main project's 'utils' is restored
#     """
#     import sys
#
#     if prefix_tag is None:
#         prefix_tag = code_dir.name + '_ext_'
#
#     original_path = sys.path.copy()
#     saved_modules = {}
#
#     # Remove potentially conflicting modules
#     for prefix in isolated_prefixes:
#         for mod_name in list(sys.modules.keys()):
#             if mod_name == prefix or mod_name.startswith(prefix + '.'):
#                 saved_modules[mod_name] = sys.modules.pop(mod_name)
#
#     sys.path.insert(0, str(code_dir))
#
#     try:
#         yield
#     finally:
#         sys.path[:] = original_path
#
#         # Rename external modules with the prefix tag to avoid future conflicts
#         for prefix in isolated_prefixes:
#             for mod_name in list(sys.modules.keys()):
#                 if mod_name == prefix or mod_name.startswith(prefix + '.'):
#                     if mod_name not in saved_modules:
#                         sys.modules[prefix_tag + mod_name] = sys.modules.pop(mod_name)
#
#         # Restore main project modules
#         sys.modules.update(saved_modules)


class MotionSmoothnessAnalyzer(VideoQualityAnalyzer):
    """Analyzer for evaluating motion smoothness in videos using the AMT-S model.

    This analyzer measures motion smoothness by:
    1. Extracting frames at even indices from the video
    2. Using the AMT-S model to interpolate between consecutive frames
    3. Comparing interpolated frames with actual frames to compute a smoothness score

    The score represents how well the motion can be predicted/interpolated,
    with smoother motion resulting in higher scores.
    """

    def __init__(self, model_path: str = "model/amt/amt-s.pth",
                 device: str = "cuda", niters: int = 1):
        """Initialize the MotionSmoothnessAnalyzer.

        Args:
            model_path: Path to the AMT-S model checkpoint
            device: Device to run the model on ('cuda' or 'cpu')
            niters: Number of interpolation iterations (default: 1)
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.niters = niters

        # Initialize model parameters
        self._initialize_params()

        # Load AMT-S model
        self.model = self._load_amt_model(model_path)
        self.model.eval()
        self.model.to(self.device)

    def _initialize_params(self):
        """Initialize parameters for video processing."""
        if self.device.type == 'cuda':
            self.anchor_resolution = 1024 * 512
            self.anchor_memory = 1500 * 1024**2
            self.anchor_memory_bias = 2500 * 1024**2
            self.vram_avail = torch.cuda.get_device_properties(self.device).total_memory
        else:
            # Do not resize in CPU mode
            self.anchor_resolution = 8192 * 8192
            self.anchor_memory = 1
            self.anchor_memory_bias = 0
            self.vram_avail = 1

        # Time embedding for interpolation (t=0.5)
        self.embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(self.device)

    def _load_amt_model(self, model_path: str):
        """Load AMT-S model.

        Args:
            model_path: Path to the model checkpoint

        Returns:
            Loaded AMT-S model
        """
        # Import the AMT-S model (note the hyphen in the filename)
        import sys
        import importlib.util

        # Load the module from a filename that contains a hyphen
        spec = importlib.util.spec_from_file_location("amt_s", "model/amt/networks/AMT-S.py")
        amt_s_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(amt_s_module)
        Model = amt_s_module.Model

        # Create model with default parameters
        model = Model(
            corr_radius=3,
            corr_lvls=4,
            num_flows=3
        )

        # Load checkpoint
        if os.path.exists(model_path):
            ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
            model.load_state_dict(ckpt['state_dict'])

        return model

    def _extract_frames(self, frames: List[Image.Image], start_from: int = 0) -> List[np.ndarray]:
        """Extract frames at even indices starting from start_from.

        Args:
            frames: List of PIL Image frames
            start_from: Starting index (default: 0)

        Returns:
            List of extracted frames as numpy arrays
        """
        extracted = []
        for i in range(start_from, len(frames), 2):
            # Convert PIL Image to numpy array
            frame_np = np.array(frames[i])
            extracted.append(frame_np)
        return extracted

    def _img2tensor(self, img: np.ndarray) -> torch.Tensor:
        """Convert numpy image to tensor.

        Args:
            img: Image as numpy array (H, W, C)

        Returns:
            Image tensor (1, C, H, W)
        """
        from model.amt.utils.utils import img2tensor
        return img2tensor(img)

    def _tensor2img(self, tensor: torch.Tensor) -> np.ndarray:
        """Convert tensor to numpy image.

        Args:
            tensor: Image tensor (1, C, H, W)

        Returns:
            Image as numpy array (H, W, C)
        """
        from model.amt.utils.utils import tensor2img
        return tensor2img(tensor)

    def _check_dim_and_resize(self, tensor_list: List[torch.Tensor]) -> List[torch.Tensor]:
        """Check dimensions and resize tensors if needed.

        Args:
            tensor_list: List of image tensors

        Returns:
            List of resized tensors
        """
        from model.amt.utils.utils import check_dim_and_resize
        return check_dim_and_resize(tensor_list)

    def _calculate_scale(self, h: int, w: int) -> float:
        """Calculate the scaling factor based on available VRAM.

        Args:
            h: Height of the image
            w: Width of the image

        Returns:
            Scaling factor
        """
        scale = self.anchor_resolution / (h * w) * np.sqrt((self.vram_avail - self.anchor_memory_bias) / self.anchor_memory)
        scale = 1 if scale > 1 else scale
        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
        return scale

    def _interpolate_frames(self, inputs: List[torch.Tensor], scale: float) -> List[torch.Tensor]:
        """Interpolate frames using the AMT-S model.

        Args:
            inputs: List of input frame tensors
            scale: Scaling factor for processing

        Returns:
            List of interpolated frame tensors
        """
        from model.amt.utils.utils import InputPadder

        # Pad inputs
        padding = int(16 / scale)
        padder = InputPadder(inputs[0].shape, padding)
        inputs = padder.pad(*inputs)

        # Perform interpolation for the specified number of iterations
        for _ in range(self.niters):
            outputs = [inputs[0]]
            for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
                in_0 = in_0.to(self.device)
                in_1 = in_1.to(self.device)
                with torch.no_grad():
                    imgt_pred = self.model(in_0, in_1, self.embt, scale_factor=scale, eval=True)['imgt_pred']
                outputs += [imgt_pred.cpu(), in_1.cpu()]
            inputs = outputs

        # Unpad outputs
        outputs = padder.unpad(*outputs)
        return outputs

    def _compute_frame_difference(self, img1: np.ndarray, img2: np.ndarray) -> float:
        """Compute the average absolute difference between two images.

        Args:
            img1: First image
            img2: Second image

        Returns:
            Average pixel difference
        """
        diff = cv2.absdiff(img1, img2)
        return np.mean(diff)

    def _compute_vfi_score(self, original_frames: List[np.ndarray], interpolated_frames: List[np.ndarray]) -> float:
        """Compute the video frame interpolation score.

        Args:
            original_frames: Original video frames
            interpolated_frames: Interpolated frames

        Returns:
            VFI score (lower difference means better interpolation)
        """
        # Extract frames at odd indices for comparison
        ori_compare = self._extract_frames([Image.fromarray(f) for f in original_frames], start_from=1)
        interp_compare = self._extract_frames([Image.fromarray(f) for f in interpolated_frames], start_from=1)

        scores = []
        for ori, interp in zip(ori_compare, interp_compare):
            score = self._compute_frame_difference(ori, interp)
            scores.append(score)

        return np.mean(scores) if scores else 0.0

    def analyze(self, frames: List[Image.Image]) -> float:
        """Analyze motion smoothness in video frames.

        Args:
            frames: List of PIL Image frames representing the video

        Returns:
            Motion smoothness score (higher is better, range [0, 1])
        """
        if len(frames) < 2:
            return 1.0  # Single frame has perfect smoothness

        # Convert PIL Images to numpy arrays
        np_frames = [np.array(frame) for frame in frames]

        # Extract frames at even indices
        frame_list = self._extract_frames(frames, start_from=0)

        # Convert to tensors
        inputs = [self._img2tensor(frame).to(self.device) for frame in frame_list]

        if len(inputs) <= 1:
            return 1.0  # Not enough frames for interpolation

        # Check dimensions and resize if needed
        inputs = self._check_dim_and_resize(inputs)
        h, w = inputs[0].shape[-2:]

        # Calculate scale based on available memory
        scale = self._calculate_scale(h, w)

        # Perform frame interpolation
        outputs = self._interpolate_frames(inputs, scale)

        # Convert outputs back to images
        output_images = [self._tensor2img(out) for out in outputs]

        # Compute VFI score
        vfi_score = self._compute_vfi_score(np_frames, output_images)

        # Normalize the score to [0, 1] (higher is better).
        # The raw score is an average pixel difference in [0, 255]; normalize and invert.
        normalized_score = (255.0 - vfi_score) / 255.0

        return normalized_score

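# Usage sketch for MotionSmoothnessAnalyzer (hypothetical frames list; assumes
# the AMT-S checkpoint and its utility modules exist under model/amt/):
#
#   smoothness = MotionSmoothnessAnalyzer(model_path="model/amt/amt-s.pth").analyze(frames)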

class DynamicDegreeAnalyzer(VideoQualityAnalyzer):
    """Analyzer for evaluating dynamic degree (motion intensity) in videos using RAFT optical flow.

    This analyzer measures the amount and intensity of motion in videos by:
    1. Computing optical flow between consecutive frames using RAFT
    2. Calculating the flow magnitude for each pixel
    3. Extracting the top 5% highest flow magnitudes
    4. Determining whether the video has sufficient dynamic motion based on thresholds

    The score represents whether the video contains dynamic motion (1.0) or is mostly static (0.0).
    """

    def __init__(self, model_path: str = "model/raft/raft-things.pth",
                 device: str = "cuda", sample_fps: int = 8):
        """Initialize the DynamicDegreeAnalyzer.

        Args:
            model_path: Path to the RAFT model checkpoint
            device: Device to run the model on ('cuda' or 'cpu')
            sample_fps: Target FPS for frame sampling (default: 8)
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.sample_fps = sample_fps

        # Load RAFT model
        self.model = self._load_raft_model(model_path)
        self.model.eval()
        self.model.to(self.device)

    def _load_raft_model(self, model_path: str):
        """Load the RAFT optical flow model.

        Args:
            model_path: Path to the model checkpoint

        Returns:
            Loaded RAFT model
        """
        from model.raft.core.raft import RAFT
        from easydict import EasyDict as edict

        # Configure RAFT arguments
        args = edict({
            "model": model_path,
            "small": False,
            "mixed_precision": False,
            "alternate_corr": False
        })

        # Create and load model
        model = RAFT(args)

        if os.path.exists(model_path):
            ckpt = torch.load(model_path, map_location="cpu")
            # Remove 'module.' prefix if present (from DataParallel)
            new_ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
            model.load_state_dict(new_ckpt)

        return model

    def _extract_frames_for_flow(self, frames: List[Image.Image], target_fps: int = 8) -> List[torch.Tensor]:
        """Extract and prepare frames for optical flow computation.

        Args:
            frames: List of PIL Image frames
            target_fps: Target sampling rate (default: 8 fps)

        Returns:
            List of prepared frame tensors
        """
        # Estimate the original FPS and calculate the sampling interval.
        # Assuming a 30 fps source video, adjust sampling to get ~8 fps.
        total_frames = len(frames)
        assumed_fps = 30  # Common video fps
        interval = max(1, round(assumed_fps / target_fps))

        # Sample frames at the interval
        sampled_frames = []
        for i in range(0, total_frames, interval):
            frame = frames[i]
            # Convert PIL to numpy array
            frame_np = np.array(frame)
            # Convert to tensor and normalize
            frame_tensor = torch.from_numpy(frame_np.astype(np.uint8)).permute(2, 0, 1).float()
            frame_tensor = frame_tensor[None].to(self.device)
            sampled_frames.append(frame_tensor)

        return sampled_frames

    def _compute_flow_magnitude(self, flow: torch.Tensor) -> float:
        """Compute a flow magnitude score from optical flow.

        Args:
            flow: Optical flow tensor (B, 2, H, W)

        Returns:
            Flow magnitude score
        """
        # Extract flow components
        flow_np = flow[0].permute(1, 2, 0).cpu().numpy()
        u = flow_np[:, :, 0]
        v = flow_np[:, :, 1]

        # Compute flow magnitude
        magnitude = np.sqrt(np.square(u) + np.square(v))

        # Get the top 5% highest magnitudes
        h, w = magnitude.shape
        magnitude_flat = magnitude.flatten()
        cut_index = int(h * w * 0.05)

        # Sort in descending order and take the mean of the top 5%
        top_magnitudes = np.sort(-magnitude_flat)[:cut_index]
        mean_magnitude = np.mean(np.abs(top_magnitudes))

        return mean_magnitude.item()

    def _determine_dynamic_threshold(self, frame_shape: tuple, num_frames: int) -> dict:
        """Determine thresholds for dynamic motion detection.

        Args:
            frame_shape: Shape of the frame tensor
            num_frames: Number of frames in the video

        Returns:
            Dictionary with threshold parameters
        """
        # Scale the magnitude threshold based on image resolution
        scale = min(frame_shape[-2:])  # min of height and width
        magnitude_threshold = 6.0 * (scale / 256.0)

        # Scale the count threshold based on the number of frames
        count_threshold = round(4 * (num_frames / 16.0))

        return {
            "magnitude_threshold": magnitude_threshold,
            "count_threshold": count_threshold
        }

    def _check_dynamic_motion(self, flow_scores: List[float], thresholds: dict) -> bool:
        """Check whether the video has dynamic motion based on flow scores.

        Args:
            flow_scores: List of optical flow magnitude scores
            thresholds: Threshold parameters

        Returns:
            True if the video has dynamic motion, False otherwise
        """
        magnitude_threshold = thresholds["magnitude_threshold"]
        count_threshold = thresholds["count_threshold"]

        # Count frames with significant motion
        motion_count = 0
        for score in flow_scores:
            if score > magnitude_threshold:
                motion_count += 1
            if motion_count >= count_threshold:
                return True

        return False

    def analyze(self, frames: List[Image.Image]) -> float:
        """Analyze dynamic degree (motion intensity) in video frames.

        Args:
            frames: List of PIL Image frames representing the video

        Returns:
            Dynamic degree score: 1.0 if the video has dynamic motion, 0.0 if mostly static
        """
        if len(frames) < 2:
            return 0.0  # Cannot compute optical flow with fewer than 2 frames

        # Extract and prepare frames for optical flow
        prepared_frames = self._extract_frames_for_flow(frames, self.sample_fps)

        if len(prepared_frames) < 2:
            return 0.0

        # Determine thresholds based on video characteristics
        thresholds = self._determine_dynamic_threshold(
            prepared_frames[0].shape,
            len(prepared_frames)
        )

        # Compute optical flow between consecutive frames
        flow_scores = []

        with torch.no_grad():
            for frame1, frame2 in zip(prepared_frames[:-1], prepared_frames[1:]):
                # Pad frames if necessary
                from model.raft.core.utils_core.utils import InputPadder
                padder = InputPadder(frame1.shape)
                frame1_padded, frame2_padded = padder.pad(frame1, frame2)

                # Compute optical flow
                _, flow_up = self.model(frame1_padded, frame2_padded, iters=20, test_mode=True)

                # Calculate the flow magnitude score
                magnitude_score = self._compute_flow_magnitude(flow_up)
                flow_scores.append(magnitude_score)

        # Check whether the video has dynamic motion
        has_dynamic_motion = self._check_dynamic_motion(flow_scores, thresholds)

        # Return a binary score: 1.0 for dynamic, 0.0 for static
        return 1.0 if has_dynamic_motion else 0.0

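# Usage sketch for DynamicDegreeAnalyzer (hypothetical frames list; assumes the
# RAFT checkpoint exists at model/raft/raft-things.pth). For 16 sampled frames
# at 256x256 resolution, the thresholds above work out to
# magnitude_threshold = 6.0 * (256 / 256) = 6.0 and
# count_threshold = round(4 * 16 / 16) = 4:
#
#   is_dynamic = DynamicDegreeAnalyzer(sample_fps=8).analyze(frames)  # 1.0 or 0.0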

class BackgroundConsistencyAnalyzer(VideoQualityAnalyzer):
    """Analyzer for evaluating background consistency across video frames using CLIP features.

    This analyzer measures how consistently the background appears across frames by:
    1. Extracting CLIP visual features from each frame
    2. Computing cosine similarity between consecutive frames and with the first frame
    3. Averaging these similarities to get a consistency score

    Similar to SubjectConsistencyAnalyzer, but focuses on overall visual consistency
    including background elements, making it suitable for detecting background stability.
    """

    def __init__(self, model_name: str = "ViT-B/32", device: str = "cuda"):
        """Initialize the BackgroundConsistencyAnalyzer.

        Args:
            model_name: CLIP model name (default: "ViT-B/32")
            device: Device to run the model on ('cuda' or 'cpu')
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

        # Load CLIP model
        self.model, self.preprocess = self._load_clip_model(model_name)
        self.model.eval()
        self.model.to(self.device)

        # Image transform for CLIP (when processing tensor inputs)
        self.tensor_transform = self._get_clip_tensor_transform(224)

    def _load_clip_model(self, model_name: str):
        """Load CLIP model.

        Args:
            model_name: Name of the CLIP model to load

        Returns:
            Tuple of (model, preprocess_function)
        """
        import clip

        model, preprocess = clip.load(model_name, device=self.device)
        return model, preprocess

    def _get_clip_tensor_transform(self, n_px: int):
        """Get the CLIP transform for tensor inputs.

        Args:
            n_px: Target image size

        Returns:
            Transform composition for tensor inputs
        """
        try:
            from torchvision.transforms import InterpolationMode
            BICUBIC = InterpolationMode.BICUBIC
        except ImportError:
            BICUBIC = Image.BICUBIC

        return Compose([
            Resize(n_px, interpolation=BICUBIC, antialias=False),
            CenterCrop(n_px),
            transforms.Lambda(lambda x: x.float().div(255.0)),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])

    def _prepare_images_for_clip(self, frames: List[Image.Image]) -> torch.Tensor:
        """Prepare PIL images for CLIP processing.

        Args:
            frames: List of PIL Image frames

        Returns:
            Batch tensor of preprocessed images
        """
        # Use CLIP's built-in preprocess for PIL images
        images = []
        for frame in frames:
            processed = self.preprocess(frame)
            images.append(processed)

        # Stack into a batch tensor
        return torch.stack(images).to(self.device)

    def analyze(self, frames: List[Image.Image]) -> float:
        """Analyze background consistency across video frames.

        Args:
            frames: List of PIL Image frames representing the video

        Returns:
            Background consistency score (higher is better, range [0, 1])
        """
        if len(frames) < 2:
            return 1.0  # Single frame is perfectly consistent with itself

        # Prepare images for CLIP
        images = self._prepare_images_for_clip(frames)

        # Extract CLIP features
        with torch.no_grad():
            image_features = self.model.encode_image(images)
            image_features = F.normalize(image_features, dim=-1, p=2)

        video_sim = 0.0
        frame_count = 0

        # Compute similarity between frames
        for i in range(len(image_features)):
            image_feature = image_features[i].unsqueeze(0)

            if i == 0:
                # Store first frame features
                first_image_feature = image_feature
            else:
                # Compute similarity with the previous frame
                sim_prev = max(0.0, F.cosine_similarity(former_image_feature, image_feature).item())

                # Compute similarity with the first frame
                sim_first = max(0.0, F.cosine_similarity(first_image_feature, image_feature).item())

                # Average the two similarities
                frame_sim = (sim_prev + sim_first) / 2.0
                video_sim += frame_sim
                frame_count += 1

            # Store current features as previous for the next iteration
            former_image_feature = image_feature

        # Return the average similarity across all frame pairs
        if frame_count > 0:
            return video_sim / frame_count
        else:
            return 1.0

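# Usage sketch for BackgroundConsistencyAnalyzer (assumes the openai/CLIP package
# is installed; clip.load downloads the ViT-B/32 weights automatically):
#
#   background_score = BackgroundConsistencyAnalyzer(model_name="ViT-B/32").analyze(frames)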

class ImagingQualityAnalyzer(VideoQualityAnalyzer):
    """Analyzer for evaluating the imaging quality of videos.

    This analyzer measures video quality by:
    1. Feeding frames into the MUSIQ image quality predictor
    2. Determining whether the video is blurry or has artifacts

    The score represents the quality of the video (higher is better).
    """

    def __init__(self, model_path: str = "model/musiq/musiq_spaq_ckpt-358bb6af.pth", device: str = "cuda"):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model = self._load_musiq(model_path)
        self.model.to(self.device)
        self.model.eval()

    def _load_musiq(self, model_path: str):
        """Load MUSIQ model.

        Args:
            model_path: Path to the MUSIQ model checkpoint

        Returns:
            MUSIQ model
        """
        # from pathlib import Path
        # CACHE_DIR = Path("/tmp/musiq_cache")
        # model_path = CACHE_DIR / model_path

        # Download the checkpoint if it does not exist locally
        if not os.path.exists(model_path):
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            wget_command = ['wget', 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/musiq_spaq_ckpt-358bb6af.pth', '-P', os.path.dirname(model_path)]
            subprocess.run(wget_command, check=True)
        try:
            from pyiqa.archs.musiq_arch import MUSIQ
        except ImportError:
            raise ImportError("Please install pyiqa to use ImagingQualityAnalyzer: pip install pyiqa")
        model = MUSIQ(pretrained_model_path=str(model_path))

        return model

    def _preprocess_frames(self, frames: List[Image.Image]) -> torch.Tensor:
        """Preprocess frames for the MUSIQ model.

        Args:
            frames: List of PIL Image frames

        Returns:
            Preprocessed frames as a tensor
        """
        frames = [pil_to_torch(frame, normalize=False) for frame in frames]  # [(C, H, W)]
        frames = torch.stack(frames)  # (T, C, H, W)

        _, _, h, w = frames.size()
        if max(h, w) > 512:
            scale = 512. / max(h, w)
            frames = F.interpolate(frames, size=(int(scale * h), int(scale * w)), mode='bilinear', align_corners=False)

        return frames

    def analyze(self, frames: List[Image.Image]) -> float:
        """Analyze the imaging quality of video frames.

        Args:
            frames: List of PIL Image frames representing the video

        Returns:
            Imaging quality score (higher is better, range [0, 1])
        """
        frame_tensor = self._preprocess_frames(frames)
        acc_score_video = 0.0
        for i in range(len(frame_tensor)):
            frame = frame_tensor[i].unsqueeze(0).to(self.device)
            score = self.model(frame)
            acc_score_video += float(score)
        # MUSIQ (SPAQ) scores lie in [0, 100]; average over frames and rescale to [0, 1]
        return acc_score_video / (100 * len(frame_tensor))
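
# Minimal end-to-end sketch (hypothetical frame directory and default checkpoint
# paths; each analyzer assumes its own model weights are available or downloadable):
if __name__ == "__main__":
    frame_dir = Path("samples/video_0")  # hypothetical directory of extracted frames
    frames = [Image.open(p).convert("RGB") for p in sorted(frame_dir.glob("*.png"))]

    analyzers = {
        "subject_consistency": SubjectConsistencyAnalyzer(),
        "background_consistency": BackgroundConsistencyAnalyzer(),
        "motion_smoothness": MotionSmoothnessAnalyzer(),
        "dynamic_degree": DynamicDegreeAnalyzer(),
        "imaging_quality": ImagingQualityAnalyzer(),
    }
    for name, analyzer in analyzers.items():
        print(f"{name}: {analyzer.analyze(frames):.4f}")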