Coverage for visualize / videomark / video_mark_visualizer.py: 90.70%

86 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import torch 

17import matplotlib.pyplot as plt 

18from matplotlib.axes import Axes 

19from matplotlib.gridspec import GridSpecFromSubplotSpec 

20import numpy as np 

21from typing import Optional 

22from visualize.base import BaseVisualizer 

23from visualize.data_for_visualization import DataForVisualization 

24 

25 

class VideoMarkVisualizer(BaseVisualizer):
    """VideoMark watermark visualization class.

    Draws the PRC-related artifacts of the VideoMark algorithm (the PRC
    codeword, the generator matrix G and the recovered codeword c̃) together
    with a latent-space difference map for video inputs.

    Attributes read from ``self.data`` (each one is optional; when it is
    missing or ``None`` the corresponding method draws a "Not Available"
    message instead):
        - watermarked_latents / inverted_latents: latent tensors, presumably
          [B, C, F, H, W] (TODO confirm); sliced through ``_get_latent_data``.
        - orig_latents: only detached here; not drawn by this class.
        - prc_codeword: PRC codeword tensor; when it has more than one
          dimension only index 0 (presumably the first frame) is drawn.
        - generator_matrix: generator matrix tensor G.
        - recovered_prc: recovered codeword tensor c̃.
    """

    # Tensor attributes of ``self.data`` that are detached once in __init__ so
    # later ``.numpy()`` conversions cannot fail on gradient-tracking tensors.
    _TENSOR_ATTRS = (
        'watermarked_latents',
        'orig_latents',
        'inverted_latents',
        'prc_codeword',
        'generator_matrix',
        'recovered_prc',
    )

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300,
                 watermarking_step: int = -1, is_video: bool = True):
        """Initialize the visualizer and detach the tensors it will draw.

        Parameters:
            data_for_visualization (DataForVisualization): Data to visualize.
            dpi (int): Forwarded to ``BaseVisualizer`` (presumably the figure DPI).
            watermarking_step (int): Forwarded to ``BaseVisualizer``.
            is_video (bool): Forwarded to ``BaseVisualizer``; True by default.
        """
        super().__init__(data_for_visualization, dpi, watermarking_step, is_video)
        # Pre-detach all tensors while keeping them on their original device.
        for attr in self._TENSOR_ATTRS:
            value = getattr(self.data, attr, None)
            if isinstance(value, torch.Tensor):
                setattr(self.data, attr, value.detach())

    @staticmethod
    def _axes_or_current(ax: Optional[Axes]) -> Axes:
        """Return ``ax``, or the current pyplot axes when ``ax`` is None."""
        return plt.gca() if ax is None else ax

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Detach ``tensor``, move it to the CPU and convert it to NumPy.

        Half-precision floats (float16 / bfloat16) are upcast to float32:
        NumPy has no bfloat16 dtype, and float16 arithmetic loses precision.
        """
        tensor = tensor.detach().cpu()
        if tensor.is_floating_point() and tensor.element_size() < 4:
            tensor = tensor.float()
        return tensor.numpy()

    @staticmethod
    def _reshape_to_2d(values: np.ndarray) -> np.ndarray:
        """Zero-pad a 1-D array and reshape it into a near-square 2-D grid.

        The height is floor(sqrt(len)) and the width is the ceiling division
        of the length by the height, so every value fits. An empty input
        yields a single zero cell instead of dividing by zero.
        """
        length = values.shape[0]
        height = max(1, int(np.sqrt(length)))
        width = max(1, -(-length // height))  # ceiling division
        padded = np.zeros(height * width)
        padded[:length] = values
        return padded.reshape(height, width)

    def draw_watermarked_video_frames(self,
                                      num_frames: int = 4,
                                      title: str = "Watermarked Video Frames",
                                      ax: Optional[Axes] = None) -> Axes:
        """
        Draw multiple frames from the watermarked video.

        DEPRECATED:
            This method is deprecated and will be removed in a future version.
            Please use `draw_watermarked_image` instead.

        Delegates to ``self._draw_video_frames`` (inherited from the base
        class), passing all arguments through unchanged.

        Args:
            num_frames: Number of frames to display (default: 4)
            title: The title of the plot
            ax: The axes to plot on

        Returns:
            The plotted axes
        """
        return self._draw_video_frames(
            title=title,
            num_frames=num_frames,
            ax=ax
        )

    def draw_generator_matrix(self,
                              title: str = "Generator Matrix G",
                              cmap: str = "Blues",
                              use_color_bar: bool = True,
                              max_display_size: int = 50,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the generator matrix visualization.

        A matrix larger than ``max_display_size`` along either axis is
        truncated to its top-left square corner (side = min(max_display_size,
        smallest matrix dimension)), and the title is annotated accordingly.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            max_display_size (int): Maximum size to display (for large matrices)
            ax (Axes): The axes to plot on; the current pyplot axes if None
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        ax = self._axes_or_current(ax)
        generator_matrix = getattr(self.data, 'generator_matrix', None)
        if generator_matrix is not None:
            gen_matrix = self._to_numpy(generator_matrix)

            # Show only a square sample of the matrix if it is too large.
            if gen_matrix.shape[0] > max_display_size or gen_matrix.shape[1] > max_display_size:
                sample_size = min(max_display_size, min(gen_matrix.shape))
                matrix_sample = gen_matrix[:sample_size, :sample_size]
                title += f" (Sample {sample_size}x{sample_size})"
            else:
                matrix_sample = gen_matrix

            im = ax.imshow(matrix_sample, cmap=cmap, aspect='auto', **kwargs)

            if use_color_bar:
                ax.figure.colorbar(im, ax=ax, shrink=0.8)
        else:
            ax.text(0.5, 0.5, 'Generator Matrix\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=10)
        ax.set_xlabel('Columns')
        ax.set_ylabel('Rows')
        return ax

    def draw_codeword(self,
                      title: str = "VideoMark Codeword",
                      cmap: str = "viridis",
                      use_color_bar: bool = True,
                      ax: Optional[Axes] = None,
                      **kwargs) -> Axes:
        """
        Draw the PRC codeword visualization.

        When ``prc_codeword`` has more than one dimension, only index 0
        (presumably the first frame's codeword) is drawn. A 1-D result is
        zero-padded and reshaped into a near-square grid before display.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            ax (Axes): The axes to plot on; the current pyplot axes if None
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        ax = self._axes_or_current(ax)
        prc_codeword = getattr(self.data, 'prc_codeword', None)
        if prc_codeword is not None:
            # Take the first-frame codeword; a 1-D codeword is used whole,
            # since indexing it would produce a scalar imshow cannot draw.
            selected = prc_codeword[0] if prc_codeword.dim() > 1 else prc_codeword
            codeword = self._to_numpy(selected)

            if codeword.ndim == 1:
                codeword = self._reshape_to_2d(codeword)

            im = ax.imshow(codeword, cmap=cmap, aspect='equal', **kwargs)

            if use_color_bar:
                ax.figure.colorbar(im, ax=ax, shrink=0.8)
        else:
            ax.text(0.5, 0.5, 'PRC Codeword\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=12)
        return ax

    def draw_recovered_codeword(self,
                                title: str = "Recovered Codeword (c̃)",
                                cmap: str = "viridis",
                                use_color_bar: bool = True,
                                vmin: float = -1.0,
                                vmax: float = 1.0,
                                ax: Optional[Axes] = None,
                                **kwargs) -> Axes:
        """
        Draw the recovered codeword (c̃) as a square image.

        The recovered codeword is flattened. It is drawn only when its length
        is a perfect square; otherwise a text message reporting the length is
        shown. The axis frame is hidden in every case.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            vmin (float): Lower bound of the color scale
            vmax (float): Upper bound of the color scale
            ax (Axes): The axes to plot on; the current pyplot axes if None
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        ax = self._axes_or_current(ax)
        recovered_prc = getattr(self.data, 'recovered_prc', None)
        if recovered_prc is not None:
            recovered_codeword = self._to_numpy(recovered_prc).flatten()
            length = len(recovered_codeword)

            side = int(length ** 0.5)

            if side * side == length:
                codeword_2d = recovered_codeword.reshape((side, side))

                im = ax.imshow(codeword_2d, cmap=cmap, vmin=vmin, vmax=vmax,
                               aspect='equal', **kwargs)

                if use_color_bar:
                    cbar = ax.figure.colorbar(im, ax=ax, shrink=0.8)
                    cbar.set_label('Codeword Value', fontsize=8)
            else:
                ax.text(0.5, 0.5,
                        f'Recovered Codeword\nLength = {length}\nCannot reshape to square',
                        ha='center', va='center', fontsize=12, transform=ax.transAxes)

        else:
            ax.text(0.5, 0.5, 'Recovered Codeword (c̃)\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=10)
        ax.axis('off')
        return ax

    def draw_difference_map(self,
                            title: str = "Difference Map",
                            cmap: str = "hot",
                            use_color_bar: bool = True,
                            channel: int = 0,
                            frame: int = 0,
                            ax: Optional[Axes] = None,
                            **kwargs) -> Axes:
        """
        Draw the absolute difference between watermarked and inverted latents.

        Both latents are sliced to the same ``channel`` and ``frame`` via
        ``_get_latent_data`` before the element-wise ``|wm - inv|`` is drawn.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            channel (int): The channel to visualize
            frame (int): The frame to visualize
            ax (Axes): The axes to plot on; the current pyplot axes if None
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        ax = self._axes_or_current(ax)
        watermarked = getattr(self.data, 'watermarked_latents', None)
        inverted = getattr(self.data, 'inverted_latents', None)
        if watermarked is not None and inverted is not None:
            wm_latents = self._to_numpy(
                self._get_latent_data(watermarked, channel=channel, frame=frame))
            inv_latents = self._to_numpy(
                self._get_latent_data(inverted, channel=channel, frame=frame))

            diff_map = np.abs(wm_latents - inv_latents)
            im = ax.imshow(diff_map, cmap=cmap, aspect='equal', **kwargs)

            if use_color_bar:
                ax.figure.colorbar(im, ax=ax, shrink=0.8)
        else:
            ax.text(0.5, 0.5, 'Difference Map\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=12)
        ax.set_xlabel('Width')
        ax.set_ylabel('Height')
        return ax