Coverage for detection / gs / gs_detection.py: 98.08%
52 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1# Copyright 2025 THU-BPM MarkDiffusion.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16import torch
17import numpy as np
18from Crypto.Random import get_random_bytes
19from Crypto.Cipher import ChaCha20
20from scipy.stats import truncnorm, norm
21from functools import reduce
22from detection.base import BaseDetector
23from typing import Union
class GSDetector(BaseDetector):
    """Bit-accuracy detector for Gaussian Shading watermarks.

    Detection binarises the reversed latents by sign, undoes the keying step
    (ChaCha20 decryption, or adding the key modulo 2), collapses the
    replicated copies by voting, and compares the recovered bits with the
    stored watermark.
    """

    def __init__(self,
                 watermarking_mask: torch.Tensor,
                 chacha: bool,
                 wm_key: Union[int, tuple[int, int]],
                 channel_copy: int,
                 hw_copy: int,
                 vote_threshold: int,
                 threshold: float,
                 device: torch.device):
        """Store the watermark, key material and voting configuration.

        Args:
            watermarking_mask: Reference watermark the recovered bits are
                compared against in ``eval_watermark``.
            chacha: If true, ``wm_key`` is unpacked as ``(key, nonce)`` for
                ChaCha20; otherwise it is stored as an additive key.
            wm_key: Key material; interpretation depends on ``chacha``.
                NOTE(review): the annotation says ints, but the ChaCha20 path
                passes both values straight to ``ChaCha20.new`` and the
                additive path adds the key to a tensor -- confirm the real
                types against the caller.
            channel_copy: Number of equal chunks the channel axis is split
                into before voting.
            hw_copy: Number of equal chunks each spatial axis is split into
                before voting.
            vote_threshold: A bit is decoded as 1 when strictly more than
                this many copies voted 1.
            threshold: Bit accuracy above which a sample counts as
                watermarked; forwarded to ``BaseDetector``.
            device: Forwarded to ``BaseDetector``; decrypted bits are moved
                to ``self.device``.
        """
        super().__init__(threshold, device)
        self.chacha = chacha
        self.channel_copy = channel_copy
        self.hw_copy = hw_copy
        self.watermark = watermarking_mask
        self.vote_threshold = vote_threshold
        if self.chacha:
            self.chacha_key, self.chacha_nonce = wm_key
        else:
            self.key = wm_key

    def _stream_key_decrypt(self, reversed_m):
        """Decrypt the watermark bits using the ChaCha20 cipher.

        Args:
            reversed_m: Array of 0/1 values; it is packed into bytes with
                ``np.packbits`` before decryption. ``eval_watermark`` passes
                a flattened NumPy array.

        Returns:
            A ``torch.uint8`` tensor of shape ``(1, 4, S, S)`` on
            ``self.device``, where ``S`` is derived from the number of
            decrypted bits.
        """
        cipher = ChaCha20.new(key=self.chacha_key, nonce=self.chacha_nonce)
        packed = np.packbits(reversed_m).tobytes()
        plain = cipher.decrypt(packed)
        sd_bit = np.unpackbits(np.frombuffer(plain, dtype=np.uint8))

        # The flattened input carries no shape, so a single square sample
        # with 4 channels is assumed. If the bit count is not 4 * S * S the
        # reshape below raises.
        latent_channels = 4
        spatial_size = int(np.sqrt(sd_bit.size / latent_channels))

        sd_tensor = torch.from_numpy(sd_bit).reshape(
            1, latent_channels, spatial_size, spatial_size
        ).to(torch.uint8)
        return sd_tensor.to(self.device)

    def _diffusion_inverse(self, reversed_sd):
        """Collapse the replicated watermark copies into one by voting.

        The tensor is cut into ``channel_copy`` chunks along the channel
        axis and ``hw_copy`` chunks along each spatial axis; every chunk is
        stacked on the batch axis and the stack is summed so each position
        holds the number of copies that voted 1.

        Args:
            reversed_sd: 4-D tensor of 0/1 values, unpacked as
                ``(batch, channels, height, width)``.

        Returns:
            Tensor of 0/1 values with one entry per position of a single
            chunk: 1 where the vote count is strictly greater than
            ``self.vote_threshold``, else 0.
        """
        # Read the channel count from the tensor instead of assuming 4, in
        # line with how the spatial sizes are obtained.
        _, channels, height, width = reversed_sd.shape

        chunk_plan = (
            (1, channels // self.channel_copy, self.channel_copy),
            (2, height // self.hw_copy, self.hw_copy),
            (3, width // self.hw_copy, self.hw_copy),
        )

        stacked = reversed_sd
        for dim, stride, copies in chunk_plan:
            # torch.split raises if the chunk sizes do not add up to the
            # size of the axis, i.e. if the axis is not evenly divisible.
            pieces = torch.split(stacked, (stride,) * copies, dim=dim)
            stacked = torch.cat(pieces, dim=0)

        counts = torch.sum(stacked, dim=0)
        # Counts are sums of 0/1 bits, so one comparison reproduces the
        # "<= threshold -> 0, > threshold -> 1" rule without in-place edits.
        return (counts > self.vote_threshold).to(counts.dtype)

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       detector_type: str = "bit_acc") -> dict[str, Union[bool, float]]:
        """Evaluate the watermark in reversed latents.

        Args:
            reversed_latents: Latents recovered by inversion; each element
                is binarised as ``> 0``.
            detector_type: Only ``'bit_acc'`` is supported.

        Returns:
            A dict with ``'bit_acc'`` (fraction of recovered bits equal to
            the stored watermark) and ``'is_watermarked'`` (whether that
            fraction is strictly greater than ``self.threshold``).

        Raises:
            ValueError: If ``detector_type`` is not ``'bit_acc'``.
        """
        if detector_type != 'bit_acc':
            raise ValueError(f'Detector type {detector_type} is not supported for Gaussian Shading.')

        reversed_m = (reversed_latents > 0).int()
        if self.chacha:
            # The cipher works on bytes, so the bits go through NumPy on CPU.
            reversed_sd = self._stream_key_decrypt(reversed_m.flatten().cpu().numpy())
        else:
            # Adding the key modulo 2 undoes the keying step (XOR on bits).
            reversed_sd = (reversed_m + self.key) % 2

        reversed_watermark = self._diffusion_inverse(reversed_sd)
        correct = (reversed_watermark == self.watermark).float().mean().item()

        return {
            'is_watermarked': bool(correct > self.threshold),
            'bit_acc': correct
        }