Coverage for detection/ri/ri_detection.py: 91.67% (36 statements)

# Copyright 2025 THU-BPM MarkDiffusion.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import torch

from detection.base import BaseDetector


class RIDetector(BaseDetector):
    """Detector for ring-based (RI) watermarks embedded in the Fourier domain of diffusion latents."""

    def __init__(self,
                 watermarking_mask: torch.Tensor,
                 ring_watermark_channel: list,
                 heter_watermark_channel: list,
                 pattern_list: list,
                 threshold: float,
                 device: torch.device):
        super().__init__(threshold, device)
        self.watermarking_mask = watermarking_mask
        self.ring_watermark_channel = ring_watermark_channel
        self.heter_watermark_channel = heter_watermark_channel
        self.watermark_channel = sorted(self.heter_watermark_channel + self.ring_watermark_channel)
        self.pattern_list = pattern_list

    def _get_distance(self, tensor1, tensor2, mask, p, mode, channel_min=False):
        """Return the masked L_p distance between two FFT-domain latents on the watermark channels."""
        channel = self.watermark_channel
        if tensor1.shape != tensor2.shape:
            raise ValueError(f'Shape mismatch during eval: {tensor1.shape} vs {tensor2.shape}')
        if mode not in ['complex', 'real', 'imag']:
            raise NotImplementedError(f'Eval mode not implemented: {mode}')

        if not channel_min:
            if p == 1:
                # A faster implementation for the L1 distance: mean absolute
                # difference over the masked watermark region.
                if mode == 'complex':
                    return torch.mean(torch.abs(tensor1[0][channel] - tensor2[0][channel])[mask]).item()
                if mode == 'real':
                    return torch.mean(torch.abs(tensor1[0][channel].real - tensor2[0][channel].real)[mask]).item()
                if mode == 'imag':
                    return torch.mean(torch.abs(tensor1[0][channel].imag - tensor2[0][channel].imag)[mask]).item()
            # General p-norm and per-channel (channel_min) variants, kept for
            # reference but currently disabled.
            # else:
            #     if mode == 'complex':
            #         return torch.norm(torch.abs(tensor1[0][channel][mask] - tensor2[0][channel][mask]),
            #                           p=p).item() / torch.sum(mask)
            #     if mode == 'real':
            #         return torch.norm(torch.abs(tensor1[0][channel][mask].real - tensor2[0][channel][mask].real),
            #                           p=p).item() / torch.sum(mask)
            #     if mode == 'imag':
            #         return torch.norm(torch.abs(tensor1[0][channel][mask].imag - tensor2[0][channel][mask].imag),
            #                           p=p).item() / torch.sum(mask)
        # else:
        #     # argmin TODO: normalize
        #     if len(self.ring_watermark_channel) > 1 and len(self.heter_watermark_channel) > 0:
        #         ring_channel_idx_list = [idx for idx, c_id in enumerate(self.watermark_channel) if
        #                                  c_id in self.ring_watermark_channel]
        #         heter_channel_idx_list = [idx for idx, c_id in enumerate(self.watermark_channel) if
        #                                   c_id in self.heter_watermark_channel]
        #         if mode == 'complex':
        #             diff = torch.abs(tensor1[0][channel] - tensor2[0][channel])  # [c, h, w]
        #         elif mode == 'real':
        #             diff = torch.abs(tensor1[0][channel].real - tensor2[0][channel].real)  # [c, h, w]
        #         elif mode == 'imag':
        #             diff = torch.abs(tensor1[0][channel].imag - tensor2[0][channel].imag)  # [c, h, w]
        #         l1_list = []
        #         num_items = []
        #         for c_idx in range(len(mask)):
        #             mask_c = torch.zeros_like(mask)
        #             mask_c[c_idx] = mask[c_idx]
        #             l1_list.append(torch.mean(diff[mask_c]).item())
        #             num_items.append(torch.sum(mask_c).item())
        #         total = 0
        #         num = 0
        #         for ring_channel_idx in ring_channel_idx_list:
        #             total += l1_list[ring_channel_idx] * num_items[ring_channel_idx]
        #             num += num_items[ring_channel_idx]
        #         ring_channels_mean = total / num
        #         return min(ring_channels_mean, min([l1_list[idx] for idx in heter_channel_idx_list]))
        #     elif len(self.ring_watermark_channel) == 1 and len(self.heter_watermark_channel) > 0:
        #         if mode == 'complex':
        #             diff = torch.abs(tensor1[0][channel] - tensor2[0][channel])  # [c, h, w]
        #         elif mode == 'real':
        #             diff = torch.abs(tensor1[0][channel].real - tensor2[0][channel].real)  # [c, h, w]
        #         elif mode == 'imag':
        #             diff = torch.abs(tensor1[0][channel].imag - tensor2[0][channel].imag)  # [c, h, w]
        #         l1_list = []
        #         for c_idx in range(len(mask)):
        #             mask_c = torch.zeros_like(mask)
        #             mask_c[c_idx] = mask[c_idx]
        #             l1_list.append(torch.mean(diff[mask_c]).item())
        #         return min(l1_list)
        #     else:
        #         raise NotImplementedError

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       reference_latents: torch.Tensor = None,
                       detector_type: str = "l1_distance", mode="complex") -> dict:
        """Evaluate whether the reversed latents carry the watermark.

        Returns a dict with the detection decision and the minimum L1 distance
        between the FFT of the reversed latents and the patterns in
        ``self.pattern_list``, restricted to the watermarking mask.
        """
        if detector_type != 'l1_distance':
            raise ValueError(f"Detector type {detector_type} not supported")

        # Move the reversed latents into the (centre-shifted) Fourier domain,
        # where the watermark patterns are defined.
        reversed_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents), dim=(-1, -2))
        all_distances = []

        # Compare against every candidate pattern and keep the closest match.
        for pattern in self.pattern_list:
            dist = self._get_distance(
                pattern,
                reversed_fft,
                self.watermarking_mask,
                p=1,
                mode=mode,
                channel_min=False
            )
            all_distances.append(dist)

        min_idx = int(np.argmin(all_distances))
        min_dist = all_distances[min_idx]

        return {
            'is_watermarked': bool(min_dist < self.threshold),
            'l1_distance': min_dist
        }
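

# --- Hypothetical usage sketch (not part of the original module) ---
# A minimal, self-contained example of how RIDetector might be driven. The
# latent shape (1, 4, 64, 64), the channel split, the square mask region,
# the random reference patterns, and the threshold value are assumptions for
# illustration only, not values taken from the repository.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Assumed latent layout: batch 1, 4 channels, 64x64 spatial resolution.
    latents = torch.randn(1, 4, 64, 64, device=device)

    # Boolean mask over the two watermark channels selecting which Fourier
    # coefficients are compared (here: a centred square region, for illustration).
    mask = torch.zeros(2, 64, 64, dtype=torch.bool, device=device)
    mask[:, 24:40, 24:40] = True

    # One candidate pattern per key, given in the shifted Fourier domain and
    # shaped like the full latent FFT so _get_distance can index its channels.
    patterns = [
        torch.fft.fftshift(torch.fft.fft2(torch.randn(1, 4, 64, 64, device=device)), dim=(-1, -2))
        for _ in range(3)
    ]

    detector = RIDetector(
        watermarking_mask=mask,
        ring_watermark_channel=[0],
        heter_watermark_channel=[3],
        pattern_list=patterns,
        threshold=65.0,  # assumed decision threshold; tune per setup
        device=device,
    )
    result = detector.eval_watermark(reversed_latents=latents)
    print(result)  # e.g. {'is_watermarked': False, 'l1_distance': ...}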