Coverage for detection / ri / ri_detection.py: 91.67%

36 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import torch 

17from detection.base import BaseDetector 

18import numpy as np 

19 

class RIDetector(BaseDetector):
    """Detect a watermark by masked L1 distance in the centred 2-D Fourier domain.

    ``eval_watermark`` transforms the input latents with ``torch.fft.fft2`` and
    ``fftshift`` (over the last two dims), then measures the distance to every
    candidate pattern in ``pattern_list``. The smallest distance is compared
    against ``self.threshold`` to decide detection.
    """

    def __init__(self,
                 watermarking_mask: torch.Tensor,
                 ring_watermark_channel: list,
                 heter_watermark_channel: list,
                 pattern_list: list,
                 threshold: float,
                 device: torch.device):
        """Store the detection configuration.

        Args:
            watermarking_mask: Index applied to the channel-selected
                absolute-difference tensor in ``_get_distance``. Presumably a
                boolean mask shaped like ``tensor[0][watermark_channel]``
                (TODO confirm).
            ring_watermark_channel: Channel indices of the ring watermark.
            heter_watermark_channel: Channel indices of the heter watermark.
            pattern_list: Candidate patterns. Each must have the same shape as
                the shifted FFT of the latents passed to ``eval_watermark``.
            threshold: Forwarded to ``BaseDetector``. A minimum distance below
                ``self.threshold`` counts as watermarked.
            device: Forwarded to ``BaseDetector``.
        """
        super().__init__(threshold, device)
        self.watermarking_mask = watermarking_mask
        self.ring_watermark_channel = ring_watermark_channel
        self.heter_watermark_channel = heter_watermark_channel
        # Union of both channel lists in ascending order. It is used as an
        # advanced index, so the mask must follow this same channel ordering.
        self.watermark_channel = sorted(self.heter_watermark_channel + self.ring_watermark_channel)
        self.pattern_list = pattern_list

    def _get_distance(self, tensor1, tensor2, mask, p, mode, channel_min=False):
        """Return the mean absolute difference between two tensors over ``mask``.

        Only the watermark channels (``self.watermark_channel``) of the first
        element along dim 0 of each tensor are compared.

        Args:
            tensor1: First tensor. Must have the same shape as ``tensor2``.
            tensor2: Second tensor.
            mask: Index applied to ``abs(diff)`` after channel selection.
            p: Norm order. Only ``p == 1`` is implemented.
            mode: ``'complex'`` takes ``abs`` of the full difference (the
                modulus for complex inputs). ``'real'`` and ``'imag'`` compare
                only the real or imaginary parts.
            channel_min: Per-channel-minimum variant. Not implemented.

        Returns:
            The masked mean absolute difference as a Python float.

        Raises:
            ValueError: If the tensor shapes differ.
            NotImplementedError: If ``mode`` is unknown, ``p != 1``, or
                ``channel_min`` is true.
        """
        if tensor1.shape != tensor2.shape:
            raise ValueError(f'Shape mismatch during eval: {tensor1.shape} vs {tensor2.shape}')
        if mode not in ['complex', 'real', 'imag']:
            # NotImplementedError is the exception class; the original
            # `NotImplemented(...)` call raised an unrelated TypeError instead.
            raise NotImplementedError(f'Eval mode not implemented: {mode}')
        if channel_min or p != 1:
            # Fail loudly here. Previously these cases fell through, returned
            # None, and only broke later inside the caller.
            raise NotImplementedError(
                f'Only p=1 with channel_min=False is implemented (got p={p}, channel_min={channel_min})'
            )

        channel = self.watermark_channel
        lhs = tensor1[0][channel]
        rhs = tensor2[0][channel]
        if mode == 'real':
            lhs, rhs = lhs.real, rhs.real
        elif mode == 'imag':
            lhs, rhs = lhs.imag, rhs.imag
        # Mean over the masked entries, i.e. an L1 distance normalised by the
        # number of selected entries.
        return torch.mean(torch.abs(lhs - rhs)[mask]).item()

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       reference_latents: torch.Tensor = None,
                       detector_type: str = "l1_distance", mode="complex") -> dict:
        """Decide whether ``reversed_latents`` carries one of the known patterns.

        Args:
            reversed_latents: Latents to test. They are transformed with
                ``fft2`` and then ``fftshift`` over the last two dims.
            reference_latents: Not used by this detector.
            detector_type: Must be ``"l1_distance"``.
            mode: Comparison mode passed to ``_get_distance``
                (``'complex'``, ``'real'`` or ``'imag'``).

        Returns:
            A dict with two keys. ``'is_watermarked'`` is true when the
            smallest distance is below ``self.threshold``. ``'l1_distance'``
            is that smallest distance.

        Raises:
            ValueError: If ``detector_type`` is unsupported or
                ``pattern_list`` is empty.
        """
        if detector_type != 'l1_distance':
            raise ValueError(f"Detector type {detector_type} not supported")
        # len() instead of truthiness so a stacked tensor of patterns also works.
        if len(self.pattern_list) == 0:
            raise ValueError("pattern_list is empty; no pattern to compare against")

        # The transform is computed once and reused for every pattern.
        reversed_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents), dim=(-1, -2))
        all_distances = [
            self._get_distance(
                pattern,
                reversed_fft,
                self.watermarking_mask,
                p=1,
                mode=mode,
                channel_min=False,
            )
            for pattern in self.pattern_list
        ]

        min_idx = int(np.argmin(all_distances))
        min_dist = all_distances[min_idx]

        return {
            'is_watermarked': bool(min_dist < self.threshold),
            'l1_distance': min_dist,
        }

135 

136