Coverage for markdiffusion / evaluation / tools / image_quality_analyzer.py: 95.94%
468 statements
« prev ^ index » next coverage.py v7.14.0, created at 2026-05-14 20:17 +0000
import os
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Tuple, Union

import lpips
import numpy as np
import piq
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
# Absolute path of the NIQE natural-scene statistics bundled next to this
# module. It is anchored on this file's directory (not the CWD) so the lookup
# works no matter where the process was started from.
_DEFAULT_NIQE_PARAMS = os.path.join(
    os.path.dirname(os.path.abspath(__file__)),
    "data",
    "niqe_image_params.mat",
)
class ImageQualityAnalyzer(ABC):
    """Abstract base class for every image quality analyzer.

    The class derives from :class:`abc.ABC` so that the ``@abstractmethod``
    marker on :meth:`analyze` is actually enforced. A subclass that does not
    override ``analyze`` cannot be instantiated. Without the ABC metaclass the
    decorator has no effect.
    """

    def __init__(self):
        # No shared state; concrete analyzers set up their own models/params.
        pass

    @abstractmethod
    def analyze(self):
        """Compute a quality measurement; each analyzer family defines the signature."""
        pass
class DirectImageQualityAnalyzer(ImageQualityAnalyzer):
    """Base class for direct image quality analyzer.

    A "direct" analyzer scores a single image on its own, with no reference.
    """

    def __init__(self):
        """Create the analyzer; this base class keeps no state."""

    def analyze(self, image: Image.Image, *args, **kwargs):
        """Score ``image``. The base implementation does nothing and returns None."""
        return None
class ReferencedImageQualityAnalyzer(ImageQualityAnalyzer):
    """Base class for referenced image quality analyzer.

    A "referenced" analyzer scores an image against a reference, which may be
    another image or a text string.
    """

    def __init__(self):
        """Create the analyzer; this base class keeps no state."""

    def analyze(self, image: Image.Image, reference: Union[Image.Image, str], *args, **kwargs):
        """Score ``image`` against ``reference``. The base implementation returns None."""
        return None
class GroupImageQualityAnalyzer(ImageQualityAnalyzer):
    """Base class for group image quality analyzer.

    A "group" analyzer compares a whole set of images with a reference set.
    """

    def __init__(self):
        """Create the analyzer; this base class keeps no state."""

    def analyze(self, images: List[Image.Image], references: List[Image.Image], *args, **kwargs):
        """Compare ``images`` with ``references``. The base implementation returns None."""
        return None
class RepeatImageQualityAnalyzer(ImageQualityAnalyzer):
    """Base class for repeat image quality analyzer.

    A "repeat" analyzer scores a set of images as a whole (for example their
    diversity), without any reference.
    """

    def __init__(self):
        """Create the analyzer; this base class keeps no state."""

    def analyze(self, images: List[Image.Image], *args, **kwargs):
        """Score the image set. The base implementation returns None."""
        return None
class ComparedImageQualityAnalyzer(ImageQualityAnalyzer):
    """Base class for compare image quality analyzer.

    A "compared" analyzer measures how close one image is to a reference image.
    """

    def __init__(self):
        """Create the analyzer; this base class keeps no state."""

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs):
        """Compare ``image`` with ``reference``. The base implementation returns None."""
        return None
