Coverage for watermark / gm / gnr.py: 93.98%

83 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1from __future__ import annotations 

2 

3from pathlib import Path 

4from typing import Optional 

5 

6import torch 

7import torch.nn as nn 

8 

9 

class DoubleConv(nn.Module):
    """Two stacked ``Conv2d(3x3, padding=1, no bias) -> BatchNorm2d -> ReLU`` stages.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the second convolution.
        mid_channels: channels produced by the first convolution; any falsy
            value (``None`` or ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None) -> None:
        super().__init__()
        hidden = out_channels if not mid_channels else mid_channels
        stages = []
        # Layer order (conv, bn, relu, conv, bn, relu) fixes the state_dict
        # keys ``double_conv.0`` ... ``double_conv.5`` that checkpoints rely on.
        for src, dst in ((in_channels, hidden), (hidden, out_channels)):
            stages.extend(
                (
                    nn.Conv2d(src, dst, kernel_size=3, padding=1, bias=False),
                    nn.BatchNorm2d(dst),
                    nn.ReLU(inplace=True),
                )
            )
        self.double_conv = nn.Sequential(*stages)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply both conv/BN/ReLU stages; spatial size is preserved."""
        return self.double_conv(x)

27 

28 

class Down(nn.Module):
    """Encoder stage: ``MaxPool2d(2)`` followed by a ``DoubleConv``.

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        pool = nn.MaxPool2d(2)
        convs = DoubleConv(in_channels, out_channels)
        # Kept as a single Sequential so state_dict keys stay ``maxpool_conv.1.*``.
        self.maxpool_conv = nn.Sequential(pool, convs)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pool by a factor of 2 spatially, then apply the double convolution."""
        return self.maxpool_conv(x)

41 

42 

class Up(nn.Module):
    """Decoder stage: upsample ``x1``, pad it to ``x2``'s size, concatenate, then ``DoubleConv``.

    Args:
        in_channels: channels expected by the ``DoubleConv`` after concatenation.
        out_channels: channels produced by the ``DoubleConv``.
        bilinear: if true, upsample with ``nn.Upsample`` (channel count kept) and
            use ``in_channels // 2`` mid channels; otherwise use a stride-2
            ``ConvTranspose2d`` that maps ``in_channels`` to ``in_channels // 2``.
    """

    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True) -> None:
        super().__init__()
        half = in_channels // 2
        # ``up`` is created before ``conv`` in both branches so parameter
        # initialisation order and state_dict ordering are stable.
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
        """Upsample ``x1``, align it with skip tensor ``x2``, and fuse them.

        ``x2`` goes first in the channel concatenation.
        """
        upsampled = self.up(x1)
        dy = x2.size(2) - upsampled.size(2)
        dx = x2.size(3) - upsampled.size(3)
        top = dy // 2
        left = dx // 2
        # F.pad order for the last two dims is (left, right, top, bottom); any
        # odd remainder goes to the right/bottom side.
        aligned = torch.nn.functional.pad(upsampled, [left, dx - left, top, dy - top])
        merged = torch.cat((x2, aligned), dim=1)
        return self.conv(merged)

65 

66 

