Coverage for visualize/seal/seal_visualizer.py: 95.31%

64 statements  

coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1from typing import Optional 

2import torch 

3import matplotlib.pyplot as plt 

4from matplotlib.axes import Axes 

5from matplotlib.gridspec import GridSpecFromSubplotSpec 

6import numpy as np 

7import math 

8from visualize.base import BaseVisualizer 

9from visualize.data_for_visualization import DataForVisualization 

10 

class SEALVisualizer(BaseVisualizer):
    """SEAL watermark visualization class.

    Provides two plots:
      * a histogram comparing ``original_embedding`` and ``inspected_embedding``;
      * a per-patch L2 difference map between the reversed latent of
        ``watermarking_step`` and ``reference_noise``.
    """

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        super().__init__(data_for_visualization, dpi, watermarking_step)

    @staticmethod
    def _resolve_axes(ax: Optional[Axes]) -> Axes:
        """Return ``ax`` unchanged, or axes on a brand-new figure when ``ax`` is None."""
        if ax is None:
            _, ax = plt.subplots()
        return ax

    @staticmethod
    def _to_numpy_flat(tensor: torch.Tensor) -> np.ndarray:
        """Detach ``tensor``, move it to CPU as float32 and return a flattened copy."""
        # detach(): .numpy() refuses tensors that require grad.
        # float(): .numpy() has no bfloat16 equivalent; float32 is plenty for a histogram.
        return tensor.detach().cpu().float().numpy().flatten()

    def draw_embedding_distributions(self,
                                     title: str = "Embedding Distributions",
                                     ax: Optional[Axes] = None,
                                     show_legend: bool = True,
                                     show_label: bool = True,
                                     show_axis: bool = True) -> Axes:
        """
        Draw histogram of embedding distributions comparison (original_embedding vs inspected_embedding).

        Parameters:
            title (str): The title of the plot. An empty string suppresses the title.
            ax (Optional[Axes]): The axes to plot on. If None, a new figure and axes are created.
            show_legend (bool): Whether to show the legend. Default: True.
            show_label (bool): Whether to show axis labels. Default: True.
            show_axis (bool): Whether to show axis ticks and labels. Default: True.

        Returns:
            Axes: The plotted axes.
        """
        ax = self._resolve_axes(ax)

        original_data = self._to_numpy_flat(self.data.original_embedding)
        inspected_data = self._to_numpy_flat(self.data.inspected_embedding)

        # Overlapping, density-normalised histograms so both distributions share one scale.
        ax.hist(original_data, bins=50, alpha=0.7, color='blue',
                label='Original Embedding', density=True, edgecolor='darkblue', linewidth=0.5)
        ax.hist(inspected_data, bins=50, alpha=0.7, color='red',
                label='Inspected Embedding', density=True, edgecolor='darkred', linewidth=0.5)

        if title != "":
            ax.set_title(title, fontsize=12)
        if show_label:
            ax.set_xlabel('Embedding Values')
            ax.set_ylabel('Density')
        if show_legend:
            ax.legend()
        if not show_axis:
            ax.set_xticks([])
            ax.set_yticks([])
        ax.grid(True, alpha=0.3)

        # Invisible colorbar: it still reserves colorbar space next to the axes,
        # presumably so this panel's width matches sibling panels that draw a real one.
        im = ax.scatter([], [], c=[], cmap='viridis')
        cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
        cbar.ax.set_visible(False)

        return ax

    @staticmethod
    def _annotate_patch_values(ax: Axes, im, diff_map: torch.Tensor, num_patches: int) -> None:
        """Write the value of each of the first ``num_patches`` cells (row-major) at its centre."""
        _, cols = diff_map.shape
        grid = max(diff_map.shape)
        # Denser grids mean smaller cells, so use smaller text and fewer decimals.
        if grid <= 4:
            fontsize, fmt = 8, '{:.2f}'
        elif grid <= 8:
            fontsize, fmt = 6, '{:.1f}'
        else:
            fontsize, fmt = 4, '{:.0f}'
        for idx in range(num_patches):
            i, j = divmod(idx, cols)
            value = diff_map[i, j].item()
            # Choose black/white text from the luminance of the cell's colour for contrast.
            r, g, b, _ = im.cmap(im.norm(value))
            luminance = 0.299 * r + 0.587 * g + 0.114 * b
            ax.text(j, i, fmt.format(value), ha='center', va='center',
                    color='black' if luminance > 0.5 else 'white',
                    fontsize=fontsize, fontweight='bold')

    def draw_patch_diff(self,
                        title: str = "Patch Difference",
                        cmap: str = 'RdBu',
                        use_color_bar: bool = True,
                        vmin: Optional[float] = None,
                        vmax: Optional[float] = None,
                        show_number: bool = False,
                        ax: Optional[Axes] = None,
                        **kwargs) -> Axes:
        """
        Draw the per-patch L2 difference between reference_noise and the reversed latent.

        The latent's spatial plane is split into ``k_value`` patches laid out
        row-major on a ``ceil(sqrt(k)) x ceil(k / ceil(sqrt(k)))`` grid. Each grid
        cell shows the L2 norm of the difference over that patch. Grid cells
        beyond ``k_value`` stay at 0.

        Parameters:
            title (str): The title of the plot. An empty string suppresses the title.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Minimum value for colormap normalization.
            vmax (Optional[float]): Maximum value for colormap normalization.
            show_number (bool): Whether to display numerical values on each patch. Default: False.
            ax (Optional[Axes]): The axes to plot on. If None, a new figure and axes are created.
            **kwargs: Forwarded to ``Axes.imshow``.

        Returns:
            Axes: The plotted axes.

        Raises:
            ValueError: If ``k_value`` is smaller than 1.
        """
        ax = self._resolve_axes(ax)

        # Original code noted both tensors as shape [1, 4, 64, 64]. Slicing below
        # assumes 4-D tensors; the spatial size is read from the last two dims.
        reversed_latent = self.data.reversed_latents[self.watermarking_step]
        reference_noise = self.data.reference_noise
        num_patches = int(self.data.k_value)
        if num_patches < 1:
            raise ValueError(f"k_value must be a positive integer, got {self.data.k_value!r}")

        height, width = reversed_latent.shape[-2], reversed_latent.shape[-1]
        rows = int(math.ceil(math.sqrt(num_patches)))
        cols = int(math.ceil(num_patches / rows))  # rows * cols >= num_patches
        patch_height = height // rows
        patch_width = width // cols
        diff_map = torch.zeros((rows, cols))

        for idx in range(num_patches):
            i, j = divmod(idx, cols)
            y_start, x_start = i * patch_height, j * patch_width
            y_end = min(y_start + patch_height, height)
            x_end = min(x_start + patch_width, width)
            patch_diff = (reversed_latent[:, :, y_start:y_end, x_start:x_end]
                          - reference_noise[:, :, y_start:y_end, x_start:x_end])
            diff_map[i, j] = torch.norm(patch_diff).item()

        im = ax.imshow(diff_map.numpy(), cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
        if title != "":
            ax.set_title(title)
        if use_color_bar:
            ax.figure.colorbar(im, ax=ax)
        if show_number:
            self._annotate_patch_values(ax, im, diff_map, num_patches)
        ax.axis('off')

        return ax