Coverage for visualize / videomark / video_mark_visualizer.py: 90.70%
86 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1# Copyright 2025 THU-BPM MarkDiffusion.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16import torch
17import matplotlib.pyplot as plt
18from matplotlib.axes import Axes
19from matplotlib.gridspec import GridSpecFromSubplotSpec
20import numpy as np
21from typing import Optional
22from visualize.base import BaseVisualizer
23from visualize.data_for_visualization import DataForVisualization
class VideoMarkVisualizer(BaseVisualizer):
    """VideoMark watermark visualization class.

    Draws the artifacts of the VideoMark watermarking algorithm:
    - the PRC codeword;
    - its generator matrix;
    - the codeword recovered during detection;
    - the absolute difference between watermarked and inverted latents.

    Attributes read from ``self.data`` are all optional. When one is missing or
    ``None``, the corresponding plot shows a placeholder message instead:
        - watermarked_latents, inverted_latents: latents sliced through
          ``self._get_latent_data(channel=..., frame=...)``; presumably video
          latents of shape [B, C, F, H, W] -- TODO confirm.
        - orig_latents: detached on construction, not plotted by this class.
        - prc_codeword: only element ``[0]`` is plotted.
        - generator_matrix: 2D matrix plotted (possibly cropped) as an image.
        - recovered_prc: flattened and plotted as a square grid.
    """

    def __init__(self,
                 data_for_visualization: DataForVisualization,
                 dpi: int = 300,
                 watermarking_step: int = -1,
                 is_video: bool = True):
        """Initialize the visualizer and detach the stored tensors.

        Args:
            data_for_visualization: Container holding the tensors to plot.
            dpi: Forwarded to ``BaseVisualizer`` (presumably the figure DPI).
            watermarking_step: Forwarded to ``BaseVisualizer``.
            is_video: Forwarded to ``BaseVisualizer``; True for this video algorithm.
        """
        super().__init__(data_for_visualization, dpi, watermarking_step, is_video)
        # Detach once, up front, so the draw_* methods can call .cpu().numpy()
        # (which refuses tensors that require grad). detach() keeps each tensor
        # on its current device.
        for attr_name in ('watermarked_latents', 'orig_latents', 'inverted_latents',
                          'prc_codeword', 'generator_matrix'):
            value = getattr(self.data, attr_name, None)
            if value is not None:
                setattr(self.data, attr_name, value.detach())

    def draw_watermarked_video_frames(self,
                                      num_frames: int = 4,
                                      title: str = "Watermarked Video Frames",
                                      ax: Optional[Axes] = None) -> Axes:
        """
        Draw multiple frames from the watermarked video.

        DEPRECATED:
            This method is deprecated and will be removed in a future version.
            Please use `draw_watermarked_image` instead.

        Delegates to ``self._draw_video_frames`` to show a grid of video frames,
        which illustrates the temporal consistency of the watermarked video.

        Args:
            num_frames: Number of frames to display (default: 4)
            title: The title of the plot
            ax: The axes to plot on (passed through unchanged)

        Returns:
            The plotted axes
        """
        return self._draw_video_frames(
            title=title,
            num_frames=num_frames,
            ax=ax
        )

    def draw_generator_matrix(self,
                              title: str = "Generator Matrix G",
                              cmap: str = "Blues",
                              use_color_bar: bool = True,
                              max_display_size: int = 50,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the generator matrix visualization.

        If either dimension exceeds ``max_display_size``, only the top-left
        square of side ``min(max_display_size, min(shape))`` is shown and the
        title is suffixed with the sample size.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            max_display_size (int): Maximum size to display (for large matrices)
            ax (Axes): The axes to plot on; if None, the current pyplot axes are used
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        if ax is None:
            ax = plt.gca()

        gen_tensor = getattr(self.data, 'generator_matrix', None)
        if gen_tensor is not None:
            gen_matrix = gen_tensor.cpu().numpy()

            # Show a sample of the matrix if it's too large
            if gen_matrix.shape[0] > max_display_size or gen_matrix.shape[1] > max_display_size:
                sample_size = min(max_display_size, min(gen_matrix.shape))
                matrix_sample = gen_matrix[:sample_size, :sample_size]
                title += f" (Sample {sample_size}x{sample_size})"
            else:
                matrix_sample = gen_matrix

            im = ax.imshow(matrix_sample, cmap=cmap, aspect='auto', **kwargs)

            if use_color_bar:
                # Attach to ax's own figure rather than pyplot's "current" figure.
                ax.figure.colorbar(im, ax=ax, shrink=0.8)
        else:
            ax.text(0.5, 0.5, 'Generator Matrix\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=10)
        ax.set_xlabel('Columns')
        ax.set_ylabel('Rows')
        return ax

    def draw_codeword(self,
                      title: str = "VideoMark Codeword",
                      cmap: str = "viridis",
                      use_color_bar: bool = True,
                      ax: Optional[Axes] = None,
                      **kwargs) -> Axes:
        """
        Draw the PRC codeword visualization.

        Only ``prc_codeword[0]`` (the first-frame codeword) is drawn. A 1D
        codeword is laid out as a near-square grid, zero-padded at the end.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            ax (Axes): The axes to plot on; if None, the current pyplot axes are used
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        if ax is None:
            ax = plt.gca()

        prc_codeword = getattr(self.data, 'prc_codeword', None)
        if prc_codeword is not None:
            # Get the first-frame codeword for visualization
            codeword = prc_codeword[0].cpu().numpy()

            if codeword.size == 0:
                # Nothing to lay out; avoids a division by zero below.
                ax.text(0.5, 0.5, 'PRC Codeword\nEmpty',
                        ha='center', va='center', fontsize=12, transform=ax.transAxes)
            else:
                if codeword.ndim == 1:
                    # Reshape to a near-square 2D grid, zero-padding the tail.
                    length = codeword.shape[0]
                    height = int(np.sqrt(length))  # >= 1 because length >= 1
                    width = -(-length // height)  # ceil(length / height)
                    padded_codeword = np.zeros(height * width)
                    padded_codeword[:length] = codeword
                    codeword = padded_codeword.reshape(height, width)

                im = ax.imshow(codeword, cmap=cmap, aspect='equal', **kwargs)

                if use_color_bar:
                    ax.figure.colorbar(im, ax=ax, shrink=0.8)
        else:
            ax.text(0.5, 0.5, 'PRC Codeword\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=12)
        return ax

    def draw_recovered_codeword(self,
                                title: str = "Recovered Codeword (c̃)",
                                cmap: str = "viridis",
                                use_color_bar: bool = True,
                                vmin: float = -1.0,
                                vmax: float = 1.0,
                                ax: Optional[Axes] = None,
                                **kwargs) -> Axes:
        """
        Draw the codeword recovered during detection as a square image.

        The recovered codeword is flattened. If its length is a non-zero perfect
        square, it is shown as a side x side image. Otherwise a message stating
        its length is drawn instead.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            vmin (float): Lower bound of the color scale
            vmax (float): Upper bound of the color scale
            ax (Axes): The axes to plot on; if None, the current pyplot axes are used
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes (axis decorations turned off)
        """
        if ax is None:
            ax = plt.gca()

        recovered_prc = getattr(self.data, 'recovered_prc', None)
        if recovered_prc is not None:
            # recovered_prc is not detached in __init__, so detach here:
            # .numpy() raises on tensors that still require grad.
            recovered_codeword = recovered_prc.detach().cpu().numpy().flatten()
            length = len(recovered_codeword)

            side = int(length ** 0.5)

            if length > 0 and side * side == length:
                codeword_2d = recovered_codeword.reshape((side, side))

                im = ax.imshow(codeword_2d, cmap=cmap, vmin=vmin, vmax=vmax,
                               aspect='equal', **kwargs)

                if use_color_bar:
                    cbar = ax.figure.colorbar(im, ax=ax, shrink=0.8)
                    cbar.set_label('Codeword Value', fontsize=8)
            else:
                ax.text(0.5, 0.5,
                        f'Recovered Codeword\nLength = {length}\nCannot reshape to square',
                        ha='center', va='center', fontsize=12, transform=ax.transAxes)
        else:
            ax.text(0.5, 0.5, 'Recovered Codeword (c̃)\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=10)
        ax.axis('off')
        return ax

    def draw_difference_map(self,
                            title: str = "Difference Map",
                            cmap: str = "hot",
                            use_color_bar: bool = True,
                            channel: int = 0,
                            frame: int = 0,
                            ax: Optional[Axes] = None,
                            **kwargs) -> Axes:
        """
        Draw the absolute difference map between watermarked and inverted latents.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            channel (int): The latent channel to visualize
            frame (int): The video frame to visualize
            ax (Axes): The axes to plot on; if None, the current pyplot axes are used
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        if ax is None:
            ax = plt.gca()

        watermarked = getattr(self.data, 'watermarked_latents', None)
        inverted = getattr(self.data, 'inverted_latents', None)
        if watermarked is not None and inverted is not None:
            wm_latents = self._get_latent_data(watermarked, channel=channel, frame=frame).cpu().numpy()
            inv_latents = self._get_latent_data(inverted, channel=channel, frame=frame).cpu().numpy()

            diff_map = np.abs(wm_latents - inv_latents)
            im = ax.imshow(diff_map, cmap=cmap, aspect='equal', **kwargs)

            if use_color_bar:
                ax.figure.colorbar(im, ax=ax, shrink=0.8)
        else:
            ax.text(0.5, 0.5, 'Difference Map\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=12)
        ax.set_xlabel('Width')
        ax.set_ylabel('Height')
        return ax