Coverage for visualize / auto_visualization.py: 100.00%
33 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1import torch
2from PIL import Image
3from typing import List, Optional, Dict, Any, Union
4from abc import ABC, abstractmethod
5from visualize.data_for_visualization import DataForVisualization
6import importlib
7from visualize.base import BaseVisualizer
# Registry of supported watermarking algorithms.
# Each key is an algorithm's short name. Each value is the dotted import path
# "package.module.ClassName" of the visualizer class for that algorithm.
# AutoVisualizer splits each path on its last '.' into module and class name
# and imports them with importlib only when load() is called.
VISUALIZATION_DATA_MAPPING = dict(
    TR='visualize.tr.TreeRingVisualizer',
    GS='visualize.gs.GaussianShadingVisualizer',
    PRC='visualize.prc.PRCVisualizer',
    RI='visualize.ri.RingIDVisualizer',
    WIND='visualize.wind.WINDVisualizer',
    SEAL='visualize.seal.SEALVisualizer',
    ROBIN='visualize.robin.ROBINVisualizer',
    VideoShield='visualize.videoshield.VideoShieldVisualizer',
    SFW='visualize.sfw.SFWVisualizer',
    VideoMark='visualize.videomark.VideoMarkVisualizer',
    GM='visualize.gm.GaussMarkerVisualizer',
)
24class AutoVisualizer:
25 """
26 Factory class for creating visualization data instances.
28 This is a generic visualization data factory that will instantiate the appropriate
29 visualization data class based on the algorithm name.
31 This class cannot be instantiated directly using __init__() (throws an error).
32 """
34 def __init__(self):
35 raise EnvironmentError(
36 "AutoVisualizer is designed to be instantiated "
37 "using the `AutoVisualizer.load(algorithm_name, **kwargs)` method."
38 )
40 @staticmethod
41 def _get_visualization_class_name(algorithm_name: str) -> Optional[str]:
42 """Get the visualization data class name from the algorithm name."""
43 for alg_name, class_path in VISUALIZATION_DATA_MAPPING.items():
44 if algorithm_name.lower() == alg_name.lower():
45 return class_path
46 return None
48 @classmethod
49 def load(cls, algorithm_name: str, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1) -> BaseVisualizer:
50 """
51 Load the visualization data instance based on the algorithm name.
53 Args:
54 algorithm_name: Name of the watermarking algorithm (e.g., 'TR', 'GS', 'PRC')
55 data_for_visualization: DataForVisualization instance
57 Returns:
58 BaseVisualizer: Instance of the appropriate visualization data class
60 Raises:
61 ValueError: If the algorithm name is not supported
62 """
63 # Check if the algorithm exists
64 class_path = cls._get_visualization_class_name(algorithm_name)
66 if algorithm_name != data_for_visualization.algorithm_name:
67 raise ValueError(f"Algorithm name mismatch: {algorithm_name} != {data_for_visualization.algorithm_name}")
69 if class_path is None:
70 supported_algs = list(VISUALIZATION_DATA_MAPPING.keys())
71 raise ValueError(
72 f"Invalid algorithm name: {algorithm_name}. "
73 f"Supported algorithms: {', '.join(supported_algs)}"
74 )
76 # Load the visualization data module and class
77 module_name, class_name = class_path.rsplit('.', 1)
78 try:
79 module = importlib.import_module(module_name)
80 visualization_class = getattr(module, class_name)
81 except (ImportError, AttributeError) as e:
82 raise ImportError(
83 f"Failed to load visualization data class '{class_name}' "
84 f"from module '{module_name}': {e}"
85 )
87 # Create and validate the instance
88 instance = visualization_class(data_for_visualization=data_for_visualization, dpi=dpi, watermarking_step=watermarking_step)
89 return instance