Coverage for markdiffusion / evaluation / tools / image_quality_analyzer.py: 95.94%

468 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-14 20:17 +0000

1import os 

2from PIL import Image 

3from typing import List, Dict, Optional, Union, Tuple 

4import torch 

5import torch.nn as nn 

6import torch.nn.functional as F 

7import numpy as np 

8from abc import abstractmethod 

9import lpips 

10import piq 

11 

# Location of the bundled NIQE statistics file. It is anchored to the absolute
# directory of this module, so the default does not depend on the current
# working directory.
_MODULE_DIR = os.path.dirname(os.path.abspath(__file__))
_DEFAULT_NIQE_PARAMS = os.path.join(_MODULE_DIR, "data", "niqe_image_params.mat")

16 

class ImageQualityAnalyzer:
    """Common root of every image quality analyzer in this module.

    Subclasses override :meth:`analyze` with the signature that fits the kind
    of measurement they perform.
    """

    def __init__(self):
        """Create the analyzer; the base class holds no state."""
        return None

    # The class does not derive from ``abc.ABC``, so this decorator only tags
    # the method as abstract; it does not prevent instantiation.
    @abstractmethod
    def analyze(self):
        """Placeholder for the analysis entry point; returns ``None``."""
        return None

26 

class DirectImageQualityAnalyzer(ImageQualityAnalyzer):
    """Root for analyzers whose ``analyze`` takes a single image and no reference."""

    def __init__(self):
        """Create the analyzer; nothing is initialised at this level."""
        return None

    def analyze(self, image: Image.Image, *args, **kwargs):
        """Placeholder overridden by concrete analyzers; returns ``None``.

        Args:
            image: Image to evaluate.
        """
        return None

35 

class ReferencedImageQualityAnalyzer(ImageQualityAnalyzer):
    """Root for analyzers that score an image against an image or text reference."""

    def __init__(self):
        """Create the analyzer; nothing is initialised at this level."""
        return None

    def analyze(self, image: Image.Image, reference: Union[Image.Image, str], *args, **kwargs):
        """Placeholder overridden by concrete analyzers; returns ``None``.

        Args:
            image: Image to evaluate.
            reference: Reference image or reference string.
        """
        return None

44 

class GroupImageQualityAnalyzer(ImageQualityAnalyzer):
    """Root for analyzers that compare a list of images with a list of references."""

    def __init__(self):
        """Create the analyzer; nothing is initialised at this level."""
        return None

    def analyze(self, images: List[Image.Image], references: List[Image.Image], *args, **kwargs):
        """Placeholder overridden by concrete analyzers; returns ``None``.

        Args:
            images: Images to evaluate.
            references: Reference images.
        """
        return None

53 

class RepeatImageQualityAnalyzer(ImageQualityAnalyzer):
    """Root for analyzers that score one list of images without any reference."""

    def __init__(self):
        """Create the analyzer; nothing is initialised at this level."""
        return None

    def analyze(self, images: List[Image.Image], *args, **kwargs):
        """Placeholder overridden by concrete analyzers; returns ``None``.

        Args:
            images: Images to evaluate together.
        """
        return None

62 

class ComparedImageQualityAnalyzer(ImageQualityAnalyzer):
    """Root for analyzers that compare one image with one reference image."""

    def __init__(self):
        """Create the analyzer; nothing is initialised at this level."""
        return None

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs):
        """Placeholder overridden by concrete analyzers; returns ``None``.

        Args:
            image: Image to evaluate.
            reference: Image to compare against.
        """
        return None

71 

