Coverage for visualize / seal / seal_visualizer.py: 95.31%
64 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1from typing import Optional
2import torch
3import matplotlib.pyplot as plt
4from matplotlib.axes import Axes
5from matplotlib.gridspec import GridSpecFromSubplotSpec
6import numpy as np
7import math
8from visualize.base import BaseVisualizer
9from visualize.data_for_visualization import DataForVisualization
class SEALVisualizer(BaseVisualizer):
    """SEAL watermark visualization class.

    Draws diagnostics from ``self.data``:
      * a density-histogram comparison of ``original_embedding`` vs
        ``inspected_embedding``;
      * a per-patch L2 distance map between one entry of ``reversed_latents``
        and ``reference_noise``.
    """

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        """
        Parameters:
            data_for_visualization (DataForVisualization): Container holding the tensors to plot.
            dpi (int): Forwarded to ``BaseVisualizer``. Default: 300.
            watermarking_step (int): Forwarded to ``BaseVisualizer``; ``draw_patch_diff`` uses it
                to index ``data.reversed_latents``. Default: -1.
        """
        super().__init__(data_for_visualization, dpi, watermarking_step)

    @staticmethod
    def _to_flat_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Return a flattened float32 NumPy copy of ``tensor`` (detached and moved to CPU).

        Detaching and casting first makes the conversion work for tensors that
        require grad or use dtypes NumPy cannot represent (e.g. bfloat16).
        """
        return tensor.detach().to(torch.float32).cpu().numpy().reshape(-1)

    def draw_embedding_distributions(self,
                                     title: str = "Embedding Distributions",
                                     ax: Optional[Axes] = None,
                                     show_legend: bool = True,
                                     show_label: bool = True,
                                     show_axis: bool = True) -> Axes:
        """
        Draw histogram of embedding distributions comparison (original_embedding vs inspected_embedding).

        Both histograms are density-normalized and share the same 50 bin edges so
        they can be compared bar-for-bar.

        Parameters:
            title (str): The title of the plot. An empty string disables the title.
            ax (plt.Axes): The axes to plot on. Default: the current axes (``plt.gca()``).
            show_legend (bool): Whether to show the legend. Default: True.
            show_label (bool): Whether to show axis labels. Default: True.
            show_axis (bool): Whether to show axis ticks and labels. Default: True.

        Returns:
            Axes: The plotted axes.
        """
        if ax is None:
            ax = plt.gca()

        original_data = self._to_flat_numpy(self.data.original_embedding)
        inspected_data = self._to_flat_numpy(self.data.inspected_embedding)

        # Compute one set of bin edges over both datasets so the two histograms
        # line up. Only finite values are used for the range, because
        # non-finite values would make np.histogram_bin_edges raise.
        combined = np.concatenate([original_data, inspected_data])
        bin_edges = np.histogram_bin_edges(combined[np.isfinite(combined)], bins=50)

        # Overlapping, semi-transparent histograms.
        ax.hist(original_data, bins=bin_edges, alpha=0.7, color='blue',
                label='Original Embedding', density=True, edgecolor='darkblue', linewidth=0.5)
        ax.hist(inspected_data, bins=bin_edges, alpha=0.7, color='red',
                label='Inspected Embedding', density=True, edgecolor='darkred', linewidth=0.5)

        if title:
            ax.set_title(title, fontsize=12)
        if show_label:
            ax.set_xlabel('Embedding Values')
            ax.set_ylabel('Density')
        if show_legend:
            ax.legend()
        if not show_axis:
            ax.set_xticks([])
            ax.set_yticks([])
        ax.grid(True, alpha=0.3)

        # Attach an invisible colorbar. It reserves the same horizontal space a
        # real colorbar would — presumably so this panel's width matches
        # sibling panels that draw one (e.g. draw_patch_diff).
        placeholder = ax.scatter([], [], c=[], cmap='viridis')
        cbar = ax.figure.colorbar(placeholder, ax=ax, alpha=0.0)
        cbar.ax.set_visible(False)

        return ax

    def draw_patch_diff(self,
                        title: str = "Patch Difference",
                        cmap: str = 'RdBu',
                        use_color_bar: bool = True,
                        vmin: Optional[float] = None,
                        vmax: Optional[float] = None,
                        show_number: bool = False,
                        ax: Optional[Axes] = None,
                        **kwargs) -> Axes:
        """
        Draw the per-patch L2 difference between reference_noise and reversed_latents.

        The latent's last two (spatial) dimensions are divided into a grid of
        ``rows = ceil(sqrt(k))`` by ``cols = ceil(k / rows)`` patches. Each patch
        has size ``(H // rows, W // cols)``. The first ``k`` patches, in
        row-major order, are colored by the L2 norm of
        ``reversed_latent - reference_noise`` over that patch (all leading
        dimensions included). Grid cells beyond the k-th patch are left blank.

        Parameters:
            title (str): The title of the plot. An empty string disables the title.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Minimum value for colormap normalization.
            vmax (Optional[float]): Maximum value for colormap normalization.
            show_number (bool): Whether to display numerical values on each patch. Default: False.
            ax (plt.Axes): The axes to plot on. Default: the current axes (``plt.gca()``).
            **kwargs: Extra keyword arguments forwarded to ``ax.imshow``.

        Returns:
            plt.axes.Axes: The plotted axes.

        Raises:
            ValueError: If ``data.k_value`` is smaller than 1.
        """
        if ax is None:
            ax = plt.gca()

        reversed_latent = self.data.reversed_latents[self.watermarking_step]  # e.g. [1, 4, 64, 64]
        reference_noise = self.data.reference_noise  # expected to match reversed_latent's shape
        k = int(self.data.k_value)
        if k < 1:
            raise ValueError(f"k_value must be a positive integer, got {self.data.k_value!r}")

        # Spatial size comes from the tensor; it is no longer hard-coded to 64x64.
        height, width = reversed_latent.shape[-2], reversed_latent.shape[-1]
        rows = math.ceil(math.sqrt(k))
        cols = math.ceil(k / rows)
        patch_height = height // rows
        patch_width = width // cols

        # Cells that do not correspond to one of the k patches stay NaN. imshow
        # masks NaN, so those cells render blank rather than as a misleading
        # zero distance, and they do not affect color autoscaling.
        diff_map = np.full((rows, cols), np.nan)
        for idx in range(k):
            i, j = divmod(idx, cols)  # row-major patch order
            y_start = i * patch_height
            x_start = j * patch_width
            y_end = min(y_start + patch_height, height)
            x_end = min(x_start + patch_width, width)
            patch_delta = (reversed_latent[..., y_start:y_end, x_start:x_end]
                           - reference_noise[..., y_start:y_end, x_start:x_end])
            diff_map[i, j] = torch.norm(patch_delta).item()

        im = ax.imshow(diff_map, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
        if title:
            ax.set_title(title)
        if use_color_bar:
            ax.figure.colorbar(im, ax=ax)

        if show_number:
            # Larger grids have smaller cells, so they get smaller text.
            cells_per_side = max(rows, cols)
            if cells_per_side <= 4:
                fontsize = 8
            elif cells_per_side <= 8:
                fontsize = 6
            else:
                fontsize = 4
            for idx in range(k):
                i, j = divmod(idx, cols)
                ax.text(j, i, f"{diff_map[i, j]:.1f}",
                        ha='center', va='center', color='white',
                        fontsize=fontsize, fontweight='bold')

        ax.axis('off')
        return ax