# watermark/videomark/video_mark.py

from ..base import BaseWatermark, BaseConfig
import torch
import numpy as np
from typing import Dict, Tuple, Any, Optional, List, Union
from PIL import Image
import galois
from scipy.sparse import csr_matrix
from scipy.special import binom
import logging
from functools import reduce
from visualize.data_for_visualization import DataForVisualization
from detection.videomark.videomark_detection import VideoMarkDetector
from utils.media_utils import *
from utils.utils import set_random_seed
from utils.pipeline_utils import is_video_pipeline, is_t2v_pipeline, is_i2v_pipeline
from utils.callbacks import DenoisingLatentsCollector
import random

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Constants
VAE_DOWNSAMPLE_FACTOR = 8
DEFAULT_CONFIDENCE_THRESHOLD = 0.6


class VideoMarkConfig(BaseConfig):
    """Config class for VideoMark algorithm."""

    def initialize_parameters(self) -> None:
        """Initialize algorithm-specific parameters."""
        self.fpr = self.config_dict['fpr']
        self.t = self.config_dict['prc_t']
        self.var = self.config_dict['var']
        self.threshold = self.config_dict['threshold']
        self.sequence_length = self.config_dict['sequence_length']  # Length of the watermark sequence
        self.message_length = self.config_dict['message_length']  # Number of bits in each per-frame message
        self.message_sequence = np.random.randint(0, 2, size=(self.sequence_length, self.message_length))  # <= 512 bits per message for robustness
        # Random start offset of this video's message window within the sequence
        self.shift = np.random.default_rng().integers(0, self.sequence_length - self.num_frames)
        self.message = self.message_sequence[self.shift : self.shift + self.num_frames]
        self.latents_height = self.image_size[0] // self.pipe.vae_scale_factor
        self.latents_width = self.image_size[1] // self.pipe.vae_scale_factor
        self.latents_channel = self.pipe.unet.config.in_channels
        self.n = self.latents_height * self.latents_width * self.latents_channel  # Dimension of the latent space
        self.GF = galois.GF(2)

        # Seeds for key generation
        self.gen_matrix_seed = self.config_dict['keygen']['gen_matrix_seed']
        self.indice_seed = self.config_dict['keygen']['indice_seed']
        self.one_time_pad_seed = self.config_dict['keygen']['one_time_pad_seed']
        self.test_bits_seed = self.config_dict['keygen']['test_bits_seed']
        self.permute_bits_seed = self.config_dict['keygen']['permute_bits_seed']

        # Seeds for encoding
        self.payload_seed = self.config_dict['encode']['payload_seed']
        self.error_seed = self.config_dict['encode']['error_seed']
        self.pseudogaussian_seed = self.config_dict['encode']['pseudogaussian_seed']
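
    # Illustrative shape of the expected config_dict (values are hypothetical):
    #
    #     {
    #         'fpr': 1e-4, 'prc_t': 3, 'var': 1.5, 'threshold': 0.6,
    #         'sequence_length': 1024, 'message_length': 512,
    #         'keygen': {'gen_matrix_seed': 0, 'indice_seed': 1, 'one_time_pad_seed': 2,
    #                    'test_bits_seed': 3, 'permute_bits_seed': 4},
    #         'encode': {'payload_seed': 5, 'error_seed': 6, 'pseudogaussian_seed': 7},
    #     }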

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        return 'VideoMark'

    @staticmethod
    def _get_message(length: int, window: int, seed=None) -> int:
        """Return a random start index for a subarray of size `window` in an array of size `length`."""
        rng = np.random.default_rng(seed)
        return rng.integers(0, length - window)


class VideoMarkUtils:
    """Utility class for VideoMark algorithm."""

    def __init__(self, config: VideoMarkConfig, *args, **kwargs) -> None:
        """Initialize VideoMark utilities and build the PRC encoding/decoding keys."""
        self.config = config
        self.encoding_key, self.decoding_key = self._generate_encoding_key(self.config.message_length)

    def _generate_encoding_key(self, message_length: int) -> Tuple[tuple, tuple]:
        """Generate the PRC encoding/decoding key pair."""
        # Set basic scheme parameters
        num_test_bits = int(np.ceil(np.log2(1 / self.config.fpr)))
        secpar = int(np.log2(binom(self.config.n, self.config.t)))
        g = secpar
        k = message_length + g + num_test_bits
        r = self.config.n - k - secpar
        noise_rate = 1 - 2 ** (-secpar / g ** 2)

        # Sample an n-by-k generator matrix (the last r rows will be overwritten below)
        generator_matrix = self.config.GF.Random(shape=(self.config.n, k), seed=self.config.gen_matrix_seed)

        # Sample a scipy.sparse parity-check matrix together with the last r rows of the generator matrix
        row_indices = []
        col_indices = []
        data = []
        for row in range(r):
            np.random.seed(self.config.indice_seed + row)
            chosen_indices = np.random.choice(self.config.n - r + row, self.config.t - 1, replace=False)
            chosen_indices = np.append(chosen_indices, self.config.n - r + row)
            row_indices.extend([row] * self.config.t)
            col_indices.extend(chosen_indices)
            data.extend([1] * self.config.t)
            generator_matrix[self.config.n - r + row] = generator_matrix[chosen_indices[:-1]].sum(axis=0)
        parity_check_matrix = csr_matrix((data, (row_indices, col_indices)))

        # Compute scheme parameters
        max_bp_iter = int(np.log(self.config.n) / np.log(self.config.t))

        # Sample one-time pad and test bits
        one_time_pad = self.config.GF.Random(self.config.n, seed=self.config.one_time_pad_seed)
        test_bits = self.config.GF.Random(num_test_bits, seed=self.config.test_bits_seed)

        # Permute bits
        np.random.seed(self.config.permute_bits_seed)
        permutation = np.random.permutation(self.config.n)
        generator_matrix = generator_matrix[permutation]
        one_time_pad = one_time_pad[permutation]
        parity_check_matrix = parity_check_matrix[:, permutation]

        encoding_key = (generator_matrix, one_time_pad, test_bits, g, noise_rate)
        decoding_key = (generator_matrix, parity_check_matrix, one_time_pad, self.config.fpr,
                        noise_rate, test_bits, g, max_bp_iter, self.config.t)
        return encoding_key, decoding_key
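
    # Worked example of the derived parameters (hypothetical values: n = 16384 for a
    # 4 x 64 x 64 latent, t = 3, fpr = 1e-4, message_length = 512):
    #     num_test_bits = ceil(log2(1 / 1e-4)) = 14
    #     secpar = g = int(log2(C(16384, 3))) = 39
    #     k = 512 + 39 + 14 = 565, r = 16384 - 565 - 39 = 15780
    #     noise_rate = 1 - 2 ** (-39 / 39 ** 2) ~= 0.0176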

    def _encode_message(self, encoding_key: tuple, message: np.ndarray = None) -> torch.Tensor:
        """Encode a message into a +/-1 PRC codeword."""
        generator_matrix, one_time_pad, test_bits, g, noise_rate = encoding_key
        n, k = generator_matrix.shape

        if message is None:
            payload = np.concatenate((test_bits, self.config.GF.Random(k - len(test_bits), seed=self.config.payload_seed)))
        else:
            assert len(message) <= k - len(test_bits) - g, "Message is too long"
            # Payload layout: [test_bits | g random bits | message | zero padding]
            payload = np.concatenate((
                test_bits,
                self.config.GF.Random(g, seed=self.config.payload_seed),
                self.config.GF(message),
                self.config.GF.Zeros(k - len(test_bits) - g - len(message)),
            ))

        np.random.seed(self.config.error_seed)
        error = self.config.GF(np.random.binomial(1, noise_rate, n))

        # Map GF(2) bits to the +/-1 alphabet: 0 -> +1, 1 -> -1
        return 1 - 2 * torch.tensor(payload @ generator_matrix.T + one_time_pad + error, dtype=float)
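
    # By construction parity_check_matrix @ generator_matrix == 0 over GF(2): each of
    # the last r generator rows equals the sum of the other rows selected by its
    # parity check. A minimal sanity check (illustrative, given a built `utils`):
    #
    #     G = utils.encoding_key[0]
    #     P = utils.decoding_key[1]
    #     assert np.all(utils.config.GF(P.toarray().astype(np.uint8)) @ G == 0)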

    def _sample_prc_codeword(self, codeword: torch.Tensor, basis: torch.Tensor = None) -> torch.Tensor:
        """Embed a +/-1 PRC codeword in the signs of pseudo-Gaussian noise."""
        codeword_np = codeword.numpy()
        np.random.seed(self.config.pseudogaussian_seed)
        # Multiplying +/-1 entries by |N(0, 1)| draws gives standard-normal marginals
        # whose signs carry the codeword
        pseudogaussian_np = codeword_np * np.abs(np.random.randn(*codeword_np.shape))
        pseudogaussian = torch.from_numpy(pseudogaussian_np).to(dtype=torch.float32)
        if basis is None:
            return pseudogaussian
        return pseudogaussian @ basis.T
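
    # Illustrative property (not executed here): without a basis, the codeword is
    # recoverable from the signs of the pseudo-Gaussian sample:
    #
    #     z = utils._sample_prc_codeword(c)  # c: +/-1 tensor from _encode_message
    #     assert torch.equal(torch.sign(z), c.to(z.dtype))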

    def inject_watermark(self) -> torch.Tensor:
        """Generate watermarked latents from per-frame PRC codewords."""
        # Step 1: Encode one message per frame
        prc_codeword = torch.stack([
            self._encode_message(self.encoding_key, self.config.message[frame_index])
            for frame_index in range(self.config.num_frames)
        ])
        # Step 2: Sample pseudo-Gaussian noise and reshape it into latents
        watermarked_latents = self._sample_prc_codeword(prc_codeword).reshape(
            self.config.num_frames, 1, self.config.latents_channel,
            self.config.latents_height, self.config.latents_width
        ).to(self.config.device)

        return watermarked_latents.permute(1, 2, 0, 3, 4)  # (b, c, f, h, w)
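
    # Illustrative usage, assuming `config` wraps a loaded T2V pipeline:
    #
    #     utils = VideoMarkUtils(config)
    #     latents = utils.inject_watermark()  # (1, c, f, h, w), one codeword per frame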


class VideoMarkWatermark(BaseWatermark):
    """Main class for the VideoMark watermarking algorithm."""

    def __init__(self, watermark_config: VideoMarkConfig, *args, **kwargs) -> None:
        """Initialize the VideoMark watermarking algorithm.

        Args:
            watermark_config: Configuration instance of the VideoMark algorithm
        """
        self.config = watermark_config
        self.utils = VideoMarkUtils(self.config)

        # Initialize detector with the decoding key from utils
        self.detector = VideoMarkDetector(
            message_sequence=self.config.message_sequence,
            watermark=self.config.message,
            num_frames=self.config.num_frames,
            var=self.config.var,
            decoding_key=self.utils.decoding_key,
            GF=self.config.GF,
            threshold=self.config.threshold,
            device=self.config.device
        )

    def _generate_watermarked_video(self, prompt: str, num_frames: Optional[int] = None, *args, **kwargs) -> List[Image.Image]:
        """Generate watermarked video using VideoMark algorithm.

        Args:
            prompt: The input prompt for video generation
            num_frames: Number of frames to generate (uses config value if None)

        Returns:
            List of generated watermarked video frames
        """
        if not is_video_pipeline(self.config.pipe):
            raise ValueError(f"This pipeline ({self.config.pipe.__class__.__name__}) does not support video generation.")

        # Set random seed for reproducibility
        set_random_seed(self.config.gen_seed)

        # Use config frames if not specified
        frames_to_generate = num_frames if num_frames is not None else self.config.num_frames

        # Set num_frames in config for watermark generation
        original_num_frames = getattr(self.config, 'num_frames', None)
        self.config.num_frames = frames_to_generate

        try:
            # Generate watermarked latents
            watermarked_latents = self.utils.inject_watermark().to(self.config.pipe.unet.dtype)

            # Save watermarked latents for visualization
            self.set_orig_watermarked_latents(watermarked_latents)

            # Construct video generation parameters
            generation_params = {
                "num_inference_steps": self.config.num_inference_steps,
                "guidance_scale": self.config.guidance_scale,
                "height": self.config.image_size[0],
                "width": self.config.image_size[1],
                "num_frames": frames_to_generate,
                "latents": watermarked_latents
            }

            # Add parameters from config.gen_kwargs
            if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
                for key, value in self.config.gen_kwargs.items():
                    if key not in generation_params:
                        generation_params[key] = value

            # Use kwargs to override default parameters
            for key, value in kwargs.items():
                if key != "num_frames":  # Prevent overriding processed parameters
                    generation_params[key] = value

            # Handle I2V pipelines that need dimension permutation (like SVD)
            final_latents = watermarked_latents
            if is_i2v_pipeline(self.config.pipe):
                logger.info("I2V pipeline detected, permuting latent dimensions.")
                final_latents = final_latents.permute(0, 2, 1, 3, 4)  # (b, c, f, h, w) -> (b, f, c, h, w)

            generation_params["latents"] = final_latents

            # Generate video
            output = self.config.pipe(
                prompt,
                **generation_params
            )

            # Extract frames from output
            if hasattr(output, 'frames'):
                frames = output.frames[0]
            elif hasattr(output, 'videos'):
                frames = output.videos[0]
            else:
                frames = output[0] if isinstance(output, tuple) else output

            # Convert frames to PIL Images
            frame_list = []
            for frame in frames:
                if not isinstance(frame, Image.Image):
                    if isinstance(frame, np.ndarray):
                        if frame.dtype == np.uint8:
                            frame_pil = Image.fromarray(frame)
                        else:
                            frame_scaled = (frame * 255).astype(np.uint8)
                            frame_pil = Image.fromarray(frame_scaled)
                    elif isinstance(frame, torch.Tensor):
                        if frame.dim() == 3 and frame.shape[-1] in [1, 3]:
                            if frame.max() <= 1.0:
                                frame = (frame * 255).byte()
                            frame_np = frame.cpu().numpy()
                            frame_pil = Image.fromarray(frame_np)
                        else:
                            raise ValueError(f"Unexpected tensor shape for frame: {frame.shape}")
                    else:
                        raise TypeError(f"Unexpected type for frame: {type(frame)}")
                else:
                    frame_pil = frame

                frame_list.append(frame_pil)

            return frame_list

        finally:
            # Restore original num_frames
            if original_num_frames is not None:
                self.config.num_frames = original_num_frames
            elif hasattr(self.config, 'num_frames'):
                delattr(self.config, 'num_frames')

    def _get_video_latents(self, vae, video_frames, sample=True, rng_generator=None, permute=True):
        """Encode video frames with the VAE and scale the latents for the pipeline."""
        encoding_dist = vae.encode(video_frames).latent_dist
        if sample:
            encoding = encoding_dist.sample(generator=rng_generator)
        else:
            encoding = encoding_dist.mode()
        latents = (encoding * 0.18215).unsqueeze(0)  # 0.18215: Stable Diffusion VAE scaling factor
        if permute:
            latents = latents.permute(0, 2, 1, 3, 4)
        return latents
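
    # Note: 0.18215 matches the original Stable Diffusion VAE; for other autoencoders
    # the configured value could be read instead (illustrative):
    #
    #     latents = (encoding * vae.config.scaling_factor).unsqueeze(0)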

    def _detect_watermark_in_video(self,
                                   video_frames: Union[torch.Tensor, List[Image.Image]],
                                   prompt: str = "",
                                   detector_type: str = 'bit_acc',
                                   *args, **kwargs) -> Dict[str, float]:
        """Detect VideoMark watermark in video.

        Args:
            video_frames: Input video frames as tensor or list of PIL images
            prompt: Text prompt used for generation
            detector_type: Detection metric to compute (e.g. 'bit_acc')

        Returns:
            Dictionary containing detection results
        """
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Convert frames to tensor if needed
        if isinstance(video_frames, list):
            from torchvision import transforms
            frames_tensor = torch.stack([transforms.ToTensor()(frame) for frame in video_frames])
            video_frames = 2.0 * frames_tensor - 1.0  # Normalize to [-1, 1]

        video_frames = video_frames.to(self.config.device).to(self.config.pipe.vae.dtype)

        # Get video latents
        with torch.no_grad():
            # TODO: Add support for I2V pipeline
            video_latents = self._get_video_latents(self.config.pipe.vae, video_frames, sample=False)

            # Perform DDIM inversion
            inversion_kwargs = {k: v for k, v in kwargs.items()
                                if k not in ['guidance_scale', 'num_inference_steps']}

            from diffusers import DDIMInverseScheduler
            original_scheduler = self.config.pipe.scheduler
            inverse_scheduler = DDIMInverseScheduler.from_config(original_scheduler.config)
            self.config.pipe.scheduler = inverse_scheduler

            video_latents = video_latents.to(self.config.pipe.unet.dtype)

            final_reversed_latents = self.config.pipe(
                prompt=prompt,
                latents=video_latents,
                num_inference_steps=num_steps_to_use,
                guidance_scale=guidance_scale_to_use,
                output_type='latent',
                **inversion_kwargs
            ).frames  # [B, F, H, W, C] (T2V)
            self.config.pipe.scheduler = original_scheduler

        # Use detector for evaluation
        return self.detector.eval_watermark(final_reversed_latents, detector_type=detector_type)
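
    # Illustrative usage, assuming `wm` is a constructed VideoMarkWatermark and the
    # detector reports a 'bit_acc' entry (hypothetical key name):
    #
    #     result = wm._detect_watermark_in_video(frames, prompt="a cat", detector_type='bit_acc')
    #     is_watermarked = result.get('bit_acc', 0.0) >= wm.config.threshold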

    def get_data_for_visualize(self,
                               video_frames: List[Image.Image],
                               prompt: str = "",
                               guidance_scale: float = 1,
                               *args, **kwargs) -> DataForVisualization:
        """Get VideoMark visualization data.

        This method generates the necessary data for visualizing VideoMark watermarks,
        including original watermarked latents and reversed latents from inversion.

        Args:
            video_frames: The video frames to visualize watermarks for
            prompt: The text prompt used for generation
            guidance_scale: Guidance scale for generation and inversion

        Returns:
            DataForVisualization object containing visualization data
        """
        # Prepare PRC-specific data
        message_bits = torch.tensor(self.config.message, dtype=torch.float32)

        # Get generator matrix
        generator_matrix = torch.tensor(np.array(self.utils.encoding_key[0], dtype=float), dtype=torch.float32)

        # Get parity check matrix
        parity_check_matrix = self.utils.decoding_key[1]

        # 1. Generate watermarked latents and collect intermediate data
        set_random_seed(self.config.gen_seed)

        # Step 1: Encode message
        prc_codeword = torch.stack([
            self.utils._encode_message(self.utils.encoding_key, self.config.message[frame_index])
            for frame_index in range(self.config.num_frames)
        ])

        # Step 2: Sample PRC codeword
        pseudogaussian_noise = self.utils._sample_prc_codeword(prc_codeword)

        # Step 3: Generate watermarked latents
        watermarked_latents = pseudogaussian_noise.reshape(
            self.config.num_frames, 1, self.config.latents_channel,
            self.config.latents_height, self.config.latents_width
        ).to(self.config.device)
        watermarked_latents = watermarked_latents.permute(1, 2, 0, 3, 4)

        # guidance_scale binds to the named parameter, so read it directly;
        # num_inference_steps may still arrive via kwargs
        guidance_scale_to_use = guidance_scale
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Convert frames to tensor if needed
        if isinstance(video_frames, list):
            from torchvision import transforms
            frames_tensor = torch.stack([transforms.ToTensor()(frame) for frame in video_frames])
            video_frames = 2.0 * frames_tensor - 1.0  # Normalize to [-1, 1]

        video_frames = video_frames.to(self.config.device).to(self.config.pipe.vae.dtype)

        # Get video latents
        with torch.no_grad():
            # TODO: Add support for I2V pipeline
            video_latents = self._get_video_latents(self.config.pipe.vae, video_frames, sample=False)

            # Perform DDIM inversion
            inversion_kwargs = {k: v for k, v in kwargs.items()
                                if k not in ['guidance_scale', 'num_inference_steps']}

            from diffusers import DDIMInverseScheduler
            original_scheduler = self.config.pipe.scheduler
            inverse_scheduler = DDIMInverseScheduler.from_config(original_scheduler.config)
            self.config.pipe.scheduler = inverse_scheduler
            collector = DenoisingLatentsCollector(save_every_n_steps=1, to_cpu=True)

            video_latents = video_latents.to(self.config.pipe.unet.dtype)

            final_reversed_latents = self.config.pipe(
                prompt=prompt,
                latents=video_latents,
                num_inference_steps=num_steps_to_use,
                guidance_scale=guidance_scale_to_use,
                output_type='latent',
                callback=collector,
                callback_steps=1,
                **inversion_kwargs
            ).frames  # [B, F, H, W, C] (T2V)
            self.config.pipe.scheduler = original_scheduler

        reversed_latents = collector.latents_list  # List[Tensor]

        inverted_latents = final_reversed_latents
        recovered_prc = None
        try:
            if inverted_latents is not None:
                # Use the detector to recover the PRC codeword
                detection_result = self.detector.eval_watermark(inverted_latents)
                # The detector may expose the recovered codeword as an attribute or in its result dict
                if hasattr(self.detector, 'recovered_prc') and self.detector.recovered_prc is not None:
                    recovered_prc = self.detector.recovered_prc
                elif 'recovered_prc' in detection_result:
                    recovered_prc = detection_result['recovered_prc']
                else:
                    logger.warning("Detector did not provide recovered_prc")
        except Exception as e:
            logger.warning(f"Could not recover PRC codeword for visualization: {e}")
            recovered_prc = None

        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            orig_watermarked_latents=watermarked_latents,
            watermarked_latents=watermarked_latents,
            reversed_latents=reversed_latents,
            inverted_latents=inverted_latents,
            video_frames=video_frames,
            # PRC-specific data
            message_bits=message_bits,
            prc_codeword=prc_codeword.to(torch.float32),
            pseudogaussian_noise=pseudogaussian_noise.to(torch.float32),
            generator_matrix=generator_matrix,
            parity_check_matrix=parity_check_matrix,
            threshold=self.config.threshold,
            recovered_prc=torch.as_tensor(recovered_prc, dtype=torch.float32) if recovered_prc is not None else None
        )
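

# Illustrative end-to-end flow (a sketch, not executed by this module; constructor
# arguments follow BaseConfig and are assumptions):
#
#     config = VideoMarkConfig(config_dict, pipe=pipe, device='cuda')
#     wm = VideoMarkWatermark(config)
#     frames = wm._generate_watermarked_video("a corgi running on the beach")
#     result = wm._detect_watermark_in_video(frames, prompt="a corgi running on the beach")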