Coverage for visualize / prc / prc_visualizer.py: 91.46%
82 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1# Copyright 2025 THU-BPM MarkDiffusion.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16from typing import Optional
17import torch
18import matplotlib.pyplot as plt
19from matplotlib.axes import Axes
20from matplotlib.gridspec import GridSpec
21import numpy as np
22from visualize.base import BaseVisualizer
23from visualize.data_for_visualization import DataForVisualization
class PRCVisualizer(BaseVisualizer):
    """PRC (Pseudorandom Codes) watermark visualizer.

    Draws the PRC-specific quantities carried by a ``DataForVisualization``
    object: the generator matrix, the embedded codeword, the codeword
    recovered at detection time, and the absolute difference between the
    watermarked and inverted latents.

    Every ``draw_*`` method accepts an optional ``ax``; when it is ``None`` the
    current pyplot axes (``plt.gca()``) is used.
    """

    # Tensor attributes of ``self.data`` that are detached once in __init__ so
    # the draw methods can call ``.cpu().numpy()`` on them even when they were
    # produced with autograd enabled (``.numpy()`` refuses grad-requiring tensors).
    _DETACHED_ATTRIBUTES = (
        'watermarked_latents',
        'orig_latents',
        'inverted_latents',
        'prc_codeword',
        'generator_matrix',
        'recovered_prc',
    )

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        """
        Initialize PRC visualizer

        Args:
            data_for_visualization: DataForVisualization object containing visualization data
            dpi: DPI for visualization (default: 300)
            watermarking_step: The step for inserting the watermark (default: -1)
        """
        super().__init__(data_for_visualization, dpi, watermarking_step)

        # Pre-detach tensors; detach() keeps each tensor on its original device.
        # Missing attributes, None values and non-tensor values are left untouched.
        for name in self._DETACHED_ATTRIBUTES:
            value = getattr(self.data, name, None)
            if isinstance(value, torch.Tensor):
                setattr(self.data, name, value.detach())

    @staticmethod
    def _codeword_to_grid(values: np.ndarray) -> np.ndarray:
        """Lay out a non-empty array as a near-square, zero-padded 2-D float grid.

        The values are flattened; the grid has ``int(sqrt(n))`` rows and the
        fewest columns that hold all ``n`` values, so a perfect-square length
        (e.g. 16384 -> 128x128) is reshaped exactly with no padding. Unused
        trailing cells are filled with 0.
        """
        flat = np.asarray(values).reshape(-1)
        length = flat.size
        height = int(np.sqrt(length))
        width = -(-length // height)  # ceiling division: smallest width with height * width >= length
        grid = np.zeros(height * width)
        grid[:length] = flat
        return grid.reshape(height, width)

    def draw_generator_matrix(self,
                              title: str = "Generator Matrix G",
                              cmap: str = "Blues",
                              use_color_bar: bool = True,
                              max_display_size: int = 50,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the generator matrix visualization

        A matrix with more than ``max_display_size`` rows or columns is cropped
        to its top-left ``min(rows, max_display_size)`` x
        ``min(cols, max_display_size)`` corner, and the displayed size is
        appended to the title.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            max_display_size (int): Maximum number of rows / columns to display (for large matrices)
            ax (Axes): The axes to plot on; the current axes (plt.gca()) when None
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        if ax is None:
            ax = plt.gca()

        generator_matrix = getattr(self.data, 'generator_matrix', None)
        if generator_matrix is not None:
            gen_matrix = generator_matrix.cpu().numpy()

            # Crop each dimension independently so a short-and-wide (or
            # tall-and-narrow) matrix is not squashed down to its smaller side.
            n_rows = min(gen_matrix.shape[0], max_display_size)
            n_cols = min(gen_matrix.shape[1], max_display_size)
            if n_rows < gen_matrix.shape[0] or n_cols < gen_matrix.shape[1]:
                matrix_sample = gen_matrix[:n_rows, :n_cols]
                title += f" (Sample {n_rows}x{n_cols})"
            else:
                matrix_sample = gen_matrix

            im = ax.imshow(matrix_sample, cmap=cmap, aspect='auto', **kwargs)

            if use_color_bar:
                # Attach to the figure owning `ax`, not pyplot's "current" figure.
                ax.figure.colorbar(im, ax=ax, shrink=0.8)
        else:
            ax.text(0.5, 0.5, 'Generator Matrix\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=10)
        ax.set_xlabel('Columns')
        ax.set_ylabel('Rows')
        return ax

    def draw_codeword(self,
                      title: str = "PRC Codeword",
                      cmap: str = "viridis",
                      use_color_bar: bool = True,
                      ax: Optional[Axes] = None,
                      **kwargs) -> Axes:
        """
        Draw the PRC codeword visualization

        A 1-D codeword is laid out as a near-square 2-D grid (zero-padded at
        the end when its length is not a perfect square); other shapes are
        passed to ``imshow`` unchanged. A missing or empty codeword is
        reported as text instead of drawn.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            ax (Axes): The axes to plot on; the current axes (plt.gca()) when None
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        if ax is None:
            ax = plt.gca()

        codeword = None
        prc_codeword = getattr(self.data, 'prc_codeword', None)
        if prc_codeword is not None:
            codeword = prc_codeword.cpu().numpy()

        # An empty codeword cannot be laid out (its grid would have 0 rows).
        if codeword is not None and codeword.size > 0:
            if codeword.ndim == 1:
                codeword = self._codeword_to_grid(codeword)

            im = ax.imshow(codeword, cmap=cmap, aspect='equal', **kwargs)

            if use_color_bar:
                # Attach to the figure owning `ax`, not pyplot's "current" figure.
                ax.figure.colorbar(im, ax=ax, shrink=0.8)
        else:
            ax.text(0.5, 0.5, 'PRC Codeword\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=12)
        return ax

    def draw_recovered_codeword(self,
                                title: str = "Recovered Codeword (c̃)",
                                cmap: str = "viridis",
                                use_color_bar: bool = True,
                                vmin: float = -1.0,
                                vmax: float = 1.0,
                                ax: Optional[Axes] = None,
                                expected_length: int = 16384,
                                **kwargs) -> Axes:
        """
        Draw the recovered codeword (c̃) from PRC detection

        This visualizes the recovered codeword from prc_detection.recovered_prc.
        The flattened codeword is drawn as a near-square 2-D grid
        (``expected_length=16384`` gives 128x128). When its length differs
        from ``expected_length``, a text message is shown instead.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            vmin (float): Minimum value for colormap (-1.0)
            vmax (float): Maximum value for colormap (1.0)
            ax (Axes): The axes to plot on; the current axes (plt.gca()) when None
            expected_length (int): Required number of codeword entries (default: 16384)
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        if ax is None:
            ax = plt.gca()

        recovered_prc = getattr(self.data, 'recovered_prc', None)
        if recovered_prc is not None:
            recovered_codeword = recovered_prc.cpu().numpy().flatten()
            length = len(recovered_codeword)

            if length > 0 and length == expected_length:
                codeword_2d = self._codeword_to_grid(recovered_codeword)
                im = ax.imshow(codeword_2d, cmap=cmap, vmin=vmin, vmax=vmax, aspect='equal', **kwargs)

                if use_color_bar:
                    cbar = ax.figure.colorbar(im, ax=ax, shrink=0.8)
                    cbar.set_label('Codeword Value', fontsize=8)
            else:
                ax.text(0.5, 0.5, f'Recovered Codeword\nUnexpected Length: {length}\n(Expected: {expected_length})',
                        ha='center', va='center', fontsize=12, transform=ax.transAxes)
        else:
            ax.text(0.5, 0.5, 'Recovered Codeword (c̃)\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=10)
        ax.axis('off')
        return ax

    def draw_difference_map(self,
                            title: str = "Difference Map",
                            cmap: str = "hot",
                            use_color_bar: bool = True,
                            channel: int = 0,
                            ax: Optional[Axes] = None,
                            **kwargs) -> Axes:
        """
        Draw difference map between watermarked and inverted latents

        Plots ``|watermarked - inverted|`` for the selected channel, where each
        latent is extracted with ``self._get_latent_data``.

        Parameters:
            title (str): The title of the plot
            cmap (str): The colormap to use
            use_color_bar (bool): Whether to display the colorbar
            channel (int): The channel to visualize
            ax (Axes): The axes to plot on; the current axes (plt.gca()) when None
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``

        Returns:
            Axes: The plotted axes
        """
        if ax is None:
            ax = plt.gca()

        watermarked_latents = getattr(self.data, 'watermarked_latents', None)
        inverted_latents = getattr(self.data, 'inverted_latents', None)
        if watermarked_latents is not None and inverted_latents is not None:
            wm_latents = self._get_latent_data(watermarked_latents, channel=channel).cpu().numpy()
            inv_latents = self._get_latent_data(inverted_latents, channel=channel).cpu().numpy()

            diff_map = np.abs(wm_latents - inv_latents)
            im = ax.imshow(diff_map, cmap=cmap, aspect='equal', **kwargs)

            if use_color_bar:
                # Attach to the figure owning `ax`, not pyplot's "current" figure.
                ax.figure.colorbar(im, ax=ax, shrink=0.8)
        else:
            ax.text(0.5, 0.5, 'Difference Map\nNot Available',
                    ha='center', va='center', fontsize=12, transform=ax.transAxes)

        ax.set_title(title, fontsize=12)
        ax.set_xlabel('Width')
        ax.set_ylabel('Height')
        return ax