Coverage for visualize / ri / ri_visualizer.py: 100.00%

56 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16from typing import Optional 

17import torch 

18import matplotlib.pyplot as plt 

19from matplotlib.axes import Axes 

20import numpy as np 

21from visualize.base import BaseVisualizer 

22from visualize.data_for_visualization import DataForVisualization 

23 

class RingIDVisualizer(BaseVisualizer):
    """RingID watermark visualization class.

    Renders Fourier-domain views of the ring / heter watermark patterns and
    of the pattern recovered from the last reversed latent.
    """

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        """
        Forward all arguments unchanged to ``BaseVisualizer.__init__``.

        Parameters:
            data_for_visualization (DataForVisualization): Data to visualize.
            dpi (int): Passed through to the base class.
            watermarking_step (int): Passed through to the base class.
        """
        super().__init__(data_for_visualization, dpi, watermarking_step)

    def _build_pattern_fft(self, watermark_channel) -> torch.Tensor:
        """
        Return a Fourier-domain tensor whose watermark channels hold the masked pattern.

        The buffer starts as the FFT of ``self.data.latents`` (``_fft_transform``
        returns a NumPy array, hence ``torch.from_numpy``). Each channel paired
        with a mask entry is then overwritten with ``pattern * mask``, so inside
        such a channel everything outside the mask is zero.

        Parameters:
            watermark_channel: Iterable of channel indices; it is zipped with
                ``self.data.mask``, so pairing stops at the shorter of the two.

        Returns:
            torch.Tensor: The buffer with the paired channels replaced.
        """
        spectrum = torch.from_numpy(self._fft_transform(self.data.latents))
        pattern = self.data.pattern.cpu()
        mask = self.data.mask.cpu()
        for channel, channel_mask in zip(watermark_channel, mask):
            # Adding an all-zero background is a no-op, so the masked pattern
            # is written directly.
            spectrum[:, channel] = pattern[:, channel] * channel_mask
        return spectrum

    def _show_fft_magnitude(self,
                            spectrum: torch.Tensor,
                            watermark_channel,
                            title: str,
                            cmap: str,
                            use_color_bar: bool,
                            vmin: Optional[float],
                            vmax: Optional[float],
                            ax: Optional[Axes],
                            **kwargs) -> Axes:
        """
        Plot the magnitude of the first watermark channel of the first batch item.

        Parameters:
            spectrum (torch.Tensor): Fourier-domain tensor indexed as
                ``spectrum[0, watermark_channel][0]``.
            watermark_channel: Channel indices used to select what is displayed.
            title (str): The title of the plot; an empty string means no title.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower color limit passed to ``imshow``.
            vmax (Optional[float]): Upper color limit passed to ``imshow``.
            ax (Optional[Axes]): The axes to plot on; the current axes are
                used when None.
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``.

        Returns:
            Axes: The plotted axes.
        """
        if ax is None:
            # The public methods advertise ``ax=None``; without this fallback
            # the call below would fail with AttributeError.
            ax = plt.gca()

        magnitude = np.abs(spectrum[0, watermark_channel][0].cpu().numpy())
        im = ax.imshow(magnitude, cmap=cmap, vmin=vmin, vmax=vmax, **kwargs)
        if title != "":
            ax.set_title(title)
        if use_color_bar:
            ax.figure.colorbar(im, ax=ax)
        ax.axis('off')

        return ax

    def draw_ring_pattern_fft(self,
                              title: str = "Ring Watermark Pattern (FFT)",
                              cmap: str = "viridis",
                              use_color_bar: bool = True,
                              vmin: Optional[float] = None,
                              vmax: Optional[float] = None,
                              ax: Optional[Axes] = None,
                              **kwargs) -> Axes:
        """
        Draw the ring watermark pattern in the Fourier domain (background all zeros).

        Parameters:
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower color limit passed to ``imshow``.
            vmax (Optional[float]): Upper color limit passed to ``imshow``.
            ax (Optional[Axes]): The axes to plot on; the current axes are
                used when None.
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``.

        Returns:
            Axes: The plotted axes.
        """
        watermark_channel = self.data.ring_watermark_channel
        spectrum = self._build_pattern_fft(watermark_channel)
        return self._show_fft_magnitude(spectrum, watermark_channel, title, cmap,
                                        use_color_bar, vmin, vmax, ax, **kwargs)

    def draw_heter_pattern_fft(self,
                               title: str = "Heter Watermark Pattern (FFT)",
                               cmap: str = "viridis",
                               use_color_bar: bool = True,
                               vmin: Optional[float] = None,
                               vmax: Optional[float] = None,
                               ax: Optional[Axes] = None,
                               **kwargs) -> Axes:
        """
        Draw the heter watermark pattern in the Fourier domain (background all zeros).

        Parameters:
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower color limit passed to ``imshow``.
            vmax (Optional[float]): Upper color limit passed to ``imshow``.
            ax (Optional[Axes]): The axes to plot on; the current axes are
                used when None.
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``.

        Returns:
            Axes: The plotted axes.
        """
        # NOTE(review): heter channels are paired with mask entries starting at
        # mask[0], exactly as the ring channels are — confirm that the mask
        # layout matches the heter channel order.
        watermark_channel = self.data.heter_watermark_channel
        spectrum = self._build_pattern_fft(watermark_channel)
        return self._show_fft_magnitude(spectrum, watermark_channel, title, cmap,
                                        use_color_bar, vmin, vmax, ax, **kwargs)

    def draw_inverted_ring_pattern_fft(self,
                                       title: str = "Inverted Ring Watermark Pattern (FFT)",
                                       cmap: str = "viridis",
                                       use_color_bar: bool = True,
                                       vmin: Optional[float] = None,
                                       vmax: Optional[float] = None,
                                       ax: Optional[Axes] = None,
                                       **kwargs) -> Axes:
        """
        Extract and visualize the watermark pattern from reversed_latents[-1] using FFT.

        Parameters:
            title (str): The title of the plot.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower color limit passed to ``imshow``.
            vmax (Optional[float]): Upper color limit passed to ``imshow``.
            ax (Optional[Axes]): The axes to plot on; the current axes are
                used when None.
            **kwargs: Extra keyword arguments forwarded to ``Axes.imshow``.

        Returns:
            Axes: The plotted axes.
        """
        # FFT of the last reversed latent.
        reversed_latent = self.data.reversed_latents[-1]
        reversed_latent_fft = torch.from_numpy(self._fft_transform(reversed_latent))

        watermark_channel = self.data.ring_watermark_channel
        mask = self.data.mask.cpu()

        # Keep only the masked region of each ring channel; everything else
        # stays zero.
        extracted_pattern = torch.zeros_like(reversed_latent_fft)
        for channel, channel_mask in zip(watermark_channel, mask):
            extracted_pattern[:, channel] = reversed_latent_fft[:, channel] * channel_mask

        return self._show_fft_magnitude(extracted_pattern, watermark_channel, title, cmap,
                                        use_color_bar, vmin, vmax, ax, **kwargs)