Coverage for evaluation / tools / image_quality_analyzer.py: 95.92%

466 statements  

coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1from PIL import Image 

2from typing import List, Dict, Optional, Union, Tuple 

3import torch 

4import torch.nn as nn 

5import torch.nn.functional as F 

6import numpy as np 

7from abc import ABC, abstractmethod 

8import lpips 

9import piq 

10 

11class ImageQualityAnalyzer(ABC): 

12 """Abstract base class for image quality analyzers.""" 

13 

14 def __init__(self): 

15 pass 

16 

17 @abstractmethod 

18 def analyze(self): 

19 pass 

20 

21class DirectImageQualityAnalyzer(ImageQualityAnalyzer): 

22 """Base class for direct image quality analyzer.""" 

23 

24 def __init__(self): 

25 pass 

26 

27 def analyze(self, image: Image.Image, *args, **kwargs): 

28 pass 

29 

30class ReferencedImageQualityAnalyzer(ImageQualityAnalyzer): 

31 """Base class for referenced image quality analyzer.""" 

32 

33 def __init__(self): 

34 pass 

35 

36 def analyze(self, image: Image.Image, reference: Union[Image.Image, str], *args, **kwargs): 

37 pass 

38 

39class GroupImageQualityAnalyzer(ImageQualityAnalyzer): 

40 """Base class for group image quality analyzer.""" 

41 

42 def __init__(self): 

43 pass 

44 

45 def analyze(self, images: List[Image.Image], references: List[Image.Image], *args, **kwargs): 

46 pass 

47 

48class RepeatImageQualityAnalyzer(ImageQualityAnalyzer): 

49 """Base class for repeat image quality analyzer.""" 

50 

51 def __init__(self): 

52 pass 

53 

54 def analyze(self, images: List[Image.Image], *args, **kwargs): 

55 pass 

56 

57class ComparedImageQualityAnalyzer(ImageQualityAnalyzer): 

58 """Base class for compare image quality analyzer.""" 

59 

60 def __init__(self): 

61 pass 

62 

63 def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs): 

64 pass 

65 
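# The classes above only fix the analyze() signature for each analyzer family.
# A minimal sketch of how a new metric would plug in (illustrative only; the
# MeanBrightnessAnalyzer name and logic below are not part of this module):
#
#     class MeanBrightnessAnalyzer(DirectImageQualityAnalyzer):
#         def analyze(self, image: Image.Image, *args, **kwargs) -> float:
#             # Average grayscale intensity as a toy "quality" signal
#             return float(np.mean(np.array(image.convert("L"), dtype=np.float32)))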

66class InceptionScoreCalculator(RepeatImageQualityAnalyzer): 

67 """Inception Score (IS) calculator for evaluating image generation quality. 

68  

69 Inception Score measures both the quality and diversity of generated images 

70 by evaluating how confidently an Inception model can classify them and how 

71 diverse the predictions are across the image set. 

72  

73 Higher IS indicates better image quality and diversity (typical range: 1-10+). 

74 """ 

75 

76 def __init__(self, device: str = "cuda", batch_size: int = 32, splits: int = 1): 

77 """Initialize the Inception Score calculator. 

78  

79 Args: 

80 device: Device to run the model on ("cuda" or "cpu") 

81 batch_size: Batch size for processing images 

82 splits: Number of splits for computing IS (default: 1). The number of images must be divisible by the number of splits for a fair comparison. 

83 To estimate the mean and standard error of IS, set splits greater than 1. 

84 If splits is 1, the IS is calculated on the entire dataset (Avg = IS, Std = 0). 

85 """ 

86 super().__init__() 

87 self.device = torch.device(device if torch.cuda.is_available() else "cpu") 

88 self.batch_size = batch_size 

89 self.splits = splits 

90 self._load_model() 

91 

92 def _load_model(self): 

93 """Load the Inception v3 model for feature extraction.""" 

94 from torchvision import models, transforms 

95 

96 # Load pre-trained Inception v3 model 

97 self.model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT) 

98 self.model.aux_logits = False # Disable auxiliary output 

99 self.model.eval() 

100 self.model.to(self.device) 

101 

102 # Keep the original classification layer for proper predictions 

103 # No need to modify model.fc - it should output 1000 classes 

104 

105 # Define preprocessing pipeline for Inception v3 

106 self.preprocess = transforms.Compose([ 

107 transforms.Resize((299, 299)), # Inception v3 input size 

108 transforms.ToTensor(), 

109 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet statistics 

110 ]) 

111 

112 def _get_predictions(self, images: List[Image.Image]) -> np.ndarray: 

113 """Extract softmax predictions from images using Inception v3. 

114  

115 Args: 

116 images: List of PIL images to process 

117  

118 Returns: 

119 Numpy array of shape (n_images, n_classes) containing softmax predictions 

120 """ 

121 predictions_list = [] 

122 

123 # Process images in batches for efficiency 

124 for i in range(0, len(images), self.batch_size): 

125 batch_images = images[i:i + self.batch_size] 

126 

127 # Preprocess batch 

128 batch_tensors = [] 

129 for img in batch_images: 

130 # Ensure RGB format 

131 if img.mode != 'RGB': 

132 img = img.convert('RGB') 

133 tensor = self.preprocess(img) 

134 batch_tensors.append(tensor) 

135 

136 # Stack into batch tensor 

137 batch_tensor = torch.stack(batch_tensors).to(self.device) 

138 

139 # Get predictions from Inception model 

140 with torch.no_grad(): 

141 logits = self.model(batch_tensor) 

142 # Apply softmax to get probability distributions 

143 probs = F.softmax(logits, dim=1) 

144 predictions_list.append(probs.cpu().numpy()) 

145 

146 return np.concatenate(predictions_list, axis=0) 

147 

148 def _calculate_inception_score(self, predictions: np.ndarray) -> List[float]: 

149 """Calculate Inception Score from predictions. 

150  

151 The IS is calculated as exp(KL divergence between conditional and marginal distributions). 

152  

153 Args: 

154 predictions: Softmax predictions of shape (n_images, n_classes) 

155  

156 Returns: 

157 List of IS values, one per split 

158 """ 

159 # Split predictions for more stable estimation 

160 n_samples = predictions.shape[0] # (n_images, n_classes) 

161 split_size = n_samples // self.splits 

162 

163 splits = self.splits 

164 

165 split_scores = [] 

166 

167 for split_idx in range(splits): 

168 # Get current split 

169 start_idx = split_idx * split_size 

170 end_idx = (split_idx + 1) * split_size if split_idx < splits - 1 else n_samples # Last split gets remaining samples 

171 split_preds = predictions[start_idx:end_idx] 

172 

173 # Calculate marginal distribution p(y) - average across all samples 

174 p_y = np.mean(split_preds, axis=0) 

175 

176 epsilon = 1e-16 

177 p_y_x_safe = split_preds + epsilon 

178 p_y_safe = p_y + epsilon 

179 kl_divergences = np.sum( 

180 p_y_x_safe * (np.log(p_y_x_safe / p_y_safe)), 

181 axis=1) 

182 

183 # Inception Score for this split is exp(mean(KL divergences)) 

184 split_score = np.exp(np.mean(kl_divergences)) 

185 split_scores.append(split_score) 

186 

187 # Directly return the list of scores for each split 

188 return split_scores 

189 

190 def analyze(self, images: List[Image.Image], *args, **kwargs) -> List[float]: 

191 """Calculate Inception Score for a set of generated images. 

192  

193 Args: 

194 images: List of generated images to evaluate 

195  

196 Returns: 

197 List[float]: Inception Score values for each split (higher is better, typical range: 1-10+) 

198 """ 

199 if len(images) < self.splits: 

200 raise ValueError(f"Inception Score requires at least {self.splits} images (one per split)") 

201 

202 if len(images) % self.splits != 0: 

203 raise ValueError(f"Inception Score requires the number of images to be divisible by the number of splits") 

204 

205 # Get predictions from Inception model 

206 predictions = self._get_predictions(images) 

207 

208 # Calculate Inception Score 

209 split_scores = self._calculate_inception_score(predictions) 

210 

211 # Warn if the spread across splits is large (the per-split scores are still returned) 

212 mean_score = np.mean(split_scores) 

213 std_score = np.std(split_scores) 

214 if std_score > 0.5 * mean_score: 

215 print(f"Warning: High standard deviation in IS calculation: {mean_score:.2f} ± {std_score:.2f}") 

216 

217 return split_scores 

218 
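# Usage sketch for InceptionScoreCalculator (illustrative; assumes `images` is a
# list of PIL images obtained elsewhere and that torchvision weights can be
# downloaded). The score follows IS = exp(E_x[ KL(p(y|x) || p(y)) ]), computed
# per split by _calculate_inception_score above:
#
#     is_calc = InceptionScoreCalculator(device="cuda", batch_size=32, splits=2)
#     split_scores = is_calc.analyze(images)   # len(images) must be divisible by splits
#     print(np.mean(split_scores), np.std(split_scores))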

219class CLIPScoreCalculator(ReferencedImageQualityAnalyzer): 

220 """CLIP score calculator for image quality analysis. 

221  

222 Calculates CLIP similarity between an image and a reference. 

223 Higher scores indicate better semantic similarity. 

224 """ 

225 

226 def __init__(self, device: str = "cuda", model_name: str = "ViT-B/32", reference_source: str = "image"): 

227 """Initialize the CLIP Score calculator. 

228  

229 Args: 

230 device: Device to run the model on ("cuda" or "cpu") 

231 model_name: CLIP model variant to use 

232 reference_source: The source of reference ('image' or 'text') 

233 """ 

234 super().__init__() 

235 self.device = torch.device(device if torch.cuda.is_available() else "cpu") 

236 self.model_name = model_name 

237 self.reference_source = reference_source 

238 self._load_model() 

239 

240 def _load_model(self): 

241 """Load the CLIP model.""" 

242 try: 

243 import clip 

244 self.model, self.preprocess = clip.load(self.model_name, device=self.device) 

self.clip = clip # keep a module handle so analyze() can tokenize text references 

245 self.model.eval() 

246 except ImportError: 

247 raise ImportError("Please install the CLIP library: pip install git+https://github.com/openai/CLIP.git") 

248 

249 def analyze(self, image: Image.Image, reference: Union[Image.Image, str], *args, **kwargs) -> float: 

250 """Calculate CLIP similarity between image and reference. 

251  

252 Args: 

253 image: Input image to evaluate 

254 reference: Reference image or text for comparison 

255 - If reference_source is 'image': expects PIL Image 

256 - If reference_source is 'text': expects string 

257  

258 Returns: 

259 float: CLIP similarity score (0 to 1) 

260 """ 

261 

262 # Convert image to RGB if necessary 

263 if image.mode != 'RGB': 

264 image = image.convert('RGB') 

265 

266 # Preprocess image 

267 img_tensor = self.preprocess(image).unsqueeze(0).to(self.device) 

268 

269 # Extract features based on reference source 

270 with torch.no_grad(): 

271 # Encode image features 

272 img_features = self.model.encode_image(img_tensor) 

273 

274 # Encode reference features based on source type 

275 if self.reference_source == 'text': 

276 if not isinstance(reference, str): 

277 raise ValueError(f"Expected string reference for text mode, got {type(reference)}") 

278 

279 # Tokenize and encode text 

280 text_tokens = self.clip.tokenize([reference]).to(self.device) 

281 ref_features = self.model.encode_text(text_tokens) 

282 

283 elif self.reference_source == 'image': 

284 if not isinstance(reference, Image.Image): 

285 raise ValueError(f"Expected PIL Image reference for image mode, got {type(reference)}") 

286 

287 # Convert reference image to RGB if necessary 

288 if reference.mode != 'RGB': 

289 reference = reference.convert('RGB') 

290 

291 # Preprocess and encode reference image 

292 ref_tensor = self.preprocess(reference).unsqueeze(0).to(self.device) 

293 ref_features = self.model.encode_image(ref_tensor) 

294 

295 else: 

296 raise ValueError(f"Invalid reference_source: {self.reference_source}. Must be 'image' or 'text'") 

297 

298 # Normalize features 

299 img_features = F.normalize(img_features, p=2, dim=1) 

300 ref_features = F.normalize(ref_features, p=2, dim=1) 

301 

302 # Calculate cosine similarity 

303 similarity = torch.cosine_similarity(img_features, ref_features).item() 

304 

305 # Convert to 0-1 range 

306 similarity = (similarity + 1) / 2 

307 

308 return similarity 

309 
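# Usage sketch for CLIPScoreCalculator (illustrative; assumes the openai/CLIP
# package is installed and `img`/`ref_img` are PIL images loaded elsewhere):
#
#     clip_img = CLIPScoreCalculator(device="cuda", reference_source="image")
#     print(clip_img.analyze(img, ref_img))                      # image-to-image similarity in [0, 1]
#
#     clip_txt = CLIPScoreCalculator(device="cuda", reference_source="text")
#     print(clip_txt.analyze(img, "a photo of a red bicycle"))   # image-to-text similarity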

310class FIDCalculator(GroupImageQualityAnalyzer): 

311 """FID calculator for image quality analysis. 

312  

313 Calculates Fréchet Inception Distance between two sets of images. 

314 Lower FID indicates better quality and similarity to reference distribution. 

315 """ 

316 

317 def __init__(self, device: str = "cuda", batch_size: int = 32, splits: int = 1): 

318 """Initialize the FID calculator. 

319  

320 Args: 

321 device: Device to run the model on ("cuda" or "cpu") 

322 batch_size: Batch size for processing images 

323 splits: Number of splits for computing FID (default: 1). The number of images must be divisible by the number of splits for a fair comparison. 

324 To estimate the mean and standard error of FID, set splits greater than 1. 

325 If splits is 1, the FID is calculated on the entire dataset (Avg = FID, Std = 0). 

326 """ 

327 super().__init__() 

328 self.device = torch.device(device if torch.cuda.is_available() else "cpu") 

329 self.batch_size = batch_size 

330 self.splits = splits 

331 self._load_model() 

332 

333 def _load_model(self): 

334 """Load the Inception v3 model for feature extraction.""" 

335 from torchvision import models, transforms 

336 

337 # Load Inception v3 model 

338 inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT, init_weights=False) 

339 inception.fc = nn.Identity() # Remove final classification layer 

340 inception.aux_logits = False 

341 inception.eval() 

342 inception.to(self.device) 

343 self.model = inception 

344 

345 # Define preprocessing 

346 self.preprocess = transforms.Compose([ 

347 transforms.Resize((512, 512)), 

348 transforms.ToTensor(), 

349 transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # ImageNet statistics 

350 ]) 

351 

352 def _extract_features(self, images: List[Image.Image]) -> np.ndarray: 

353 """Extract features from a list of images. 

354  

355 Args: 

356 images: List of PIL images 

357  

358 Returns: 

359 Feature matrix of shape (n_images, 2048) 

360 """ 

361 features_list = [] 

362 

363 for i in range(0, len(images), self.batch_size): 

364 batch_images = images[i:i + self.batch_size] 

365 

366 # Preprocess batch 

367 batch_tensors = [] 

368 for img in batch_images: 

369 if img.mode != 'RGB': 

370 img = img.convert('RGB') 

371 tensor = self.preprocess(img) 

372 batch_tensors.append(tensor) 

373 

374 batch_tensor = torch.stack(batch_tensors).to(self.device) 

375 

376 # Extract features 

377 with torch.no_grad(): 

378 features = self.model(batch_tensor) # (batch_size, 2048) 

379 features_list.append(features.cpu().numpy()) 

380 

381 return np.concatenate(features_list, axis=0) # (n_images, 2048) 

382 

383 def _calculate_fid(self, features1: np.ndarray, features2: np.ndarray) -> float: 

384 """Calculate FID between two feature sets. 

385  

386 Args: 

387 features1: First feature set 

388 features2: Second feature set 

389  

390 Returns: 

391 float: FID score 

392 """ 

393 from scipy.linalg import sqrtm 

394 

395 # Calculate statistics 

396 mu1, sigma1 = features1.mean(axis=0), np.cov(features1, rowvar=False) 

397 mu2, sigma2 = features2.mean(axis=0), np.cov(features2, rowvar=False) 

398 

399 # Calculate FID 

400 diff = mu1 - mu2 

401 covmean = sqrtm(sigma1.dot(sigma2)) 

402 

403 # Numerical stability 

404 if np.iscomplexobj(covmean): 

405 covmean = covmean.real 

406 

407 fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean) 

408 return float(fid) 

409 

410 def analyze(self, images: List[Image.Image], references: List[Image.Image], *args, **kwargs) -> float: 

411 """Calculate FID between two sets of images. 

412  

413 Args: 

414 images: Set of images to evaluate 

415 references: Reference set of images 

416  

417 Returns: 

418 float: FID value computed over the full image and reference sets (lower is better) 

419 """ 

420 if len(images) < 2 or len(references) < 2: 

421 raise ValueError("FID requires at least 2 images in each set") 

422 if len(images) % self.splits != 0 or len(references) % self.splits != 0: 

423 raise ValueError("FID requires the number of images to be divisible by the number of splits") 

424 

425 # Extract features 

426 features1 = self._extract_features(images) 

427 features2 = self._extract_features(references) 

428 

429 # Calculate a single FID over the full sets 

430 # (a per-split variant matching the splits parameter would look like): 

431 # for i in range(self.splits): 

432 #     start_idx = i * len(images) // self.splits 

433 #     end_idx = (i + 1) * len(images) // self.splits 

434 #     fid_scores.append(self._calculate_fid(features1[start_idx:end_idx], features2[start_idx:end_idx])) 

435 

436 fid_score = self._calculate_fid(features1, features2) 

437 

438 return fid_score 

439 
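# Usage sketch for FIDCalculator (illustrative; assumes `generated` and `real`
# are lists of PIL images). _calculate_fid implements
# FID = ||mu1 - mu2||^2 + Tr(S1 + S2 - 2*(S1 S2)^(1/2)) on Inception features:
#
#     fid_calc = FIDCalculator(device="cuda", batch_size=32)
#     fid = fid_calc.analyze(generated, real)   # lower is better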

440class LPIPSAnalyzer(RepeatImageQualityAnalyzer): 

441 """LPIPS analyzer for image quality analysis. 

442  

443 Calculates perceptual diversity within a set of images. 

444 Higher LPIPS indicates more diverse/varied images. 

445 """ 

446 

447 def __init__(self, device: str = "cuda", net: str = "alex"): 

448 """Initialize the LPIPS analyzer. 

449  

450 Args: 

451 device: Device to run the model on ("cuda" or "cpu") 

452 net: Network to use ('alex', 'vgg', or 'squeeze') 

453 """ 

454 super().__init__() 

455 self.device = torch.device(device if torch.cuda.is_available() else "cpu") 

456 self.net = net 

457 self._load_model() 

458 

459 def _load_model(self) -> None: 

460 """ 

461 Load the LPIPS model. 

462 """ 

463 self.model = lpips.LPIPS(net=self.net) 

464 self.model.eval() 

465 self.model.to(self.device) 

466 

467 def analyze(self, images: List[Image.Image], *args, **kwargs) -> float: 

468 """Calculate average pairwise LPIPS distance within a set of images. 

469  

470 Args: 

471 images: List of images to analyze diversity 

472  

473 Returns: 

474 float: Average LPIPS distance (diversity score) 

475 """ 

476 if len(images) < 2: 

477 return 0.0 # No diversity with single image 

478 

479 # Preprocess all images 

480 tensors = [] 

481 for img in images: 

482 if img.mode != 'RGB': 

483 img = img.convert('RGB') 

484 tensor = lpips.im2tensor(np.array(img).astype(np.uint8)).to(self.device) # Convert to tensor 

485 tensors.append(tensor) 

486 

487 # Calculate pairwise LPIPS distances 

488 distances = [] 

489 for i in range(len(tensors)): 

490 for j in range(i + 1, len(tensors)): 

491 with torch.no_grad(): 

492 distance = self.model.forward(tensors[i], tensors[j]).item() 

493 distances.append(distance) 

494 

495 # Return average distance as diversity score 

496 return np.mean(distances) if distances else 0.0 

497 
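# Usage sketch for LPIPSAnalyzer (illustrative; assumes the lpips package is
# installed and `samples` is a list of PIL images generated from one prompt):
#
#     lpips_div = LPIPSAnalyzer(device="cuda", net="alex")
#     diversity = lpips_div.analyze(samples)    # mean pairwise distance; higher = more diverse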

498class PSNRAnalyzer(ComparedImageQualityAnalyzer): 

499 """PSNR analyzer for image quality analysis. 

500  

501 Calculates Peak Signal-to-Noise Ratio between two images. 

502 Higher PSNR indicates better quality/similarity. 

503 """ 

504 

505 def __init__(self, max_pixel_value: float = 255.0): 

506 """Initialize the PSNR analyzer. 

507  

508 Args: 

509 max_pixel_value: Maximum pixel value (255 for 8-bit images) 

510 """ 

511 super().__init__() 

512 self.max_pixel_value = max_pixel_value 

513 

514 def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float: 

515 """Calculate PSNR between two images. 

516  

517 Args: 

518 image (Image.Image): Image to evaluate 

519 reference (Image.Image): Reference image 

520 

521 Returns: 

522 float: PSNR value in dB 

523 """ 

524 # Convert to RGB if necessary 

525 if image.mode != 'RGB': 

526 image = image.convert('RGB') 

527 if reference.mode != 'RGB': 

528 reference = reference.convert('RGB') 

529 

530 # Resize if necessary 

531 if image.size != reference.size: 

532 reference = reference.resize(image.size, Image.Resampling.BILINEAR) 

533 

534 # Convert to numpy arrays 

535 img_array = np.array(image, dtype=np.float32) 

536 ref_array = np.array(reference, dtype=np.float32) 

537 

538 # Calculate MSE 

539 mse = np.mean((img_array - ref_array) ** 2) 

540 

541 # Avoid division by zero 

542 if mse == 0: 

543 return float('inf') 

544 

545 # Calculate PSNR 

546 psnr = 20 * np.log10(self.max_pixel_value / np.sqrt(mse)) 

547 

548 return float(psnr) 

549 

550 
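# Worked example for PSNRAnalyzer (illustrative numbers): for 8-bit images with
# MSE = 25, PSNR = 20 * log10(255 / sqrt(25)) = 20 * log10(51) ≈ 34.15 dB.
#
#     psnr = PSNRAnalyzer()
#     print(psnr.analyze(img, ref_img))         # identical images return float('inf')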

551class NIQECalculator(DirectImageQualityAnalyzer): 

552 """Natural Image Quality Evaluator (NIQE) for no-reference image quality assessment. 

553  

554 NIQE evaluates image quality based on deviations from natural scene statistics. 

555 It uses a pre-trained model of natural image statistics to assess quality without 

556 requiring reference images. 

557  

558 Lower NIQE scores indicate better/more natural image quality (typical range: 2-8). 

559 """ 

560 

561 def __init__(self, 

562 model_path: str = "evaluation/tools/data/niqe_image_params.mat", 

563 patch_size: int = 96, 

564 sigma: float = 7.0/6.0, 

565 C: float = 1.0): 

566 """Initialize NIQE calculator with pre-trained natural image statistics. 

567  

568 Args: 

569 model_path: Path to the pre-trained NIQE model parameters (.mat file) 

570 patch_size: Size of patches for feature extraction (default: 96) 

571 sigma: Standard deviation for Gaussian window (default: 7/6) 

572 C: Constant for numerical stability in MSCN transform (default: 1.0) 

573 """ 

574 super().__init__() 

575 self.patch_size = patch_size 

576 self.sigma = sigma 

577 self.C = C 

578 

579 # Load pre-trained natural image statistics 

580 self._load_model_params(model_path) 

581 

582 # Pre-compute gamma lookup table for AGGD parameter estimation 

583 self._precompute_gamma_table() 

584 

585 # Generate Gaussian window for local mean/variance computation 

586 self.avg_window = self._generate_gaussian_window(3, self.sigma) 

587 

588 def _load_model_params(self, model_path: str) -> None: 

589 """Load pre-trained NIQE model parameters from MAT file. 

590  

591 Args: 

592 model_path: Path to the model parameters file 

593 """ 

594 import scipy.io 

595 try: 

596 params = scipy.io.loadmat(model_path) 

597 self.pop_mu = np.ravel(params["pop_mu"]) 

598 self.pop_cov = params["pop_cov"] 

599 except Exception as e: 

600 raise RuntimeError(f"Failed to load NIQE model parameters from {model_path}: {e}") 

601 

602 def _precompute_gamma_table(self) -> None: 

603 """Pre-compute gamma function values for AGGD parameter estimation.""" 

604 import scipy.special 

605 

606 self.gamma_range = np.arange(0.2, 10, 0.001) 

607 a = scipy.special.gamma(2.0 / self.gamma_range) 

608 a *= a 

609 b = scipy.special.gamma(1.0 / self.gamma_range) 

610 c = scipy.special.gamma(3.0 / self.gamma_range) 

611 self.prec_gammas = a / (b * c) 

612 

613 def _generate_gaussian_window(self, window_size: int, sigma: float) -> np.ndarray: 

614 """Generate 1D Gaussian window for filtering. 

615  

616 Args: 

617 window_size: Half-size of the window (full size = 2*window_size + 1) 

618 sigma: Standard deviation of Gaussian 

619  

620 Returns: 

621 1D Gaussian window weights 

622 """ 

623 window_size = int(window_size) 

624 weights = np.zeros(2 * window_size + 1) 

625 weights[window_size] = 1.0 

626 sum_weights = 1.0 

627 

628 sigma_sq = sigma * sigma 

629 for i in range(1, window_size + 1): 

630 tmp = np.exp(-0.5 * i * i / sigma_sq) 

631 weights[window_size + i] = tmp 

632 weights[window_size - i] = tmp 

633 sum_weights += 2.0 * tmp 

634 

635 weights /= sum_weights 

636 return weights 

637 

638 def _compute_mscn_transform(self, image: np.ndarray, extend_mode: str = 'constant') -> tuple: 

639 """Compute Mean Subtracted Contrast Normalized (MSCN) coefficients. 

640  

641 MSCN transformation normalizes image patches by local mean and variance, 

642 making the coefficients more suitable for statistical modeling. 

643  

644 Args: 

645 image: Input image array 

646 extend_mode: Boundary extension mode for filtering 

647  

648 Returns: 

649 Tuple of (mscn_coefficients, local_variance, local_mean) 

650 """ 

651 import scipy.ndimage 

652 

653 assert len(image.shape) == 2, "Input must be grayscale image" 

654 h, w = image.shape 

655 

656 # Allocate arrays for local statistics 

657 mu_image = np.zeros((h, w), dtype=np.float32) 

658 var_image = np.zeros((h, w), dtype=np.float32) 

659 image_float = image.astype(np.float32) 

660 

661 # Compute local mean using separable Gaussian filtering 

662 scipy.ndimage.correlate1d(image_float, self.avg_window, 0, mu_image, mode=extend_mode) 

663 scipy.ndimage.correlate1d(mu_image, self.avg_window, 1, mu_image, mode=extend_mode) 

664 

665 # Compute local variance 

666 scipy.ndimage.correlate1d(image_float**2, self.avg_window, 0, var_image, mode=extend_mode) 

667 scipy.ndimage.correlate1d(var_image, self.avg_window, 1, var_image, mode=extend_mode) 

668 

669 # Local standard deviation: sqrt(|E[X^2] - E[X]^2|) 

670 var_image = np.sqrt(np.abs(var_image - mu_image**2)) 

671 

672 # MSCN transform 

673 mscn = (image_float - mu_image) / (var_image + self.C) 

674 

675 return mscn, var_image, mu_image 

676 

677 def _compute_aggd_features(self, coefficients: np.ndarray) -> tuple: 

678 """Compute Asymmetric Generalized Gaussian Distribution (AGGD) parameters. 

679  

680 AGGD models the distribution of MSCN coefficients and their products, 

681 capturing shape and asymmetry characteristics. 

682  

683 Args: 

684 coefficients: MSCN coefficients 

685  

686 Returns: 

687 Tuple of (alpha, N, bl, br, left_std, right_std) 

688 """ 

689 import scipy.special 

690 

691 # Flatten coefficients 

692 coeffs_flat = coefficients.flatten() 

693 coeffs_squared = coeffs_flat * coeffs_flat 

694 

695 # Separate left (negative) and right (positive) tail data 

696 left_data = coeffs_squared[coeffs_flat < 0] 

697 right_data = coeffs_squared[coeffs_flat >= 0] 

698 

699 # Compute standard deviations for left and right tails 

700 left_std = np.sqrt(np.mean(left_data)) if len(left_data) > 0 else 0 

701 right_std = np.sqrt(np.mean(right_data)) if len(right_data) > 0 else 0 

702 

703 # Estimate gamma (shape asymmetry parameter) 

704 if right_std != 0: 

705 gamma_hat = left_std / right_std 

706 else: 

707 gamma_hat = np.inf 

708 

709 # Estimate r_hat (generalized Gaussian ratio) 

710 mean_abs = np.mean(np.abs(coeffs_flat)) 

711 mean_squared = np.mean(coeffs_squared) 

712 

713 if mean_squared != 0: 

714 r_hat = (mean_abs ** 2) / mean_squared 

715 else: 

716 r_hat = np.inf 

717 

718 # Normalize r_hat using gamma 

719 rhat_norm = r_hat * (((gamma_hat**3 + 1) * (gamma_hat + 1)) / 

720 ((gamma_hat**2 + 1) ** 2)) 

721 

722 # Find best-fitting alpha by comparing with pre-computed values 

723 pos = np.argmin((self.prec_gammas - rhat_norm) ** 2) 

724 alpha = self.gamma_range[pos] 

725 

726 # Compute AGGD parameters 

727 gam1 = scipy.special.gamma(1.0 / alpha) 

728 gam2 = scipy.special.gamma(2.0 / alpha) 

729 gam3 = scipy.special.gamma(3.0 / alpha) 

730 

731 aggd_ratio = np.sqrt(gam1) / np.sqrt(gam3) 

732 bl = aggd_ratio * left_std # Left scale parameter 

733 br = aggd_ratio * right_std # Right scale parameter 

734 

735 # Mean parameter 

736 N = (br - bl) * (gam2 / gam1) 

737 

738 return alpha, N, bl, br, left_std, right_std 

739 

740 def _compute_paired_products(self, mscn_coeffs: np.ndarray) -> tuple: 

741 """Compute products of adjacent MSCN coefficients in four orientations. 

742  

743 These products capture dependencies between neighboring pixels. 

744  

745 Args: 

746 mscn_coeffs: MSCN coefficient matrix 

747  

748 Returns: 

749 Tuple of (horizontal, vertical, diagonal1, diagonal2) products 

750 """ 

751 # Shift in four directions and compute products 

752 shift_h = np.roll(mscn_coeffs, 1, axis=1) # Horizontal shift 

753 shift_v = np.roll(mscn_coeffs, 1, axis=0) # Vertical shift 

754 shift_d1 = np.roll(shift_v, 1, axis=1) # Main diagonal shift 

755 shift_d2 = np.roll(shift_v, -1, axis=1) # Anti-diagonal shift 

756 

757 # Compute products 

758 prod_h = mscn_coeffs * shift_h # Horizontal pairs 

759 prod_v = mscn_coeffs * shift_v # Vertical pairs 

760 prod_d1 = mscn_coeffs * shift_d1 # Diagonal pairs 

761 prod_d2 = mscn_coeffs * shift_d2 # Anti-diagonal pairs 

762 

763 return prod_h, prod_v, prod_d1, prod_d2 

764 

765 def _extract_subband_features(self, mscn_coeffs: np.ndarray) -> np.ndarray: 

766 """Extract statistical features from MSCN coefficients and their products. 

767  

768 Args: 

769 mscn_coeffs: MSCN coefficient matrix 

770  

771 Returns: 

772 Feature vector of length 18 

773 """ 

774 # Extract AGGD parameters from MSCN coefficients 

775 alpha_m, N, bl, br, _, _ = self._compute_aggd_features(mscn_coeffs) 

776 

777 # Compute paired products in four orientations 

778 prod_h, prod_v, prod_d1, prod_d2 = self._compute_paired_products(mscn_coeffs) 

779 

780 # Extract AGGD parameters for each product orientation 

781 alpha1, N1, bl1, br1, _, _ = self._compute_aggd_features(prod_h) 

782 alpha2, N2, bl2, br2, _, _ = self._compute_aggd_features(prod_v) 

783 alpha3, N3, bl3, br3, _, _ = self._compute_aggd_features(prod_d1) 

784 alpha4, N4, bl4, br4, _, _ = self._compute_aggd_features(prod_d2) 

785 

786 # Combine all features into feature vector 

787 # Note: For diagonal pairs in reference, bl3 is repeated twice (not br3) 

788 features = np.array([ 

789 alpha_m, (bl + br) / 2.0, # Shape and scale of MSCN 

790 alpha1, N1, bl1, br1, # Horizontal pairs (H) 

791 alpha2, N2, bl2, br2, # Vertical pairs (V) 

792 alpha3, N3, bl3, bl3, # Diagonal pairs (D1) - note: bl3 repeated 

793 alpha4, N4, bl4, bl4, # Anti-diagonal pairs (D2) - note: bl4 repeated 

794 ]) 

795 

796 return features 

797 

798 def _extract_multiscale_features(self, image: np.ndarray) -> tuple: 

799 """Extract features at multiple scales. 

800  

801 Args: 

802 image: Input grayscale image 

803  

804 Returns: 

805 Tuple of (all_features, mean_features, sample_covariance) 

806 """ 

807 h, w = image.shape 

808 

809 # Check minimum size requirements 

810 if h < self.patch_size or w < self.patch_size: 

811 raise ValueError(f"Image too small. Minimum size: {self.patch_size}x{self.patch_size}") 

812 

813 # Ensure that the patch divides evenly into img 

814 hoffset = h % self.patch_size 

815 woffset = w % self.patch_size 

816 

817 if hoffset > 0: 

818 image = image[:-hoffset, :] 

819 if woffset > 0: 

820 image = image[:, :-woffset] 

821 

822 # Convert to float32 for processing 

823 image = image.astype(np.float32) 

824 

825 # Downsample image by factor of 2 using PIL (as in reference) 

826 img_pil = Image.fromarray(image) 

827 size = tuple((np.array(img_pil.size) * 0.5).astype(int)) 

828 img2 = np.array(img_pil.resize(size, Image.BICUBIC)) 

829 

830 # Compute MSCN transforms at two scales 

831 mscn1, _, _ = self._compute_mscn_transform(image) 

832 mscn1 = mscn1.astype(np.float32) 

833 

834 mscn2, _, _ = self._compute_mscn_transform(img2) 

835 mscn2 = mscn2.astype(np.float32) 

836 

837 # Extract features from patches at each scale 

838 feats_lvl1 = self._extract_patches_test_features(mscn1, self.patch_size) 

839 feats_lvl2 = self._extract_patches_test_features(mscn2, self.patch_size // 2) 

840 

841 # Concatenate features from both scales 

842 feats = np.hstack((feats_lvl1, feats_lvl2)) 

843 

844 # Calculate mean and covariance 

845 sample_mu = np.mean(feats, axis=0) 

846 sample_cov = np.cov(feats.T) 

847 

848 return feats, sample_mu, sample_cov 

849 

850 def _extract_patches_test_features(self, mscn: np.ndarray, patch_size: int) -> np.ndarray: 

851 """Extract features from non-overlapping patches for test images. 

852  

853 Args: 

854 mscn: MSCN coefficient matrix 

855 patch_size: Size of patches 

856  

857 Returns: 

858 Array of patch features 

859 """ 

860 h, w = mscn.shape 

861 patch_size = int(patch_size) 

862 

863 # Extract non-overlapping patches 

864 patches = [] 

865 for j in range(0, h - patch_size + 1, patch_size): 

866 for i in range(0, w - patch_size + 1, patch_size): 

867 patch = mscn[j:j + patch_size, i:i + patch_size] 

868 patches.append(patch) 

869 

870 patches = np.array(patches) 

871 

872 # Extract features from each patch 

873 patch_features = [] 

874 for p in patches: 

875 patch_features.append(self._extract_subband_features(p)) 

876 

877 patch_features = np.array(patch_features) 

878 

879 return patch_features 

880 

881 def analyze(self, image: Image.Image, *args, **kwargs) -> float: 

882 """Calculate NIQE score for a single image. 

883  

884 Args: 

885 image: Input image to evaluate 

886  

887 Returns: 

888 float: NIQE score (lower is better, typical range: 2-8) 

889 """ 

890 

891 import scipy.linalg 

892 import scipy.special 

893 

894 # Convert to grayscale if needed 

895 if image.mode != 'L': 

896 if image.mode == 'RGB': 

897 # Convert RGB to grayscale as in reference: using 'LA' and taking first channel 

898 image = image.convert('LA') 

899 img_array = np.array(image)[:,:,0].astype(np.float32) 

900 else: 

901 image = image.convert('L') 

902 img_array = np.array(image, dtype=np.float32) 

903 else: 

904 img_array = np.array(image, dtype=np.float32) 

905 

906 # Check minimum size requirements 

907 min_size = self.patch_size * 2 + 1 

908 if img_array.shape[0] < min_size or img_array.shape[1] < min_size: 

909 raise ValueError(f"Image too small. Minimum size: {min_size}x{min_size}") 

910 

911 # Extract multi-scale features  

912 all_features, sample_mu, sample_cov = self._extract_multiscale_features(img_array) 

913 

914 # Compute distance from natural image statistics 

915 X = sample_mu - self.pop_mu 

916 

917 # Calculate Mahalanobis-like distance 

918 # Use average of sample and population covariance as in reference 

919 covmat = (self.pop_cov + sample_cov) / 2.0 

920 

921 # Compute pseudo-inverse for numerical stability 

922 pinv_cov = scipy.linalg.pinv(covmat) 

923 

924 # Calculate NIQE score 

925 niqe_score = np.sqrt(np.dot(np.dot(X, pinv_cov), X)) 

926 

927 return float(niqe_score) 

928 

929 
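# Usage sketch for NIQECalculator (illustrative; assumes the pre-trained
# niqe_image_params.mat file is present at the default path and the input image
# is at least 2*patch_size+1 pixels on each side):
#
#     niqe = NIQECalculator()
#     print(niqe.analyze(img))                  # lower is better, typically ~2-8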

930class SSIMAnalyzer(ComparedImageQualityAnalyzer): 

931 """SSIM analyzer for image quality analysis. 

932  

933 Calculates Structural Similarity Index between two images. 

934 Higher SSIM indicates better quality/similarity. 

935 """ 

936 

937 def __init__(self, max_pixel_value: float = 255.0): 

938 """Initialize the SSIM analyzer. 

939  

940 Args: 

941 max_pixel_value: Maximum pixel value (255 for 8-bit images) 

942 """ 

943 super().__init__() 

944 self.max_pixel_value = max_pixel_value 

945 self.C1 = (0.01 * max_pixel_value) ** 2 

946 self.C2 = (0.03 * max_pixel_value) ** 2 

947 self.C3 = self.C2 / 2.0 

948 

949 def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float: 

950 """Calculate SSIM between two images. 

951 

952 Args: 

953 image (Image.Image): Image to evaluate 

954 reference (Image.Image): Reference image 

955 

956 Returns: 

957 float: SSIM value (-1 to 1; higher means more similar) 

958 """ 

959 # Convert to RGB if necessary 

960 if image.mode != 'RGB': 

961 image = image.convert('RGB') 

962 if reference.mode != 'RGB': 

963 reference = reference.convert('RGB') 

964 

965 # Resize if necessary 

966 if image.size != reference.size: 

967 reference = reference.resize(image.size, Image.Resampling.BILINEAR) 

968 

969 # Convert to numpy arrays 

970 img_array = np.array(image, dtype=np.float32) 

971 ref_array = np.array(reference, dtype=np.float32) 

972 

973 # Calculate means 

974 mu_x = np.mean(img_array) 

975 mu_y = np.mean(ref_array) 

976 

977 # Calculate variances and covariance 

978 sigma_x = np.std(img_array) 

979 sigma_y = np.std(ref_array) 

980 sigma_xy = np.mean((img_array - mu_x) * (ref_array - mu_y)) 

981 

982 # Calculate SSIM 

983 luminance = (2 * mu_x * mu_y + self.C1) / (mu_x**2 + mu_y**2 + self.C1) 

984 contrast = (2 * sigma_x * sigma_y + self.C2) / (sigma_x**2 + sigma_y**2 + self.C2) 

985 structure = (sigma_xy + self.C3) / (sigma_x * sigma_y + self.C3) 

986 ssim = luminance * contrast * structure 

987 

988 return float(ssim) 

989 

990 
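# Usage sketch for SSIMAnalyzer (illustrative). Note that this implementation
# computes a single global SSIM from whole-image statistics rather than the
# windowed average used by, e.g., skimage:
#
#     ssim = SSIMAnalyzer()
#     print(ssim.analyze(img, ref_img))         # 1.0 for identical images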

991class BRISQUEAnalyzer(DirectImageQualityAnalyzer): 

992 """BRISQUE analyzer for no-reference image quality analysis. 

993  

994 BRISQUE (Blind/Referenceless Image Spatial Quality Evaluator) 

995 evaluates perceptual quality of an image without requiring 

996 a reference. Lower BRISQUE scores indicate better quality. 

997 Typical range: 0 (best) ~ 100 (worst). 

998 """ 

999 

1000 def __init__(self, device: str = "cuda"): 

1001 super().__init__() 

1002 self.device = torch.device(device if torch.cuda.is_available() else "cpu") 

1003 

1004 def _preprocess(self, image: Image.Image) -> torch.Tensor: 

1005 """Convert PIL Image to tensor in range [0,1] with shape (1,C,H,W).""" 

1006 if image.mode != 'RGB': 

1007 image = image.convert('RGB') 

1008 arr = np.array(image).astype(np.float32) / 255.0 

1009 tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) # BCHW 

1010 return tensor.to(self.device) 

1011 

1012 def analyze(self, image: Image.Image, *args, **kwargs) -> float: 

1013 """Calculate BRISQUE score for a single image. 

1014  

1015 Args: 

1016 image: PIL Image 

1017  

1018 Returns: 

1019 float: BRISQUE score (lower is better) 

1020 """ 

1021 x = self._preprocess(image) 

1022 with torch.no_grad(): 

1023 score = piq.brisque(x, data_range=1.0) # piq expects [0,1] 

1024 return float(score.item()) 

1025 

1026 
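# Usage sketch for BRISQUEAnalyzer (illustrative; assumes the piq package is
# installed and `img` is a PIL image loaded elsewhere):
#
#     brisque = BRISQUEAnalyzer(device="cpu")
#     print(brisque.analyze(img))               # lower is better, roughly 0-100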

1027class VIFAnalyzer(ComparedImageQualityAnalyzer): 

1028 """VIF (Visual Information Fidelity) analyzer using piq. 

1029  

1030 VIF compares a distorted image with a reference image to  

1031 quantify the amount of visual information preserved. 

1032 Higher VIF indicates better quality/similarity. 

1033 Typical range: 0 ~ 1 (sometimes higher for good quality). 

1034 """ 

1035 

1036 def __init__(self, device: str = "cuda"): 

1037 super().__init__() 

1038 self.device = torch.device(device if torch.cuda.is_available() else "cpu") 

1039 

1040 def _preprocess(self, image: Image.Image) -> torch.Tensor: 

1041 """Convert PIL Image to tensor in range [0,1] with shape (1,C,H,W).""" 

1042 if image.mode != 'RGB': 

1043 image = image.convert('RGB') 

1044 arr = np.array(image).astype(np.float32) / 255.0 

1045 tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) # BCHW 

1046 return tensor.to(self.device) 

1047 

1048 def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float: 

1049 """Calculate VIF score between image and reference. 

1050  

1051 Args: 

1052 image: Distorted/test image (PIL) 

1053 reference: Reference image (PIL) 

1054  

1055 Returns: 

1056 float: VIF score (higher is better) 

1057 """ 

1058 x = self._preprocess(image) 

1059 y = self._preprocess(reference) 

1060 

1061 # Ensure same size (piq expects matching shapes) 

1062 if x.shape != y.shape: 

1063 _, _, h, w = x.shape 

1064 y = torch.nn.functional.interpolate(y, size=(h, w), mode='bilinear', align_corners=False) 

1065 

1066 with torch.no_grad(): 

1067 score = piq.vif_p(x, y, data_range=1.0) 

1068 return float(score.item()) 

1069 

1070 
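# Usage sketch for VIFAnalyzer (illustrative; assumes `img` and `ref_img` are
# PIL images; the reference is resized to the test image if shapes differ):
#
#     vif = VIFAnalyzer(device="cpu")
#     print(vif.analyze(img, ref_img))          # higher is better; values near 1 mean little information loss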

1071class FSIMAnalyzer(ComparedImageQualityAnalyzer): 

1072 """FSIM (Feature Similarity Index) analyzer using piq. 

1073  

1074 FSIM compares structural similarity between two images  

1075 based on phase congruency and gradient magnitude. 

1076 Higher FSIM indicates better quality/similarity. 

1077 Typical range: 0 ~ 1. 

1078 """ 

1079 

1080 def __init__(self, device: str = "cuda"): 

1081 super().__init__() 

1082 self.device = torch.device(device if torch.cuda.is_available() else "cpu") 

1083 

1084 def _preprocess(self, image: Image.Image) -> torch.Tensor: 

1085 """Convert PIL Image to tensor in range [0,1] with shape (1,C,H,W).""" 

1086 if image.mode != 'RGB': 

1087 image = image.convert('RGB') 

1088 arr = np.array(image).astype(np.float32) / 255.0 

1089 tensor = torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0) # BCHW 

1090 return tensor.to(self.device) 

1091 

1092 def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float: 

1093 """Calculate FSIM score between image and reference. 

1094  

1095 Args: 

1096 image: Distorted/test image (PIL) 

1097 reference: Reference image (PIL) 

1098  

1099 Returns: 

1100 float: FSIM score (higher is better) 

1101 """ 

1102 x = self._preprocess(image) 

1103 y = self._preprocess(reference) 

1104 

1105 # Ensure same size 

1106 if x.shape != y.shape: 

1107 _, _, h, w = x.shape 

1108 y = torch.nn.functional.interpolate(y, size=(h, w), mode='bilinear', align_corners=False) 

1109 

1110 with torch.no_grad(): 

1111 score = piq.fsim(x, y, data_range=1.0) 

1112 return float(score.item())
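# Usage sketch for FSIMAnalyzer (illustrative; same preprocessing assumptions as
# VIFAnalyzer above):
#
#     fsim = FSIMAnalyzer(device="cpu")
#     print(fsim.analyze(img, ref_img))         # higher is better, in [0, 1]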