Coverage for visualize / gs / gs_visualizer.py: 94.85%
97 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1import torch
2import matplotlib.pyplot as plt
3from matplotlib.axes import Axes
4from matplotlib.gridspec import GridSpecFromSubplotSpec
5from typing import Optional
6import numpy as np
7from visualize.base import BaseVisualizer
8from visualize.data_for_visualization import DataForVisualization
9from Crypto.Cipher import ChaCha20
class GaussianShadingVisualizer(BaseVisualizer):
    """Gaussian Shading watermark visualization class.

    Draws the original watermark bits and the watermark bits reconstructed
    from the inverted (reversed) latents, either one channel per axes or all
    watermark channels as a grid inside a single axes.

    The latent layout is fixed at ``1 x LATENT_CHANNELS x LATENT_SIZE x
    LATENT_SIZE`` (1 x 4 x 64 x 64). A single watermark has
    ``LATENT_CHANNELS // channel_copy`` channels of
    ``(LATENT_SIZE // hw_copy)`` x ``(LATENT_SIZE // hw_copy)`` bits, and the
    latent holds ``channel_copy * hw_copy ** 2`` copies of it.
    """

    # Latent geometry assumed by this visualizer (1 x 4 x 64 x 64).
    LATENT_CHANNELS = 4
    LATENT_SIZE = 64

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        super().__init__(data_for_visualization, dpi, watermarking_step)

    # ------------------------------------------------------------------ #
    # Geometry / validation helpers
    # ------------------------------------------------------------------ #
    def _num_watermark_channels(self) -> int:
        """Return the number of distinct watermark channels (LATENT_CHANNELS // channel_copy)."""
        return self.LATENT_CHANNELS // self.data.channel_copy

    def _watermark_grid_shape(self) -> tuple:
        """Return the (1, C, H, W) shape of one un-replicated watermark."""
        side = self.LATENT_SIZE // self.data.hw_copy
        return (1, self._num_watermark_channels(), side, side)

    def _validate_channel(self, channel: int) -> None:
        """Raise ValueError unless ``0 <= channel < number of watermark channels``."""
        num_channels = self._num_watermark_channels()
        # Negative values are rejected as well: they would otherwise index from
        # the end of the tensor (or raise IndexError) and mislabel the panel.
        if channel < 0 or channel >= num_channels:
            raise ValueError(f"Channel {channel} is out of range. Max channel: {num_channels - 1}")

    @staticmethod
    def _resolve_ax(ax: Optional[Axes]) -> Axes:
        """Return ``ax``, or a new single-subplot axes when ``ax`` is None."""
        if ax is None:
            _, ax = plt.subplots()
        return ax

    # ------------------------------------------------------------------ #
    # Watermark extraction
    # ------------------------------------------------------------------ #
    def _stream_key_decrypt(self, reversed_m):
        """Decrypt the watermark using ChaCha20 cipher.

        Parameters:
            reversed_m: Flat 0/1 numpy array of recovered latent-sign bits. It is
                expected to hold LATENT_CHANNELS * LATENT_SIZE ** 2 bits so the
                decrypted result reshapes to (1, 4, 64, 64).

        Returns:
            torch.Tensor: uint8 tensor of shape (1, LATENT_CHANNELS, LATENT_SIZE,
            LATENT_SIZE) on ``self.data.device``.
        """
        cipher = ChaCha20.new(key=self.data.chacha_key, nonce=self.data.chacha_nonce)
        # Bits -> bytes for the stream cipher, then back to one value per bit.
        sd_byte = cipher.decrypt(np.packbits(reversed_m).tobytes())
        sd_bit = np.unpackbits(np.frombuffer(sd_byte, dtype=np.uint8))
        latent_shape = (1, self.LATENT_CHANNELS, self.LATENT_SIZE, self.LATENT_SIZE)
        sd_tensor = torch.from_numpy(sd_bit).reshape(latent_shape).to(torch.uint8)
        return sd_tensor.to(self.data.device)

    def _diffusion_inverse(self, reversed_sd):
        """Inverse the diffusion process to extract the watermark.

        The input is split into ``channel_copy`` groups along dim 1 and into
        ``hw_copy`` tiles along dims 2 and 3. All copies are stacked along
        dim 0 and summed. A bit is 1 iff its vote count exceeds
        ``self.data.vote_threshold``.

        Parameters:
            reversed_sd: 0/1 tensor laid out as (1, LATENT_CHANNELS, LATENT_SIZE,
                LATENT_SIZE). This layout is implied by the splits below.

        Returns:
            torch.Tensor: 0/1 tensor of shape (C, H, W) for one watermark copy.
        """
        ch_stride = self.LATENT_CHANNELS // self.data.channel_copy
        hw_stride = self.LATENT_SIZE // self.data.hw_copy
        ch_sizes = (ch_stride,) * self.data.channel_copy
        hw_sizes = (hw_stride,) * self.data.hw_copy
        # Move every replicated copy onto the batch dimension.
        copies = torch.cat(torch.split(reversed_sd, ch_sizes, dim=1), dim=0)
        copies = torch.cat(torch.split(copies, hw_sizes, dim=2), dim=0)
        copies = torch.cat(torch.split(copies, hw_sizes, dim=3), dim=0)
        # torch.sum allocates a fresh tensor, so the in-place thresholding below
        # cannot touch reversed_sd.
        vote = torch.sum(copies, dim=0)
        threshold = self.data.vote_threshold
        vote[vote <= threshold] = 0
        vote[vote > threshold] = 1
        return vote

    # ------------------------------------------------------------------ #
    # Drawing helpers
    # ------------------------------------------------------------------ #
    @staticmethod
    def _draw_single_channel(ax: Axes, channel_data, cmap: str, title_text: Optional[str]) -> None:
        """Draw one 2-D 0/1 bit map on ``ax`` (title only if ``title_text`` is not None)."""
        im = ax.imshow(channel_data, cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
        if title_text is not None:
            ax.set_title(title_text, fontsize=10)
        ax.axis('off')
        # An invisible colorbar reserves the space a real colorbar would take,
        # presumably so this panel lines up with neighbours that show one.
        cbar = ax.figure.colorbar(im, ax=ax, alpha=0.0)
        cbar.ax.set_visible(False)

    @staticmethod
    def _draw_channel_grid(ax: Axes, grid, num_channels: int, cmap: str, title_text: Optional[str]) -> None:
        """Clear ``ax`` and fill its slot with a near-square grid of per-channel sub-axes.

        ``grid`` is indexed as ``grid[0, i]`` for channel ``i``.
        """
        rows = int(np.ceil(np.sqrt(num_channels)))
        cols = int(np.ceil(num_channels / rows))

        ax.clear()
        if title_text is not None:
            ax.set_title(title_text, pad=20)
        ax.axis('off')

        # Sub-axes are laid out inside ax's own subplot slot, so ax is expected
        # to belong to a subplot grid (get_subplotspec() must return a spec).
        gs = GridSpecFromSubplotSpec(rows, cols, subplot_spec=ax.get_subplotspec(),
                                     wspace=0.3, hspace=0.4)
        for i in range(num_channels):
            sub_ax = ax.figure.add_subplot(gs[i // cols, i % cols])
            sub_ax.imshow(grid[0, i].cpu().numpy(), cmap=cmap, vmin=0, vmax=1, interpolation='nearest')
            sub_ax.set_title(f'Channel {i}', fontsize=8, pad=3)
            sub_ax.axis('off')

    # ------------------------------------------------------------------ #
    # Public API
    # ------------------------------------------------------------------ #
    def draw_watermark_bits(self,
                            channel: Optional[int] = None,
                            title: str = "Original Watermark Bits",
                            cmap: str = "binary",
                            ax: Optional[Axes] = None) -> Axes:
        """
        Draw the original watermark bits (``sd`` in the GS class).

        With ``channel=None`` all ``4 // channel_copy`` channels are drawn as a
        grid inside ``ax``.

        Parameters:
            channel (Optional[int]): The channel to visualize. If None, all channels are shown.
            title (str): The title of the plot. An empty string suppresses it.
            cmap (str): The colormap to use.
            ax (Optional[Axes]): The axes to plot on. If None, a new figure/axes is created.

        Returns:
            Axes: The plotted axes.

        Raises:
            ValueError: If ``channel`` is outside ``[0, num_channels)``.
        """
        if channel is not None:
            self._validate_channel(channel)
        ax = self._resolve_ax(ax)

        watermark = self.data.watermark.reshape(self._watermark_grid_shape())

        if channel is not None:
            single_title = f"{title} - Channel {channel}" if title != "" else None
            self._draw_single_channel(ax, watermark[0, channel].cpu().numpy(), cmap, single_title)
        else:
            grid_title = title if title != "" else None
            self._draw_channel_grid(ax, watermark, self._num_watermark_channels(), cmap, grid_title)

        return ax

    def draw_reconstructed_watermark_bits(self,
                                          channel: Optional[int] = None,
                                          title: str = "Reconstructed Watermark Bits",
                                          cmap: str = "binary",
                                          ax: Optional[Axes] = None) -> Axes:
        """
        Draw the watermark bits reconstructed from ``reversed_latents``.

        The latent at ``self.watermarking_step`` is thresholded at 0. It is then
        unmasked with ChaCha20 (``self.data.chacha``) or XOR with
        ``self.data.key``, and majority-voted across copies. The bit accuracy
        against the original watermark is added to the title.

        Parameters:
            channel (Optional[int]): The channel to visualize. If None, all channels are shown.
            title (str): The title of the plot. An empty string keeps only the bit accuracy.
            cmap (str): The colormap to use.
            ax (Optional[Axes]): The axes to plot on. If None, a new figure/axes is created.

        Returns:
            Axes: The plotted axes.

        Raises:
            ValueError: If ``channel`` is outside ``[0, num_channels)``.
        """
        if channel is not None:
            self._validate_channel(channel)

        # Step 1: reconstruct the watermark bits from the sign of the inverted latent.
        reversed_latent = self.data.reversed_latents[self.watermarking_step]
        reversed_m = (reversed_latent > 0).int()
        if self.data.chacha:
            reversed_sd = self._stream_key_decrypt(reversed_m.flatten().cpu().numpy())
        else:
            reversed_sd = (reversed_m + self.data.key) % 2

        # Compare on an explicit common shape/device instead of relying on
        # broadcasting a (C, H, W) tensor against the stored watermark layout.
        grid_shape = self._watermark_grid_shape()
        reconstructed = self._diffusion_inverse(reversed_sd).reshape(grid_shape)
        original = self.data.watermark.reshape(grid_shape).to(reconstructed.device)
        bit_acc = (reconstructed == original).float().mean().item()

        ax = self._resolve_ax(ax)

        # Step 2: draw, with the bit accuracy in the title.
        if channel is not None:
            prefix = f"{title} - " if title != "" else ""
            single_title = f"{prefix}Channel {channel} (Bit Acc: {bit_acc:.3f})"
            self._draw_single_channel(ax, reconstructed[0, channel].cpu().numpy(), cmap, single_title)
        else:
            if title != "":
                grid_title = f'{title} (Bit Acc: {bit_acc:.3f})'
            else:
                grid_title = f'(Bit Acc: {bit_acc:.3f})'
            self._draw_channel_grid(ax, reconstructed, self._num_watermark_channels(), cmap, grid_title)

        return ax