Coverage for evaluation/tools/video_quality_analyzer.py: 93.54% (325 statements)
coverage.py v7.13.0, created at 2025-12-22 13:30 +0000
from typing import List
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from torchvision.transforms import Compose, Resize, ToTensor, Normalize, CenterCrop
import numpy as np
from tqdm import tqdm
import cv2
import os
import subprocess
from utils.media_utils import pil_to_torch

try:
    from torchvision.transforms import InterpolationMode
    BICUBIC = InterpolationMode.BICUBIC
except ImportError:
    BICUBIC = Image.BICUBIC

from pathlib import Path

# Compatibility shim: np.sctypes was removed in NumPy 2.0; restore it for code
# that still expects it.
if not hasattr(np, 'sctypes'):
    np.sctypes = {
        'int': [np.int8, np.int16, np.int32, np.int64],
        'uint': [np.uint8, np.uint16, np.uint32, np.uint64],
        'float': [np.float16, np.float32, np.float64],
        'complex': [np.complex64, np.complex128],
        'others': [bool, object, bytes, str, np.void]
    }
def dino_transform_Image(n_px):
    """DINO transform for PIL Images."""
    return Compose([
        Resize(size=n_px, antialias=False),
        ToTensor(),
        Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
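
# A minimal usage sketch for the transform above (assumes `img` is a PIL RGB image;
# it is not defined in this module):
#
#     transform = dino_transform_Image(224)
#     tensor = transform(img)  # normalized float tensor (3, H', W'), shorter side resized to 224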

class VideoQualityAnalyzer:
    """Video quality analyzer base class."""

    def __init__(self):
        pass

    def analyze(self, frames: List[Image.Image]):
        """Analyze video quality.

        Args:
            frames: List of PIL Image frames representing the video

        Returns:
            Quality score(s)
        """
        raise NotImplementedError("Subclasses must implement analyze method")

class SubjectConsistencyAnalyzer(VideoQualityAnalyzer):
    """Analyzer for evaluating subject consistency across video frames using DINO features.

    This analyzer measures how consistently the main subject appears across frames by:
    1. Extracting DINO features from each frame
    2. Computing cosine similarity between consecutive frames and with the first frame
    3. Averaging these similarities to get a consistency score
    """

    def __init__(
        self,
        model_url: str = "https://dl.fbaipublicfiles.com/dino/dino_vitbase16_pretrain/dino_vitbase16_pretrain_full_checkpoint.pth",
        model_path: str = "dino_vitb16_full.pth",
        device: str = "cuda"
    ):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model_path = model_path
        self.model_url = model_url

        # Ensure weights exist; download automatically if missing
        self._download_weights()

        # Load model via timm
        self.model = self._load_dino_model()
        self.model.eval()
        self.model.to(self.device)
    def _download_weights(self):
        if not os.path.exists(self.model_path):
            import urllib.request
            print("Downloading DINO ViT-B/16 weights...")
            urllib.request.urlretrieve(self.model_url, self.model_path)
            print("Download complete:", self.model_path)
        else:
            print("Weights already exist:", self.model_path)
    def _load_dino_model(self):
        import timm

        # timm ViT-B/16 structure
        model = timm.create_model(
            "vit_base_patch16_224",
            pretrained=False,
            num_classes=0
        )

        # Load full checkpoint
        ckpt = torch.load(self.model_path, map_location="cpu")

        # For a full checkpoint the state dict is nested
        if "teacher" in ckpt:
            state_dict = ckpt["teacher"]
        elif "student" in ckpt:
            state_dict = ckpt["student"]
        else:
            state_dict = ckpt

        # Remove classifier head keys
        state_dict = {k: v for k, v in state_dict.items() if "head" not in k}

        model.load_state_dict(state_dict, strict=False)
        return model
    def transform(self, img: Image.Image) -> torch.Tensor:
        """Transform PIL Image to tensor for DINO model."""
        transform = dino_transform_Image(224)
        return transform(img)
    def analyze(self, frames: List[Image.Image]) -> float:
        """Analyze subject consistency across video frames.

        Args:
            frames: List of PIL Image frames representing the video

        Returns:
            Subject consistency score (higher is better, range [0, 1])
        """
        if len(frames) < 2:
            return 1.0  # Single frame is perfectly consistent with itself

        video_sim = 0.0
        frame_count = 0

        # Process frames and extract features
        with torch.no_grad():
            for i, frame in enumerate(frames):
                # Transform and prepare frame
                frame_tensor = self.transform(frame).unsqueeze(0).to(self.device)

                # Extract features
                features = self.model(frame_tensor)
                features = F.normalize(features, dim=-1, p=2)

                if i == 0:
                    # Store first frame features
                    first_frame_features = features
                else:
                    # Compute similarity with previous frame
                    sim_prev = max(0.0, F.cosine_similarity(prev_features, features).item())

                    # Compute similarity with first frame
                    sim_first = max(0.0, F.cosine_similarity(first_frame_features, features).item())

                    # Average the two similarities
                    frame_sim = (sim_prev + sim_first) / 2.0
                    video_sim += frame_sim
                    frame_count += 1

                # Store current features as previous for next iteration
                prev_features = features

        # Return average similarity across all frame pairs
        if frame_count > 0:
            return video_sim / frame_count
        else:
            return 1.0
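
# A minimal usage sketch (assumes the DINO checkpoint can be downloaded and that
# `frame_paths` lists the extracted video frames; both are placeholders):
#
#     analyzer = SubjectConsistencyAnalyzer(device="cuda")
#     frames = [Image.open(p).convert("RGB") for p in frame_paths]
#     score = analyzer.analyze(frames)  # float in [0, 1]; higher means a more consistent subject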

# from contextlib import contextmanager

# @contextmanager
# def isolated_import_context(code_dir, isolated_prefixes, prefix_tag=None):
#     """Context manager for isolated module imports to avoid conflicts with main project.
#
#     Args:
#         code_dir: External code directory to add to sys.path
#         isolated_prefixes: List of module name prefixes to isolate (e.g., ['utils', 'networks'])
#         prefix_tag: Tag to prefix external modules with after loading (default: code_dir.name + '_ext_')
#
#     Example:
#         with isolated_import_context(CODE_DIR, ['utils', 'networks']):
#             # imports here will use CODE_DIR's modules
#             spec = importlib.util.spec_from_file_location("entry", CODE_DIR / "main.py")
#             ...
#         # after exiting, main project's 'utils' is restored
#     """
#     import sys
#
#     if prefix_tag is None:
#         prefix_tag = code_dir.name + '_ext_'
#
#     original_path = sys.path.copy()
#     saved_modules = {}
#
#     # Remove potentially conflicting modules
#     for prefix in isolated_prefixes:
#         for mod_name in list(sys.modules.keys()):
#             if mod_name == prefix or mod_name.startswith(prefix + '.'):
#                 saved_modules[mod_name] = sys.modules.pop(mod_name)
#
#     sys.path.insert(0, str(code_dir))
#
#     try:
#         yield
#     finally:
#         sys.path[:] = original_path
#
#         # Rename external modules with prefix tag to avoid future conflicts
#         for prefix in isolated_prefixes:
#             for mod_name in list(sys.modules.keys()):
#                 if mod_name == prefix or mod_name.startswith(prefix + '.'):
#                     if mod_name not in saved_modules:
#                         sys.modules[prefix_tag + mod_name] = sys.modules.pop(mod_name)
#
#         # Restore main project modules
#         sys.modules.update(saved_modules)

class MotionSmoothnessAnalyzer(VideoQualityAnalyzer):
    """Analyzer for evaluating motion smoothness in videos using the AMT-S model.

    This analyzer measures motion smoothness by:
    1. Extracting frames at even indices from the video
    2. Using the AMT-S model to interpolate between consecutive frames
    3. Comparing interpolated frames with actual frames to compute a smoothness score

    The score represents how well the motion can be predicted/interpolated,
    with smoother motion resulting in higher scores.
    """

    def __init__(self, model_path: str = "model/amt/amt-s.pth",
                 device: str = "cuda", niters: int = 1):
        """Initialize the MotionSmoothnessAnalyzer.

        Args:
            model_path: Path to the AMT-S model checkpoint
            device: Device to run the model on ('cuda' or 'cpu')
            niters: Number of interpolation iterations (default: 1)
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.niters = niters

        # Initialize model parameters
        self._initialize_params()

        # Load AMT-S model
        self.model = self._load_amt_model(model_path)
        self.model.eval()
        self.model.to(self.device)
    def _initialize_params(self):
        """Initialize parameters for video processing."""
        if self.device.type == 'cuda':
            self.anchor_resolution = 1024 * 512
            self.anchor_memory = 1500 * 1024**2
            self.anchor_memory_bias = 2500 * 1024**2
            self.vram_avail = torch.cuda.get_device_properties(self.device).total_memory
        else:
            # Do not resize in CPU mode
            self.anchor_resolution = 8192 * 8192
            self.anchor_memory = 1
            self.anchor_memory_bias = 0
            self.vram_avail = 1

        # Time embedding for interpolation (t=0.5)
        self.embt = torch.tensor(1/2).float().view(1, 1, 1, 1).to(self.device)
    def _load_amt_model(self, model_path: str):
        """Load the AMT-S model.

        Args:
            model_path: Path to the model checkpoint

        Returns:
            Loaded AMT-S model
        """
        # Import AMT-S model (note the hyphen in the filename)
        import importlib.util

        # Load the module with a hyphen in its filename
        spec = importlib.util.spec_from_file_location("amt_s", "model/amt/networks/AMT-S.py")
        amt_s_module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(amt_s_module)
        Model = amt_s_module.Model

        # Create model with default parameters
        model = Model(
            corr_radius=3,
            corr_lvls=4,
            num_flows=3
        )

        # Load checkpoint
        if os.path.exists(model_path):
            ckpt = torch.load(model_path, map_location="cpu", weights_only=False)
            model.load_state_dict(ckpt['state_dict'])

        return model
    def _extract_frames(self, frames: List[Image.Image], start_from: int = 0) -> List[np.ndarray]:
        """Extract frames at even indices starting from start_from.

        Args:
            frames: List of PIL Image frames
            start_from: Starting index (default: 0)

        Returns:
            List of extracted frames as numpy arrays
        """
        extracted = []
        for i in range(start_from, len(frames), 2):
            # Convert PIL Image to numpy array
            frame_np = np.array(frames[i])
            extracted.append(frame_np)
        return extracted
    def _img2tensor(self, img: np.ndarray) -> torch.Tensor:
        """Convert a numpy image to a tensor.

        Args:
            img: Image as numpy array (H, W, C)

        Returns:
            Image tensor (1, C, H, W)
        """
        from model.amt.utils.utils import img2tensor
        return img2tensor(img)

    def _tensor2img(self, tensor: torch.Tensor) -> np.ndarray:
        """Convert a tensor to a numpy image.

        Args:
            tensor: Image tensor (1, C, H, W)

        Returns:
            Image as numpy array (H, W, C)
        """
        from model.amt.utils.utils import tensor2img
        return tensor2img(tensor)

    def _check_dim_and_resize(self, tensor_list: List[torch.Tensor]) -> List[torch.Tensor]:
        """Check dimensions and resize tensors if needed.

        Args:
            tensor_list: List of image tensors

        Returns:
            List of resized tensors
        """
        from model.amt.utils.utils import check_dim_and_resize
        return check_dim_and_resize(tensor_list)
    def _calculate_scale(self, h: int, w: int) -> float:
        """Calculate the scaling factor based on available VRAM.

        Args:
            h: Height of the image
            w: Width of the image

        Returns:
            Scaling factor
        """
        scale = self.anchor_resolution / (h * w) * np.sqrt((self.vram_avail - self.anchor_memory_bias) / self.anchor_memory)
        scale = 1 if scale > 1 else scale
        scale = 1 / np.floor(1 / np.sqrt(scale) * 16) * 16
        return scale
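
    # Worked example for the snapping above: a raw scale of 0.25 becomes
    # 16 / floor(16 / sqrt(0.25)) = 0.5, which makes 16 / scale an integer, so the
    # padding unit int(16 / scale) used in _interpolate_frames involves no truncation.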
    def _interpolate_frames(self, inputs: List[torch.Tensor], scale: float) -> List[torch.Tensor]:
        """Interpolate frames using the AMT-S model.

        Args:
            inputs: List of input frame tensors
            scale: Scaling factor for processing

        Returns:
            List of interpolated frame tensors
        """
        from model.amt.utils.utils import InputPadder

        # Pad inputs
        padding = int(16 / scale)
        padder = InputPadder(inputs[0].shape, padding)
        inputs = padder.pad(*inputs)

        # Perform interpolation for the specified number of iterations
        for _ in range(self.niters):
            outputs = [inputs[0]]
            for in_0, in_1 in zip(inputs[:-1], inputs[1:]):
                in_0 = in_0.to(self.device)
                in_1 = in_1.to(self.device)
                with torch.no_grad():
                    imgt_pred = self.model(in_0, in_1, self.embt, scale_factor=scale, eval=True)['imgt_pred']
                outputs += [imgt_pred.cpu(), in_1.cpu()]
            inputs = outputs

        # Unpad outputs
        outputs = padder.unpad(*outputs)
        return outputs
    def _compute_frame_difference(self, img1: np.ndarray, img2: np.ndarray) -> float:
        """Compute the average absolute difference between two images.

        Args:
            img1: First image
            img2: Second image

        Returns:
            Average pixel difference
        """
        diff = cv2.absdiff(img1, img2)
        return np.mean(diff)
    def _compute_vfi_score(self, original_frames: List[np.ndarray], interpolated_frames: List[np.ndarray]) -> float:
        """Compute the video frame interpolation score.

        Args:
            original_frames: Original video frames
            interpolated_frames: Interpolated frames

        Returns:
            VFI score (lower difference means better interpolation)
        """
        # Extract frames at odd indices for comparison
        ori_compare = self._extract_frames([Image.fromarray(f) for f in original_frames], start_from=1)
        interp_compare = self._extract_frames([Image.fromarray(f) for f in interpolated_frames], start_from=1)

        scores = []
        for ori, interp in zip(ori_compare, interp_compare):
            score = self._compute_frame_difference(ori, interp)
            scores.append(score)

        return np.mean(scores) if scores else 0.0
    def analyze(self, frames: List[Image.Image]) -> float:
        """Analyze motion smoothness in video frames.

        Args:
            frames: List of PIL Image frames representing the video

        Returns:
            Motion smoothness score (higher is better, range [0, 1])
        """
        if len(frames) < 2:
            return 1.0  # Single frame has perfect smoothness

        # Convert PIL Images to numpy arrays
        np_frames = [np.array(frame) for frame in frames]

        # Extract frames at even indices
        frame_list = self._extract_frames(frames, start_from=0)

        # Convert to tensors
        inputs = [self._img2tensor(frame).to(self.device) for frame in frame_list]

        if len(inputs) <= 1:
            return 1.0  # Not enough frames for interpolation

        # Check dimensions and resize if needed
        inputs = self._check_dim_and_resize(inputs)
        h, w = inputs[0].shape[-2:]

        # Calculate scale based on available memory
        scale = self._calculate_scale(h, w)

        # Perform frame interpolation
        outputs = self._interpolate_frames(inputs, scale)

        # Convert outputs back to images
        output_images = [self._tensor2img(out) for out in outputs]

        # Compute VFI score
        vfi_score = self._compute_vfi_score(np_frames, output_images)

        # Normalize score to [0, 1] (higher is better): the raw score is an
        # average pixel difference in [0, 255], so invert and rescale
        normalized_score = (255.0 - vfi_score) / 255.0

        return normalized_score
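
# A minimal usage sketch (assumes the AMT-S checkpoint exists at the default path
# and `frames` is a list of PIL Images; both are placeholders):
#
#     analyzer = MotionSmoothnessAnalyzer(device="cuda")
#     smoothness = analyzer.analyze(frames)  # float in [0, 1]; higher means smoother motion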

class DynamicDegreeAnalyzer(VideoQualityAnalyzer):
    """Analyzer for evaluating dynamic degree (motion intensity) in videos using RAFT optical flow.

    This analyzer measures the amount and intensity of motion in videos by:
    1. Computing optical flow between consecutive frames using RAFT
    2. Calculating the flow magnitude for each pixel
    3. Extracting the top 5% highest flow magnitudes
    4. Determining if the video has sufficient dynamic motion based on thresholds

    The score represents whether the video contains dynamic motion (1.0) or is mostly static (0.0).
    """

    def __init__(self, model_path: str = "model/raft/raft-things.pth",
                 device: str = "cuda", sample_fps: int = 8):
        """Initialize the DynamicDegreeAnalyzer.

        Args:
            model_path: Path to the RAFT model checkpoint
            device: Device to run the model on ('cuda' or 'cpu')
            sample_fps: Target FPS for frame sampling (default: 8)
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.sample_fps = sample_fps

        # Load RAFT model
        self.model = self._load_raft_model(model_path)
        self.model.eval()
        self.model.to(self.device)
    def _load_raft_model(self, model_path: str):
        """Load the RAFT optical flow model.

        Args:
            model_path: Path to the model checkpoint

        Returns:
            Loaded RAFT model
        """
        from model.raft.core.raft import RAFT
        from easydict import EasyDict as edict

        # Configure RAFT arguments
        args = edict({
            "model": model_path,
            "small": False,
            "mixed_precision": False,
            "alternate_corr": False
        })

        # Create and load model
        model = RAFT(args)

        if os.path.exists(model_path):
            ckpt = torch.load(model_path, map_location="cpu")
            # Remove the 'module.' prefix if present (from DataParallel)
            new_ckpt = {k.replace('module.', ''): v for k, v in ckpt.items()}
            model.load_state_dict(new_ckpt)

        return model
    def _extract_frames_for_flow(self, frames: List[Image.Image], target_fps: int = 8) -> List[torch.Tensor]:
        """Extract and prepare frames for optical flow computation.

        Args:
            frames: List of PIL Image frames
            target_fps: Target sampling rate (default: 8 fps)

        Returns:
            List of prepared frame tensors
        """
        # Calculate the sampling interval, assuming a 30 fps source video,
        # so that roughly target_fps frames per second are kept
        total_frames = len(frames)
        assumed_fps = 30  # Common video fps
        interval = max(1, round(assumed_fps / target_fps))

        # Sample frames at the computed interval
        sampled_frames = []
        for i in range(0, total_frames, interval):
            frame = frames[i]
            # Convert PIL to numpy array
            frame_np = np.array(frame)
            # Convert to a float tensor with values left in [0, 255] (no normalization here)
            frame_tensor = torch.from_numpy(frame_np.astype(np.uint8)).permute(2, 0, 1).float()
            frame_tensor = frame_tensor[None].to(self.device)
            sampled_frames.append(frame_tensor)

        return sampled_frames
    def _compute_flow_magnitude(self, flow: torch.Tensor) -> float:
        """Compute a flow magnitude score from optical flow.

        Args:
            flow: Optical flow tensor (B, 2, H, W)

        Returns:
            Flow magnitude score
        """
        # Extract flow components
        flow_np = flow[0].permute(1, 2, 0).cpu().numpy()
        u = flow_np[:, :, 0]
        v = flow_np[:, :, 1]

        # Compute flow magnitude
        magnitude = np.sqrt(np.square(u) + np.square(v))

        # Get the top 5% highest magnitudes
        h, w = magnitude.shape
        magnitude_flat = magnitude.flatten()
        cut_index = int(h * w * 0.05)

        # Sort in descending order and take the mean of the top 5%
        top_magnitudes = np.sort(-magnitude_flat)[:cut_index]
        mean_magnitude = np.mean(np.abs(top_magnitudes))

        return mean_magnitude.item()
    def _determine_dynamic_threshold(self, frame_shape: tuple, num_frames: int) -> dict:
        """Determine thresholds for dynamic motion detection.

        Args:
            frame_shape: Shape of the frame tensor
            num_frames: Number of frames in the video

        Returns:
            Dictionary with threshold parameters
        """
        # Scale the magnitude threshold based on image resolution
        scale = min(frame_shape[-2:])  # min of height and width
        magnitude_threshold = 6.0 * (scale / 256.0)

        # Scale the count threshold based on the number of frames
        count_threshold = round(4 * (num_frames / 16.0))

        return {
            "magnitude_threshold": magnitude_threshold,
            "count_threshold": count_threshold
        }
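
    # Worked example: with a shorter frame side of 256 px and 16 sampled frames,
    # magnitude_threshold = 6.0 * (256 / 256) = 6.0 and count_threshold = round(4 * 16 / 16) = 4,
    # i.e. at least 4 frame pairs must have a mean top-5% flow above 6.0 pixels
    # for the clip to count as dynamic.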
    def _check_dynamic_motion(self, flow_scores: List[float], thresholds: dict) -> bool:
        """Check whether the video has dynamic motion based on flow scores.

        Args:
            flow_scores: List of optical flow magnitude scores
            thresholds: Threshold parameters

        Returns:
            True if the video has dynamic motion, False otherwise
        """
        magnitude_threshold = thresholds["magnitude_threshold"]
        count_threshold = thresholds["count_threshold"]

        # Count frames with significant motion
        motion_count = 0
        for score in flow_scores:
            if score > magnitude_threshold:
                motion_count += 1
                if motion_count >= count_threshold:
                    return True

        return False
    def analyze(self, frames: List[Image.Image]) -> float:
        """Analyze dynamic degree (motion intensity) in video frames.

        Args:
            frames: List of PIL Image frames representing the video

        Returns:
            Dynamic degree score: 1.0 if the video has dynamic motion, 0.0 if it is mostly static
        """
        if len(frames) < 2:
            return 0.0  # Cannot compute optical flow with fewer than 2 frames

        # Extract and prepare frames for optical flow
        prepared_frames = self._extract_frames_for_flow(frames, self.sample_fps)

        if len(prepared_frames) < 2:
            return 0.0

        # Determine thresholds based on video characteristics
        thresholds = self._determine_dynamic_threshold(
            prepared_frames[0].shape,
            len(prepared_frames)
        )

        # Compute optical flow between consecutive frames
        from model.raft.core.utils_core.utils import InputPadder

        flow_scores = []

        with torch.no_grad():
            for frame1, frame2 in zip(prepared_frames[:-1], prepared_frames[1:]):
                # Pad frames if necessary
                padder = InputPadder(frame1.shape)
                frame1_padded, frame2_padded = padder.pad(frame1, frame2)

                # Compute optical flow
                _, flow_up = self.model(frame1_padded, frame2_padded, iters=20, test_mode=True)

                # Calculate the flow magnitude score
                magnitude_score = self._compute_flow_magnitude(flow_up)
                flow_scores.append(magnitude_score)

        # Check if the video has dynamic motion
        has_dynamic_motion = self._check_dynamic_motion(flow_scores, thresholds)

        # Return a binary score: 1.0 for dynamic, 0.0 for static
        return 1.0 if has_dynamic_motion else 0.0
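
# A minimal usage sketch (assumes the RAFT checkpoint exists at the default path
# and `frames` is a list of PIL Images; both are placeholders):
#
#     analyzer = DynamicDegreeAnalyzer(device="cuda")
#     is_dynamic = analyzer.analyze(frames)  # 1.0 if the clip shows dynamic motion, else 0.0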

class BackgroundConsistencyAnalyzer(VideoQualityAnalyzer):
    """Analyzer for evaluating background consistency across video frames using CLIP features.

    This analyzer measures how consistently the background appears across frames by:
    1. Extracting CLIP visual features from each frame
    2. Computing cosine similarity between consecutive frames and with the first frame
    3. Averaging these similarities to get a consistency score

    Similar to SubjectConsistencyAnalyzer but focuses on overall visual consistency
    including background elements, making it suitable for detecting background stability.
    """

    def __init__(self, model_name: str = "ViT-B/32", device: str = "cuda"):
        """Initialize the BackgroundConsistencyAnalyzer.

        Args:
            model_name: CLIP model name (default: "ViT-B/32")
            device: Device to run the model on ('cuda' or 'cpu')
        """
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

        # Load CLIP model
        self.model, self.preprocess = self._load_clip_model(model_name)
        self.model.eval()
        self.model.to(self.device)

        # Image transform for CLIP (when processing tensor inputs)
        self.tensor_transform = self._get_clip_tensor_transform(224)
    def _load_clip_model(self, model_name: str):
        """Load the CLIP model.

        Args:
            model_name: Name of the CLIP model to load

        Returns:
            Tuple of (model, preprocess_function)
        """
        import clip

        model, preprocess = clip.load(model_name, device=self.device)
        return model, preprocess
    def _get_clip_tensor_transform(self, n_px: int):
        """Get the CLIP transform for tensor inputs.

        Args:
            n_px: Target image size

        Returns:
            Transform composition for tensor inputs
        """
        try:
            from torchvision.transforms import InterpolationMode
            BICUBIC = InterpolationMode.BICUBIC
        except ImportError:
            BICUBIC = Image.BICUBIC

        return Compose([
            Resize(n_px, interpolation=BICUBIC, antialias=False),
            CenterCrop(n_px),
            transforms.Lambda(lambda x: x.float().div(255.0)),
            Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
        ])
    def _prepare_images_for_clip(self, frames: List[Image.Image]) -> torch.Tensor:
        """Prepare PIL images for CLIP processing.

        Args:
            frames: List of PIL Image frames

        Returns:
            Batch tensor of preprocessed images
        """
        # Use CLIP's built-in preprocess for PIL images
        images = []
        for frame in frames:
            processed = self.preprocess(frame)
            images.append(processed)

        # Stack into a batch tensor
        return torch.stack(images).to(self.device)
    def analyze(self, frames: List[Image.Image]) -> float:
        """Analyze background consistency across video frames.

        Args:
            frames: List of PIL Image frames representing the video

        Returns:
            Background consistency score (higher is better, range [0, 1])
        """
        if len(frames) < 2:
            return 1.0  # Single frame is perfectly consistent with itself

        # Prepare images for CLIP
        images = self._prepare_images_for_clip(frames)

        # Extract CLIP features
        with torch.no_grad():
            image_features = self.model.encode_image(images)
            image_features = F.normalize(image_features, dim=-1, p=2)

        video_sim = 0.0
        frame_count = 0

        # Compute similarity between frames
        for i in range(len(image_features)):
            image_feature = image_features[i].unsqueeze(0)

            if i == 0:
                # Store first frame features
                first_image_feature = image_feature
            else:
                # Compute similarity with previous frame
                sim_prev = max(0.0, F.cosine_similarity(former_image_feature, image_feature).item())

                # Compute similarity with first frame
                sim_first = max(0.0, F.cosine_similarity(first_image_feature, image_feature).item())

                # Average the two similarities
                frame_sim = (sim_prev + sim_first) / 2.0
                video_sim += frame_sim
                frame_count += 1

            # Store current features as previous for next iteration
            former_image_feature = image_feature

        # Return average similarity across all frame pairs
        if frame_count > 0:
            return video_sim / frame_count
        else:
            return 1.0
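
# A minimal usage sketch (assumes the `clip` package is installed and `frames`
# is a list of PIL Images; both are placeholders):
#
#     analyzer = BackgroundConsistencyAnalyzer(model_name="ViT-B/32", device="cuda")
#     score = analyzer.analyze(frames)  # float in [0, 1]; higher means a more stable background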

class ImagingQualityAnalyzer(VideoQualityAnalyzer):
    """Analyzer for evaluating the imaging quality of videos.

    This analyzer measures video quality by:
    1. Feeding each frame to the MUSIQ image quality predictor
    2. Using the predicted scores to judge whether the video is blurry or has artifacts

    The score represents the quality of the video (higher is better).
    """

    def __init__(self, model_path: str = "model/musiq/musiq_spaq_ckpt-358bb6af.pth", device: str = "cuda"):
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model = self._load_musiq(model_path)
        self.model.to(self.device)
        self.model.eval()
    def _load_musiq(self, model_path: str):
        """Load the MUSIQ model.

        Args:
            model_path: Path to the MUSIQ model checkpoint

        Returns:
            MUSIQ model
        """
        # from pathlib import Path
        # CACHE_DIR = Path("/tmp/musiq_cache")
        # model_path = CACHE_DIR / model_path

        # Download the checkpoint with wget if it does not exist locally
        if not os.path.exists(model_path):
            os.makedirs(os.path.dirname(model_path), exist_ok=True)
            wget_command = ['wget', 'https://github.com/chaofengc/IQA-PyTorch/releases/download/v0.1-weights/musiq_spaq_ckpt-358bb6af.pth', '-P', os.path.dirname(model_path)]
            subprocess.run(wget_command, check=True)

        try:
            from pyiqa.archs.musiq_arch import MUSIQ
        except ImportError:
            raise ImportError("Please install pyiqa to use ImagingQualityAnalyzer: pip install pyiqa")

        model = MUSIQ(pretrained_model_path=str(model_path))
        return model
    def _preprocess_frames(self, frames: List[Image.Image]) -> torch.Tensor:
        """Preprocess frames for the MUSIQ model.

        Args:
            frames: List of PIL Image frames

        Returns:
            Preprocessed frames as a tensor
        """
        frames = [pil_to_torch(frame, normalize=False) for frame in frames]  # [(C, H, W)]
        frames = torch.stack(frames)  # (T, C, H, W)

        # Downscale so that the longer side is at most 512 pixels
        _, _, h, w = frames.size()
        if max(h, w) > 512:
            scale = 512. / max(h, w)
            frames = F.interpolate(frames, size=(int(scale * h), int(scale * w)), mode='bilinear', align_corners=False)

        return frames
    def analyze(self, frames: List[Image.Image]) -> float:
        """Analyze the imaging quality of video frames.

        Args:
            frames: List of PIL Image frames representing the video

        Returns:
            Imaging quality score (higher is better, range [0, 1])
        """
        frame_tensor = self._preprocess_frames(frames)
        acc_score_video = 0.0
        for i in range(len(frame_tensor)):
            frame = frame_tensor[i].unsqueeze(0).to(self.device)
            score = self.model(frame)
            acc_score_video += float(score)
        # Average over frames and divide by 100 to map the MUSIQ score to roughly [0, 1]
        return acc_score_video / (100 * len(frame_tensor))
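
# A minimal end-to-end smoke test for the analyzers above. It is a sketch only:
# it assumes frame images live in "sample_frames/" (a placeholder path) and that
# the checkpoints referenced by each analyzer are available or downloadable.
if __name__ == "__main__":
    import glob

    # Load frames as RGB PIL Images, sorted to preserve temporal order
    frame_paths = sorted(glob.glob("sample_frames/*.png"))
    frames = [Image.open(p).convert("RGB") for p in frame_paths]

    analyzers = {
        "subject_consistency": SubjectConsistencyAnalyzer(),
        "background_consistency": BackgroundConsistencyAnalyzer(),
        "motion_smoothness": MotionSmoothnessAnalyzer(),
        "dynamic_degree": DynamicDegreeAnalyzer(),
        "imaging_quality": ImagingQualityAnalyzer(),
    }

    for name, analyzer in analyzers.items():
        print(f"{name}: {analyzer.analyze(frames):.4f}")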