Coverage for visualize / ri / ri_visualizer.py: 100.00%
56 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1# Copyright 2025 THU-BPM MarkDiffusion.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16from typing import Optional
17import torch
18import matplotlib.pyplot as plt
19from matplotlib.axes import Axes
20import numpy as np
21from visualize.base import BaseVisualizer
22from visualize.data_for_visualization import DataForVisualization
class RingIDVisualizer(BaseVisualizer):
    """RingID watermark visualization class.

    Draws, in the Fourier domain, the ring and heter watermark patterns
    (``self.data.pattern`` restricted by ``self.data.mask``), as well as the
    masked spectrum of the last reversed latent.
    """

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        """Forward all construction arguments to ``BaseVisualizer``."""
        super().__init__(data_for_visualization, dpi, watermarking_step)

    def _masked_pattern_fft(self, watermark_channel) -> torch.Tensor:
        """Build an all-zero spectrum holding ``pattern * mask`` on the given channels.

        The FFT of ``self.data.latents`` is computed only to obtain the target
        shape and dtype; its values are not kept, so every non-watermarked
        entry is exactly zero.
        """
        latents_fft = torch.from_numpy(self._fft_transform(self.data.latents))
        pattern_fft = torch.zeros_like(latents_fft)
        pattern = self.data.pattern.cpu()
        mask = self.data.mask.cpu()
        # zip pairs channels with masks positionally (and stops at the shorter one).
        # NOTE(review): confirm the ordering of self.data.mask matches the channel
        # list passed in (ring vs. heter channels).
        for channel, channel_mask in zip(watermark_channel, mask):
            pattern_fft[:, channel] = pattern[:, channel] * channel_mask
        return pattern_fft

    @staticmethod
    def _plot_channel_magnitude(ax: Optional[Axes],
                                spectrum: torch.Tensor,
                                watermark_channel,
                                title: str,
                                cmap: str,
                                use_color_bar: bool,
                                vmin: Optional[float],
                                vmax: Optional[float],
                                imshow_kwargs: dict) -> Axes:
        """Show ``|spectrum|`` for the first element along dim 0 and the first listed channel.

        When ``ax`` is None a new figure/axes pair is created, so the
        ``ax=None`` default advertised by the public methods works.
        """
        if ax is None:
            _, ax = plt.subplots()
        # [0, channels] selects index 0 on dim 0 (presumably the batch) and the
        # listed channels; the trailing [0] keeps only the first of those channels.
        magnitude = np.abs(spectrum[0, watermark_channel][0].cpu().numpy())
        im = ax.imshow(magnitude, cmap=cmap, vmin=vmin, vmax=vmax, **imshow_kwargs)
        if title != "":
            ax.set_title(title)
        if use_color_bar:
            ax.figure.colorbar(im, ax=ax)
        ax.axis('off')
        return ax

    def draw_ring_pattern_fft(self,
                              title: str = "Ring Watermark Pattern (FFT)",
                              cmap: str = "viridis",
                              use_color_bar: bool = True,
                              vmin: Optional[float] = None,
                              vmax: Optional[float] = None,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the ring watermark pattern in the Fourier domain (background all zeros).

        Only the first channel of ``self.data.ring_watermark_channel`` is rendered.

        Parameters:
            title (str): The title of the plot; an empty string disables it.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (float, optional): Lower bound of the color scale.
            vmax (float, optional): Upper bound of the color scale.
            ax (Axes, optional): The axes to plot on; a new one is created if None.
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``.

        Returns:
            Axes: The plotted axes.
        """
        watermark_channel = self.data.ring_watermark_channel
        pattern_fft = self._masked_pattern_fft(watermark_channel)
        return self._plot_channel_magnitude(ax, pattern_fft, watermark_channel, title,
                                            cmap, use_color_bar, vmin, vmax, kwargs)

    def draw_heter_pattern_fft(self,
                               title: str = "Heter Watermark Pattern (FFT)",
                               cmap: str = "viridis",
                               use_color_bar: bool = True,
                               vmin: Optional[float] = None,
                               vmax: Optional[float] = None,
                               ax: Optional[Axes] = None,
                               **kwargs) -> Axes:
        """
        Draw the heter watermark pattern in the Fourier domain (background all zeros).

        Only the first channel of ``self.data.heter_watermark_channel`` is rendered.

        Parameters:
            title (str): The title of the plot; an empty string disables it.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (float, optional): Lower bound of the color scale.
            vmax (float, optional): Upper bound of the color scale.
            ax (Axes, optional): The axes to plot on; a new one is created if None.
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``.

        Returns:
            Axes: The plotted axes.
        """
        watermark_channel = self.data.heter_watermark_channel
        pattern_fft = self._masked_pattern_fft(watermark_channel)
        return self._plot_channel_magnitude(ax, pattern_fft, watermark_channel, title,
                                            cmap, use_color_bar, vmin, vmax, kwargs)

    def draw_inverted_ring_pattern_fft(self,
                                       title: str = "Inverted Ring Watermark Pattern (FFT)",
                                       cmap: str = "viridis",
                                       use_color_bar: bool = True,
                                       vmin: Optional[float] = None,
                                       vmax: Optional[float] = None,
                                       ax: Optional[Axes] = None,
                                       **kwargs) -> Axes:
        """
        Extract and visualize the watermark region of ``reversed_latents[-1]`` using FFT.

        The FFT of the last reversed latent is multiplied by the mask on each
        ring watermark channel; everything else is zero. Only the first ring
        channel is rendered.

        Parameters:
            title (str): The title of the plot; an empty string disables it.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (float, optional): Lower bound of the color scale.
            vmax (float, optional): Upper bound of the color scale.
            ax (Axes, optional): The axes to plot on; a new one is created if None.
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``.

        Returns:
            Axes: The plotted axes.
        """
        reversed_latent = self.data.reversed_latents[-1]
        reversed_latent_fft = torch.from_numpy(self._fft_transform(reversed_latent))

        watermark_channel = self.data.ring_watermark_channel
        mask = self.data.mask.cpu()

        # Keep only the masked region of each ring channel; the rest stays zero.
        extracted_pattern = torch.zeros_like(reversed_latent_fft)
        for channel, channel_mask in zip(watermark_channel, mask):
            extracted_pattern[:, channel] = reversed_latent_fft[:, channel] * channel_mask

        return self._plot_channel_magnitude(ax, extracted_pattern, watermark_channel, title,
                                            cmap, use_color_bar, vmin, vmax, kwargs)