Coverage for detection / tr / tr_detection.py: 96.00%

25 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import torch 

17from detection.base import BaseDetector 

18from scipy.stats import ncx2 

19 

class TRDetector(BaseDetector):
    """Detector for Tree-Ring watermarks placed in the Fourier domain of latents.

    Detection takes the centred 2-D FFT of the given latents (over the last two
    dimensions), keeps only the positions selected by ``watermarking_mask`` and
    compares them with the same positions of the key pattern ``gt_patch``.
    """

    def __init__(self,
                 watermarking_mask: torch.Tensor,
                 gt_patch: torch.Tensor,
                 threshold: float,
                 device: torch.device):
        """Store the watermark key used for detection.

        Args:
            watermarking_mask: Index (presumably a boolean mask with the same
                shape as the latents -- confirm with the caller) selecting the
                watermarked frequency positions. It is applied both to the
                shifted FFT of the latents and to ``gt_patch``.
            gt_patch: Key pattern in the (shifted) frequency domain. Must be
                complex-valued for the ``'p_value'`` test, which reads ``.imag``.
            threshold: Decision threshold. For ``'l1_distance'`` a distance
                below it means "watermarked". For ``'p_value'`` a p-value below
                it means "watermarked".
            device: Forwarded to ``BaseDetector``.
        """
        super().__init__(threshold, device)
        self.watermarking_mask = watermarking_mask
        self.gt_patch = gt_patch

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       reference_latents: torch.Tensor = None,
                       detector_type: str = "l1_distance") -> dict:
        """Decide whether ``reversed_latents`` carries the Tree-Ring watermark.

        Args:
            reversed_latents: Latents to test. The FFT is taken over the last
                two dimensions. (Presumably obtained by inverting the diffusion
                process -- verify against the caller.)
            reference_latents: Not used by this detector. It is accepted so the
                signature matches the other detectors.
            detector_type: ``'l1_distance'`` (mean absolute distance to the key)
                or ``'p_value'`` (noncentral chi-squared test).

        Returns:
            For ``'l1_distance'``: ``{'is_watermarked': bool, 'l1_distance': float}``.
            For ``'p_value'``: ``{'is_watermarked': bool, 'p_value': float}``.

        Raises:
            ValueError: If ``detector_type`` is not one of the supported values.
        """
        if detector_type == 'l1_distance':
            return self._eval_l1_distance(reversed_latents)
        if detector_type == 'p_value':
            return self._eval_p_value(reversed_latents)
        # Use the local parameter: ``self.detector_type`` is never assigned here,
        # and referencing it raised AttributeError instead of this ValueError.
        raise ValueError(f"Tree Ring's watermark detector type {detector_type} not supported")

    def _masked_fft(self, latents: torch.Tensor) -> torch.Tensor:
        """Return the centred 2-D FFT of ``latents`` at the watermark positions."""
        latents_fft = torch.fft.fftshift(torch.fft.fft2(latents), dim=(-1, -2))
        return latents_fft[self.watermarking_mask]

    def _eval_l1_distance(self, reversed_latents: torch.Tensor) -> dict:
        """Compare the masked FFT to the key by mean absolute (complex) distance."""
        latents_fft = self._masked_fft(reversed_latents)
        target_patch = self.gt_patch[self.watermarking_mask]
        l1_distance = torch.abs(latents_fft - target_patch).mean().item()
        return {
            'is_watermarked': bool(l1_distance < self.threshold),
            'l1_distance': l1_distance,
        }

    def _eval_p_value(self, reversed_latents: torch.Tensor) -> dict:
        """Compare the masked FFT to the key with a noncentral chi-squared test."""
        latents_fft = self._masked_fft(reversed_latents).flatten()
        target_patch = self.gt_patch[self.watermarking_mask].flatten()
        # Treat real and imaginary parts as separate real-valued samples.
        latents_fft = torch.cat([latents_fft.real, latents_fft.imag])
        target_patch = torch.cat([target_patch.real, target_patch.imag])
        # Noise scale is estimated from the observed samples themselves.
        sigma = latents_fft.std()
        # Assume the un-watermarked samples are i.i.d. N(0, sigma^2). Then the
        # normalized squared distance to the key is noncentral chi-squared with
        # df = number of samples and noncentrality ||key||^2 / sigma^2. A small
        # CDF value means the latents are closer to the key than chance predicts.
        lambda_ = (target_patch ** 2 / sigma ** 2).sum().item()
        x = (((latents_fft - target_patch) / sigma) ** 2).sum().item()
        p_value = float(ncx2.cdf(x=x, df=len(target_patch), nc=lambda_))
        return {
            'is_watermarked': bool(p_value < self.threshold),
            'p_value': p_value,
        }

58