class InceptionScoreCalculator(RepeatImageQualityAnalyzer):
    """Inception Score (IS) calculator for evaluating image generation quality.

    Inception Score measures both the quality and diversity of generated images
    by evaluating how confidently an Inception model can classify them and how
    diverse the predictions are across the image set.

    Higher IS indicates better image quality and diversity (typical range: 1-10+).
    """

    def __init__(self, device: str = "cuda", batch_size: int = 32, splits: int = 1):
        """Initialize the Inception Score calculator.

        Args:
            device: Device to run the model on ("cuda" or "cpu"). Falls back to
                CPU when CUDA is not available.
            batch_size: Batch size for processing images.
            splits: Number of equal-sized splits the image set is divided into
                (default: 1). The number of images passed to :meth:`analyze`
                must be divisible by ``splits``. Use ``splits > 1`` to get one
                IS per split (e.g. to report a mean and standard error). With
                ``splits == 1`` a single IS over the whole set is returned.
        """
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.splits = splits
        self._load_model()

    def _load_model(self):
        """Load the pre-trained Inception v3 classifier and its preprocessing pipeline."""
        from torchvision import models, transforms

        # The 1000-way ImageNet classification head is kept on purpose: IS is
        # computed from the class-probability distribution p(y|x).
        self.model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
        self.model.aux_logits = False  # Disable auxiliary output
        self.model.eval()
        self.model.to(self.device)

        self.preprocess = transforms.Compose([
            transforms.Resize((299, 299)),  # Inception v3 input size
            transforms.ToTensor(),
            # ImageNet channel statistics
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _get_predictions(self, images: List[Image.Image]) -> np.ndarray:
        """Run Inception v3 over ``images`` in batches and return softmax probabilities.

        Args:
            images: List of PIL images to process (non-RGB images are converted).

        Returns:
            Array of shape (n_images, n_classes) with one probability row per image.
        """
        predictions_list = []
        for start in range(0, len(images), self.batch_size):
            batch_images = images[start:start + self.batch_size]
            batch_tensor = torch.stack([
                self.preprocess(img if img.mode == 'RGB' else img.convert('RGB'))
                for img in batch_images
            ]).to(self.device)

            with torch.no_grad():
                logits = self.model(batch_tensor)
                probs = F.softmax(logits, dim=1)
            predictions_list.append(probs.cpu().numpy())

        return np.concatenate(predictions_list, axis=0)

    def _calculate_inception_score(self, predictions: np.ndarray) -> List[float]:
        """Compute the Inception Score of each split of ``predictions``.

        For every split, IS = exp(mean over x of KL(p(y|x) || p(y))), where p(y)
        is the mean of that split's predictions.

        Args:
            predictions: Softmax predictions of shape (n_images, n_classes).

        Returns:
            List with one IS value per split (``self.splits`` entries). If the
            number of rows is not divisible by ``self.splits``, the last split
            absorbs the remaining rows.
        """
        n_samples = predictions.shape[0]
        splits = self.splits
        split_size = n_samples // splits
        epsilon = 1e-16  # keeps log() finite for classes with zero probability

        split_scores = []
        for split_idx in range(splits):
            start_idx = split_idx * split_size
            # Last split gets remaining samples
            end_idx = (split_idx + 1) * split_size if split_idx < splits - 1 else n_samples
            split_preds = predictions[start_idx:end_idx]

            # Marginal distribution p(y) over this split
            p_y = np.mean(split_preds, axis=0)

            p_y_x_safe = split_preds + epsilon
            p_y_safe = p_y + epsilon
            kl_divergences = np.sum(p_y_x_safe * np.log(p_y_x_safe / p_y_safe), axis=1)

            split_scores.append(float(np.exp(np.mean(kl_divergences))))

        return split_scores

    def analyze(self, images: List[Image.Image], *args, **kwargs) -> List[float]:
        """Calculate Inception Score for a set of generated images.

        Args:
            images: List of generated images to evaluate. Its length must be at
                least ``self.splits`` and divisible by ``self.splits``.

        Returns:
            List[float]: Inception Score values for each split (higher is better,
            typical range: 1-10+).

        Raises:
            ValueError: If ``splits`` is not positive, or the number of images is
                too small or not divisible by ``splits``.
        """
        if self.splits < 1:
            raise ValueError(f"splits must be a positive integer, got {self.splits}")
        if len(images) < self.splits:
            raise ValueError(f"Inception Score requires at least {self.splits} images (one per split)")
        if len(images) % self.splits != 0:
            raise ValueError("Inception Score requires the number of images to be divisible by the number of splits")

        predictions = self._get_predictions(images)
        split_scores = self._calculate_inception_score(predictions)

        # Diagnostic only: flag strong disagreement between splits.
        mean_score = np.mean(split_scores)
        std_score = np.std(split_scores)
        if std_score > 0.5 * mean_score:
            print(f"Warning: High standard deviation in IS calculation: {mean_score:.2f} ± {std_score:.2f}")

        return split_scores
class CLIPScoreCalculator(ReferencedImageQualityAnalyzer):
    """CLIP score calculator for image quality analysis.

    Calculates CLIP similarity between an image and a reference.
    Higher scores indicate better semantic similarity.
    """

    def __init__(self, device: str = "cuda", model_name: str = "ViT-B/32", reference_source: str = "image"):
        """Initialize the CLIP Score calculator.

        Args:
            device: Device to run the model on ("cuda" or "cpu"). Falls back to
                CPU when CUDA is not available.
            model_name: CLIP model variant to use.
            reference_source: The source of reference ('image' or 'text').
        """
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.reference_source = reference_source
        self._load_model()

    def _load_model(self):
        """Load the CLIP model and its matching preprocessing transform.

        Raises:
            ImportError: If the ``clip`` package is not installed.
        """
        try:
            import clip
        except ImportError as e:
            raise ImportError("Please install the CLIP library: pip install git+https://github.com/openai/CLIP.git") from e
        self.model, self.preprocess = clip.load(self.model_name, device=self.device)
        self.model.eval()

    def analyze(self, image: Image.Image, reference: Union[Image.Image, str], *args, **kwargs) -> float:
        """Calculate CLIP similarity between image and reference.

        Args:
            image: Input image to evaluate.
            reference: Reference image or text for comparison.
                - If reference_source is 'image': expects PIL Image
                - If reference_source is 'text': expects string

        Returns:
            float: Cosine similarity of the CLIP embeddings, mapped from [-1, 1]
            to [0, 1].

        Raises:
            ValueError: If the reference type does not match ``reference_source``
                or ``reference_source`` is neither 'image' nor 'text'.
        """
        # Validate the reference first so no model work is wasted on bad input.
        if self.reference_source == 'text':
            if not isinstance(reference, str):
                raise ValueError(f"Expected string reference for text mode, got {type(reference)}")
        elif self.reference_source == 'image':
            if not isinstance(reference, Image.Image):
                raise ValueError(f"Expected PIL Image reference for image mode, got {type(reference)}")
        else:
            raise ValueError(f"Invalid reference_source: {self.reference_source}. Must be 'image' or 'text'")

        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            img_features = self.model.encode_image(img_tensor)

            if self.reference_source == 'text':
                # `clip` is only imported inside _load_model, so it is not in
                # scope here; re-import (served from sys.modules after the first).
                import clip
                text_tokens = clip.tokenize([reference]).to(self.device)
                ref_features = self.model.encode_text(text_tokens)
            else:
                if reference.mode != 'RGB':
                    reference = reference.convert('RGB')
                ref_tensor = self.preprocess(reference).unsqueeze(0).to(self.device)
                ref_features = self.model.encode_image(ref_tensor)

            img_features = F.normalize(img_features, p=2, dim=1)
            ref_features = F.normalize(ref_features, p=2, dim=1)
            similarity = torch.cosine_similarity(img_features, ref_features).item()

        # Map cosine similarity from [-1, 1] to [0, 1].
        return (similarity + 1) / 2
class FIDCalculator(GroupImageQualityAnalyzer):
    """FID calculator for image quality analysis.

    Calculates Fréchet Inception Distance between two sets of images.
    Lower FID indicates better quality and similarity to reference distribution.
    """

    def __init__(self, device: str = "cuda", batch_size: int = 32, splits: int = 1):
        """Initialize the FID calculator.

        Args:
            device: Device to run the model on ("cuda" or "cpu"). Falls back to
                CPU when CUDA is not available.
            batch_size: Batch size for processing images.
            splits: Number of splits (default: 1). Both image sets passed to
                :meth:`analyze` must contain a number of images divisible by
                ``splits``. Note: :meth:`analyze` currently uses ``splits``
                only for this validation and always computes a single FID
                over the full sets.
        """
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.splits = splits
        self._load_model()

    def _load_model(self):
        """Load Inception v3 with its classification layer removed (2048-d pooled features)."""
        from torchvision import models, transforms

        inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT, init_weights=False)
        inception.fc = nn.Identity()  # Remove final classification layer
        inception.aux_logits = False
        inception.eval()
        inception.to(self.device)
        self.model = inception

        self.preprocess = transforms.Compose([
            # NOTE(review): reference FID implementations feed 299x299 inputs;
            # this 512x512 resize is left unchanged so scores do not shift.
            # Confirm it is intentional.
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            # ImageNet channel statistics
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
        ])

    def _extract_features(self, images: List[Image.Image]) -> np.ndarray:
        """Extract Inception features from a list of images in batches.

        Args:
            images: List of PIL images (non-RGB images are converted).

        Returns:
            Feature matrix of shape (n_images, 2048).
        """
        features_list = []
        for start in range(0, len(images), self.batch_size):
            batch_images = images[start:start + self.batch_size]
            batch_tensor = torch.stack([
                self.preprocess(img if img.mode == 'RGB' else img.convert('RGB'))
                for img in batch_images
            ]).to(self.device)

            with torch.no_grad():
                features = self.model(batch_tensor)  # (batch_size, 2048)
            features_list.append(features.cpu().numpy())

        return np.concatenate(features_list, axis=0)  # (n_images, 2048)

    def _calculate_fid(self, features1: np.ndarray, features2: np.ndarray, eps: float = 1e-6) -> float:
        """Calculate the Fréchet distance between two feature sets.

        FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2 * sqrt(S1 @ S2)).

        Args:
            features1: First feature set, shape (n1, d).
            features2: Second feature set, shape (n2, d).
            eps: Diagonal offset added to both covariances when the matrix
                square root is not finite (e.g. singular covariances from
                fewer samples than feature dimensions).

        Returns:
            float: FID score.
        """
        from scipy.linalg import sqrtm

        mu1 = features1.mean(axis=0)
        mu2 = features2.mean(axis=0)
        # atleast_2d keeps the dot/trace algebra valid for 1-D features, where
        # np.cov returns a 0-d array.
        sigma1 = np.atleast_2d(np.cov(features1, rowvar=False))
        sigma2 = np.atleast_2d(np.cov(features2, rowvar=False))

        diff = mu1 - mu2
        covmean = sqrtm(sigma1.dot(sigma2))

        if not np.isfinite(covmean).all():
            # A singular product can yield inf/nan; retry with a small
            # diagonal regularizer.
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Numerical error can introduce a tiny imaginary component.
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
        return float(fid)

    def analyze(self, images: List[Image.Image], references: List[Image.Image], *args, **kwargs) -> float:
        """Calculate FID between two sets of images.

        Args:
            images: Set of images to evaluate.
            references: Reference set of images.

        Returns:
            float: A single FID computed over the full ``images`` and
            ``references`` sets (lower is better).

        Raises:
            ValueError: If either set has fewer than 2 images, ``splits`` is not
                positive, or a set size is not divisible by ``splits``.
        """
        if len(images) < 2 or len(references) < 2:
            raise ValueError("FID requires at least 2 images in each set")
        if self.splits < 1:
            raise ValueError(f"splits must be a positive integer, got {self.splits}")
        if len(images) % self.splits != 0 or len(references) % self.splits != 0:
            raise ValueError("FID requires the number of images to be divisible by the number of splits")

        features1 = self._extract_features(images)
        features2 = self._extract_features(references)
        return self._calculate_fid(features1, features2)
class LPIPSAnalyzer(RepeatImageQualityAnalyzer):
    """LPIPS analyzer for image quality analysis.

    Calculates perceptual diversity within a set of images.
    Higher LPIPS indicates more diverse/varied images.
    """

    def __init__(self, device: str = "cuda", net: str = "alex"):
        """Initialize the LPIPS analyzer.

        Args:
            device: Device to run the model on ("cuda" or "cpu"). Falls back to
                CPU when CUDA is not available.
            net: Network to use ('alex', 'vgg', or 'squeeze').
        """
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.net = net
        self._load_model()

    def _load_model(self) -> None:
        """Load the LPIPS model in eval mode on ``self.device``."""
        self.model = lpips.LPIPS(net=self.net)
        self.model.eval()
        self.model.to(self.device)

    def analyze(self, images: List[Image.Image], *args, **kwargs) -> float:
        """Calculate the average pairwise LPIPS distance within a set of images.

        Every unordered pair is evaluated once, so the cost is n*(n-1)/2
        LPIPS forward passes.

        Args:
            images: List of images to analyze for diversity.

        Returns:
            float: Mean LPIPS distance over all pairs (diversity score); 0.0
            when fewer than two images are given.
        """
        import itertools

        if len(images) < 2:
            return 0.0  # No diversity with single image

        tensors = []
        for img in images:
            if img.mode != 'RGB':
                img = img.convert('RGB')
            # lpips.im2tensor converts an HxWxC uint8 array into a tensor for the model.
            tensors.append(lpips.im2tensor(np.array(img).astype(np.uint8)).to(self.device))

        # Enter no_grad once for all pairs instead of once per pair.
        with torch.no_grad():
            distances = [self.model(a, b).item() for a, b in itertools.combinations(tensors, 2)]

        return float(np.mean(distances))
class PSNRAnalyzer(ComparedImageQualityAnalyzer):
    """Peak Signal-to-Noise Ratio between an image and a reference.

    The mean squared error is taken over all pixels and all RGB channels.
    A larger PSNR means the two images are closer.
    """

    def __init__(self, max_pixel_value: float = 255.0):
        """Set up the analyzer.

        Args:
            max_pixel_value: Largest representable pixel value (255 for 8-bit images).
        """
        super().__init__()
        self.max_pixel_value = max_pixel_value

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float:
        """Return the PSNR (in dB) of ``image`` relative to ``reference``.

        Both inputs are coerced to RGB. If the sizes differ, ``reference`` is
        bilinearly resized to ``image``'s size. Identical images yield ``inf``.
        """
        image = image if image.mode == 'RGB' else image.convert('RGB')
        reference = reference if reference.mode == 'RGB' else reference.convert('RGB')

        if reference.size != image.size:
            reference = reference.resize(image.size, Image.Resampling.BILINEAR)

        residual = np.array(image, dtype=np.float32) - np.array(reference, dtype=np.float32)
        mse = np.mean(residual ** 2)

        # Zero error: PSNR is unbounded.
        if mse == 0:
            return float('inf')

        return float(20 * np.log10(self.max_pixel_value / np.sqrt(mse)))
class NIQECalculator(DirectImageQualityAnalyzer):
    """Natural Image Quality Evaluator (NIQE) for no-reference image quality assessment.

    NIQE evaluates image quality based on deviations from natural scene statistics.
    It uses a pre-trained model of natural image statistics to assess quality without
    requiring reference images.

    Lower NIQE scores indicate better/more natural image quality (typical range: 2-8).
    """

    def __init__(self,
                 model_path: str = _DEFAULT_NIQE_PARAMS,
                 patch_size: int = 96,
                 sigma: float = 7.0 / 6.0,
                 C: float = 1.0):
        """Initialize NIQE calculator with pre-trained natural image statistics.

        Args:
            model_path: Path to the pre-trained NIQE model parameters (.mat file
                containing ``pop_mu`` and ``pop_cov``).
            patch_size: Size of patches for feature extraction (default: 96).
            sigma: Standard deviation for Gaussian window (default: 7/6).
            C: Constant for numerical stability in MSCN transform (default: 1.0).

        Raises:
            RuntimeError: If the model parameters cannot be loaded.
        """
        super().__init__()
        self.patch_size = patch_size
        self.sigma = sigma
        self.C = C

        # Load pre-trained natural image statistics
        self._load_model_params(model_path)

        # Pre-compute gamma lookup table for AGGD parameter estimation
        self._precompute_gamma_table()

        # 7-tap (half-width 3) Gaussian window for local mean/variance
        self.avg_window = self._generate_gaussian_window(3, self.sigma)

    def _load_model_params(self, model_path: str) -> None:
        """Load ``pop_mu`` (flattened) and ``pop_cov`` from a MAT file.

        Args:
            model_path: Path to the model parameters file.

        Raises:
            RuntimeError: Wrapping any error raised while reading the file.
        """
        import scipy.io
        try:
            params = scipy.io.loadmat(model_path)
            self.pop_mu = np.ravel(params["pop_mu"])
            self.pop_cov = params["pop_cov"]
        except Exception as e:
            raise RuntimeError(f"Failed to load NIQE model parameters from {model_path}: {e}") from e

    def _precompute_gamma_table(self) -> None:
        """Tabulate Gamma(2/a)^2 / (Gamma(1/a) * Gamma(3/a)) for a in [0.2, 10) in steps of 0.001."""
        import scipy.special

        self.gamma_range = np.arange(0.2, 10, 0.001)
        a = scipy.special.gamma(2.0 / self.gamma_range)
        a *= a
        b = scipy.special.gamma(1.0 / self.gamma_range)
        c = scipy.special.gamma(3.0 / self.gamma_range)
        self.prec_gammas = a / (b * c)

    def _generate_gaussian_window(self, window_size: int, sigma: float) -> np.ndarray:
        """Generate a normalized 1D Gaussian window for filtering.

        Args:
            window_size: Half-size of the window (full size = 2*window_size + 1).
            sigma: Standard deviation of the Gaussian.

        Returns:
            1D Gaussian window weights summing to 1.
        """
        window_size = int(window_size)
        weights = np.zeros(2 * window_size + 1)
        weights[window_size] = 1.0
        sum_weights = 1.0

        sigma_sq = sigma * sigma
        for i in range(1, window_size + 1):
            tmp = np.exp(-0.5 * i * i / sigma_sq)
            weights[window_size + i] = tmp
            weights[window_size - i] = tmp
            sum_weights += 2.0 * tmp

        weights /= sum_weights
        return weights

    def _compute_mscn_transform(self, image: np.ndarray, extend_mode: str = 'constant') -> tuple:
        """Compute Mean Subtracted Contrast Normalized (MSCN) coefficients.

        The image is normalized by its local (Gaussian-weighted) mean and
        standard deviation.

        Args:
            image: 2-D grayscale image array.
            extend_mode: Boundary extension mode for filtering.

        Returns:
            Tuple of (mscn_coefficients, local_std, local_mean).

        Raises:
            ValueError: If ``image`` is not 2-D.
        """
        import scipy.ndimage

        # Explicit check rather than `assert`, which is stripped under -O.
        if image.ndim != 2:
            raise ValueError("Input must be grayscale image")
        h, w = image.shape

        mu_image = np.zeros((h, w), dtype=np.float32)
        var_image = np.zeros((h, w), dtype=np.float32)
        image_float = image.astype(np.float32)

        # Local mean via separable Gaussian filtering (rows, then columns)
        scipy.ndimage.correlate1d(image_float, self.avg_window, 0, mu_image, mode=extend_mode)
        scipy.ndimage.correlate1d(mu_image, self.avg_window, 1, mu_image, mode=extend_mode)

        # Local E[X^2] with the same separable filter
        scipy.ndimage.correlate1d(image_float**2, self.avg_window, 0, var_image, mode=extend_mode)
        scipy.ndimage.correlate1d(var_image, self.avg_window, 1, var_image, mode=extend_mode)

        # sqrt(|E[X^2] - E[X]^2|): local standard deviation (abs guards rounding)
        var_image = np.sqrt(np.abs(var_image - mu_image**2))

        mscn = (image_float - mu_image) / (var_image + self.C)

        return mscn, var_image, mu_image

    def _compute_aggd_features(self, coefficients: np.ndarray) -> tuple:
        """Estimate Asymmetric Generalized Gaussian Distribution (AGGD) parameters.

        The shape ``alpha`` is chosen by nearest match against the precomputed
        gamma table.

        Args:
            coefficients: MSCN coefficients or products of neighboring coefficients.

        Returns:
            Tuple of (alpha, N, bl, br, left_std, right_std).
        """
        import scipy.special

        coeffs_flat = coefficients.flatten()
        coeffs_squared = coeffs_flat * coeffs_flat

        # Split squared values by the sign of the original coefficient
        left_data = coeffs_squared[coeffs_flat < 0]
        right_data = coeffs_squared[coeffs_flat >= 0]

        left_std = np.sqrt(np.mean(left_data)) if len(left_data) > 0 else 0
        right_std = np.sqrt(np.mean(right_data)) if len(right_data) > 0 else 0

        # Left/right spread ratio
        if right_std != 0:
            gamma_hat = left_std / right_std
        else:
            gamma_hat = np.inf

        mean_abs = np.mean(np.abs(coeffs_flat))
        mean_squared = np.mean(coeffs_squared)

        if mean_squared != 0:
            r_hat = (mean_abs ** 2) / mean_squared
        else:
            r_hat = np.inf

        rhat_norm = r_hat * (((gamma_hat**3 + 1) * (gamma_hat + 1)) /
                             ((gamma_hat**2 + 1) ** 2))

        # Nearest entry in the precomputed table gives the shape estimate
        pos = np.argmin((self.prec_gammas - rhat_norm) ** 2)
        alpha = self.gamma_range[pos]

        gam1 = scipy.special.gamma(1.0 / alpha)
        gam2 = scipy.special.gamma(2.0 / alpha)
        gam3 = scipy.special.gamma(3.0 / alpha)

        aggd_ratio = np.sqrt(gam1) / np.sqrt(gam3)
        bl = aggd_ratio * left_std  # Left scale parameter
        br = aggd_ratio * right_std  # Right scale parameter

        # Mean parameter
        N = (br - bl) * (gam2 / gam1)

        return alpha, N, bl, br, left_std, right_std

    def _compute_paired_products(self, mscn_coeffs: np.ndarray) -> tuple:
        """Multiply MSCN coefficients with their neighbors in four orientations.

        ``np.roll`` wraps around the borders.

        Args:
            mscn_coeffs: MSCN coefficient matrix.

        Returns:
            Tuple of (horizontal, vertical, diagonal, anti-diagonal) products.
        """
        shift_h = np.roll(mscn_coeffs, 1, axis=1)  # left neighbor
        shift_v = np.roll(mscn_coeffs, 1, axis=0)  # upper neighbor
        shift_d1 = np.roll(shift_v, 1, axis=1)  # upper-left neighbor
        shift_d2 = np.roll(shift_v, -1, axis=1)  # upper-right neighbor

        prod_h = mscn_coeffs * shift_h
        prod_v = mscn_coeffs * shift_v
        prod_d1 = mscn_coeffs * shift_d1
        prod_d2 = mscn_coeffs * shift_d2

        return prod_h, prod_v, prod_d1, prod_d2

    def _extract_subband_features(self, mscn_coeffs: np.ndarray) -> np.ndarray:
        """Extract the 18 AGGD-based features of one MSCN patch.

        Args:
            mscn_coeffs: MSCN coefficient matrix.

        Returns:
            Feature vector of length 18. The order must match the layout of the
            pre-trained ``pop_mu``/``pop_cov``.
        """
        alpha_m, N, bl, br, _, _ = self._compute_aggd_features(mscn_coeffs)

        prod_h, prod_v, prod_d1, prod_d2 = self._compute_paired_products(mscn_coeffs)

        alpha1, N1, bl1, br1, _, _ = self._compute_aggd_features(prod_h)
        alpha2, N2, bl2, br2, _, _ = self._compute_aggd_features(prod_v)
        alpha3, N3, bl3, br3, _, _ = self._compute_aggd_features(prod_d1)
        alpha4, N4, bl4, br4, _, _ = self._compute_aggd_features(prod_d2)

        # The diagonal groups repeat bl (not br), following the reference
        # implementation; changing this would break the pre-trained statistics.
        features = np.array([
            alpha_m, (bl + br) / 2.0,  # Shape and scale of MSCN
            alpha1, N1, bl1, br1,      # Horizontal-neighbor products
            alpha2, N2, bl2, br2,      # Vertical-neighbor products
            alpha3, N3, bl3, bl3,      # Diagonal products - bl3 repeated
            alpha4, N4, bl4, bl4,      # Anti-diagonal products - bl4 repeated
        ])

        return features

    def _extract_multiscale_features(self, image: np.ndarray) -> tuple:
        """Extract patch features at full and half resolution.

        Args:
            image: Input grayscale image.

        Returns:
            Tuple of (all_features, mean_features, sample_covariance).

        Raises:
            ValueError: If the image is smaller than one patch.
        """
        h, w = image.shape

        if h < self.patch_size or w < self.patch_size:
            raise ValueError(f"Image too small. Minimum size: {self.patch_size}x{self.patch_size}")

        # Crop so patches tile the image exactly
        hoffset = h % self.patch_size
        woffset = w % self.patch_size

        if hoffset > 0:
            image = image[:-hoffset, :]
        if woffset > 0:
            image = image[:, :-woffset]

        image = image.astype(np.float32)

        # Half-resolution copy via PIL bicubic resize (float32 -> mode 'F')
        img_pil = Image.fromarray(image)
        size = tuple((np.array(img_pil.size) * 0.5).astype(int))
        img2 = np.array(img_pil.resize(size, Image.BICUBIC))

        mscn1, _, _ = self._compute_mscn_transform(image)
        mscn1 = mscn1.astype(np.float32)

        mscn2, _, _ = self._compute_mscn_transform(img2)
        mscn2 = mscn2.astype(np.float32)

        # Patch size halves with the image so patches cover the same regions
        feats_lvl1 = self._extract_patches_test_features(mscn1, self.patch_size)
        feats_lvl2 = self._extract_patches_test_features(mscn2, self.patch_size // 2)

        feats = np.hstack((feats_lvl1, feats_lvl2))

        sample_mu = np.mean(feats, axis=0)
        sample_cov = np.cov(feats.T)

        return feats, sample_mu, sample_cov

    def _extract_patches_test_features(self, mscn: np.ndarray, patch_size: int) -> np.ndarray:
        """Extract features from non-overlapping patches in row-major order.

        Args:
            mscn: MSCN coefficient matrix.
            patch_size: Size of patches.

        Returns:
            Array with one feature row per patch.
        """
        h, w = mscn.shape
        patch_size = int(patch_size)

        patch_features = [
            self._extract_subband_features(mscn[row:row + patch_size, col:col + patch_size])
            for row in range(0, h - patch_size + 1, patch_size)
            for col in range(0, w - patch_size + 1, patch_size)
        ]

        return np.array(patch_features)

    def analyze(self, image: Image.Image, *args, **kwargs) -> float:
        """Calculate NIQE score for a single image.

        Args:
            image: Input image to evaluate (converted to grayscale).

        Returns:
            float: NIQE score (lower is better, typical range: 2-8).

        Raises:
            ValueError: If the image is smaller than (2*patch_size + 1) on either side.
        """
        import scipy.linalg

        if image.mode != 'L':
            if image.mode == 'RGB':
                # Convert RGB to grayscale as in reference: using 'LA' and taking first channel
                image = image.convert('LA')
                img_array = np.array(image)[:, :, 0].astype(np.float32)
            else:
                image = image.convert('L')
                img_array = np.array(image, dtype=np.float32)
        else:
            img_array = np.array(image, dtype=np.float32)

        min_size = self.patch_size * 2 + 1
        if img_array.shape[0] < min_size or img_array.shape[1] < min_size:
            raise ValueError(f"Image too small. Minimum size: {min_size}x{min_size}")

        _, sample_mu, sample_cov = self._extract_multiscale_features(img_array)

        X = sample_mu - self.pop_mu

        # Mahalanobis-like distance with the averaged covariance; pinv handles
        # (near-)singular matrices.
        covmat = (self.pop_cov + sample_cov) / 2.0
        pinv_cov = scipy.linalg.pinv(covmat)

        niqe_score = np.sqrt(np.dot(np.dot(X, pinv_cov), X))

        return float(niqe_score)
class SSIMAnalyzer(ComparedImageQualityAnalyzer):
    """SSIM analyzer for image quality analysis.

    Computes a *global* Structural Similarity Index: means, standard deviations
    and covariance are taken over the whole image (all pixels and all RGB
    channels together), i.e. a single window. This differs from the standard
    sliding-window SSIM, so values are not directly comparable with
    windowed-SSIM implementations. Higher SSIM indicates better
    quality/similarity.
    """

    def __init__(self, max_pixel_value: float = 255.0):
        """Initialize the SSIM analyzer.

        Args:
            max_pixel_value: Maximum pixel value (255 for 8-bit images). Used to
                derive the stabilizing constants C1 = (0.01*L)^2 and
                C2 = (0.03*L)^2, with C3 = C2 / 2.
        """
        super().__init__()
        self.max_pixel_value = max_pixel_value
        self.C1 = (0.01 * max_pixel_value) ** 2
        self.C2 = (0.03 * max_pixel_value) ** 2
        self.C3 = self.C2 / 2.0

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float:
        """Calculate global SSIM between two images.

        Both images are converted to RGB. If the sizes differ, ``reference`` is
        bilinearly resized to ``image``'s size.

        Args:
            image (Image.Image): Image to evaluate.
            reference (Image.Image): Reference image.

        Returns:
            float: SSIM value. It equals 1 for identical images, and the
            structure term lets it fall below 0 (range roughly [-1, 1]).
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if reference.mode != 'RGB':
            reference = reference.convert('RGB')

        if image.size != reference.size:
            reference = reference.resize(image.size, Image.Resampling.BILINEAR)

        img_array = np.array(image, dtype=np.float32)
        ref_array = np.array(reference, dtype=np.float32)

        # Global statistics over every pixel and channel
        mu_x = np.mean(img_array)
        mu_y = np.mean(ref_array)

        sigma_x = np.std(img_array)
        sigma_y = np.std(ref_array)
        sigma_xy = np.mean((img_array - mu_x) * (ref_array - mu_y))

        luminance_mean = (2 * mu_x * mu_y + self.C1) / (mu_x**2 + mu_y**2 + self.C1)
        contrast = (2 * sigma_x * sigma_y + self.C2) / (sigma_x**2 + sigma_y**2 + self.C2)
        structure_comparison = (sigma_xy + self.C3) / (sigma_x * sigma_y + self.C3)
        ssim = luminance_mean * contrast * structure_comparison

        return float(ssim)
class BRISQUEAnalyzer(DirectImageQualityAnalyzer):
    """No-reference image quality analyzer based on BRISQUE.

    BRISQUE (Blind/Referenceless Image Spatial Quality Evaluator) rates the
    perceptual quality of a single image without any reference image. The
    score comes from ``piq.brisque``. Lower values mean better quality; typical
    values span roughly 0 (best) to 100 (worst).
    """

    def __init__(self, device: str = "cuda"):
        super().__init__()
        # Whatever device was requested, fall back to CPU when CUDA is unavailable.
        chosen = device if torch.cuda.is_available() else "cpu"
        self.device = torch.device(chosen)

    def _preprocess(self, image: Image.Image) -> torch.Tensor:
        """Return *image* as a (1, 3, H, W) float tensor scaled to [0, 1] on ``self.device``."""
        rgb = image if image.mode == 'RGB' else image.convert('RGB')
        pixels = np.asarray(rgb, dtype=np.float32) / 255.0  # (H, W, 3)
        channels_first = torch.from_numpy(pixels).permute(2, 0, 1)
        return channels_first.unsqueeze(0).to(self.device)

    def analyze(self, image: Image.Image, *args, **kwargs) -> float:
        """Compute the BRISQUE score of one image.

        Args:
            image: PIL image to score.

        Returns:
            float: BRISQUE score (lower is better).
        """
        batch = self._preprocess(image)
        with torch.no_grad():
            # _preprocess scales into [0, 1], hence data_range=1.0.
            result = piq.brisque(batch, data_range=1.0)
        return float(result.item())
class VIFAnalyzer(ComparedImageQualityAnalyzer):
    """VIF (Visual Information Fidelity) analyzer using piq.

    VIF compares a distorted image with a reference image to
    quantify the amount of visual information preserved.
    Higher VIF indicates better quality/similarity.
    Typical range: 0 ~ 1 (sometimes higher for good quality).
    """

    def __init__(self, device: str = "cuda"):
        super().__init__()
        # Falls back to CPU whenever CUDA is unavailable, regardless of `device`.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

    def _preprocess(self, image: Image.Image, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        """Convert a PIL image to a tensor in range [0, 1] with shape (1, C, H, W).

        Args:
            image: Input PIL image; converted to RGB first if needed.
            size: Optional target ``(width, height)``. When given and different
                from the image's size, the image is resized with PIL bilinear
                resampling, which antialiases when downscaling.

        Returns:
            torch.Tensor: Float tensor of shape (1, 3, H, W) on ``self.device``.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if size is not None and image.size != tuple(size):
            image = image.resize(tuple(size), Image.Resampling.BILINEAR)
        arr = np.array(image).astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # BCHW
        return tensor.to(self.device)

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float:
        """Calculate VIF score between image and reference.

        If the sizes differ, the reference is resized to the image's size.

        Args:
            image: Distorted/test image (PIL)
            reference: Reference image (PIL)

        Returns:
            float: VIF score (higher is better)
        """
        x = self._preprocess(image)
        # piq expects matching shapes. Resizing in PIL space avoids the aliasing
        # that non-antialiased torch bilinear interpolation causes when downscaling.
        y = self._preprocess(reference, size=image.size)

        with torch.no_grad():
            score = piq.vif_p(x, y, data_range=1.0)
        return float(score.item())
class FSIMAnalyzer(ComparedImageQualityAnalyzer):
    """FSIM (Feature Similarity Index) analyzer using piq.

    FSIM compares structural similarity between two images
    based on phase congruency and gradient magnitude.
    Higher FSIM indicates better quality/similarity.
    Typical range: 0 ~ 1.
    """

    def __init__(self, device: str = "cuda"):
        super().__init__()
        # Falls back to CPU whenever CUDA is unavailable, regardless of `device`.
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")

    def _preprocess(self, image: Image.Image, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
        """Convert a PIL image to a tensor in range [0, 1] with shape (1, C, H, W).

        Args:
            image: Input PIL image; converted to RGB first if needed.
            size: Optional target ``(width, height)``. When given and different
                from the image's size, the image is resized with PIL bilinear
                resampling, which antialiases when downscaling.

        Returns:
            torch.Tensor: Float tensor of shape (1, 3, H, W) on ``self.device``.
        """
        if image.mode != 'RGB':
            image = image.convert('RGB')
        if size is not None and image.size != tuple(size):
            image = image.resize(tuple(size), Image.Resampling.BILINEAR)
        arr = np.array(image).astype(np.float32) / 255.0
        tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0)  # BCHW
        return tensor.to(self.device)

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float:
        """Calculate FSIM score between image and reference.

        If the sizes differ, the reference is resized to the image's size.

        Args:
            image: Distorted/test image (PIL)
            reference: Reference image (PIL)

        Returns:
            float: FSIM score (higher is better)
        """
        x = self._preprocess(image)
        # piq expects matching shapes. Resizing in PIL space avoids the aliasing
        # that non-antialiased torch bilinear interpolation causes when downscaling.
        y = self._preprocess(reference, size=image.size)

        with torch.no_grad():
            score = piq.fsim(x, y, data_range=1.0)
        return float(score.item())