Coverage for evaluation/tools/image_quality_analyzer.py: 95.92% (466 statements)
from PIL import Image
from typing import List, Dict, Optional, Union, Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from abc import ABC, abstractmethod
import lpips
import piq
class ImageQualityAnalyzer(ABC):
    """Base class for image quality analyzers."""

    def __init__(self):
        pass

    @abstractmethod
    def analyze(self):
        pass
class DirectImageQualityAnalyzer(ImageQualityAnalyzer):
    """Base class for direct (no-reference) image quality analyzers."""

    def __init__(self):
        pass

    def analyze(self, image: Image.Image, *args, **kwargs):
        pass


class ReferencedImageQualityAnalyzer(ImageQualityAnalyzer):
    """Base class for referenced image quality analyzers."""

    def __init__(self):
        pass

    def analyze(self, image: Image.Image, reference: Union[Image.Image, str], *args, **kwargs):
        pass


class GroupImageQualityAnalyzer(ImageQualityAnalyzer):
    """Base class for group (set-to-set) image quality analyzers."""

    def __init__(self):
        pass

    def analyze(self, images: List[Image.Image], references: List[Image.Image], *args, **kwargs):
        pass


class RepeatImageQualityAnalyzer(ImageQualityAnalyzer):
    """Base class for repeat (within-set) image quality analyzers."""

    def __init__(self):
        pass

    def analyze(self, images: List[Image.Image], *args, **kwargs):
        pass


class ComparedImageQualityAnalyzer(ImageQualityAnalyzer):
    """Base class for compared (pairwise) image quality analyzers."""

    def __init__(self):
        pass

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs):
        pass
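# Illustrative sketch (not part of the original module): a new metric plugs in by
# subclassing one of the bases above and implementing analyze(). A hypothetical
# no-reference sharpness analyzer might look like:
#
#     class SharpnessAnalyzer(DirectImageQualityAnalyzer):
#         def analyze(self, image: Image.Image, *args, **kwargs) -> float:
#             gray = np.array(image.convert("L"), dtype=np.float32)
#             gy, gx = np.gradient(gray)
#             return float(np.mean(gx**2 + gy**2))  # mean gradient energy as a crude proxy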
class InceptionScoreCalculator(RepeatImageQualityAnalyzer):
    """Inception Score (IS) calculator for evaluating image generation quality.

    Inception Score measures both the quality and diversity of generated images
    by evaluating how confidently an Inception model can classify them and how
    diverse the predictions are across the image set.

    Higher IS indicates better image quality and diversity (typical range: 1-10+).
    """

    def __init__(self, device: str = "cuda", batch_size: int = 32, splits: int = 1):
        """Initialize the Inception Score calculator.

        Args:
            device: Device to run the model on ("cuda" or "cpu")
            batch_size: Batch size for processing images
            splits: Number of splits for computing IS (default: 1). The number of images
                must be divisible by the number of splits for a fair comparison.
                To obtain a mean and standard error of IS, set splits greater than 1.
                If splits is 1, IS is calculated on the entire dataset (avg = IS, std = 0).
        """
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.splits = splits
        self._load_model()
    def _load_model(self):
        """Load the Inception v3 model for feature extraction."""
        from torchvision import models, transforms

        # Load pre-trained Inception v3 model
        self.model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
        self.model.aux_logits = False  # Disable auxiliary output
        self.model.eval()
        self.model.to(self.device)

        # Keep the original classification layer for proper predictions:
        # model.fc is left unchanged so the model outputs 1000 class logits.

        # Define preprocessing pipeline for Inception v3
        self.preprocess = transforms.Compose([
            transforms.Resize((299, 299)),  # Inception v3 input size
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
        ])

    def _get_predictions(self, images: List[Image.Image]) -> np.ndarray:
        """Extract softmax predictions from images using Inception v3.

        Args:
            images: List of PIL images to process

        Returns:
            Numpy array of shape (n_images, n_classes) containing softmax predictions
        """
        predictions_list = []

        # Process images in batches for efficiency
        for i in range(0, len(images), self.batch_size):
            batch_images = images[i:i + self.batch_size]

            # Preprocess batch
            batch_tensors = []
            for img in batch_images:
                # Ensure RGB format
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                tensor = self.preprocess(img)
                batch_tensors.append(tensor)

            # Stack into batch tensor
            batch_tensor = torch.stack(batch_tensors).to(self.device)

            # Get predictions from Inception model
            with torch.no_grad():
                logits = self.model(batch_tensor)
                # Apply softmax to get probability distributions
                probs = F.softmax(logits, dim=1)
                predictions_list.append(probs.cpu().numpy())

        return np.concatenate(predictions_list, axis=0)
    def _calculate_inception_score(self, predictions: np.ndarray) -> List[float]:
        """Calculate Inception Score from predictions.

        The IS is exp(mean KL divergence between the conditional distribution
        p(y|x) and the marginal distribution p(y)).

        Args:
            predictions: Softmax predictions of shape (n_images, n_classes)

        Returns:
            List of IS values, one per split
        """
        # Split predictions for more stable estimation
        n_samples = predictions.shape[0]  # (n_images, n_classes)
        split_size = n_samples // self.splits
        splits = self.splits

        split_scores = []

        for split_idx in range(splits):
            # Get current split
            start_idx = split_idx * split_size
            end_idx = (split_idx + 1) * split_size if split_idx < splits - 1 else n_samples  # Last split gets remaining samples
            split_preds = predictions[start_idx:end_idx]

            # Calculate marginal distribution p(y) - average across all samples in the split
            p_y = np.mean(split_preds, axis=0)

            epsilon = 1e-16
            p_y_x_safe = split_preds + epsilon
            p_y_safe = p_y + epsilon
            kl_divergences = np.sum(
                p_y_x_safe * (np.log(p_y_x_safe / p_y_safe)),
                axis=1)

            # Inception Score for this split is exp(mean(KL divergences))
            split_score = np.exp(np.mean(kl_divergences))
            split_scores.append(split_score)

        # Directly return the list of scores for each split
        return split_scores
    def analyze(self, images: List[Image.Image], *args, **kwargs) -> List[float]:
        """Calculate Inception Score for a set of generated images.

        Args:
            images: List of generated images to evaluate

        Returns:
            List[float]: Inception Score values for each split (higher is better, typical range: 1-10+)
        """
        if len(images) < self.splits:
            raise ValueError(f"Inception Score requires at least {self.splits} images (one per split)")

        if len(images) % self.splits != 0:
            raise ValueError("Inception Score requires the number of images to be divisible by the number of splits")

        # Get predictions from Inception model
        predictions = self._get_predictions(images)

        # Calculate Inception Score
        split_scores = self._calculate_inception_score(predictions)

        # Warn when the spread across splits is large; the full list of split scores is still returned
        mean_score = np.mean(split_scores)
        std_score = np.std(split_scores)
        if std_score > 0.5 * mean_score:
            print(f"Warning: High standard deviation in IS calculation: {mean_score:.2f} ± {std_score:.2f}")

        return split_scores
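# Illustrative usage sketch (not from the original module; assumes `generated_images`
# is a list of PIL images). IS = exp( E_x[ KL( p(y|x) || p(y) ) ] ), computed per split:
#
#     is_calc = InceptionScoreCalculator(device="cuda", batch_size=32, splits=2)
#     split_scores = is_calc.analyze(generated_images)
#     print(f"IS: {np.mean(split_scores):.2f} ± {np.std(split_scores):.2f}")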
class CLIPScoreCalculator(ReferencedImageQualityAnalyzer):
    """CLIP score calculator for image quality analysis.

    Calculates CLIP similarity between an image and a reference.
    Higher scores indicate better semantic similarity.
    """

    def __init__(self, device: str = "cuda", model_name: str = "ViT-B/32", reference_source: str = "image"):
        """Initialize the CLIP Score calculator.

        Args:
            device: Device to run the model on ("cuda" or "cpu")
            model_name: CLIP model variant to use
            reference_source: The source of the reference ('image' or 'text')
        """
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.reference_source = reference_source
        self._load_model()

    def _load_model(self):
        """Load the CLIP model."""
        try:
            import clip
            self.model, self.preprocess = clip.load(self.model_name, device=self.device)
            self.model.eval()
        except ImportError:
            raise ImportError("Please install the CLIP library: pip install git+https://github.com/openai/CLIP.git")
    def analyze(self, image: Image.Image, reference: Union[Image.Image, str], *args, **kwargs) -> float:
        """Calculate CLIP similarity between image and reference.

        Args:
            image: Input image to evaluate
            reference: Reference image or text for comparison
                - If reference_source is 'image': expects PIL Image
                - If reference_source is 'text': expects string

        Returns:
            float: CLIP similarity score (0 to 1)
        """
        # Convert image to RGB if necessary
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Preprocess image
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        # Extract features based on reference source
        with torch.no_grad():
            # Encode image features
            img_features = self.model.encode_image(img_tensor)

            # Encode reference features based on source type
            if self.reference_source == 'text':
                if not isinstance(reference, str):
                    raise ValueError(f"Expected string reference for text mode, got {type(reference)}")

                # Tokenize and encode text (clip was only imported locally in _load_model,
                # so it must be imported again here)
                import clip
                text_tokens = clip.tokenize([reference]).to(self.device)
                ref_features = self.model.encode_text(text_tokens)

            elif self.reference_source == 'image':
                if not isinstance(reference, Image.Image):
                    raise ValueError(f"Expected PIL Image reference for image mode, got {type(reference)}")

                # Convert reference image to RGB if necessary
                if reference.mode != 'RGB':
                    reference = reference.convert('RGB')

                # Preprocess and encode reference image
                ref_tensor = self.preprocess(reference).unsqueeze(0).to(self.device)
                ref_features = self.model.encode_image(ref_tensor)

            else:
                raise ValueError(f"Invalid reference_source: {self.reference_source}. Must be 'image' or 'text'")

            # Normalize features
            img_features = F.normalize(img_features, p=2, dim=1)
            ref_features = F.normalize(ref_features, p=2, dim=1)

            # Calculate cosine similarity
            similarity = torch.cosine_similarity(img_features, ref_features).item()

        # Convert to 0-1 range
        similarity = (similarity + 1) / 2

        return similarity
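# Illustrative usage sketch (not from the original module; assumes `img` is a PIL image
# and a text prompt is used as the reference):
#
#     clip_calc = CLIPScoreCalculator(device="cuda", reference_source="text")
#     score = clip_calc.analyze(img, "a photo of a red sports car")  # 0..1, higher = closer match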
class FIDCalculator(GroupImageQualityAnalyzer):
    """FID calculator for image quality analysis.

    Calculates Fréchet Inception Distance between two sets of images.
    Lower FID indicates better quality and similarity to the reference distribution.
    """

    def __init__(self, device: str = "cuda", batch_size: int = 32, splits: int = 1):
        """Initialize the FID calculator.

        Args:
            device: Device to run the model on ("cuda" or "cpu")
            batch_size: Batch size for processing images
            splits: Number of splits for computing FID (default: 1). The number of images
                must be divisible by the number of splits for a fair comparison.
                To obtain a mean and standard error of FID, set splits greater than 1.
                If splits is 1, FID is calculated on the entire dataset (avg = FID, std = 0).
        """
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.splits = splits
        self._load_model()

    def _load_model(self):
        """Load the Inception v3 model for feature extraction."""
        from torchvision import models, transforms

        # Load Inception v3 model
        inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT, init_weights=False)
        inception.fc = nn.Identity()  # Remove final classification layer
        inception.aux_logits = False
        inception.eval()
        inception.to(self.device)
        self.model = inception

        # Define preprocessing
        self.preprocess = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
        ])

    def _extract_features(self, images: List[Image.Image]) -> np.ndarray:
        """Extract features from a list of images.

        Args:
            images: List of PIL images

        Returns:
            Feature matrix of shape (n_images, 2048)
        """
        features_list = []

        for i in range(0, len(images), self.batch_size):
            batch_images = images[i:i + self.batch_size]

            # Preprocess batch
            batch_tensors = []
            for img in batch_images:
                if img.mode != 'RGB':
                    img = img.convert('RGB')
                tensor = self.preprocess(img)
                batch_tensors.append(tensor)

            batch_tensor = torch.stack(batch_tensors).to(self.device)

            # Extract features
            with torch.no_grad():
                features = self.model(batch_tensor)  # (batch_size, 2048)
                features_list.append(features.cpu().numpy())

        return np.concatenate(features_list, axis=0)  # (n_images, 2048)

    def _calculate_fid(self, features1: np.ndarray, features2: np.ndarray) -> float:
        """Calculate FID between two feature sets.

        Args:
            features1: First feature set
            features2: Second feature set

        Returns:
            float: FID score
        """
        from scipy.linalg import sqrtm

        # Calculate statistics
        mu1, sigma1 = features1.mean(axis=0), np.cov(features1, rowvar=False)
        mu2, sigma2 = features2.mean(axis=0), np.cov(features2, rowvar=False)

        # Calculate FID
        diff = mu1 - mu2
        covmean = sqrtm(sigma1.dot(sigma2))

        # Numerical stability
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
        return float(fid)
    def analyze(self, images: List[Image.Image], references: List[Image.Image], *args, **kwargs) -> float:
        """Calculate FID between two sets of images.

        Args:
            images: Set of images to evaluate
            references: Reference set of images

        Returns:
            float: FID value computed over the two full sets
                (per-split computation is currently disabled, see below)
        """
        if len(images) < 2 or len(references) < 2:
            raise ValueError("FID requires at least 2 images in each set")
        if len(images) % self.splits != 0 or len(references) % self.splits != 0:
            raise ValueError("FID requires the number of images to be divisible by the number of splits")

        # Extract features
        features1 = self._extract_features(images)
        features2 = self._extract_features(references)

        # Per-split computation (currently disabled):
        # fid_scores = []
        # for i in range(self.splits):
        #     start_idx = i * len(images) // self.splits
        #     end_idx = (i + 1) * len(images) // self.splits
        #     fid_scores.append(self._calculate_fid(features1[start_idx:end_idx], features2[start_idx:end_idx]))

        # Calculate FID over the full feature sets
        fid_score = self._calculate_fid(features1, features2)

        return fid_score
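# Illustrative usage sketch (not from the original module; assumes two lists of PIL images).
# FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2*(S1 S2)^(1/2)) over Inception features:
#
#     fid_calc = FIDCalculator(device="cuda", batch_size=32)
#     fid = fid_calc.analyze(generated_images, real_images)  # lower is better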
class LPIPSAnalyzer(RepeatImageQualityAnalyzer):
    """LPIPS analyzer for image quality analysis.

    Calculates perceptual diversity within a set of images.
    Higher LPIPS indicates more diverse/varied images.
    """

    def __init__(self, device: str = "cuda", net: str = "alex"):
        """Initialize the LPIPS analyzer.

        Args:
            device: Device to run the model on ("cuda" or "cpu")
            net: Network to use ('alex', 'vgg', or 'squeeze')
        """
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.net = net
        self._load_model()

    def _load_model(self) -> None:
        """Load the LPIPS model."""
        self.model = lpips.LPIPS(net=self.net)
        self.model.eval()
        self.model.to(self.device)

    def analyze(self, images: List[Image.Image], *args, **kwargs) -> float:
        """Calculate average pairwise LPIPS distance within a set of images.

        Args:
            images: List of images to analyze diversity

        Returns:
            float: Average LPIPS distance (diversity score)
        """
        if len(images) < 2:
            return 0.0  # No diversity with a single image

        # Preprocess all images
        tensors = []
        for img in images:
            if img.mode != 'RGB':
                img = img.convert('RGB')
            tensor = lpips.im2tensor(np.array(img).astype(np.uint8)).to(self.device)  # Convert to tensor
            tensors.append(tensor)

        # Calculate pairwise LPIPS distances
        distances = []
        for i in range(len(tensors)):
            for j in range(i + 1, len(tensors)):
                with torch.no_grad():
                    distance = self.model.forward(tensors[i], tensors[j]).item()
                distances.append(distance)

        # Return average distance as diversity score
        return np.mean(distances) if distances else 0.0
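# Illustrative usage sketch (not from the original module; assumes `samples` is a list of
# PIL images generated from the same prompt; average pairwise LPIPS acts as a diversity score):
#
#     lpips_div = LPIPSAnalyzer(device="cuda", net="alex")
#     diversity = lpips_div.analyze(samples)  # higher = more varied outputs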
class PSNRAnalyzer(ComparedImageQualityAnalyzer):
    """PSNR analyzer for image quality analysis.

    Calculates Peak Signal-to-Noise Ratio between two images.
    Higher PSNR indicates better quality/similarity.
    """

    def __init__(self, max_pixel_value: float = 255.0):
        """Initialize the PSNR analyzer.

        Args:
            max_pixel_value: Maximum pixel value (255 for 8-bit images)
        """
        super().__init__()
        self.max_pixel_value = max_pixel_value

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float:
        """Calculate PSNR between two images.

        Args:
            image (Image.Image): Image to evaluate
            reference (Image.Image): Reference image

        Returns:
            float: PSNR value in dB
        """
        # Convert to RGB if necessary
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if reference.mode != 'RGB':
            reference = reference.convert('RGB')

        # Resize if necessary
        if image.size != reference.size:
            reference = reference.resize(image.size, Image.Resampling.BILINEAR)

        # Convert to numpy arrays
        img_array = np.array(image, dtype=np.float32)
        ref_array = np.array(reference, dtype=np.float32)

        # Calculate MSE
        mse = np.mean((img_array - ref_array) ** 2)

        # Avoid division by zero
        if mse == 0:
            return float('inf')

        # Calculate PSNR
        psnr = 20 * np.log10(self.max_pixel_value / np.sqrt(mse))

        return float(psnr)
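# Illustrative usage sketch (not from the original module; assumes `restored` and `original`
# are PIL images). PSNR = 20 * log10(MAX / sqrt(MSE)), reported in dB:
#
#     psnr = PSNRAnalyzer(max_pixel_value=255.0).analyze(restored, original)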
class NIQECalculator(DirectImageQualityAnalyzer):
    """Natural Image Quality Evaluator (NIQE) for no-reference image quality assessment.

    NIQE evaluates image quality based on deviations from natural scene statistics.
    It uses a pre-trained model of natural image statistics to assess quality without
    requiring reference images.

    Lower NIQE scores indicate better/more natural image quality (typical range: 2-8).
    """

    def __init__(self,
                 model_path: str = "evaluation/tools/data/niqe_image_params.mat",
                 patch_size: int = 96,
                 sigma: float = 7.0 / 6.0,
                 C: float = 1.0):
        """Initialize NIQE calculator with pre-trained natural image statistics.

        Args:
            model_path: Path to the pre-trained NIQE model parameters (.mat file)
            patch_size: Size of patches for feature extraction (default: 96)
            sigma: Standard deviation for Gaussian window (default: 7/6)
            C: Constant for numerical stability in MSCN transform (default: 1.0)
        """
        super().__init__()
        self.patch_size = patch_size
        self.sigma = sigma
        self.C = C

        # Load pre-trained natural image statistics
        self._load_model_params(model_path)

        # Pre-compute gamma lookup table for AGGD parameter estimation
        self._precompute_gamma_table()

        # Generate Gaussian window for local mean/variance computation
        self.avg_window = self._generate_gaussian_window(3, self.sigma)

    def _load_model_params(self, model_path: str) -> None:
        """Load pre-trained NIQE model parameters from a MAT file.

        Args:
            model_path: Path to the model parameters file
        """
        import scipy.io
        try:
            params = scipy.io.loadmat(model_path)
            self.pop_mu = np.ravel(params["pop_mu"])
            self.pop_cov = params["pop_cov"]
        except Exception as e:
            raise RuntimeError(f"Failed to load NIQE model parameters from {model_path}: {e}")

    def _precompute_gamma_table(self) -> None:
        """Pre-compute gamma function values for AGGD parameter estimation."""
        import scipy.special

        self.gamma_range = np.arange(0.2, 10, 0.001)
        a = scipy.special.gamma(2.0 / self.gamma_range)
        a *= a
        b = scipy.special.gamma(1.0 / self.gamma_range)
        c = scipy.special.gamma(3.0 / self.gamma_range)
        self.prec_gammas = a / (b * c)

    def _generate_gaussian_window(self, window_size: int, sigma: float) -> np.ndarray:
        """Generate a 1D Gaussian window for filtering.

        Args:
            window_size: Half-size of the window (full size = 2*window_size + 1)
            sigma: Standard deviation of the Gaussian

        Returns:
            1D Gaussian window weights
        """
        window_size = int(window_size)
        weights = np.zeros(2 * window_size + 1)
        weights[window_size] = 1.0
        sum_weights = 1.0

        sigma_sq = sigma * sigma
        for i in range(1, window_size + 1):
            tmp = np.exp(-0.5 * i * i / sigma_sq)
            weights[window_size + i] = tmp
            weights[window_size - i] = tmp
            sum_weights += 2.0 * tmp

        weights /= sum_weights
        return weights

    def _compute_mscn_transform(self, image: np.ndarray, extend_mode: str = 'constant') -> tuple:
        """Compute Mean Subtracted Contrast Normalized (MSCN) coefficients.

        The MSCN transformation normalizes image patches by local mean and variance,
        making the coefficients more suitable for statistical modeling.

        Args:
            image: Input image array
            extend_mode: Boundary extension mode for filtering

        Returns:
            Tuple of (mscn_coefficients, local_variance, local_mean)
        """
        import scipy.ndimage

        assert len(image.shape) == 2, "Input must be a grayscale image"
        h, w = image.shape

        # Allocate arrays for local statistics
        mu_image = np.zeros((h, w), dtype=np.float32)
        var_image = np.zeros((h, w), dtype=np.float32)
        image_float = image.astype(np.float32)

        # Compute local mean using separable Gaussian filtering
        scipy.ndimage.correlate1d(image_float, self.avg_window, 0, mu_image, mode=extend_mode)
        scipy.ndimage.correlate1d(mu_image, self.avg_window, 1, mu_image, mode=extend_mode)

        # Compute local variance
        scipy.ndimage.correlate1d(image_float**2, self.avg_window, 0, var_image, mode=extend_mode)
        scipy.ndimage.correlate1d(var_image, self.avg_window, 1, var_image, mode=extend_mode)

        # Variance = E[X^2] - E[X]^2
        var_image = np.sqrt(np.abs(var_image - mu_image**2))

        # MSCN transform
        mscn = (image_float - mu_image) / (var_image + self.C)

        return mscn, var_image, mu_image

    def _compute_aggd_features(self, coefficients: np.ndarray) -> tuple:
        """Compute Asymmetric Generalized Gaussian Distribution (AGGD) parameters.

        AGGD models the distribution of MSCN coefficients and their products,
        capturing shape and asymmetry characteristics.

        Args:
            coefficients: MSCN coefficients

        Returns:
            Tuple of (alpha, N, bl, br, left_std, right_std)
        """
        import scipy.special

        # Flatten coefficients
        coeffs_flat = coefficients.flatten()
        coeffs_squared = coeffs_flat * coeffs_flat

        # Separate left (negative) and right (positive) tail data
        left_data = coeffs_squared[coeffs_flat < 0]
        right_data = coeffs_squared[coeffs_flat >= 0]

        # Compute standard deviations for left and right tails
        left_std = np.sqrt(np.mean(left_data)) if len(left_data) > 0 else 0
        right_std = np.sqrt(np.mean(right_data)) if len(right_data) > 0 else 0

        # Estimate gamma (shape asymmetry parameter)
        if right_std != 0:
            gamma_hat = left_std / right_std
        else:
            gamma_hat = np.inf

        # Estimate r_hat (generalized Gaussian ratio)
        mean_abs = np.mean(np.abs(coeffs_flat))
        mean_squared = np.mean(coeffs_squared)

        if mean_squared != 0:
            r_hat = (mean_abs ** 2) / mean_squared
        else:
            r_hat = np.inf

        # Normalize r_hat using gamma
        rhat_norm = r_hat * (((gamma_hat**3 + 1) * (gamma_hat + 1)) /
                             ((gamma_hat**2 + 1) ** 2))

        # Find the best-fitting alpha by comparing with pre-computed values
        pos = np.argmin((self.prec_gammas - rhat_norm) ** 2)
        alpha = self.gamma_range[pos]

        # Compute AGGD parameters
        gam1 = scipy.special.gamma(1.0 / alpha)
        gam2 = scipy.special.gamma(2.0 / alpha)
        gam3 = scipy.special.gamma(3.0 / alpha)

        aggd_ratio = np.sqrt(gam1) / np.sqrt(gam3)
        bl = aggd_ratio * left_std   # Left scale parameter
        br = aggd_ratio * right_std  # Right scale parameter

        # Mean parameter
        N = (br - bl) * (gam2 / gam1)

        return alpha, N, bl, br, left_std, right_std

    def _compute_paired_products(self, mscn_coeffs: np.ndarray) -> tuple:
        """Compute products of adjacent MSCN coefficients in four orientations.

        These products capture dependencies between neighboring pixels.

        Args:
            mscn_coeffs: MSCN coefficient matrix

        Returns:
            Tuple of (horizontal, vertical, diagonal1, diagonal2) products
        """
        # Shift in four directions and compute products
        shift_h = np.roll(mscn_coeffs, 1, axis=1)   # Horizontal shift
        shift_v = np.roll(mscn_coeffs, 1, axis=0)   # Vertical shift
        shift_d1 = np.roll(shift_v, 1, axis=1)      # Main diagonal shift
        shift_d2 = np.roll(shift_v, -1, axis=1)     # Anti-diagonal shift

        # Compute products
        prod_h = mscn_coeffs * shift_h    # Horizontal pairs
        prod_v = mscn_coeffs * shift_v    # Vertical pairs
        prod_d1 = mscn_coeffs * shift_d1  # Diagonal pairs
        prod_d2 = mscn_coeffs * shift_d2  # Anti-diagonal pairs

        return prod_h, prod_v, prod_d1, prod_d2
    def _extract_subband_features(self, mscn_coeffs: np.ndarray) -> np.ndarray:
        """Extract statistical features from MSCN coefficients and their products.

        Args:
            mscn_coeffs: MSCN coefficient matrix

        Returns:
            Feature vector of length 18
        """
        # Extract AGGD parameters from MSCN coefficients
        alpha_m, N, bl, br, _, _ = self._compute_aggd_features(mscn_coeffs)

        # Compute paired products in four orientations
        prod_h, prod_v, prod_d1, prod_d2 = self._compute_paired_products(mscn_coeffs)

        # Extract AGGD parameters for each product orientation
        alpha1, N1, bl1, br1, _, _ = self._compute_aggd_features(prod_h)
        alpha2, N2, bl2, br2, _, _ = self._compute_aggd_features(prod_v)
        alpha3, N3, bl3, br3, _, _ = self._compute_aggd_features(prod_d1)
        alpha4, N4, bl4, br4, _, _ = self._compute_aggd_features(prod_d2)

        # Combine all features into a feature vector
        # Note: for the diagonal pairs, bl3/bl4 are repeated (not br3/br4), as in the reference implementation
        features = np.array([
            alpha_m, (bl + br) / 2.0,   # Shape and scale of MSCN
            alpha1, N1, bl1, br1,       # Horizontal pairs (prod_h)
            alpha2, N2, bl2, br2,       # Vertical pairs (prod_v)
            alpha3, N3, bl3, bl3,       # Diagonal pairs (prod_d1) - note: bl3 repeated
            alpha4, N4, bl4, bl4,       # Anti-diagonal pairs (prod_d2) - note: bl4 repeated
        ])

        return features
    def _extract_multiscale_features(self, image: np.ndarray) -> tuple:
        """Extract features at multiple scales.

        Args:
            image: Input grayscale image

        Returns:
            Tuple of (all_features, mean_features, sample_covariance)
        """
        h, w = image.shape

        # Check minimum size requirements
        if h < self.patch_size or w < self.patch_size:
            raise ValueError(f"Image too small. Minimum size: {self.patch_size}x{self.patch_size}")

        # Ensure that the patch size divides evenly into the image
        hoffset = h % self.patch_size
        woffset = w % self.patch_size

        if hoffset > 0:
            image = image[:-hoffset, :]
        if woffset > 0:
            image = image[:, :-woffset]

        # Convert to float32 for processing
        image = image.astype(np.float32)

        # Downsample the image by a factor of 2 using PIL (as in the reference implementation)
        img_pil = Image.fromarray(image)
        size = tuple((np.array(img_pil.size) * 0.5).astype(int))
        img2 = np.array(img_pil.resize(size, Image.BICUBIC))

        # Compute MSCN transforms at two scales
        mscn1, _, _ = self._compute_mscn_transform(image)
        mscn1 = mscn1.astype(np.float32)

        mscn2, _, _ = self._compute_mscn_transform(img2)
        mscn2 = mscn2.astype(np.float32)

        # Extract features from patches at each scale
        feats_lvl1 = self._extract_patches_test_features(mscn1, self.patch_size)
        feats_lvl2 = self._extract_patches_test_features(mscn2, self.patch_size // 2)

        # Concatenate features from both scales
        feats = np.hstack((feats_lvl1, feats_lvl2))

        # Calculate mean and covariance
        sample_mu = np.mean(feats, axis=0)
        sample_cov = np.cov(feats.T)

        return feats, sample_mu, sample_cov

    def _extract_patches_test_features(self, mscn: np.ndarray, patch_size: int) -> np.ndarray:
        """Extract features from non-overlapping patches of a test image.

        Args:
            mscn: MSCN coefficient matrix
            patch_size: Size of patches

        Returns:
            Array of patch features
        """
        h, w = mscn.shape
        patch_size = int(patch_size)

        # Extract non-overlapping patches
        patches = []
        for j in range(0, h - patch_size + 1, patch_size):
            for i in range(0, w - patch_size + 1, patch_size):
                patch = mscn[j:j + patch_size, i:i + patch_size]
                patches.append(patch)

        patches = np.array(patches)

        # Extract features from each patch
        patch_features = []
        for p in patches:
            patch_features.append(self._extract_subband_features(p))

        patch_features = np.array(patch_features)

        return patch_features

    def analyze(self, image: Image.Image, *args, **kwargs) -> float:
        """Calculate NIQE score for a single image.

        Args:
            image: Input image to evaluate

        Returns:
            float: NIQE score (lower is better, typical range: 2-8)
        """
        import scipy.linalg
        import scipy.special

        # Convert to grayscale if needed
        if image.mode != 'L':
            if image.mode == 'RGB':
                # Convert RGB to grayscale as in the reference: use 'LA' and take the first channel
                image = image.convert('LA')
                img_array = np.array(image)[:, :, 0].astype(np.float32)
            else:
                image = image.convert('L')
                img_array = np.array(image, dtype=np.float32)
        else:
            img_array = np.array(image, dtype=np.float32)

        # Check minimum size requirements
        min_size = self.patch_size * 2 + 1
        if img_array.shape[0] < min_size or img_array.shape[1] < min_size:
            raise ValueError(f"Image too small. Minimum size: {min_size}x{min_size}")

        # Extract multi-scale features
        all_features, sample_mu, sample_cov = self._extract_multiscale_features(img_array)

        # Compute distance from natural image statistics
        X = sample_mu - self.pop_mu

        # Calculate a Mahalanobis-like distance using the average of the
        # sample and population covariance, as in the reference implementation
        covmat = (self.pop_cov + sample_cov) / 2.0

        # Compute pseudo-inverse for numerical stability
        pinv_cov = scipy.linalg.pinv(covmat)

        # Calculate NIQE score
        niqe_score = np.sqrt(np.dot(np.dot(X, pinv_cov), X))

        return float(niqe_score)
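# Illustrative usage sketch (not from the original module; assumes the .mat parameter file
# exists at the default path and `img` is a sufficiently large PIL image):
#
#     niqe = NIQECalculator()
#     score = niqe.analyze(img)  # lower = closer to natural image statistics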
class SSIMAnalyzer(ComparedImageQualityAnalyzer):
    """SSIM analyzer for image quality analysis.

    Calculates a global (non-windowed) Structural Similarity Index between two images.
    Higher SSIM indicates better quality/similarity.
    """

    def __init__(self, max_pixel_value: float = 255.0):
        """Initialize the SSIM analyzer.

        Args:
            max_pixel_value: Maximum pixel value (255 for 8-bit images)
        """
        super().__init__()
        self.max_pixel_value = max_pixel_value
        self.C1 = (0.01 * max_pixel_value) ** 2
        self.C2 = (0.03 * max_pixel_value) ** 2
        self.C3 = self.C2 / 2.0

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float:
        """Calculate SSIM between two images.

        Args:
            image (Image.Image): Image to evaluate
            reference (Image.Image): Reference image

        Returns:
            float: SSIM value (0 to 1)
        """
        # Convert to RGB if necessary
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if reference.mode != 'RGB':
            reference = reference.convert('RGB')

        # Resize if necessary
        if image.size != reference.size:
            reference = reference.resize(image.size, Image.Resampling.BILINEAR)

        # Convert to numpy arrays
        img_array = np.array(image, dtype=np.float32)
        ref_array = np.array(reference, dtype=np.float32)

        # Calculate means
        mu_x = np.mean(img_array)
        mu_y = np.mean(ref_array)

        # Calculate variances and covariance
        sigma_x = np.std(img_array)
        sigma_y = np.std(ref_array)
        sigma_xy = np.mean((img_array - mu_x) * (ref_array - mu_y))

        # Calculate SSIM as the product of luminance, contrast, and structure terms
        luminance = (2 * mu_x * mu_y + self.C1) / (mu_x**2 + mu_y**2 + self.C1)
        contrast = (2 * sigma_x * sigma_y + self.C2) / (sigma_x**2 + sigma_y**2 + self.C2)
        structure = (sigma_xy + self.C3) / (sigma_x * sigma_y + self.C3)
        ssim = luminance * contrast * structure

        return float(ssim)
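# Illustrative usage sketch (not from the original module; note that this class computes a
# global SSIM over whole-image statistics rather than the windowed variant):
#
#     ssim = SSIMAnalyzer().analyze(restored, original)  # closer to 1 = more similar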
class BRISQUEAnalyzer(DirectImageQualityAnalyzer):
    """BRISQUE analyzer for no-reference image quality analysis.

    BRISQUE (Blind/Referenceless Image Spatial Quality Evaluator)
    evaluates the perceptual quality of an image without requiring
    a reference. Lower BRISQUE scores indicate better quality.
    Typical range: 0 (best) to 100 (worst).
    """

    def __init__(self, device: str = "cuda"):
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

    def _preprocess(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL Image to a tensor in range [0, 1] with shape (1, C, H, W)."""
        if image.mode != 'RGB':
            image = image.convert('RGB')
        arr = np.array(image).astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # BCHW
        return tensor.to(self.device)

    def analyze(self, image: Image.Image, *args, **kwargs) -> float:
        """Calculate BRISQUE score for a single image.

        Args:
            image: PIL Image

        Returns:
            float: BRISQUE score (lower is better)
        """
        x = self._preprocess(image)
        with torch.no_grad():
            score = piq.brisque(x, data_range=1.0)  # piq expects values in [0, 1]
        return float(score.item())
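# Illustrative usage sketch (not from the original module; no reference image is needed):
#
#     brisque = BRISQUEAnalyzer(device="cuda")
#     score = brisque.analyze(img)  # roughly 0 (best) to 100 (worst)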
class VIFAnalyzer(ComparedImageQualityAnalyzer):
    """VIF (Visual Information Fidelity) analyzer using piq.

    VIF compares a distorted image with a reference image to
    quantify the amount of visual information preserved.
    Higher VIF indicates better quality/similarity.
    Typical range: 0 to 1 (sometimes higher for very good quality).
    """

    def __init__(self, device: str = "cuda"):
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

    def _preprocess(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL Image to a tensor in range [0, 1] with shape (1, C, H, W)."""
        if image.mode != 'RGB':
            image = image.convert('RGB')
        arr = np.array(image).astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # BCHW
        return tensor.to(self.device)

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float:
        """Calculate VIF score between image and reference.

        Args:
            image: Distorted/test image (PIL)
            reference: Reference image (PIL)

        Returns:
            float: VIF score (higher is better)
        """
        x = self._preprocess(image)
        y = self._preprocess(reference)

        # Ensure same size (piq expects matching shapes)
        if x.shape != y.shape:
            _, _, h, w = x.shape
            y = torch.nn.functional.interpolate(y, size=(h, w), mode='bilinear', align_corners=False)

        with torch.no_grad():
            score = piq.vif_p(x, y, data_range=1.0)
        return float(score.item())
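# Illustrative usage sketch (not from the original module; both arguments are PIL images):
#
#     vif = VIFAnalyzer(device="cuda").analyze(compressed, original)  # higher = more information preserved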
class FSIMAnalyzer(ComparedImageQualityAnalyzer):
    """FSIM (Feature Similarity Index) analyzer using piq.

    FSIM compares structural similarity between two images
    based on phase congruency and gradient magnitude.
    Higher FSIM indicates better quality/similarity.
    Typical range: 0 to 1.
    """

    def __init__(self, device: str = "cuda"):
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

    def _preprocess(self, image: Image.Image) -> torch.Tensor:
        """Convert a PIL Image to a tensor in range [0, 1] with shape (1, C, H, W)."""
        if image.mode != 'RGB':
            image = image.convert('RGB')
        arr = np.array(image).astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # BCHW
        return tensor.to(self.device)

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float:
        """Calculate FSIM score between image and reference.

        Args:
            image: Distorted/test image (PIL)
            reference: Reference image (PIL)

        Returns:
            float: FSIM score (higher is better)
        """
        x = self._preprocess(image)
        y = self._preprocess(reference)

        # Ensure same size
        if x.shape != y.shape:
            _, _, h, w = x.shape
            y = torch.nn.functional.interpolate(y, size=(h, w), mode='bilinear', align_corners=False)

        with torch.no_grad():
            score = piq.fsim(x, y, data_range=1.0)
        return float(score.item())
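# Illustrative usage sketch (not from the original module; both arguments are PIL images):
#
#     fsim = FSIMAnalyzer(device="cuda").analyze(stylized, original)  # 0..1, higher = more similar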