class OutConv(nn.Module):
    """Final 1x1 convolution (with bias) projecting features to ``out_channels``."""

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the pointwise projection; spatial size is unchanged."""
        projected = self.conv(x)
        return projected

74 

75 

class GNRUNet(nn.Module):
    """UNet backbone used by the GaussMarker GNR restoration module.

    Three ``Down`` encoder stages and three ``Up`` decoder stages with skip
    connections, followed by a 1x1 ``OutConv``.

    Args:
        in_channels: channels of the input tensor.
        out_channels: channels produced by the final 1x1 convolution.
        nf: base feature width. Encoder widths are ``nf, 2*nf, 4*nf, 8*nf``;
            the deepest width is divided by 2 when ``bilinear`` is true.
        bilinear: upsample with ``nn.Upsample`` instead of ``ConvTranspose2d``.
    """

    def __init__(self, in_channels: int, out_channels: int, nf: int = 128, bilinear: bool = False) -> None:
        super().__init__()
        factor = 2 if bilinear else 1
        self.inc = DoubleConv(in_channels, nf)
        self.down1 = Down(nf, nf * 2)
        self.down2 = Down(nf * 2, nf * 4)
        # Bilinear upsampling keeps the channel count, whereas ConvTranspose2d
        # halves it. The bottleneck must therefore already be halved so that
        # upsample(x4) + x3 has the nf * 8 channels ``up2`` expects.
        # For bilinear=False, factor == 1 and this stays nf * 8, so existing
        # checkpoints load unchanged.
        self.down3 = Down(nf * 4, nf * 8 // factor)
        self.up2 = Up(nf * 8, nf * 4 // factor, bilinear)
        self.up3 = Up(nf * 4, nf * 2 // factor, bilinear)
        self.up4 = Up(nf * 2, nf, bilinear)
        self.outc = OutConv(nf, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the encoder/decoder and return raw (pre-activation) outputs."""
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        # Each decoder stage fuses the upsampled path with the matching encoder skip.
        x = self.up2(x4, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.outc(x)

100 

101 

class GNRRestorer:
    """Wrapper for loading and running the GaussMarker GNR restoration network.

    Args:
        checkpoint_path: file passed to ``torch.load``; its contents are handed
            directly to ``GNRUNet.load_state_dict``.
        in_channels: forwarded to ``GNRUNet``.
        out_channels: forwarded to ``GNRUNet``.
        nf: base feature width forwarded to ``GNRUNet`` (``bilinear`` keeps its
            default).
        device: device the model and ``base_message`` are moved to.
        classifier_type: when ``1``, ``base_message`` is concatenated in front
            of the inputs along dim 1 before the forward pass.
        base_message: optional tensor, required when ``classifier_type == 1``.
    """

    def __init__(
        self,
        checkpoint_path: Path,
        in_channels: int,
        out_channels: int,
        nf: int,
        device: torch.device,
        classifier_type: int,
        base_message: Optional[torch.Tensor] = None,
    ) -> None:
        self.device = device
        self.classifier_type = classifier_type
        self.base_message = base_message.to(device) if base_message is not None else None
        self.model = GNRUNet(in_channels, out_channels, nf=nf)
        # NOTE(review): torch.load unpickles the file. Only load trusted
        # checkpoints, or pass weights_only=True on torch versions that support it.
        state = torch.load(checkpoint_path, map_location="cpu")
        self.model.load_state_dict(state)
        self.model.to(device)
        self.model.eval()

    def restore(self, reversed_m: torch.Tensor) -> torch.Tensor:
        """Run the GNR model and return the restored watermark bits (probabilities).

        Args:
            reversed_m: model input; cast to float32 and moved to ``self.device``.

        Returns:
            ``sigmoid`` of the model output.

        Raises:
            ValueError: if ``classifier_type == 1`` and no ``base_message`` was given.
        """
        with torch.no_grad():
            inputs = reversed_m.to(self.device, dtype=torch.float32)
            if self.classifier_type == 1:
                inputs = torch.cat([self._batched_base_message(inputs), inputs], dim=1)
            logits = self.model(inputs)
            probs = torch.sigmoid(logits)
        return probs

    def _batched_base_message(self, inputs: torch.Tensor) -> torch.Tensor:
        """Return ``base_message`` matched to ``inputs`` in dtype, device and batch size."""
        if self.base_message is None:
            raise ValueError("Base watermark message is required when classifier_type=1")
        # Match the input dtype so a float64/int message cannot promote the
        # concatenation away from the float32 the model weights use.
        base = self.base_message.to(device=inputs.device, dtype=inputs.dtype)
        if base.dim() == inputs.dim() - 1:
            # Message given without a batch dimension.
            base = base.unsqueeze(0)
        if base.dim() == inputs.dim() and base.size(0) == 1 and inputs.size(0) != 1:
            # Share a single message across every item in the batch.
            base = base.expand(inputs.size(0), *base.shape[1:])
        return base

    def restore_binary(self, reversed_m: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """Return restored bits as a float tensor of 0/1 (``probs > threshold``)."""
        probs = self.restore(reversed_m)
        return (probs > threshold).float()