Coverage for detection / seal / seal_detection.py: 95.56%

45 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import torch 

17from detection.base import BaseDetector 

18from transformers import Blip2Processor, Blip2ForConditionalGeneration 

19from sentence_transformers import SentenceTransformer 

20from PIL import Image 

21import math 

22 

class SEALDetector(BaseDetector):
    """Patch-wise latent-matching detector for the SEAL watermark.

    The spatial plane of the latents is split into ``k`` rectangular patches
    and each patch of the reversed latents is compared with the same patch of
    the reference latents by L2 distance. A patch counts as a match when its
    distance is below ``patch_distance_threshold``. The input is reported as
    watermarked when at least ``match_threshold`` patches match.
    """

    def __init__(self,
                 k: int,
                 b: int,
                 theta_mid: float,
                 cap_processor: Blip2Processor,
                 cap_model: Blip2ForConditionalGeneration,
                 sentence_transformer: SentenceTransformer,
                 patch_distance_threshold: float,
                 device: torch.device):
        """
        Args:
            k: Number of patches compared per input; must be >= 1.
            b: Exponent ``b`` in rho(theta) = (1 - theta / 180) ** b.
            theta_mid: Angle in degrees plugged into rho(theta) to derive
                the number of patches that must match.
            cap_processor: Stored on the instance; not read by the methods
                of this class.
            cap_model: Stored on the instance; not read by the methods of
                this class.
            sentence_transformer: Stored on the instance; not read by the
                methods of this class.
            patch_distance_threshold: A patch matches when its L2 distance
                is strictly below this value.
            device: Forwarded to ``BaseDetector.__init__``.

        Raises:
            ValueError: If ``k < 1``. With no patches there is nothing to
                compare, and ``eval_watermark`` would divide by zero.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")
        super().__init__(patch_distance_threshold, device)
        self.k = k
        self.b = b
        self.theta_mid = theta_mid
        # Captioning / embedding models are kept on the instance; nothing in
        # this class reads them.
        self.cap_processor = cap_processor
        self.cap_model = cap_model
        self.sentence_transformer = sentence_transformer
        self.patch_distance_threshold = patch_distance_threshold

        # Calculate the match threshold
        # m^{\text{match}} = \left\lfloor n \rho(\theta^{\text{mid}}) \right\rfloor
        # \rho(\theta) = \left( 1 - \frac{\theta}{180^\circ} \right)^b
        # n = k
        self.match_threshold = math.floor(self.k * ((1 - self.theta_mid / 180) ** self.b))

    def _calculate_patch_l2(self, noise1: torch.Tensor, noise2: torch.Tensor, k: int) -> list[float]:
        """Return the L2 distance between ``noise1`` and ``noise2`` for each of the first ``k`` patches.

        The last two dimensions (H, W) are divided into a grid of
        ``ceil(sqrt(k))`` rows by ``ceil(k / rows)`` columns. Patches are
        visited in row-major order and only the first ``k`` are used.

        Each patch spans all leading dimensions, so a single norm is taken
        over every element of that patch. Trailing rows/columns that do not
        fill a whole patch (when H or W is not a multiple of the grid size)
        are ignored.

        Raises:
            ValueError: If ``k < 1``, or if the grid is finer than the
                tensor so that patches would be empty. An empty patch has
                distance 0 and would otherwise always count as a match.
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")
        rows = math.ceil(math.sqrt(k))
        cols = math.ceil(k / rows)

        # Dynamically calculate patch size based on input tensor dimensions
        _, _, height, width = noise1.shape
        patch_h = height // rows
        patch_w = width // cols
        if patch_h == 0 or patch_w == 0:
            raise ValueError(
                f"Cannot split a {height}x{width} latent into {k} patches "
                f"({rows}x{cols} grid): patches would be empty"
            )

        distances = []
        for idx in range(k):
            # rows * cols >= k, so the row index always stays inside the grid.
            row, col = divmod(idx, cols)
            y0, x0 = row * patch_h, col * patch_w
            patch1 = noise1[:, :, y0:y0 + patch_h, x0:x0 + patch_w]
            patch2 = noise2[:, :, y0:y0 + patch_h, x0:x0 + patch_w]
            distances.append(torch.norm(patch1 - patch2))
        # A single host transfer instead of one ``.item()`` sync per patch.
        return torch.stack(distances).tolist()

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       reference_latents: torch.Tensor,
                       detector_type: str = "patch_accuracy") -> dict:
        """Compare ``reversed_latents`` with ``reference_latents`` patch by patch.

        Args:
            reversed_latents: 4-D latents (unpacked as ``_, _, H, W``).
                Presumably these are recovered by inverting the generation
                process on the image under test; confirm against the caller.
            reference_latents: Latents with the same layout to compare
                against.
            detector_type: Only ``"patch_accuracy"`` is supported.

        Returns:
            A dict with two keys:

            - ``"is_watermarked"`` (bool): True when at least
              ``self.match_threshold`` patches have an L2 distance below
              ``self.patch_distance_threshold``.
            - ``"patch_accuracy"`` (float): the fraction of the ``self.k``
              patches that matched.

        Raises:
            ValueError: For an unsupported ``detector_type``, or when the
                latents are too small to hold ``self.k`` non-empty patches.
        """
        if detector_type != "patch_accuracy":
            raise ValueError(f"Detector type {detector_type} is not supported for SEAL detector")

        l2_patch_list = self._calculate_patch_l2(reversed_latents, reference_latents, self.k)

        # Count the patches whose distance is strictly below the threshold.
        num_patches_below_threshold = sum(1 for l2 in l2_patch_list if l2 < self.patch_distance_threshold)

        return {
            "is_watermarked": bool(num_patches_below_threshold >= self.match_threshold),
            "patch_accuracy": num_patches_below_threshold / self.k,
        }

95 

96