Coverage for markdiffusion / detection / tr / tr_detection.py: 96.15%

26 statements  

« prev     ^ index     » next       coverage.py v7.14.0, created at 2026-05-14 19:25 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import torch 

17from markdiffusion.detection.base import BaseDetector 

18from scipy.stats import ncx2 

19 

class TRDetector(BaseDetector):
    """Tree Ring watermark detector working in the Fourier domain of the latents.

    The detector compares the centred 2-D FFT of the reversed latents with a
    ground-truth patch, restricted to the entries selected by
    ``watermarking_mask``. Two decision rules are available, see
    :meth:`eval_watermark`.
    """

    def __init__(self,
                 watermarking_mask: torch.Tensor,
                 gt_patch: torch.Tensor,
                 threshold: float,
                 device: torch.device,
                 threshold_p_value: float = 0.01):
        """Store the detection parameters.

        Args:
            watermarking_mask: Tensor used to index both the FFT of the latents
                and ``gt_patch`` (presumably a boolean mask -- confirm with caller).
            gt_patch: Ground-truth patch the FFT of the latents is compared to.
                The ``p_value`` rule reads its ``.real`` / ``.imag`` parts.
            threshold: Decision threshold for the ``l1_distance`` rule; forwarded
                to ``BaseDetector.__init__``.
            device: Forwarded to ``BaseDetector.__init__``.
            threshold_p_value: Decision threshold for the ``p_value`` rule.
        """
        super().__init__(threshold, device)
        self.watermarking_mask = watermarking_mask
        self.gt_patch = gt_patch
        self.threshold_p_value = threshold_p_value

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       reference_latents: torch.Tensor = None,
                       detector_type: str = "l1_distance") -> dict:
        """Decide whether ``reversed_latents`` carry the watermark.

        Args:
            reversed_latents: Latents to test; their 2-D FFT (shifted over the
                last two dimensions) is compared with ``self.gt_patch``.
            reference_latents: Not used by this detector.
            detector_type: ``'l1_distance'`` or ``'p_value'``.

        Returns:
            For ``'l1_distance'``: ``{'is_watermarked': ..., 'l1_distance': ...}``
            where the distance is the mean absolute difference over the masked
            entries and the flag is ``l1_distance < self.threshold``.
            For ``'p_value'``: ``{'is_watermarked': ..., 'p_value': ...}`` where
            the flag is ``p_value < self.threshold_p_value``.

        Raises:
            ValueError: If ``detector_type`` is not one of the supported values.
        """
        if detector_type == 'l1_distance':
            reversed_latents_fft = torch.fft.fftshift(torch.fft.fft2(reversed_latents), dim=(-1, -2))
            target_patch = self.gt_patch
            mask = self.watermarking_mask
            l1_distance = torch.abs(reversed_latents_fft[mask] - target_patch[mask]).mean().item()
            return {
                # self.threshold is not assigned in this class; presumably set by
                # BaseDetector.__init__ -- confirm against the base class.
                'is_watermarked': l1_distance < self.threshold,
                'l1_distance': l1_distance
            }
        elif detector_type == 'p_value':
            reversed_latents_fft = torch.fft.fftshift(
                torch.fft.fft2(reversed_latents), dim=(-1, -2)
            )[self.watermarking_mask].flatten()
            target_patch = self.gt_patch[self.watermarking_mask].flatten()
            # Stack real and imaginary parts so the statistic is computed over
            # real-valued samples only.
            target_patch = torch.concatenate([target_patch.real, target_patch.imag])
            reversed_latents_fft = torch.concatenate([reversed_latents_fft.real, reversed_latents_fft.imag])
            # Scale is estimated from the observed (masked) FFT values themselves.
            sigma_ = reversed_latents_fft.std()
            # Non-centrality parameter: normalised energy of the target patch.
            lambda_ = (target_patch ** 2 / sigma_ ** 2).sum().item()
            # Test statistic: normalised squared distance to the target patch.
            x = (((reversed_latents_fft - target_patch) / sigma_) ** 2).sum().item()
            # CDF of the non-central chi-square at the observed statistic.
            p = ncx2.cdf(x=x, df=len(target_patch), nc=lambda_)
            return {
                'is_watermarked': bool(p < self.threshold_p_value),
                'p_value': p
            }
        else:
            # Use the local argument: this class never sets `self.detector_type`,
            # so reading it here would raise AttributeError instead of ValueError.
            raise ValueError(f"Tree Ring's watermark detector type {detector_type} not supported")

60