Coverage for watermark/videomark/video_mark.py: 97.42% (233 statements), coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

from ..base import BaseWatermark, BaseConfig
import torch
import numpy as np
from typing import Dict, Tuple, Any, Optional, List, Union
from PIL import Image
import galois
from scipy.sparse import csr_matrix
from scipy.special import binom
import logging
from functools import reduce
from visualize.data_for_visualization import DataForVisualization
from detection.videomark.videomark_detection import VideoMarkDetector
from utils.media_utils import *
from utils.utils import set_random_seed
from utils.pipeline_utils import is_video_pipeline, is_t2v_pipeline, is_i2v_pipeline
from utils.callbacks import DenoisingLatentsCollector
import random

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Constants
VAE_DOWNSAMPLE_FACTOR = 8
DEFAULT_CONFIDENCE_THRESHOLD = 0.6


class VideoMarkConfig(BaseConfig):
    """Config class for VideoMark algorithm."""

    def initialize_parameters(self) -> None:
        """Initialize algorithm-specific parameters."""
        self.fpr = self.config_dict['fpr']
        self.t = self.config_dict['prc_t']
        self.var = self.config_dict['var']
        self.threshold = self.config_dict['threshold']
        self.sequence_length = self.config_dict['sequence_length']  # Length of the watermark sequence
        self.message_length = self.config_dict['message_length']  # Number of bits in each sequence
        self.message_sequence = np.random.randint(0, 2, size=(self.sequence_length, self.message_length))  # <= 512 bits for robustness
        self.shift = np.random.default_rng().integers(0, self.sequence_length - self.num_frames)
        self.message = self.message_sequence[self.shift : self.shift + self.num_frames]
        self.latents_height = self.image_size[0] // self.pipe.vae_scale_factor
        self.latents_width = self.image_size[1] // self.pipe.vae_scale_factor
        self.latents_channel = self.pipe.unet.config.in_channels
        self.n = self.latents_height * self.latents_width * self.latents_channel  # Dimension of the latent space
        self.GF = galois.GF(2)

        # Seeds for key generation
        self.gen_matrix_seed = self.config_dict['keygen']['gen_matrix_seed']
        self.indice_seed = self.config_dict['keygen']['indice_seed']
        self.one_time_pad_seed = self.config_dict['keygen']['one_time_pad_seed']
        self.test_bits_seed = self.config_dict['keygen']['test_bits_seed']
        self.permute_bits_seed = self.config_dict['keygen']['permute_bits_seed']

        # Seeds for encoding
        self.payload_seed = self.config_dict['encode']['payload_seed']
        self.error_seed = self.config_dict['encode']['error_seed']
        self.pseudogaussian_seed = self.config_dict['encode']['pseudogaussian_seed']
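
    # For reference, a hypothetical config_dict layout covering the keys read
    # above (values are purely illustrative, not recommendations):
    #
    #   {
    #       'fpr': 1e-6, 'prc_t': 3, 'var': 1.5, 'threshold': 0.6,
    #       'sequence_length': 256, 'message_length': 512,
    #       'keygen': {'gen_matrix_seed': 0, 'indice_seed': 1,
    #                  'one_time_pad_seed': 2, 'test_bits_seed': 3,
    #                  'permute_bits_seed': 4},
    #       'encode': {'payload_seed': 5, 'error_seed': 6,
    #                  'pseudogaussian_seed': 7},
    #   }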

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        return 'VideoMark'

    @staticmethod
    def _get_message(length: int, window: int, seed: Optional[int] = None) -> int:
        """Return a random start index for a subarray of size `window` in an array of size `length`."""
        rng = np.random.default_rng(seed)
        return int(rng.integers(0, length - window))

class VideoMarkUtils:
    """Utility class for VideoMark algorithm."""

    def __init__(self, config: VideoMarkConfig, *args, **kwargs) -> None:
        """Initialize PRC utility."""
        self.config = config
        self.encoding_key, self.decoding_key = self._generate_encoding_key(self.config.message_length)

    def _generate_encoding_key(self, message_length: int) -> Tuple[tuple, tuple]:
        """Generate the encoding and decoding keys for the PRC algorithm."""
        # Set basic scheme parameters
        num_test_bits = int(np.ceil(np.log2(1 / self.config.fpr)))
        secpar = int(np.log2(binom(self.config.n, self.config.t)))
        g = secpar
        k = message_length + g + num_test_bits
        r = self.config.n - k - secpar
        noise_rate = 1 - 2 ** (-secpar / g ** 2)

        # Sample an n-by-k generator matrix (the last r rows will be overwritten below)
        generator_matrix = self.config.GF.Random(shape=(self.config.n, k), seed=self.config.gen_matrix_seed)

        # Sample a scipy.sparse parity-check matrix together with the last r rows of the generator matrix
        row_indices = []
        col_indices = []
        data = []
        for row in range(r):
            np.random.seed(self.config.indice_seed + row)
            chosen_indices = np.random.choice(self.config.n - r + row, self.config.t - 1, replace=False)
            chosen_indices = np.append(chosen_indices, self.config.n - r + row)
            row_indices.extend([row] * self.config.t)
            col_indices.extend(chosen_indices)
            data.extend([1] * self.config.t)
            generator_matrix[self.config.n - r + row] = generator_matrix[chosen_indices[:-1]].sum(axis=0)
        parity_check_matrix = csr_matrix((data, (row_indices, col_indices)))

        # Compute scheme parameters
        max_bp_iter = int(np.log(self.config.n) / np.log(self.config.t))

        # Sample one-time pad and test bits
        one_time_pad = self.config.GF.Random(self.config.n, seed=self.config.one_time_pad_seed)
        test_bits = self.config.GF.Random(num_test_bits, seed=self.config.test_bits_seed)

        # Permute bits
        np.random.seed(self.config.permute_bits_seed)
        permutation = np.random.permutation(self.config.n)
        generator_matrix = generator_matrix[permutation]
        one_time_pad = one_time_pad[permutation]
        parity_check_matrix = parity_check_matrix[:, permutation]

        encoding_key = (generator_matrix, one_time_pad, test_bits, g, noise_rate)
        decoding_key = (generator_matrix, parity_check_matrix, one_time_pad, self.config.fpr,
                        noise_rate, test_bits, g, max_bp_iter, self.config.t)
        return encoding_key, decoding_key
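
    # A worked example of the parameter arithmetic above (illustrative numbers,
    # assuming fpr=1e-6, t=3, and a 4x64x64 latent space, so n=16384):
    #   num_test_bits = ceil(log2(1 / 1e-6)) = 20
    #   secpar = floor(log2(C(16384, 3))) = 39, and g = secpar
    #   k = message_length + g + num_test_bits
    #   r = n - k - secpar          (number of sparse parity-check rows)
    #   noise_rate = 1 - 2^(-secpar / g^2) ≈ 0.018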

    def _encode_message(self, encoding_key: tuple, message: np.ndarray = None) -> torch.Tensor:
        """Encode a message into a ±1 PRC codeword."""
        generator_matrix, one_time_pad, test_bits, g, noise_rate = encoding_key
        n, k = generator_matrix.shape

        if message is None:
            payload = np.concatenate((test_bits, self.config.GF.Random(k - len(test_bits), seed=self.config.payload_seed)))
        else:
            assert len(message) <= k - len(test_bits) - g, "Message is too long"
            payload = np.concatenate((
                test_bits,
                self.config.GF.Random(g, seed=self.config.payload_seed),
                self.config.GF(message),
                self.config.GF.Zeros(k - len(test_bits) - g - len(message)),
            ))

        np.random.seed(self.config.error_seed)
        error = self.config.GF(np.random.binomial(1, noise_rate, n))

        return 1 - 2 * torch.tensor(payload @ generator_matrix.T + one_time_pad + error, dtype=float)
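
    # Payload layout when a message is supplied (left to right):
    #
    #   [ test_bits | g random bits | message | zero padding ]   -> k bits total
    #
    # The codeword payload @ G^T + one_time_pad + error is computed over GF(2);
    # the final 1 - 2*x maps bits {0, 1} to signs {+1, -1}.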

    def _sample_prc_codeword(self, codeword: torch.Tensor, basis: torch.Tensor = None) -> torch.Tensor:
        """Sample pseudo-Gaussian noise carrying a ±1 PRC codeword."""
        codeword_np = codeword.numpy()
        np.random.seed(self.config.pseudogaussian_seed)
        pseudogaussian_np = codeword_np * np.abs(np.random.randn(*codeword_np.shape))
        pseudogaussian = torch.from_numpy(pseudogaussian_np).to(dtype=torch.float32)
        if basis is None:
            return pseudogaussian
        return pseudogaussian @ basis.T
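
    # Multiplying each ±1 symbol by an independent |N(0, 1)| draw yields noise
    # whose marginal distribution is standard Gaussian (so the watermarked
    # latents look like ordinary diffusion noise), while sign(noise) still
    # recovers the underlying codeword symbol exactly.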

    def inject_watermark(self) -> torch.Tensor:
        """Generate watermarked latents from PRC codewords."""
        # Step 1: Encode one message per frame
        prc_codeword = torch.stack([
            self._encode_message(self.encoding_key, self.config.message[frame_index])
            for frame_index in range(self.config.num_frames)
        ])
        # Step 2: Sample pseudo-Gaussian noise and reshape into latents
        watermarked_latents = self._sample_prc_codeword(prc_codeword).reshape(
            self.config.num_frames, 1, self.config.latents_channel,
            self.config.latents_height, self.config.latents_width
        ).to(self.config.device)

        return watermarked_latents.permute(1, 2, 0, 3, 4)  # (b, c, f, h, w)
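
    # A minimal usage sketch (assuming a fully initialized VideoMarkConfig):
    #
    #   utils = VideoMarkUtils(config)
    #   latents = utils.inject_watermark()  # shape (1, c, num_frames, h, w)
    #   # pass `latents` to the diffusion pipeline as its initial noise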


class VideoMarkWatermark(BaseWatermark):
    """Main class for VideoMark watermarking algorithm."""

    def __init__(self, watermark_config: VideoMarkConfig, *args, **kwargs) -> None:
        """Initialize the VideoMark watermarking algorithm.

        Args:
            watermark_config: Configuration instance of the VideoMark algorithm
        """
        self.config = watermark_config
        self.utils = VideoMarkUtils(self.config)

        # Initialize detector with encryption keys from utils
        self.detector = VideoMarkDetector(
            message_sequence=self.config.message_sequence,
            watermark=self.config.message,
            num_frames=self.config.num_frames,
            var=self.config.var,
            decoding_key=self.utils.decoding_key,
            GF=self.config.GF,
            threshold=self.config.threshold,
            device=self.config.device
        )

    def _generate_watermarked_video(self, prompt: str, num_frames: Optional[int] = None, *args, **kwargs) -> List[Image.Image]:
        """Generate watermarked video using VideoMark algorithm.

        Args:
            prompt: The input prompt for video generation
            num_frames: Number of frames to generate (uses config value if None)

        Returns:
            List of generated watermarked video frames
        """
        if not is_video_pipeline(self.config.pipe):
            raise ValueError(f"This pipeline ({self.config.pipe.__class__.__name__}) does not support video generation.")

        # Set random seed for reproducibility
        set_random_seed(self.config.gen_seed)

        # Use config frames if not specified
        frames_to_generate = num_frames if num_frames is not None else self.config.num_frames

        # Set num_frames in config for watermark generation
        original_num_frames = getattr(self.config, 'num_frames', None)
        self.config.num_frames = frames_to_generate

        try:
            # Generate watermarked latents
            watermarked_latents = self.utils.inject_watermark().to(self.config.pipe.unet.dtype)

            # Save watermarked latents for visualization
            self.set_orig_watermarked_latents(watermarked_latents)

            # Construct video generation parameters
            generation_params = {
                "num_inference_steps": self.config.num_inference_steps,
                "guidance_scale": self.config.guidance_scale,
                "height": self.config.image_size[0],
                "width": self.config.image_size[1],
                "num_frames": frames_to_generate,
                "latents": watermarked_latents
            }

            # Add parameters from config.gen_kwargs
            if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
                for key, value in self.config.gen_kwargs.items():
                    if key not in generation_params:
                        generation_params[key] = value

            # Use kwargs to override default parameters
            for key, value in kwargs.items():
                if key != "num_frames":  # Prevent overriding processed parameters
                    generation_params[key] = value

            # Handle I2V pipelines that need dimension permutation (like SVD)
            final_latents = watermarked_latents
            if is_i2v_pipeline(self.config.pipe):
                logger.info("I2V pipeline detected, permuting latent dimensions.")
                final_latents = final_latents.permute(0, 2, 1, 3, 4)  # (b, c, f, h, w) -> (b, f, c, h, w)

            generation_params["latents"] = final_latents

            # Generate video
            output = self.config.pipe(
                prompt,
                **generation_params
            )

            # Extract frames from output
            if hasattr(output, 'frames'):
                frames = output.frames[0]
            elif hasattr(output, 'videos'):
                frames = output.videos[0]
            else:
                frames = output[0] if isinstance(output, tuple) else output

            # Convert frames to PIL Images
            frame_list = []
            for frame in frames:
                if not isinstance(frame, Image.Image):
                    if isinstance(frame, np.ndarray):
                        if frame.dtype == np.uint8:
                            frame_pil = Image.fromarray(frame)
                        else:
                            frame_scaled = (frame * 255).astype(np.uint8)
                            frame_pil = Image.fromarray(frame_scaled)
                    elif isinstance(frame, torch.Tensor):
                        if frame.dim() == 3 and frame.shape[-1] in [1, 3]:
                            if frame.max() <= 1.0:
                                frame = (frame * 255).byte()
                            frame_np = frame.cpu().numpy()
                            frame_pil = Image.fromarray(frame_np)
                        else:
                            raise ValueError(f"Unexpected tensor shape for frame: {frame.shape}")
                    else:
                        raise TypeError(f"Unexpected type for frame: {type(frame)}")
                else:
                    frame_pil = frame

                frame_list.append(frame_pil)

            return frame_list

        finally:
            # Restore original num_frames
            if original_num_frames is not None:
                self.config.num_frames = original_num_frames
            elif hasattr(self.config, 'num_frames'):
                delattr(self.config, 'num_frames')

    def _get_video_latents(self, vae, video_frames, sample=True, rng_generator=None, permute=True):
        """Encode video frames into VAE latents."""
        encoding_dist = vae.encode(video_frames).latent_dist
        if sample:
            encoding = encoding_dist.sample(generator=rng_generator)
        else:
            encoding = encoding_dist.mode()
        latents = (encoding * 0.18215).unsqueeze(0)  # 0.18215: Stable-Diffusion VAE scaling factor
        if permute:
            latents = latents.permute(0, 2, 1, 3, 4)
        return latents
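
    # Shape walkthrough (assuming an SD-style VAE with spatial downsample factor 8):
    #   video_frames: (F, 3, H, W) -> encode: (F, C, H/8, W/8)
    #   unsqueeze(0): (1, F, C, H/8, W/8)
    #   permute:      (1, C, F, H/8, W/8), matching the (b, c, f, h, w) latent layout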

    def _detect_watermark_in_video(self,
                                   video_frames: Union[torch.Tensor, List[Image.Image]],
                                   prompt: str = "",
                                   detector_type: str = 'bit_acc',
                                   *args, **kwargs) -> Dict[str, float]:
        """Detect VideoMark watermark in video.

        Args:
            video_frames: Input video frames as tensor or list of PIL images
            prompt: Text prompt used for generation
            detector_type: Metric used by the detector (e.g., 'bit_acc')

        Returns:
            Dictionary containing detection results
        """
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Convert frames to tensor if needed
        if isinstance(video_frames, list):
            from torchvision import transforms
            frames_tensor = torch.stack([transforms.ToTensor()(frame) for frame in video_frames])
            video_frames = 2.0 * frames_tensor - 1.0  # Normalize to [-1, 1]

        video_frames = video_frames.to(self.config.device).to(self.config.pipe.vae.dtype)

        # Get video latents
        with torch.no_grad():
            # TODO: Add support for I2V pipeline
            video_latents = self._get_video_latents(self.config.pipe.vae, video_frames, sample=False)

            # Perform DDIM inversion
            inversion_kwargs = {k: v for k, v in kwargs.items()
                                if k not in ['guidance_scale', 'num_inference_steps']}

            from diffusers import DDIMInverseScheduler
            original_scheduler = self.config.pipe.scheduler
            inverse_scheduler = DDIMInverseScheduler.from_config(original_scheduler.config)
            self.config.pipe.scheduler = inverse_scheduler

            video_latents = video_latents.to(self.config.pipe.unet.dtype)

            final_reversed_latents = self.config.pipe(
                prompt=prompt,
                latents=video_latents,
                num_inference_steps=num_steps_to_use,
                guidance_scale=guidance_scale_to_use,
                output_type='latent',
                **inversion_kwargs
            ).frames  # [B, F, H, W, C] (T2V)
            self.config.pipe.scheduler = original_scheduler

        # Use detector for evaluation
        return self.detector.eval_watermark(final_reversed_latents, detector_type=detector_type)
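
    # A minimal detection sketch (hypothetical names; `wm` is a VideoMarkWatermark
    # instance and `frames` a list of PIL images from _generate_watermarked_video):
    #
    #   result = wm._detect_watermark_in_video(frames, prompt="", detector_type='bit_acc')
    #   # `result` is the metrics dict returned by VideoMarkDetector.eval_watermark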

    def get_data_for_visualize(self,
                               video_frames: List[Image.Image],
                               prompt: str = "",
                               guidance_scale: float = 1,
                               *args, **kwargs) -> DataForVisualization:
        """Get VideoMark visualization data.

        This method generates the necessary data for visualizing VideoMark watermarks,
        including original watermarked latents and reversed latents from inversion.

        Args:
            video_frames: The video frames to visualize watermarks for
            prompt: The text prompt used for generation
            guidance_scale: Guidance scale for generation and inversion

        Returns:
            DataForVisualization object containing visualization data
        """
        # Prepare PRC-specific data
        message_bits = torch.tensor(self.config.message, dtype=torch.float32)

        # Get generator matrix
        generator_matrix = torch.tensor(np.array(self.utils.encoding_key[0], dtype=float), dtype=torch.float32)

        # Get parity check matrix
        parity_check_matrix = self.utils.decoding_key[1]

        # 1. Generate watermarked latents and collect intermediate data
        set_random_seed(self.config.gen_seed)

        # Step 1: Encode message
        prc_codeword = torch.stack([
            self.utils._encode_message(self.utils.encoding_key, self.config.message[frame_index])
            for frame_index in range(self.config.num_frames)
        ])

        # Step 2: Sample PRC codeword
        pseudogaussian_noise = self.utils._sample_prc_codeword(prc_codeword)

        # Step 3: Generate watermarked latents
        watermarked_latents = pseudogaussian_noise.reshape(
            self.config.num_frames, 1, self.config.latents_channel,
            self.config.latents_height, self.config.latents_width
        ).to(self.config.device)
        watermarked_latents = watermarked_latents.permute(1, 2, 0, 3, 4)

        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Convert frames to tensor if needed
        if isinstance(video_frames, list):
            from torchvision import transforms
            frames_tensor = torch.stack([transforms.ToTensor()(frame) for frame in video_frames])
            video_frames = 2.0 * frames_tensor - 1.0  # Normalize to [-1, 1]

        video_frames = video_frames.to(self.config.device).to(self.config.pipe.vae.dtype)

        # Get video latents
        with torch.no_grad():
            # TODO: Add support for I2V pipeline
            video_latents = self._get_video_latents(self.config.pipe.vae, video_frames, sample=False)

            # Perform DDIM inversion
            inversion_kwargs = {k: v for k, v in kwargs.items()
                                if k not in ['guidance_scale', 'num_inference_steps']}

            from diffusers import DDIMInverseScheduler
            original_scheduler = self.config.pipe.scheduler
            inverse_scheduler = DDIMInverseScheduler.from_config(original_scheduler.config)
            self.config.pipe.scheduler = inverse_scheduler
            collector = DenoisingLatentsCollector(save_every_n_steps=1, to_cpu=True)

            video_latents = video_latents.to(self.config.pipe.unet.dtype)

            final_reversed_latents = self.config.pipe(
                prompt=prompt,
                latents=video_latents,
                num_inference_steps=num_steps_to_use,
                guidance_scale=guidance_scale_to_use,
                output_type='latent',
                callback=collector,
                callback_steps=1,
                **inversion_kwargs
            ).frames  # [B, F, H, W, C] (T2V)
            self.config.pipe.scheduler = original_scheduler

        reversed_latents = collector.latents_list  # List[Tensor]

        inverted_latents = final_reversed_latents
        recovered_prc = None
        try:
            if inverted_latents is not None:
                # Use the detector to recover the PRC codeword
                detection_result = self.detector.eval_watermark(inverted_latents)
                # The detector should expose recovered_prc as an attribute or in its result dict
                if hasattr(self.detector, 'recovered_prc') and self.detector.recovered_prc is not None:
                    recovered_prc = self.detector.recovered_prc
                elif 'recovered_prc' in detection_result:
                    recovered_prc = detection_result['recovered_prc']
                else:
                    logger.warning("Detector did not provide recovered_prc")
        except Exception as e:
            logger.warning(f"Could not recover PRC codeword for visualization: {e}")
            recovered_prc = None

        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            orig_watermarked_latents=watermarked_latents,
            watermarked_latents=watermarked_latents,
            reversed_latents=reversed_latents,
            inverted_latents=inverted_latents,
            video_frames=video_frames,
            # PRC-specific data
            message_bits=message_bits,
            prc_codeword=prc_codeword.detach().to(torch.float32),  # already a tensor; avoid torch.tensor() copy warning
            pseudogaussian_noise=pseudogaussian_noise.detach().to(torch.float32),
            generator_matrix=generator_matrix,
            parity_check_matrix=parity_check_matrix,
            threshold=self.config.threshold,
            recovered_prc=torch.tensor(recovered_prc, dtype=torch.float32) if recovered_prc is not None else None
        )
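

if __name__ == "__main__":
    # A minimal, pipeline-free self-check of the pseudo-Gaussian mapping used by
    # _sample_prc_codeword (illustrative only): x * |N(0, 1)| keeps the sign of x,
    # so taking sign() of the sampled noise recovers the ±1 codeword.
    rng = np.random.default_rng(0)
    codeword = 1 - 2 * rng.integers(0, 2, size=1000)  # ±1 symbols
    noise = codeword * np.abs(rng.standard_normal(1000))
    assert np.array_equal(np.sign(noise), codeword)
    logger.info("Pseudo-Gaussian sign-recovery self-check passed.")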