class InceptionScoreCalculator(RepeatImageQualityAnalyzer):
    """Inception Score (IS) calculator for evaluating image generation quality.

    For each split the score is ``exp(mean_x KL(p(y|x) || p(y)))``, where
    ``p(y|x)`` is the Inception v3 softmax output for one image and ``p(y)``
    is the average of those outputs over the split.

    Higher IS indicates better image quality and diversity.
    """

    def __init__(self, device: str = "cuda", batch_size: int = 32, splits: int = 1):
        """Initialize the Inception Score calculator.

        Args:
            device: Device to run the model on ("cuda" or "cpu"). Falls back
                to "cpu" whenever CUDA is not available.
            batch_size: Number of images pushed through the model at once.
            splits: Number of splits for computing IS (default: 1). The number
                of images passed to :meth:`analyze` must be divisible by
                ``splits``. Use ``splits > 1`` to obtain several scores from
                which a mean and spread can be derived; with ``splits == 1``
                a single score over the whole set is produced.

        Raises:
            ValueError: If ``batch_size`` or ``splits`` is smaller than 1.
        """
        super().__init__()
        # Reject values that would otherwise surface later as a
        # ZeroDivisionError or an empty-batch failure.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if splits < 1:
            raise ValueError(f"splits must be a positive integer, got {splits}")

        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.splits = splits
        self._load_model()

    def _load_model(self):
        """Load the pre-trained Inception v3 classifier and its preprocessing."""
        from torchvision import models, transforms

        # The classification head is kept: IS needs the 1000-class predictions.
        self.model = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT)
        self.model.aux_logits = False  # Disable auxiliary output
        self.model.eval()
        self.model.to(self.device)

        self.preprocess = transforms.Compose([
            transforms.Resize((299, 299)),  # Inception v3 input size
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
        ])

    def _get_predictions(self, images: List[Image.Image]) -> np.ndarray:
        """Run the images through Inception v3 and return softmax outputs.

        Args:
            images: List of PIL images to process

        Returns:
            Numpy array of shape (n_images, n_classes) containing softmax predictions
        """
        predictions_list = []

        for start in range(0, len(images), self.batch_size):
            batch_images = images[start:start + self.batch_size]

            # The preprocessing pipeline expects 3-channel input.
            batch_tensors = [
                self.preprocess(img if img.mode == 'RGB' else img.convert('RGB'))
                for img in batch_images
            ]
            batch_tensor = torch.stack(batch_tensors).to(self.device)

            with torch.no_grad():
                logits = self.model(batch_tensor)
                probs = F.softmax(logits, dim=1)
                predictions_list.append(probs.cpu().numpy())

        return np.concatenate(predictions_list, axis=0)

    def _calculate_inception_score(self, predictions: np.ndarray) -> List[float]:
        """Calculate one Inception Score per split.

        Args:
            predictions: Softmax predictions of shape (n_images, n_classes)

        Returns:
            List with ``self.splits`` scores, in split order.
        """
        n_samples = predictions.shape[0]
        splits = self.splits
        split_size = n_samples // splits
        epsilon = 1e-16  # keeps log() and the division away from zero

        split_scores = []
        for split_idx in range(splits):
            start_idx = split_idx * split_size
            # The last split absorbs any remainder.
            end_idx = n_samples if split_idx == splits - 1 else start_idx + split_size
            split_preds = predictions[start_idx:end_idx]

            # Marginal distribution p(y): mean prediction over the split.
            p_y = np.mean(split_preds, axis=0)

            p_y_x_safe = split_preds + epsilon
            p_y_safe = p_y + epsilon
            kl_divergences = np.sum(
                p_y_x_safe * np.log(p_y_x_safe / p_y_safe),
                axis=1)

            # Score of the split is exp(mean KL); stored as a plain float.
            split_scores.append(float(np.exp(np.mean(kl_divergences))))

        return split_scores

    def analyze(self, images: List[Image.Image], *args, **kwargs) -> List[float]:
        """Calculate Inception Score for a set of generated images.

        Args:
            images: List of generated images to evaluate

        Returns:
            List[float]: Inception Score for each split (higher is better)

        Raises:
            ValueError: If there are fewer images than splits, or the number
                of images is not divisible by the number of splits.
        """
        if len(images) < self.splits:
            raise ValueError(f"Inception Score requires at least {self.splits} images (one per split)")

        if len(images) % self.splits != 0:
            raise ValueError("Inception Score requires the number of images to be divisible by the number of splits")

        predictions = self._get_predictions(images)
        split_scores = self._calculate_inception_score(predictions)

        # Mean and spread are only used for the diagnostic below; the caller
        # receives the raw per-split scores.
        mean_score = np.mean(split_scores)
        std_score = np.std(split_scores)
        if std_score > 0.5 * mean_score:
            print(f"Warning: High standard deviation in IS calculation: {mean_score:.2f} ± {std_score:.2f}")

        return split_scores

224 

class CLIPScoreCalculator(ReferencedImageQualityAnalyzer):
    """CLIP score calculator for image quality analysis.

    Calculates CLIP similarity between an image and a reference (another
    image or a text string). Higher scores indicate better semantic similarity.
    """

    def __init__(self, device: str = "cuda", model_name: str = "ViT-B/32", reference_source: str = "image"):
        """Initialize the CLIP Score calculator.

        Args:
            device: Device to run the model on ("cuda" or "cpu"). Falls back
                to "cpu" whenever CUDA is not available.
            model_name: CLIP model variant to use
            reference_source: The source of reference ('image' or 'text')
        """
        super().__init__()
        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.model_name = model_name
        self.reference_source = reference_source
        self._load_model()

    def _load_model(self):
        """Load the CLIP model and its preprocessing transform.

        Raises:
            ImportError: If the ``clip`` package is not installed.
        """
        try:
            import clip
        except ImportError as exc:
            raise ImportError(
                "Please install the CLIP library: pip install git+https://github.com/openai/CLIP.git"
            ) from exc

        self.model, self.preprocess = clip.load(self.model_name, device=self.device)
        self.model.eval()

    def analyze(self, image: Image.Image, reference: Union[Image.Image, str], *args, **kwargs) -> float:
        """Calculate CLIP similarity between image and reference.

        Args:
            image: Input image to evaluate
            reference: Reference image or text for comparison
                - If reference_source is 'image': expects PIL Image
                - If reference_source is 'text': expects string

        Returns:
            float: CLIP similarity score (0 to 1)

        Raises:
            ValueError: If ``reference_source`` is neither 'image' nor 'text',
                or the reference has the wrong type for the configured source.
        """
        source = self.reference_source

        # Validate up front so no encoder work is wasted on a bad call.
        if source == 'text':
            if not isinstance(reference, str):
                raise ValueError(f"Expected string reference for text mode, got {type(reference)}")
        elif source == 'image':
            if not isinstance(reference, Image.Image):
                raise ValueError(f"Expected PIL Image reference for image mode, got {type(reference)}")
        else:
            raise ValueError(f"Invalid reference_source: {self.reference_source}. Must be 'image' or 'text'")

        if image.mode != 'RGB':
            image = image.convert('RGB')
        img_tensor = self.preprocess(image).unsqueeze(0).to(self.device)

        with torch.no_grad():
            img_features = self.model.encode_image(img_tensor)

            if source == 'text':
                # ``clip`` is imported only inside ``_load_model``, so it is
                # not a module-level name; import it here for the tokenizer.
                import clip

                text_tokens = clip.tokenize([reference]).to(self.device)
                ref_features = self.model.encode_text(text_tokens)
            else:
                if reference.mode != 'RGB':
                    reference = reference.convert('RGB')
                ref_tensor = self.preprocess(reference).unsqueeze(0).to(self.device)
                ref_features = self.model.encode_image(ref_tensor)

            # L2-normalise both embeddings before comparing them.
            img_features = F.normalize(img_features, p=2, dim=1)
            ref_features = F.normalize(ref_features, p=2, dim=1)

            similarity = torch.cosine_similarity(img_features, ref_features).item()

            # Map cosine similarity from [-1, 1] to [0, 1].
            similarity = (similarity + 1) / 2

        return similarity

