Coverage for watermark / gm / gnr.py: 93.98%
83 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1from __future__ import annotations
3from pathlib import Path
4from typing import Optional
6import torch
7import torch.nn as nn
class DoubleConv(nn.Module):
    """Two stacked 3x3 conv stages, each followed by batch norm and ReLU.

    Both convolutions use ``padding=1`` and no bias, so the spatial size of
    the input is preserved and only the channel count changes.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the second stage.
        mid_channels: Channel count between the two stages. A falsy value
            (``None`` or ``0``) falls back to ``out_channels``.
    """

    def __init__(self, in_channels: int, out_channels: int, mid_channels: Optional[int] = None) -> None:
        super().__init__()
        hidden = mid_channels if mid_channels else out_channels
        # Layers are created in forward order so parameter initialisation
        # draws from the RNG in the same sequence as the layer indices.
        first_stage = [
            nn.Conv2d(in_channels, hidden, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(hidden),
            nn.ReLU(inplace=True),
        ]
        second_stage = [
            nn.Conv2d(hidden, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        ]
        self.double_conv = nn.Sequential(*first_stage, *second_stage)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401 - inherited
        """Apply both conv/BN/ReLU stages to ``x`` and return the result."""
        features = self.double_conv(x)
        return features
class Down(nn.Module):
    """Encoder stage: 2x2 max pooling followed by a ``DoubleConv``.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count produced by the ``DoubleConv``.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        pool = nn.MaxPool2d(2)
        conv = DoubleConv(in_channels, out_channels)
        self.maxpool_conv = nn.Sequential(pool, conv)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401 - inherited
        """Pool ``x`` and run the result through the double convolution."""
        downsampled = self.maxpool_conv(x)
        return downsampled
class Up(nn.Module):
    """Decoder stage: upsample, align with the skip tensor, then ``DoubleConv``.

    Args:
        in_channels: Channel count the ``DoubleConv`` receives, i.e. the
            channels of the skip tensor plus those of the upsampled tensor.
        out_channels: Channel count produced by the ``DoubleConv``.
        bilinear: When true, upsample with bilinear interpolation (channel
            count unchanged); otherwise use a stride-2 transposed convolution
            that also halves the channel count.
    """

    def __init__(self, in_channels: int, out_channels: int, bilinear: bool = True) -> None:
        super().__init__()
        half = in_channels // 2
        if not bilinear:
            self.up = nn.ConvTranspose2d(in_channels, half, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)
        else:
            self.up = nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, half)

    def forward(self, x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:  # noqa: D401 - inherited
        """Upsample ``x1``, fit it to ``x2`` spatially, concatenate and convolve.

        Args:
            x1: Tensor from the previous (coarser) stage.
            x2: Skip tensor whose height/width (dims 2 and 3) set the target size.

        Returns:
            Output of the double convolution applied to ``cat([x2, x1], dim=1)``.
        """
        upsampled = self.up(x1)
        gap_h = x2.size(2) - upsampled.size(2)
        gap_w = x2.size(3) - upsampled.size(3)
        # Split each gap between the two sides; the odd remainder goes to the
        # right/bottom. Negative gaps yield negative padding, i.e. a crop.
        left = gap_w // 2
        top = gap_h // 2
        padding = [left, gap_w - left, top, gap_h - top]
        upsampled = torch.nn.functional.pad(upsampled, padding)
        merged = torch.cat([x2, upsampled], dim=1)
        return self.conv(merged)
class OutConv(nn.Module):
    """Final 1x1 convolution mapping feature channels to output channels.

    Args:
        in_channels: Channel count of the incoming feature map.
        out_channels: Channel count of the produced tensor.
    """

    def __init__(self, in_channels: int, out_channels: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401 - inherited
        """Apply the 1x1 convolution to ``x``."""
        projected = self.conv(x)
        return projected
class GNRUNet(nn.Module):
    """UNet backbone used by the GaussMarker GNR restoration module.

    The network has an input ``DoubleConv``, three ``Down`` stages, three
    ``Up`` stages with skip connections, and a 1x1 output convolution.

    Args:
        in_channels: Channel count of the input tensor.
        out_channels: Channel count of the returned tensor.
        nf: Base feature width; encoder stages use ``nf``, ``2*nf``, ``4*nf``
            and ``8*nf`` channels (the last one halved when ``bilinear``).
        bilinear: Use bilinear upsampling in the decoder instead of
            transposed convolutions.
    """

    def __init__(self, in_channels: int, out_channels: int, nf: int = 128, bilinear: bool = False) -> None:
        super().__init__()
        # Bilinear upsampling keeps the channel count (a transposed conv
        # halves it), so the tensor entering each Up stage must already be
        # half as wide for the skip concatenation to add up to the stage's
        # declared in_channels.
        factor = 2 if bilinear else 1
        self.inc = DoubleConv(in_channels, nf)
        self.down1 = Down(nf, nf * 2)
        self.down2 = Down(nf * 2, nf * 4)
        # Previously always nf * 8, which made the bilinear path fail in
        # forward with a channel mismatch (nf*8 upsampled + nf*4 skip fed into
        # a conv expecting nf*8). With factor == 1 this is unchanged, so
        # non-bilinear checkpoints keep loading.
        self.down3 = Down(nf * 4, nf * 8 // factor)
        self.up2 = Up(nf * 8, nf * 4 // factor, bilinear)
        self.up3 = Up(nf * 4, nf * 2 // factor, bilinear)
        self.up4 = Up(nf * 2, nf, bilinear)
        self.outc = OutConv(nf, out_channels)

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # noqa: D401 - inherited
        """Run the encoder/decoder and return the raw output of the last conv.

        Args:
            x: Input tensor with ``in_channels`` channels in dim 1.

        Returns:
            Tensor with ``out_channels`` channels; no activation is applied.
        """
        # Encoder: keep every intermediate result for the skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        # Decoder: each stage merges with the encoder output of matching depth.
        x = self.up2(x4, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        return self.outc(x)
class GNRRestorer:
    """Wrapper for loading and running the GaussMarker GNR restoration network.

    Args:
        checkpoint_path: File holding the state dict for ``GNRUNet``.
        in_channels: Input channel count passed to ``GNRUNet``.
        out_channels: Output channel count passed to ``GNRUNet``.
        nf: Base feature width passed to ``GNRUNet``.
        device: Device the model and the base message are moved to.
        classifier_type: When equal to ``1``, the base message is concatenated
            in front of the input along dim 1 before running the model.
        base_message: Base watermark message; required for
            ``classifier_type == 1``.
    """

    def __init__(
        self,
        checkpoint_path: Path,
        in_channels: int,
        out_channels: int,
        nf: int,
        device: torch.device,
        classifier_type: int,
        base_message: Optional[torch.Tensor] = None,
    ) -> None:
        self.device = device
        self.classifier_type = classifier_type
        self.base_message = base_message.to(device) if base_message is not None else None
        self.model = GNRUNet(in_channels, out_channels, nf=nf)
        # NOTE(security): torch.load unpickles the file; depending on the
        # installed torch version's ``weights_only`` default this can execute
        # arbitrary code. Only load checkpoints from trusted sources.
        # Load on CPU first so the checkpoint's original device need not exist.
        state = torch.load(checkpoint_path, map_location="cpu")
        self.model.load_state_dict(state)
        self.model.to(device)
        self.model.eval()

    def restore(self, reversed_m: torch.Tensor) -> torch.Tensor:
        """Run the GNR model and return the restored watermark bits (probabilities).

        Args:
            reversed_m: Input tensor; it is moved to ``self.device`` and cast
                to float32 before being fed to the model.

        Returns:
            ``sigmoid`` of the model output, computed without gradients.

        Raises:
            ValueError: If ``classifier_type == 1`` and no base message was given.
        """
        with torch.no_grad():
            inputs = reversed_m.to(self.device, dtype=torch.float32)
            if self.classifier_type == 1:
                if self.base_message is None:
                    raise ValueError("Base watermark message is required when classifier_type=1")
                # Match the input's dtype/device so the concatenation cannot
                # promote to a dtype the float32 model would reject.
                base = self.base_message.to(device=inputs.device, dtype=inputs.dtype)
                # A single base message is shared by every sample in the batch.
                if base.dim() == inputs.dim() and base.size(0) == 1 and inputs.size(0) != 1:
                    base = base.expand(inputs.size(0), *base.shape[1:])
                inputs = torch.cat([base, inputs], dim=1)
            logits = self.model(inputs)
            probs = torch.sigmoid(logits)
        return probs

    def restore_binary(self, reversed_m: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
        """Convenience helper returning binarised restored watermark bits.

        Args:
            reversed_m: Input tensor forwarded to :meth:`restore`.
            threshold: Probabilities strictly greater than this become ``1.0``.

        Returns:
            Float tensor of zeros and ones with the shape of the probabilities.
        """
        probs = self.restore(reversed_m)
        return (probs > threshold).float()