Coverage for visualize/auto_visualization.py: 100.00%

33 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1import torch 

2from PIL import Image 

3from typing import List, Optional, Dict, Any, Union 

4from abc import ABC, abstractmethod 

5from visualize.data_for_visualization import DataForVisualization 

6import importlib 

7from visualize.base import BaseVisualizer 

8 

# Registry of supported algorithms: short algorithm name -> dotted path
# "<module>.<ClassName>" of the visualizer class that AutoVisualizer.load
# imports and instantiates for that name. Insertion order is significant only
# for display: it is the order in which supported names are listed in the
# "Invalid algorithm name" error message.
VISUALIZATION_DATA_MAPPING: Dict[str, str] = dict(
    TR='visualize.tr.TreeRingVisualizer',
    GS='visualize.gs.GaussianShadingVisualizer',
    PRC='visualize.prc.PRCVisualizer',
    RI='visualize.ri.RingIDVisualizer',
    WIND='visualize.wind.WINDVisualizer',
    SEAL='visualize.seal.SEALVisualizer',
    ROBIN='visualize.robin.ROBINVisualizer',
    VideoShield='visualize.videoshield.VideoShieldVisualizer',
    SFW='visualize.sfw.SFWVisualizer',
    VideoMark='visualize.videomark.VideoMarkVisualizer',
    GM='visualize.gm.GaussMarkerVisualizer',
)

23 

class AutoVisualizer:
    """
    Factory for creating watermark visualizer instances.

    Resolves an algorithm name (case-insensitively) to a visualizer class
    registered in ``VISUALIZATION_DATA_MAPPING``, imports that class at call
    time, and instantiates it with the supplied visualization data.

    This class cannot be instantiated directly: ``__init__()`` raises an
    error. Use ``AutoVisualizer.load(...)`` instead.
    """

    def __init__(self):
        raise EnvironmentError(
            "AutoVisualizer is designed to be instantiated "
            "using the `AutoVisualizer.load(algorithm_name, **kwargs)` method."
        )

    @staticmethod
    def _get_visualization_class_name(algorithm_name: str) -> Optional[str]:
        """Return the dotted class path registered for *algorithm_name*.

        The match is case-insensitive. Returns ``None`` if no registry entry
        matches.
        """
        wanted = algorithm_name.lower()  # hoisted: invariant across the loop
        for alg_name, class_path in VISUALIZATION_DATA_MAPPING.items():
            if alg_name.lower() == wanted:
                return class_path
        return None

    @classmethod
    def load(cls, algorithm_name: str, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1) -> BaseVisualizer:
        """
        Create the visualizer registered for ``algorithm_name``.

        Args:
            algorithm_name: Name of the watermarking algorithm (e.g., 'TR',
                'GS', 'PRC'). Matched case-insensitively against the registry.
            data_for_visualization: DataForVisualization instance. Its
                ``algorithm_name`` attribute must match ``algorithm_name``,
                also case-insensitively.
            dpi: Forwarded as ``dpi`` to the visualizer constructor.
            watermarking_step: Forwarded as ``watermarking_step`` to the
                visualizer constructor.

        Returns:
            BaseVisualizer: Instance of the visualizer class registered for
            the algorithm.

        Raises:
            ValueError: If the algorithm name is not supported, or if it does
                not match ``data_for_visualization.algorithm_name``.
            ImportError: If the registered module or class cannot be loaded.
        """
        # Reject unsupported names first, so callers get the list of valid
        # algorithms rather than a less helpful mismatch message.
        class_path = cls._get_visualization_class_name(algorithm_name)
        if class_path is None:
            supported_algs = list(VISUALIZATION_DATA_MAPPING.keys())
            raise ValueError(
                f"Invalid algorithm name: {algorithm_name}. "
                f"Supported algorithms: {', '.join(supported_algs)}"
            )

        # Compare case-insensitively, consistent with the registry lookup
        # above. A non-string name on the data object can never match.
        data_alg_name = data_for_visualization.algorithm_name
        if not isinstance(data_alg_name, str) or data_alg_name.lower() != algorithm_name.lower():
            raise ValueError(f"Algorithm name mismatch: {algorithm_name} != {data_alg_name}")

        # Import the visualizer lazily: only the module for the requested
        # algorithm is loaded.
        module_name, class_name = class_path.rsplit('.', 1)
        try:
            module = importlib.import_module(module_name)
            visualization_class = getattr(module, class_name)
        except (ImportError, AttributeError) as e:
            # Chain the cause so the underlying import/attribute failure
            # stays visible in the traceback.
            raise ImportError(
                f"Failed to load visualization data class '{class_name}' "
                f"from module '{module_name}': {e}"
            ) from e

        # Instantiate the visualizer; no further validation is done here.
        return visualization_class(
            data_for_visualization=data_for_visualization,
            dpi=dpi,
            watermarking_step=watermarking_step,
        )