Coverage for detection / seal / seal_detection.py: 95.56%
45 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1# Copyright 2025 THU-BPM MarkDiffusion.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16import torch
17from detection.base import BaseDetector
18from transformers import Blip2Processor, Blip2ForConditionalGeneration
19from sentence_transformers import SentenceTransformer
20from PIL import Image
21import math
class SEALDetector(BaseDetector):
    """Patch-wise latent-matching detector used for SEAL watermark detection.

    The latents under test are compared with reference latents patch by patch.
    A patch "matches" when the L2 distance between the two patches is strictly
    below ``patch_distance_threshold``. The input is reported as watermarked
    when at least ``match_threshold`` of the ``k`` compared patches match.
    """

    def __init__(self,
                 k: int,
                 b: int,
                 theta_mid: float,
                 cap_processor: Blip2Processor,
                 cap_model: Blip2ForConditionalGeneration,
                 sentence_transformer: SentenceTransformer,
                 patch_distance_threshold: float,
                 device: torch.device):
        """Store the detector configuration and precompute the match threshold.

        Args:
            k: Number of patches compared per detection; also ``n`` in the
                match-threshold formula.
            b: Exponent ``b`` of ``rho(theta) = (1 - theta / 180) ** b``.
            theta_mid: Angle (degrees) plugged into ``rho`` to derive the
                match threshold.
            cap_processor: BLIP-2 processor. Stored on the instance but not
                used by any method of this class (NOTE(review): presumably
                needed elsewhere in the SEAL pipeline; confirm).
            cap_model: BLIP-2 model. Stored only, same note as ``cap_processor``.
            sentence_transformer: Sentence embedding model. Stored only, same
                note as ``cap_processor``.
            patch_distance_threshold: A patch matches when its L2 distance is
                strictly below this value.
            device: Forwarded to ``BaseDetector.__init__``.
        """
        super().__init__(patch_distance_threshold, device)
        self.k = k
        self.b = b
        self.theta_mid = theta_mid
        self.cap_processor = cap_processor
        self.cap_model = cap_model
        self.sentence_transformer = sentence_transformer
        # NOTE(review): BaseDetector.__init__ also receives this value and may
        # already store it; assigning it here guarantees the attribute exists.
        self.patch_distance_threshold = patch_distance_threshold

        # Calculate the match threshold
        # m^{\text{match}} = \left\lfloor n \rho(\theta^{\text{mid}}) \right\rfloor
        # \rho(\theta) = \left( 1 - \frac{\theta}{180^\circ} \right)^b
        # n = k
        self.match_threshold = math.floor(self.k * ((1 - self.theta_mid / 180) ** self.b))

    def _calculate_patch_l2(self, noise1: torch.Tensor, noise2: torch.Tensor, k: int) -> list:
        """Return the L2 distances between ``noise1`` and ``noise2`` for the first ``k`` patches.

        Grid layout:
            The last two dimensions (H, W) are split into a grid of
            ``rows = ceil(sqrt(k))`` by ``cols = ceil(k / rows)`` cells. Each
            patch is ``H // rows`` by ``W // cols`` elements, so any remainder
            rows/columns at the bottom/right edges are never compared.

        Ordering:
            Patches are visited in row-major order and the first ``k`` are
            kept. Each distance is taken over every element of the patch slice,
            including all leading (non-spatial) dimensions.

        Args:
            noise1: 4-D tensor. Its shape is unpacked as ``(_, _, H, W)`` and
                only H and W determine the grid.
            noise2: Tensor sliced with the same indices as ``noise1``
                (NOTE(review): expected to share ``noise1``'s shape; not checked).
            k: Number of patches to compare; must be >= 1.

        Returns:
            A list of ``k`` Python floats, one per patch.

        Raises:
            ValueError: If ``k < 1``, or if the spatial size is too small for
                the grid (a patch would be zero elements tall or wide).
        """
        if k < 1:
            raise ValueError(f"k must be a positive integer, got {k}")

        patch_per_side_h = int(math.ceil(math.sqrt(k)))
        patch_per_side_w = int(math.ceil(k / patch_per_side_h))

        # Dynamically calculate patch size based on input tensor dimensions
        _, _, H, W = noise1.shape
        patch_height = H // patch_per_side_h
        patch_width = W // patch_per_side_w
        if patch_height == 0 or patch_width == 0:
            # An empty patch has distance 0 and would always count as a match,
            # silently inflating the detection score.
            raise ValueError(
                f"Latent spatial size {H}x{W} is too small for a "
                f"{patch_per_side_h}x{patch_per_side_w} patch grid (k={k})"
            )

        distances = []
        for idx in range(k):
            # Row-major grid position; rows * cols >= k guarantees row < rows,
            # so every slice below lies fully inside the tensor.
            row, col = divmod(idx, patch_per_side_w)
            y0 = row * patch_height
            x0 = col * patch_width
            patch1 = noise1[:, :, y0:y0 + patch_height, x0:x0 + patch_width]
            patch2 = noise2[:, :, y0:y0 + patch_height, x0:x0 + patch_width]
            distances.append(torch.linalg.vector_norm(patch1 - patch2))
        # A single device-to-host transfer instead of one `.item()` sync per patch.
        return torch.stack(distances).tolist()

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       reference_latents: torch.Tensor,
                       detector_type: str = "patch_accuracy") -> dict:
        """Decide whether ``reversed_latents`` match ``reference_latents`` closely enough.

        Args:
            reversed_latents: Latents under test (presumably obtained by
                inverting the candidate image; confirm with the caller).
            reference_latents: Latents to compare against, sliced with the same
                patch indices as ``reversed_latents``.
            detector_type: Only ``"patch_accuracy"`` is supported.

        Returns:
            A dict with two keys:

            - ``"is_watermarked"`` (bool): True when at least
              ``self.match_threshold`` patches have an L2 distance strictly
              below ``self.patch_distance_threshold``.
            - ``"patch_accuracy"`` (float): the fraction of the ``self.k``
              patches that match.

        Raises:
            ValueError: If ``detector_type`` is not ``"patch_accuracy"``, or
                from ``_calculate_patch_l2`` for an invalid ``k`` or latents
                too small for the patch grid.
        """
        if detector_type != "patch_accuracy":
            raise ValueError(f"Detector type {detector_type} is not supported for SEAL detector")

        l2_patch_list = self._calculate_patch_l2(reversed_latents, reference_latents, self.k)

        # Count the patches whose distance is strictly below the threshold.
        num_patches_below_threshold = sum(
            1 for l2 in l2_patch_list if l2 < self.patch_distance_threshold
        )

        return {
            "is_watermarked": bool(num_patches_below_threshold >= self.match_threshold),
            "patch_accuracy": num_patches_below_threshold / self.k,
        }