315 

class FIDCalculator(GroupImageQualityAnalyzer):
    """FID calculator for image quality analysis.

    Calculates Fréchet Inception Distance between two sets of images.
    Lower FID indicates better quality and similarity to reference distribution.
    """

    def __init__(self, device: str = "cuda", batch_size: int = 32, splits: int = 1):
        """Initialize the FID calculator.

        Args:
            device: Device to run the model on ("cuda" or "cpu"). Falls back
                to "cpu" whenever CUDA is not available.
            batch_size: Number of images pushed through the model at once.
            splits: Number of splits (default: 1). The sizes of both image
                sets passed to :meth:`analyze` must be divisible by this
                value. The FID itself is computed once over the full sets;
                ``splits`` is only used for that divisibility check.

        Raises:
            ValueError: If ``batch_size`` or ``splits`` is smaller than 1.
        """
        super().__init__()
        # Reject values that would otherwise surface later as a
        # ZeroDivisionError or an empty-batch failure.
        if batch_size < 1:
            raise ValueError(f"batch_size must be a positive integer, got {batch_size}")
        if splits < 1:
            raise ValueError(f"splits must be a positive integer, got {splits}")

        self.device = torch.device(device if torch.cuda.is_available() else "cpu")
        self.batch_size = batch_size
        self.splits = splits
        self._load_model()

    def _load_model(self):
        """Load Inception v3 as a feature extractor (classification head removed)."""
        from torchvision import models, transforms

        inception = models.inception_v3(weights=models.Inception_V3_Weights.DEFAULT, init_weights=False)
        inception.fc = nn.Identity()  # Remove final classification layer
        inception.aux_logits = False
        inception.eval()
        inception.to(self.device)
        self.model = inception

        # NOTE(review): images are resized to 512x512 here, whereas the
        # Inception Score calculator in this file uses 299x299 - confirm
        # this is intentional before comparing with externally reported FID.
        self.preprocess = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # ImageNet statistics
        ])

    def _extract_features(self, images: List[Image.Image]) -> np.ndarray:
        """Extract features from a list of images.

        Args:
            images: List of PIL images

        Returns:
            Feature matrix of shape (n_images, 2048)
        """
        features_list = []

        for start in range(0, len(images), self.batch_size):
            batch_images = images[start:start + self.batch_size]

            # The preprocessing pipeline expects 3-channel input.
            batch_tensors = [
                self.preprocess(img if img.mode == 'RGB' else img.convert('RGB'))
                for img in batch_images
            ]
            batch_tensor = torch.stack(batch_tensors).to(self.device)

            with torch.no_grad():
                features = self.model(batch_tensor)  # (batch_size, 2048)
                features_list.append(features.cpu().numpy())

        return np.concatenate(features_list, axis=0)  # (n_images, 2048)

    def _calculate_fid(self, features1: np.ndarray, features2: np.ndarray, eps: float = 1e-6) -> float:
        """Calculate FID between two feature sets.

        Args:
            features1: First feature set, one row per image
            features2: Second feature set, one row per image
            eps: Diagonal offset added to both covariances when the matrix
                square root of their product is not finite.

        Returns:
            float: FID score
        """
        from scipy.linalg import sqrtm

        mu1, sigma1 = features1.mean(axis=0), np.cov(features1, rowvar=False)
        mu2, sigma2 = features2.mean(axis=0), np.cov(features2, rowvar=False)

        diff = mu1 - mu2
        covmean = sqrtm(sigma1.dot(sigma2))

        # A singular product can make sqrtm return inf/nan; retry with a
        # small ridge on both covariances so the score stays finite.
        if not np.isfinite(covmean).all():
            offset = np.eye(sigma1.shape[0]) * eps
            covmean = sqrtm((sigma1 + offset).dot(sigma2 + offset))

        # Drop the imaginary part if sqrtm returned a complex result.
        if np.iscomplexobj(covmean):
            covmean = covmean.real

        fid = diff.dot(diff) + np.trace(sigma1 + sigma2 - 2 * covmean)
        return float(fid)

    def analyze(self, images: List[Image.Image], references: List[Image.Image], *args, **kwargs) -> float:
        """Calculate FID between two sets of images.

        Args:
            images: Set of images to evaluate
            references: Reference set of images

        Returns:
            float: FID computed over the complete sets

        Raises:
            ValueError: If either set has fewer than 2 images, or a set size
                is not divisible by the number of splits.
        """
        if len(images) < 2 or len(references) < 2:
            raise ValueError("FID requires at least 2 images in each set")
        if len(images) % self.splits != 0 or len(references) % self.splits != 0:
            raise ValueError("FID requires the number of images to be divisible by the number of splits")

        features1 = self._extract_features(images)
        features2 = self._extract_features(references)

        return self._calculate_fid(features1, features2)

