Coverage for visualize / tr / tr_visualizer.py: 97.37%

38 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1import torch 

2import matplotlib.pyplot as plt 

3from matplotlib.axes import Axes 

4from typing import Optional 

5import numpy as np 

6from visualize.base import BaseVisualizer 

7from visualize.data_for_visualization import DataForVisualization 

8 

class TreeRingVisualizer(BaseVisualizer):
    """Tree-Ring watermark visualization class.

    Plots the magnitude of the FFT of a single latent channel
    (``self.data.w_channel`` of the first batch item), keeping only the
    positions selected by ``self.data.watermarking_mask`` and zeroing
    everything else.
    """

    def __init__(self, data_for_visualization: DataForVisualization, dpi: int = 300, watermarking_step: int = -1):
        super().__init__(data_for_visualization, dpi, watermarking_step)

    def _render_masked_fft(self,
                           latent,
                           title: str,
                           cmap: str,
                           use_color_bar: bool,
                           vmin: Optional[float],
                           vmax: Optional[float],
                           ax: Optional[Axes],
                           imshow_kwargs: dict) -> Axes:
        """Shared drawing routine: FFT of ``latent`` restricted to the watermark mask."""
        # The public methods advertise ``ax=None`` as a valid default, so fall
        # back to the current pyplot axes instead of failing on ``None.imshow``.
        if ax is None:
            ax = plt.gca()

        # NOTE(review): the mask is used for boolean-style indexing below;
        # presumably it is a bool tensor with the same 2D shape as the FFT.
        mask = self.data.watermarking_mask[0, self.data.w_channel].cpu()

        # ``_fft_transform`` must yield a NumPy array (it is fed to from_numpy).
        spectrum = torch.from_numpy(self._fft_transform(latent))

        # All-zero background; ``zeros_like`` already matches dtype and device.
        masked_spectrum = torch.zeros_like(spectrum)
        masked_spectrum[mask] = spectrum[mask]

        # Display the magnitude of the masked spectrum.
        image = ax.imshow(np.abs(masked_spectrum.cpu().numpy()),
                          cmap=cmap, vmin=vmin, vmax=vmax, **imshow_kwargs)
        if title:
            ax.set_title(title)
        if use_color_bar:
            ax.figure.colorbar(image, ax=ax)
        ax.axis('off')

        return ax

    def draw_pattern_fft(self,
                         title: str = "Tree-Ring FFT with Watermark Area",
                         cmap: str = "viridis",
                         use_color_bar: bool = True,
                         vmin: Optional[float] = None,
                         vmax: Optional[float] = None,
                         ax: Optional[Axes] = None,
                         **kwargs) -> Axes:
        """
        Draw FFT visualization with original watermark pattern, with all 0 background.

        Parameters:
            title (str): The title of the plot. An empty string draws no title.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower color limit forwarded to ``imshow``.
            vmax (Optional[float]): Upper color limit forwarded to ``imshow``.
            ax (Axes): The axes to plot on. If None, the current pyplot axes are used.
            **kwargs: Extra keyword arguments forwarded to ``ax.imshow``.

        Returns:
            Axes: The plotted axes.
        """
        orig_latent = self.data.orig_watermarked_latents[0, self.data.w_channel].cpu()
        return self._render_masked_fft(orig_latent, title, cmap, use_color_bar,
                                       vmin, vmax, ax, kwargs)

    def draw_inverted_pattern_fft(self,
                                  step: Optional[int] = None,
                                  title: str = "Tree-Ring FFT with Inverted Watermark Area",
                                  cmap: str = "viridis",
                                  use_color_bar: bool = True,
                                  vmin: Optional[float] = None,
                                  vmax: Optional[float] = None,
                                  ax: Optional[Axes] = None,
                                  **kwargs) -> Axes:
        """
        Draw FFT visualization with inverted pattern, with all 0 background.

        Parameters:
            step (Optional[int]): Index into ``self.data.reversed_latents``. If None,
                ``self.watermarking_step`` is used.
            title (str): The title of the plot. An empty string draws no title.
            cmap (str): The colormap to use.
            use_color_bar (bool): Whether to display the colorbar.
            vmin (Optional[float]): Lower color limit forwarded to ``imshow``.
            vmax (Optional[float]): Upper color limit forwarded to ``imshow``.
            ax (Axes): The axes to plot on. If None, the current pyplot axes are used.
            **kwargs: Extra keyword arguments forwarded to ``ax.imshow``.

        Returns:
            Axes: The plotted axes.
        """
        latent_index = self.watermarking_step if step is None else step
        inverted_latent = self.data.reversed_latents[latent_index][0, self.data.w_channel]
        return self._render_masked_fft(inverted_latent, title, cmap, use_color_bar,
                                       vmin, vmax, ax, kwargs)

92