Coverage for detection/robin/robin_detection.py: 100.00%
43 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
16import torch
17from detection.base import BaseDetector
18from scipy.stats import ncx2
19from torch.nn import functional as F
class ROBINDetector(BaseDetector):
    """Frequency-domain watermark detector for ROBIN.

    Inverted latents are scored by comparing their centred 2-D FFT
    (over the last two dimensions) with a stored ground-truth patch.
    Only the positions selected by a boolean watermarking mask are
    compared.
    """

    def __init__(self,
                 watermarking_mask: torch.Tensor,
                 gt_patch: torch.Tensor,
                 threshold: float,
                 device: torch.device):
        """Store the watermark pattern used for detection.

        Args:
            watermarking_mask: Boolean mask selecting the compared frequency
                positions. Its last two dimensions are treated as spatial.
                When resizing is needed it goes through ``F.interpolate``
                with a 2-D target size, so it must then be 4-D
                (presumably (batch, channels, height, width) — TODO confirm).
            gt_patch: Target spectrum values compared against the masked
                FFT. Expected to be complex; a real tensor is also accepted
                when resizing.
            threshold: Decision threshold. Its meaning depends on the
                ``detector_type`` passed to ``eval_watermark``.
            device: Forwarded unchanged to ``BaseDetector.__init__``.
        """
        super().__init__(threshold, device)
        self.watermarking_mask = watermarking_mask
        self.gt_patch = gt_patch

    def _resize_mask(self, size: tuple) -> torch.Tensor:
        """Return the watermarking mask with spatial size ``size`` (H, W)."""
        mask = self.watermarking_mask
        if tuple(mask.shape[-2:]) == size:
            return mask
        # Nearest-neighbour keeps the mask binary. interpolate needs a float
        # input, so cast to float and back to bool afterwards.
        resized = F.interpolate(mask.float(), size=size, mode='nearest')
        return resized.bool()

    def _resize_gt_patch(self, size: tuple) -> torch.Tensor:
        """Return the ground-truth patch with spatial size ``size`` (H, W)."""
        patch = self.gt_patch
        if tuple(patch.shape[-2:]) == size:
            return patch
        if not torch.is_complex(patch):
            # Calling .imag on a real tensor raises, so interpolate directly.
            return F.interpolate(patch, size=size, mode='bilinear',
                                 align_corners=False)
        # Bilinear for continuous values, applied to the real and imaginary
        # parts separately and recombined.
        real = F.interpolate(patch.real, size=size, mode='bilinear',
                             align_corners=False)
        imag = F.interpolate(patch.imag, size=size, mode='bilinear',
                             align_corners=False)
        return torch.complex(real, imag)

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       reference_latents: torch.Tensor = None,
                       detector_type: str = "l1_distance") -> dict:
        """Score ``reversed_latents`` against the stored watermark pattern.

        A 2-D FFT is taken over the last two dimensions of
        ``reversed_latents`` and centred with ``fftshift``. The mask and the
        ground-truth patch are each resized to that spatial size if they do
        not already match it.

        Args:
            reversed_latents: Latents recovered from the image under test.
            reference_latents: Unused; accepted for interface compatibility.
            detector_type: ``'l1_distance'``, ``'p_value'`` or
                ``'cosine_similarity'``.

        Returns:
            A dict with ``'is_watermarked'`` (Python bool) plus the score,
            stored under a key equal to ``detector_type``:

            - ``'l1_distance'`` (float): mean absolute difference between the
              masked FFT values and the patch. Watermarked when below
              ``threshold``.
            - ``'p_value'``: ``scipy.stats.ncx2.cdf`` of the normalised
              squared error. The non-centrality is derived from the patch
              energy. Watermarked when below ``threshold``.
            - ``'cosine_similarity'`` (0-d tensor): cosine similarity of the
              real parts only. Watermarked when above ``threshold``.

        Raises:
            ValueError: If ``detector_type`` is not one of the above.
        """
        latents_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents),
                                         dim=(-1, -2))

        spatial_size = tuple(reversed_latents.shape[-2:])
        mask = self._resize_mask(spatial_size)
        gt_patch = self._resize_gt_patch(spatial_size)

        if detector_type == 'l1_distance':
            l1_distance = torch.abs(latents_fft[mask] - gt_patch[mask]).mean().item()
            return {
                'is_watermarked': bool(l1_distance < self.threshold),
                'l1_distance': l1_distance,
            }

        if detector_type == 'p_value':
            wm_area = latents_fft[mask].flatten()
            target = gt_patch[mask].flatten()
            # Real and imaginary parts are stacked into one real-valued vector.
            wm_area = torch.cat([wm_area.real, wm_area.imag])
            target = torch.cat([target.real, target.imag])
            sigma = wm_area.std()
            lambda_ = (target ** 2 / sigma ** 2).sum().item()
            x = (((wm_area - target) / sigma) ** 2).sum().item()
            p = ncx2.cdf(x=x, df=len(target), nc=lambda_)
            return {
                # ncx2.cdf returns a NumPy scalar; normalise the flag to bool.
                'is_watermarked': bool(p < self.threshold),
                'p_value': p,
            }

        if detector_type == 'cosine_similarity':
            wm_area = latents_fft[mask].flatten()
            target = gt_patch[mask].flatten()
            # Only the real parts of the spectra are compared.
            cosine_similarity = F.cosine_similarity(wm_area.real, target.real, dim=0)
            return {
                'is_watermarked': bool(cosine_similarity > self.threshold),
                'cosine_similarity': cosine_similarity,
            }

        raise ValueError(f"ROBIN's watermark detector type {detector_type} not supported")