445 

class LPIPSAnalyzer(RepeatImageQualityAnalyzer):
    """Perceptual diversity of an image set measured with LPIPS.

    The score is the mean LPIPS distance over all unordered image pairs.
    Higher LPIPS indicates more diverse/varied images.
    """

    def __init__(self, device: str = "cuda", net: str = "alex"):
        """Set up the analyzer and load the LPIPS network.

        Args:
            device: Device to run the model on ("cuda" or "cpu"). Falls back
                to "cpu" whenever CUDA is not available.
            net: Network to use ('alex', 'vgg', or 'squeeze')
        """
        super().__init__()
        chosen_device = device if torch.cuda.is_available() else "cpu"
        self.device = torch.device(chosen_device)
        self.net = net
        self._load_model()

    def _load_model(self) -> None:
        """Instantiate the LPIPS network in eval mode on the selected device."""
        self.model = lpips.LPIPS(net=self.net)
        self.model.eval()
        self.model.to(self.device)

    def analyze(self, images: List[Image.Image], *args, **kwargs) -> float:
        """Return the mean pairwise LPIPS distance of ``images``.

        Args:
            images: List of images to analyze diversity

        Returns:
            float: Average LPIPS distance; 0.0 when fewer than two images
            are supplied.
        """
        if len(images) < 2:
            return 0.0  # a single image has no pair to compare

        # Convert every image to an LPIPS input tensor on the target device.
        tensors = []
        for picture in images:
            rgb = picture if picture.mode == 'RGB' else picture.convert('RGB')
            tensors.append(lpips.im2tensor(np.array(rgb).astype(np.uint8)).to(self.device))

        # Distance for every unordered pair (first < second).
        count = len(tensors)
        with torch.no_grad():
            distances = [
                self.model.forward(tensors[first], tensors[second]).item()
                for first in range(count)
                for second in range(first + 1, count)
            ]

        if not distances:
            return 0.0
        return np.mean(distances)

503 

class PSNRAnalyzer(ComparedImageQualityAnalyzer):
    """Peak Signal-to-Noise Ratio between an image and a reference.

    Higher PSNR indicates better quality/similarity.
    """

    def __init__(self, max_pixel_value: float = 255.0):
        """Store the peak signal value used in the PSNR formula.

        Args:
            max_pixel_value: Maximum pixel value (255 for 8-bit images)
        """
        super().__init__()
        self.max_pixel_value = max_pixel_value

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float:
        """Compute the PSNR of ``image`` with respect to ``reference``.

        Both inputs are converted to RGB, and the reference is resized
        (bilinear) to the image size when the sizes differ.

        Args:
            image (Image.Image): Image to evaluate
            reference (Image.Image): Reference image

        Returns:
            float: PSNR value in dB; ``inf`` when the pixel data is identical.
        """
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        rgb_reference = reference if reference.mode == 'RGB' else reference.convert('RGB')

        if rgb_image.size != rgb_reference.size:
            rgb_reference = rgb_reference.resize(rgb_image.size, Image.Resampling.BILINEAR)

        pixels = np.array(rgb_image, dtype=np.float32)
        ref_pixels = np.array(rgb_reference, dtype=np.float32)

        squared_error = (pixels - ref_pixels) ** 2
        mse = np.mean(squared_error)

        # Identical pixel data: the ratio is unbounded.
        if mse == 0:
            return float('inf')

        rmse = np.sqrt(mse)
        return float(20 * np.log10(self.max_pixel_value / rmse))

555 

556 

