Coverage for detection / robin / robin_detection.py: 100.00%

43 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import torch 

17from detection.base import BaseDetector 

18from scipy.stats import ncx2 

19from torch.nn import functional as F 

20 

class ROBINDetector(BaseDetector):
    """Score reversed latents against a stored Fourier-domain patch.

    The detector takes the centred 2-D FFT of the reversed latents, selects
    the coefficients under ``watermarking_mask`` and compares them with the
    same coefficients of ``gt_patch`` using one of three metrics.
    """

    def __init__(self,
                 watermarking_mask: torch.Tensor,
                 gt_patch: torch.Tensor,
                 threshold: float,
                 device: torch.device):
        """Store the mask and the reference patch.

        Args:
            watermarking_mask: Mask used to boolean-index the FFT of the
                latents. The resize path casts it to float and back to bool.
            gt_patch: Reference patch indexed with the same mask. The resize
                path reads ``.real`` and ``.imag``, so it is expected to be
                complex.
            threshold: Decision threshold, forwarded to ``BaseDetector``.
            device: Forwarded to ``BaseDetector``.
        """
        super().__init__(threshold, device)
        self.watermarking_mask = watermarking_mask
        self.gt_patch = gt_patch

    def _mask_and_patch_for(self, reversed_latents: torch.Tensor):
        """Return ``(mask, gt_patch)`` matching the latents' spatial size."""
        target_size = tuple(reversed_latents.shape[-2:])
        if tuple(self.watermarking_mask.shape[-2:]) == target_size:
            return self.watermarking_mask, self.gt_patch

        # F.interpolate with a 2-element ``size`` expects (N, C, H, W) input.
        # Nearest-neighbour keeps the mask binary after the float round trip.
        mask_resized = F.interpolate(
            self.watermarking_mask.float(), size=target_size, mode='nearest'
        ).bool()

        # Real and imaginary parts are interpolated separately and then
        # recombined into a complex tensor.
        gt_real_resized = F.interpolate(
            self.gt_patch.real, size=target_size,
            mode='bilinear', align_corners=False,
        )
        gt_imag_resized = F.interpolate(
            self.gt_patch.imag, size=target_size,
            mode='bilinear', align_corners=False,
        )
        return mask_resized, torch.complex(gt_real_resized, gt_imag_resized)

    def eval_watermark(self,
                       reversed_latents: torch.Tensor,
                       reference_latents: torch.Tensor = None,
                       detector_type: str = "l1_distance") -> dict:
        """Evaluate whether ``reversed_latents`` carry the watermark.

        Args:
            reversed_latents: Latents to test; the FFT is taken over the last
                two dimensions.
            reference_latents: Accepted for interface compatibility; not used
                by this method.
            detector_type: ``'l1_distance'``, ``'p_value'`` or
                ``'cosine_similarity'``.

        Returns:
            A dict with ``'is_watermarked'`` (``bool``) and one metric entry
            named after ``detector_type``:

            * ``'l1_distance'`` (``float``): mean absolute difference under
              the mask; watermarked when below the threshold.
            * ``'p_value'`` (``float``): non-central chi-squared CDF of the
              normalised squared error; watermarked when below the threshold.
            * ``'cosine_similarity'`` (0-dim tensor): cosine similarity of the
              real parts under the mask; watermarked when above the threshold.

        Raises:
            ValueError: If ``detector_type`` is not one of the supported names.
        """
        reversed_latents_fft = torch.fft.fftshift(
            torch.fft.fft2(reversed_latents), dim=(-1, -2)
        )
        current_mask, current_gt_patch = self._mask_and_patch_for(reversed_latents)

        if detector_type == 'l1_distance':
            l1_distance = torch.abs(
                reversed_latents_fft[current_mask] - current_gt_patch[current_mask]
            ).mean().item()
            return {
                'is_watermarked': bool(l1_distance < self.threshold),
                'l1_distance': l1_distance
            }

        if detector_type == 'p_value':
            observed = reversed_latents_fft[current_mask].flatten()
            target_patch = current_gt_patch[current_mask].flatten()
            # Stack real and imaginary parts into one real-valued vector.
            target_patch = torch.concatenate([target_patch.real, target_patch.imag])
            observed = torch.concatenate([observed.real, observed.imag])
            sigma_ = observed.std()
            lambda_ = (target_patch ** 2 / sigma_ ** 2).sum().item()
            x = (((observed - target_patch) / sigma_) ** 2).sum().item()
            p = float(ncx2.cdf(x=x, df=len(target_patch), nc=lambda_))
            return {
                'is_watermarked': bool(p < self.threshold),
                'p_value': p
            }

        if detector_type == 'cosine_similarity':
            observed = reversed_latents_fft[current_mask].flatten()
            target_patch = current_gt_patch[current_mask].flatten()
            # Only the real parts are compared. The metric stays a tensor so
            # existing callers that call ``.item()`` on it keep working.
            cosine_similarity = F.cosine_similarity(observed.real, target_patch.real, dim=0)
            return {
                'is_watermarked': bool(cosine_similarity > self.threshold),
                'cosine_similarity': cosine_similarity
            }

        # Format the argument, not ``self.detector_type``: that attribute is
        # never set by this class.
        raise ValueError(f"ROBIN's watermark detector type {detector_type} not supported")

90