Coverage for visualize / prc / prc_visualizer.py: 91.46%

82 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16from typing import Optional 

17import torch 

18import matplotlib.pyplot as plt 

19from matplotlib.axes import Axes 

20from matplotlib.gridspec import GridSpec 

21import numpy as np 

22from visualize.base import BaseVisualizer 

23from visualize.data_for_visualization import DataForVisualization 

24 

class PRCVisualizer(BaseVisualizer):
    """PRC (Pseudorandom Codes) watermark visualizer.

    Each ``draw_*`` method renders one panel onto a matplotlib ``Axes`` and
    returns that axes. When the attribute a panel needs is absent from (or
    ``None`` on) ``self.data``, a centred "Not Available" notice is drawn
    instead of raising. When no axes is supplied, the current pyplot axes
    is used.
    """

    # Attributes of ``self.data`` that are replaced by their ``.detach()``-ed
    # versions when the visualizer is constructed.
    _DETACHED_ATTRS = (
        'watermarked_latents',
        'orig_latents',
        'inverted_latents',
        'prc_codeword',
        'generator_matrix',
    )

    # A recovered codeword is only rendered when it has exactly this many
    # entries; it is then laid out on this grid (128 * 128 == 16384).
    _RECOVERED_LENGTH = 16384
    _RECOVERED_GRID_SHAPE = (128, 128)

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        """
        Initialize PRC visualizer

        Args:
            data_for_visualization: DataForVisualization object containing visualization data
            dpi: DPI for visualization (default: 300)
            watermarking_step: The step for inserting the watermark (default: -1)
        """
        super().__init__(data_for_visualization, dpi, watermarking_step)

        # Pre-detach all tensors while maintaining device compatibility.
        # Attributes that are missing or None are left untouched.
        for attr_name in self._DETACHED_ATTRS:
            tensor = getattr(self.data, attr_name, None)
            if tensor is not None:
                setattr(self.data, attr_name, tensor.detach())

    @staticmethod
    def _axes_or_current(ax: Optional[Axes]) -> Axes:
        """Return ``ax``, or the current pyplot axes when ``ax`` is None."""
        return plt.gca() if ax is None else ax

    @staticmethod
    def _show_unavailable(ax: Axes, message: str) -> None:
        """Write ``message`` centred on ``ax`` (axes-relative coordinates)."""
        ax.text(0.5, 0.5, message,
                ha='center', va='center', fontsize=12, transform=ax.transAxes)

    @staticmethod
    def _fold_codeword(values: np.ndarray) -> Optional[np.ndarray]:
        """Lay a 1D array out on a near-square, zero-padded 2D grid.

        Returns None for an empty array, which has no sensible grid.
        """
        length = values.shape[0]
        if length == 0:
            return None
        height = int(np.sqrt(length))
        width = -(-length // height)  # ceiling division: every entry must fit
        grid = np.zeros(height * width)
        grid[:length] = values
        return grid.reshape(height, width)

    def draw_generator_matrix(self,
                              title: str = "Generator Matrix G",
                              cmap: str = "Blues",
                              use_color_bar: bool = True,
                              max_display_size: int = 50,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the generator matrix visualization

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            max_display_size (int): Maximum size to display (for large matrices)
            ax (Axes): The axes to plot on (defaults to the current pyplot axes)
            **kwargs: Forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        ax = self._axes_or_current(ax)
        generator = getattr(self.data, 'generator_matrix', None)

        if generator is None:
            self._show_unavailable(ax, 'Generator Matrix\nNot Available')
        else:
            gen_matrix = generator.cpu().numpy()
            rows, cols = gen_matrix.shape[0], gen_matrix.shape[1]

            # Show a square top-left sample of the matrix if it's too large
            if rows > max_display_size or cols > max_display_size:
                sample_size = min(max_display_size, min(gen_matrix.shape))
                shown = gen_matrix[:sample_size, :sample_size]
                title += f" (Sample {sample_size}x{sample_size})"
            else:
                shown = gen_matrix

            im = ax.imshow(shown, cmap=cmap, aspect='auto', **kwargs)

            if use_color_bar:
                # Attach the colorbar to the figure that owns ``ax`` rather
                # than to whatever figure pyplot currently considers active.
                ax.figure.colorbar(im, ax=ax, shrink=0.8)

        ax.set_title(title, fontsize=10)
        ax.set_xlabel('Columns')
        ax.set_ylabel('Rows')
        return ax

    def draw_codeword(self,
                      title: str = "PRC Codeword",
                      cmap: str = "viridis",
                      use_color_bar: bool = True,
                      ax: Optional[Axes] = None,
                      **kwargs) -> Axes:
        """
        Draw the PRC codeword visualization

        A 1D codeword is folded onto a near-square grid (zero-padded at the
        end); an empty 1D codeword is reported as not available.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            ax (Axes): The axes to plot on (defaults to the current pyplot axes)
            **kwargs: Forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        ax = self._axes_or_current(ax)
        raw_codeword = getattr(self.data, 'prc_codeword', None)

        codeword = None
        if raw_codeword is not None:
            codeword = raw_codeword.cpu().numpy()
            # If 1D, reshape for visualization
            if len(codeword.shape) == 1:
                codeword = self._fold_codeword(codeword)

        if codeword is None:
            self._show_unavailable(ax, 'PRC Codeword\nNot Available')
        else:
            im = ax.imshow(codeword, cmap=cmap, aspect='equal', **kwargs)
            if use_color_bar:
                ax.figure.colorbar(im, ax=ax, shrink=0.8)

        ax.set_title(title, fontsize=12)
        return ax

    def draw_recovered_codeword(self,
                                title: str = "Recovered Codeword (c̃)",
                                cmap: str = "viridis",
                                use_color_bar: bool = True,
                                vmin: float = -1.0,
                                vmax: float = 1.0,
                                ax: Optional[Axes] = None,
                                **kwargs) -> Axes:
        """
        Draw the recovered codeword (c̃) from PRC detection

        This visualizes ``self.data.recovered_prc``. The flattened codeword is
        only rendered when it has exactly 16384 entries (shown as 128x128);
        any other length produces an explanatory notice instead.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            vmin (float): Minimum value for colormap (-1.0)
            vmax (float): Maximum value for colormap (1.0)
            ax (Axes): The axes to plot on (defaults to the current pyplot axes)
            **kwargs: Forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        ax = self._axes_or_current(ax)
        recovered = getattr(self.data, 'recovered_prc', None)

        if recovered is None:
            self._show_unavailable(ax, 'Recovered Codeword (c̃)\nNot Available')
        else:
            # ``recovered_prc`` is not detached in ``__init__``, so detach it
            # here: ``.numpy()`` refuses tensors that still require grad.
            recovered_codeword = recovered.detach().cpu().numpy().flatten()

            # Ensure it's the expected length
            if len(recovered_codeword) == self._RECOVERED_LENGTH:
                # Reshape to 2D for visualization (128x128 = 16384)
                codeword_2d = recovered_codeword.reshape(self._RECOVERED_GRID_SHAPE)

                im = ax.imshow(codeword_2d, cmap=cmap, vmin=vmin, vmax=vmax, aspect='equal', **kwargs)

                if use_color_bar:
                    cbar = ax.figure.colorbar(im, ax=ax, shrink=0.8)
                    cbar.set_label('Codeword Value', fontsize=8)
            else:
                self._show_unavailable(
                    ax,
                    f'Recovered Codeword\nUnexpected Length: {len(recovered_codeword)}\n(Expected: {self._RECOVERED_LENGTH})')

        ax.set_title(title, fontsize=10)
        ax.axis('off')
        return ax

    def draw_difference_map(self,
                            title: str = "Difference Map",
                            cmap: str = "hot",
                            use_color_bar: bool = True,
                            channel: int = 0,
                            ax: Optional[Axes] = None,
                            **kwargs) -> Axes:
        """
        Draw difference map between watermarked and inverted latents

        The map is the element-wise absolute difference of the two latents
        for the selected channel.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            channel (int): The channel to visualize
            ax (Axes): The axes to plot on (defaults to the current pyplot axes)
            **kwargs: Forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        ax = self._axes_or_current(ax)
        watermarked = getattr(self.data, 'watermarked_latents', None)
        inverted = getattr(self.data, 'inverted_latents', None)

        if watermarked is None or inverted is None:
            self._show_unavailable(ax, 'Difference Map\nNot Available')
        else:
            # ``_get_latent_data`` is defined outside this class body;
            # presumably it returns the 2D slice for ``channel`` -- confirm.
            wm_latents = self._get_latent_data(watermarked, channel=channel).cpu().numpy()
            inv_latents = self._get_latent_data(inverted, channel=channel).cpu().numpy()

            diff_map = np.abs(wm_latents - inv_latents)
            im = ax.imshow(diff_map, cmap=cmap, aspect='equal', **kwargs)

            if use_color_bar:
                ax.figure.colorbar(im, ax=ax, shrink=0.8)

        ax.set_title(title, fontsize=12)
        ax.set_xlabel('Width')
        ax.set_ylabel('Height')
        return ax