class NIQECalculator(DirectImageQualityAnalyzer):
    """Natural Image Quality Evaluator (NIQE) for no-reference image quality assessment.

    NIQE evaluates image quality based on deviations from natural scene statistics.
    It uses a pre-trained model of natural image statistics to assess quality without
    requiring reference images.

    Lower NIQE scores indicate better/more natural image quality (typical range: 2-8).
    """

    def __init__(self,
                 model_path: str = _DEFAULT_NIQE_PARAMS,
                 patch_size: int = 96,
                 sigma: float = 7.0/6.0,
                 C: float = 1.0):
        """Initialize NIQE calculator with pre-trained natural image statistics.

        Args:
            model_path: Path to the pre-trained NIQE model parameters (.mat file)
            patch_size: Size of patches for feature extraction (default: 96)
            sigma: Standard deviation for Gaussian window (default: 7/6)
            C: Constant for numerical stability in MSCN transform (default: 1.0)

        Raises:
            RuntimeError: If the model parameters cannot be loaded.
        """
        super().__init__()
        self.patch_size = patch_size
        self.sigma = sigma
        self.C = C

        # Population mean/covariance of natural-image features.
        self._load_model_params(model_path)

        # Lookup table used to estimate the AGGD shape parameter.
        self._precompute_gamma_table()

        # 7-tap (half-size 3) Gaussian window for local statistics.
        self.avg_window = self._generate_gaussian_window(3, self.sigma)

    def _load_model_params(self, model_path: str) -> None:
        """Load pre-trained NIQE model parameters from MAT file.

        Reads the ``pop_mu`` and ``pop_cov`` entries and stores them as
        ``self.pop_mu`` (flattened) and ``self.pop_cov``.

        Args:
            model_path: Path to the model parameters file

        Raises:
            RuntimeError: On any failure to read the file or its entries; the
                original exception is attached as the cause.
        """
        import scipy.io
        try:
            params = scipy.io.loadmat(model_path)
            self.pop_mu = np.ravel(params["pop_mu"])
            self.pop_cov = params["pop_cov"]
        except Exception as e:
            raise RuntimeError(f"Failed to load NIQE model parameters from {model_path}: {e}") from e

    def _precompute_gamma_table(self) -> None:
        """Pre-compute gamma function values for AGGD parameter estimation.

        Stores the candidate shape values in ``self.gamma_range`` and the
        matching ratio ``Gamma(2/a)^2 / (Gamma(1/a) * Gamma(3/a))`` in
        ``self.prec_gammas``.
        """
        import scipy.special

        self.gamma_range = np.arange(0.2, 10, 0.001)
        a = scipy.special.gamma(2.0 / self.gamma_range)
        a *= a
        b = scipy.special.gamma(1.0 / self.gamma_range)
        c = scipy.special.gamma(3.0 / self.gamma_range)
        self.prec_gammas = a / (b * c)

    def _generate_gaussian_window(self, window_size: int, sigma: float) -> np.ndarray:
        """Generate 1D Gaussian window for filtering.

        Args:
            window_size: Half-size of the window (full size = 2*window_size + 1)
            sigma: Standard deviation of Gaussian

        Returns:
            1D Gaussian window weights, normalised to sum to 1
        """
        window_size = int(window_size)
        weights = np.zeros(2 * window_size + 1)
        weights[window_size] = 1.0
        sum_weights = 1.0

        sigma_sq = sigma * sigma
        # Fill both sides symmetrically around the centre tap.
        for i in range(1, window_size + 1):
            tmp = np.exp(-0.5 * i * i / sigma_sq)
            weights[window_size + i] = tmp
            weights[window_size - i] = tmp
            sum_weights += 2.0 * tmp

        weights /= sum_weights
        return weights

    def _compute_mscn_transform(self, image: np.ndarray, extend_mode: str = 'constant') -> tuple:
        """Compute Mean Subtracted Contrast Normalized (MSCN) coefficients.

        Each pixel has its Gaussian-weighted local mean subtracted and is
        divided by the local standard deviation plus ``self.C``.

        Args:
            image: Input 2-D (grayscale) image array
            extend_mode: Boundary extension mode for filtering

        Returns:
            Tuple of (mscn_coefficients, local_std, local_mean)
        """
        import scipy.ndimage

        assert len(image.shape) == 2, "Input must be grayscale image"
        h, w = image.shape

        mu_image = np.zeros((h, w), dtype=np.float32)
        var_image = np.zeros((h, w), dtype=np.float32)
        image_float = image.astype(np.float32)

        # Local mean: separable Gaussian filter (rows, then columns).
        scipy.ndimage.correlate1d(image_float, self.avg_window, 0, mu_image, mode=extend_mode)
        scipy.ndimage.correlate1d(mu_image, self.avg_window, 1, mu_image, mode=extend_mode)

        # Local second moment E[X^2], filtered the same way.
        scipy.ndimage.correlate1d(image_float**2, self.avg_window, 0, var_image, mode=extend_mode)
        scipy.ndimage.correlate1d(var_image, self.avg_window, 1, var_image, mode=extend_mode)

        # sqrt(|E[X^2] - E[X]^2|): from here on ``var_image`` holds the local
        # standard deviation; abs() guards against tiny negative round-off.
        var_image = np.sqrt(np.abs(var_image - mu_image**2))

        mscn = (image_float - mu_image) / (var_image + self.C)

        return mscn, var_image, mu_image

    def _compute_aggd_features(self, coefficients: np.ndarray) -> tuple:
        """Compute Asymmetric Generalized Gaussian Distribution (AGGD) parameters.

        AGGD models the distribution of MSCN coefficients and their products,
        capturing shape and asymmetry characteristics.

        Args:
            coefficients: MSCN coefficients (or products of them)

        Returns:
            Tuple of (alpha, N, bl, br, left_std, right_std)
        """
        import scipy.special

        coeffs_flat = coefficients.flatten()
        coeffs_squared = coeffs_flat * coeffs_flat

        # Squared values of the negative and the non-negative coefficients.
        left_data = coeffs_squared[coeffs_flat < 0]
        right_data = coeffs_squared[coeffs_flat >= 0]

        # Root-mean-square of each side; 0 when a side is empty.
        left_std = np.sqrt(np.mean(left_data)) if len(left_data) > 0 else 0
        right_std = np.sqrt(np.mean(right_data)) if len(right_data) > 0 else 0

        # Ratio of left to right spread.
        if right_std != 0:
            gamma_hat = left_std / right_std
        else:
            gamma_hat = np.inf

        mean_abs = np.mean(np.abs(coeffs_flat))
        mean_squared = np.mean(coeffs_squared)

        if mean_squared != 0:
            r_hat = (mean_abs ** 2) / mean_squared
        else:
            r_hat = np.inf

        rhat_norm = r_hat * (((gamma_hat**3 + 1) * (gamma_hat + 1)) /
                             ((gamma_hat**2 + 1) ** 2))

        # Pick the tabulated shape value whose ratio is closest to rhat_norm.
        pos = np.argmin((self.prec_gammas - rhat_norm) ** 2)
        alpha = self.gamma_range[pos]

        gam1 = scipy.special.gamma(1.0 / alpha)
        gam2 = scipy.special.gamma(2.0 / alpha)
        gam3 = scipy.special.gamma(3.0 / alpha)

        aggd_ratio = np.sqrt(gam1) / np.sqrt(gam3)
        bl = aggd_ratio * left_std   # Left scale parameter
        br = aggd_ratio * right_std  # Right scale parameter

        # Mean parameter
        N = (br - bl) * (gam2 / gam1)

        return alpha, N, bl, br, left_std, right_std

    def _compute_paired_products(self, mscn_coeffs: np.ndarray) -> tuple:
        """Compute products of adjacent MSCN coefficients in four orientations.

        Neighbours are obtained with ``np.roll``, so values wrap around at
        the array borders.

        Args:
            mscn_coeffs: MSCN coefficient matrix

        Returns:
            Tuple of (horizontal, vertical, diagonal1, diagonal2) products
        """
        shift_h = np.roll(mscn_coeffs, 1, axis=1)   # Horizontal shift
        shift_v = np.roll(mscn_coeffs, 1, axis=0)   # Vertical shift
        shift_d1 = np.roll(shift_v, 1, axis=1)      # Main diagonal shift
        shift_d2 = np.roll(shift_v, -1, axis=1)     # Anti-diagonal shift

        prod_h = mscn_coeffs * shift_h
        prod_v = mscn_coeffs * shift_v
        prod_d1 = mscn_coeffs * shift_d1
        prod_d2 = mscn_coeffs * shift_d2

        return prod_h, prod_v, prod_d1, prod_d2

    def _extract_subband_features(self, mscn_coeffs: np.ndarray) -> np.ndarray:
        """Extract statistical features from MSCN coefficients and their products.

        Args:
            mscn_coeffs: MSCN coefficient matrix

        Returns:
            Feature vector of length 18
        """
        alpha_m, N, bl, br, _, _ = self._compute_aggd_features(mscn_coeffs)

        prod_h, prod_v, prod_d1, prod_d2 = self._compute_paired_products(mscn_coeffs)

        alpha1, N1, bl1, br1, _, _ = self._compute_aggd_features(prod_h)
        alpha2, N2, bl2, br2, _, _ = self._compute_aggd_features(prod_v)
        alpha3, N3, bl3, br3, _, _ = self._compute_aggd_features(prod_d1)
        alpha4, N4, bl4, br4, _, _ = self._compute_aggd_features(prod_d2)

        # For both diagonal orientations the left scale is stored twice and
        # the right scale (br3 / br4) is not used. The original code notes
        # this mirrors the reference implementation, so it is kept as is.
        features = np.array([
            alpha_m, (bl + br) / 2.0,   # Shape and mean scale of MSCN
            alpha1, N1, bl1, br1,       # Horizontal products (prod_h)
            alpha2, N2, bl2, br2,       # Vertical products (prod_v)
            alpha3, N3, bl3, bl3,       # Diagonal products (prod_d1), bl3 repeated
            alpha4, N4, bl4, bl4,       # Anti-diagonal products (prod_d2), bl4 repeated
        ])

        return features

    def _extract_multiscale_features(self, image: np.ndarray) -> tuple:
        """Extract patch features at full and half resolution.

        Args:
            image: Input grayscale image

        Returns:
            Tuple of (all_features, mean_features, sample_covariance)

        Raises:
            ValueError: If the image is smaller than one patch in either
                dimension.
        """
        h, w = image.shape

        if h < self.patch_size or w < self.patch_size:
            raise ValueError(f"Image too small. Minimum size: {self.patch_size}x{self.patch_size}")

        # Crop so both dimensions are whole multiples of the patch size.
        hoffset = h % self.patch_size
        woffset = w % self.patch_size

        if hoffset > 0:
            image = image[:-hoffset, :]
        if woffset > 0:
            image = image[:, :-woffset]

        image = image.astype(np.float32)

        # Half-resolution copy via PIL bicubic resize (as in reference).
        img_pil = Image.fromarray(image)
        size = tuple((np.array(img_pil.size) * 0.5).astype(int))
        img2 = np.array(img_pil.resize(size, Image.BICUBIC))

        mscn1, _, _ = self._compute_mscn_transform(image)
        mscn1 = mscn1.astype(np.float32)

        mscn2, _, _ = self._compute_mscn_transform(img2)
        mscn2 = mscn2.astype(np.float32)

        # Patch size is halved on the half-resolution image.
        feats_lvl1 = self._extract_patches_test_features(mscn1, self.patch_size)
        feats_lvl2 = self._extract_patches_test_features(mscn2, self.patch_size // 2)

        # One row per patch: level-1 features followed by level-2 features.
        feats = np.hstack((feats_lvl1, feats_lvl2))

        sample_mu = np.mean(feats, axis=0)
        sample_cov = np.cov(feats.T)

        return feats, sample_mu, sample_cov

    def _extract_patches_test_features(self, mscn: np.ndarray, patch_size: int) -> np.ndarray:
        """Extract features from non-overlapping patches for test images.

        Args:
            mscn: MSCN coefficient matrix
            patch_size: Size of patches

        Returns:
            Array of patch features, one row per patch in row-major order
        """
        h, w = mscn.shape
        patch_size = int(patch_size)

        patches = [
            mscn[top:top + patch_size, left:left + patch_size]
            for top in range(0, h - patch_size + 1, patch_size)
            for left in range(0, w - patch_size + 1, patch_size)
        ]

        return np.array([self._extract_subband_features(patch) for patch in patches])

    def analyze(self, image: Image.Image, *args, **kwargs) -> float:
        """Calculate NIQE score for a single image.

        Args:
            image: Input image to evaluate

        Returns:
            float: NIQE score (lower is better, typical range: 2-8)

        Raises:
            ValueError: If either image dimension is smaller than
                ``2 * patch_size + 1``.
        """
        import scipy.linalg

        if image.mode == 'RGB':
            # RGB goes through 'LA' and keeps the luminance channel (as in reference).
            img_array = np.array(image.convert('LA'))[:, :, 0].astype(np.float32)
        else:
            if image.mode != 'L':
                image = image.convert('L')
            img_array = np.array(image, dtype=np.float32)

        min_size = self.patch_size * 2 + 1
        if img_array.shape[0] < min_size or img_array.shape[1] < min_size:
            raise ValueError(f"Image too small. Minimum size: {min_size}x{min_size}")

        _, sample_mu, sample_cov = self._extract_multiscale_features(img_array)

        # Offset of this image's mean features from the population mean.
        X = sample_mu - self.pop_mu

        # Average of population and sample covariance (as in reference).
        covmat = (self.pop_cov + sample_cov) / 2.0

        # Pseudo-inverse for numerical stability.
        pinv_cov = scipy.linalg.pinv(covmat)

        niqe_score = np.sqrt(np.dot(np.dot(X, pinv_cov), X))

        return float(niqe_score)

934 

935 

class SSIMAnalyzer(ComparedImageQualityAnalyzer):
    """Global SSIM analyzer for comparing an image against a reference.

    The Structural Similarity Index is computed once from statistics taken
    over all pixels and channels (no sliding window). Higher values indicate
    greater similarity.
    """

    def __init__(self, max_pixel_value: float = 255.0):
        """Set up the stabilising constants used by the SSIM formula.

        Args:
            max_pixel_value: Maximum pixel value (255 for 8-bit images)
        """
        super().__init__()
        self.max_pixel_value = max_pixel_value
        # Constants derived from the dynamic range; they keep the ratios
        # below finite when means or variances are close to zero.
        self.C1 = (0.01 * max_pixel_value) ** 2
        self.C2 = (0.03 * max_pixel_value) ** 2
        self.C3 = self.C2 / 2.0

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float:
        """Compute the global SSIM between ``image`` and ``reference``.

        Both inputs are converted to RGB; the reference is bilinearly resized
        to the image's size when the two differ.

        Args:
            image (Image.Image): Image to evaluate
            reference (Image.Image): Reference image

        Returns:
            float: SSIM value; 1 for identical images, and it can be negative
                when the covariance between the two images is negative
        """
        # Bring both inputs to a common mode and size.
        test_img = image if image.mode == 'RGB' else image.convert('RGB')
        ref_img = reference if reference.mode == 'RGB' else reference.convert('RGB')
        if test_img.size != ref_img.size:
            ref_img = ref_img.resize(test_img.size, Image.Resampling.BILINEAR)

        x = np.array(test_img, dtype=np.float32)
        y = np.array(ref_img, dtype=np.float32)

        # Global first- and second-order statistics.
        mu_x = np.mean(x)
        mu_y = np.mean(y)
        sigma_x = np.std(x)
        sigma_y = np.std(y)
        sigma_xy = np.mean((x - mu_x) * (y - mu_y))

        # SSIM is the product of luminance, contrast and structure terms.
        luminance = (2 * mu_x * mu_y + self.C1) / (mu_x**2 + mu_y**2 + self.C1)
        contrast = (2 * sigma_x * sigma_y + self.C2) / (sigma_x**2 + sigma_y**2 + self.C2)
        structure = (sigma_xy + self.C3) / (sigma_x * sigma_y + self.C3)

        return float(luminance * contrast * structure)

995 

996 

class BRISQUEAnalyzer(DirectImageQualityAnalyzer):
    """No-reference image quality analyzer based on BRISQUE.

    BRISQUE (Blind/Referenceless Image Spatial Quality Evaluator) rates the
    perceptual quality of a single image without a reference. The score is
    delegated to ``piq.brisque``; lower scores indicate better quality.
    """

    def __init__(self, device: str = "cuda"):
        """Select the torch device used for scoring.

        Args:
            device: Requested device string; "cpu" is used instead whenever
                CUDA is not available
        """
        super().__init__()
        chosen = device if torch.cuda.is_available() else "cpu"
        self.device = torch.device(chosen)

    def _preprocess(self, image: Image.Image) -> torch.Tensor:
        """Turn a PIL image into a (1, C, H, W) float tensor scaled to [0, 1]."""
        rgb = image if image.mode == 'RGB' else image.convert('RGB')
        pixels = np.array(rgb).astype(np.float32) / 255.0
        # HWC -> CHW, then add the batch dimension.
        batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
        return batch.to(self.device)

    def analyze(self, image: Image.Image, *args, **kwargs) -> float:
        """Compute the BRISQUE score of one image.

        Args:
            image: PIL Image

        Returns:
            float: BRISQUE score (lower is better)
        """
        batch = self._preprocess(image)
        with torch.no_grad():
            # Inputs are already scaled to [0, 1], hence data_range=1.0.
            result = piq.brisque(batch, data_range=1.0)
        return float(result.item())

1031 

1032 

class VIFAnalyzer(ComparedImageQualityAnalyzer):
    """Full-reference analyzer based on VIF (Visual Information Fidelity).

    VIF quantifies how much of the reference image's visual information is
    preserved in the test image. The score is delegated to ``piq.vif_p``;
    higher values indicate better quality/similarity.
    """

    def __init__(self, device: str = "cuda"):
        """Select the torch device used for scoring.

        Args:
            device: Requested device string; "cpu" is used instead whenever
                CUDA is not available
        """
        super().__init__()
        chosen = device if torch.cuda.is_available() else "cpu"
        self.device = torch.device(chosen)

    def _preprocess(self, image: Image.Image) -> torch.Tensor:
        """Turn a PIL image into a (1, C, H, W) float tensor scaled to [0, 1]."""
        rgb = image if image.mode == 'RGB' else image.convert('RGB')
        pixels = np.array(rgb).astype(np.float32) / 255.0
        # HWC -> CHW, then add the batch dimension.
        batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
        return batch.to(self.device)

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float:
        """Compute the VIF score of ``image`` against ``reference``.

        Args:
            image: Distorted/test image (PIL)
            reference: Reference image (PIL)

        Returns:
            float: VIF score (higher is better)
        """
        test_t = self._preprocess(image)
        ref_t = self._preprocess(reference)

        # piq needs matching shapes: resample the reference onto the test
        # image's spatial grid when they differ.
        if test_t.shape != ref_t.shape:
            height, width = test_t.shape[2], test_t.shape[3]
            ref_t = torch.nn.functional.interpolate(
                ref_t, size=(height, width), mode='bilinear', align_corners=False
            )

        with torch.no_grad():
            result = piq.vif_p(test_t, ref_t, data_range=1.0)
        return float(result.item())

1075 

1076 

class FSIMAnalyzer(ComparedImageQualityAnalyzer):
    """Full-reference analyzer based on FSIM (Feature Similarity Index).

    FSIM compares two images using phase congruency and gradient magnitude.
    The score is delegated to ``piq.fsim``; higher values indicate better
    quality/similarity.
    """

    def __init__(self, device: str = "cuda"):
        """Select the torch device used for scoring.

        Args:
            device: Requested device string; "cpu" is used instead whenever
                CUDA is not available
        """
        super().__init__()
        chosen = device if torch.cuda.is_available() else "cpu"
        self.device = torch.device(chosen)

    def _preprocess(self, image: Image.Image) -> torch.Tensor:
        """Turn a PIL image into a (1, C, H, W) float tensor scaled to [0, 1]."""
        rgb = image if image.mode == 'RGB' else image.convert('RGB')
        pixels = np.array(rgb).astype(np.float32) / 255.0
        # HWC -> CHW, then add the batch dimension.
        batch = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0)
        return batch.to(self.device)

    def analyze(self, image: Image.Image, reference: Image.Image, *args, **kwargs) -> float:
        """Compute the FSIM score of ``image`` against ``reference``.

        Args:
            image: Distorted/test image (PIL)
            reference: Reference image (PIL)

        Returns:
            float: FSIM score (higher is better)
        """
        test_t = self._preprocess(image)
        ref_t = self._preprocess(reference)

        # Resample the reference onto the test image's spatial grid when the
        # two tensors do not already have the same shape.
        if test_t.shape != ref_t.shape:
            height, width = test_t.shape[2], test_t.shape[3]
            ref_t = torch.nn.functional.interpolate(
                ref_t, size=(height, width), mode='bilinear', align_corners=False
            )

        with torch.no_grad():
            result = piq.fsim(test_t, ref_t, data_range=1.0)
        return float(result.item())