Coverage for watermark / videoshield / video_shield.py: 94.12%
204 statements
from ..base import BaseWatermark, BaseConfig
import torch
import numpy as np
from typing import Dict, Tuple, Any, Optional, List, Union
from PIL import Image
from Crypto.Cipher import ChaCha20
from Crypto.Random import get_random_bytes
import logging
from scipy.stats import norm, truncnorm
from functools import reduce
from visualize.data_for_visualization import DataForVisualization
from detection.videoshield.videoshield_detection import VideoShieldDetector
from utils.media_utils import *
from utils.utils import set_random_seed
from utils.pipeline_utils import is_video_pipeline, is_t2v_pipeline, is_i2v_pipeline
from utils.callbacks import DenoisingLatentsCollector
import random

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Constants
VAE_DOWNSAMPLE_FACTOR = 8
DEFAULT_CONFIDENCE_THRESHOLD = 0.6
class VideoShieldConfig(BaseConfig):
    """Config class for VideoShield algorithm."""

    def initialize_parameters(self) -> None:
        """Initialize VideoShield configuration."""
        # Repetition factors for video watermarking
        self.k_f: int = self.config_dict['k_f']  # Frame repetition factor
        self.k_c: int = self.config_dict['k_c']  # Channel repetition factor
        self.k_h: int = self.config_dict['k_h']  # Height repetition factor
        self.k_w: int = self.config_dict['k_w']  # Width repetition factor

        # Temporal threshold for localization
        self.t_temp: float = self.config_dict['t_temp']

        # HSTR (Hierarchical Spatio-Temporal Refinement) parameters
        self.hstr_levels: int = self.config_dict['hstr_levels']
        self.t_wm: List[float] = self.config_dict['t_wm']
        self.t_orig: List[float] = self.config_dict['t_orig']

        # Watermark generation parameters
        self.wm_key: int = self.config_dict.get('wm_key', 42)
        self.chacha_key_seed: int = self.config_dict.get('chacha_key_seed', 123456)
        self.chacha_nonce_seed: int = self.config_dict.get('chacha_nonce_seed', 789012)

        # Detection threshold
        self.threshold: float = self.config_dict.get('threshold', 0.6)

        # Calculate latent dimensions
        self.latents_height = self.image_size[0] // VAE_DOWNSAMPLE_FACTOR
        self.latents_width = self.image_size[1] // VAE_DOWNSAMPLE_FACTOR

        # Adjust repetition factors if they exceed dimensions to avoid empty tensors
        if hasattr(self, 'num_frames') and self.num_frames > 0:
            if self.k_f > self.num_frames:
                logger.warning(f"k_f ({self.k_f}) is larger than num_frames ({self.num_frames}). Adjusting k_f to {self.num_frames}.")
                self.k_f = self.num_frames

        if self.k_h > self.latents_height:
            logger.warning(f"k_h ({self.k_h}) is larger than latents_height ({self.latents_height}). Adjusting k_h to {self.latents_height}.")
            self.k_h = self.latents_height

        if self.k_w > self.latents_width:
            logger.warning(f"k_w ({self.k_w}) is larger than latents_width ({self.latents_width}). Adjusting k_w to {self.latents_width}.")
            self.k_w = self.latents_width

        # Generate watermark pattern
        generator = torch.Generator(device=self.device)
        generator.manual_seed(self.wm_key)

        # For video, we need frame dimension
        if hasattr(self, 'num_frames') and self.num_frames > 0:
            self.watermark = torch.randint(
                0, 2,
                [1, 4 // self.k_c, self.num_frames // self.k_f,
                 self.latents_height // self.k_h, self.latents_width // self.k_w],
                generator=generator, device=self.device
            )
        else:
            # Fallback for image-only case
            self.watermark = torch.randint(
                0, 2,
                [1, 4 // self.k_c, self.latents_height // self.k_h, self.latents_width // self.k_w],
                generator=generator, device=self.device
            )

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        return 'VideoShield'
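
# A minimal shape sketch (illustration, not part of the original source): assuming
# a 512x512 generation with 16 frames and k_c = k_f = k_h = k_w = 1, the config
# above yields a 64x64 latent grid and a watermark template of shape
# [1, 4, 16, 64, 64]:
#
#   latents_height = 512 // VAE_DOWNSAMPLE_FACTOR            # 64
#   latents_width  = 512 // VAE_DOWNSAMPLE_FACTOR             # 64
#   watermark.shape == (1, 4 // 1, 16 // 1, 64 // 1, 64 // 1)  # [1, 4, 16, 64, 64]
#
# Larger repetition factors shrink the template accordingly; e.g. k_h = k_w = 2
# would give [1, 4, 16, 32, 32], with each template bit later repeated 2x2 spatially.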
class VideoShieldUtils:
    """Utility class for VideoShield algorithm."""

    def __init__(self, config: VideoShieldConfig, *args, **kwargs) -> None:
        """Initialize the VideoShield watermarking utility."""
        self.config = config

        # Generate deterministic cryptographic keys using seeds
        self.chacha_key = self._get_bytes_with_seed(self.config.chacha_key_seed, 32)
        self.chacha_nonce = self._get_bytes_with_seed(self.config.chacha_nonce_seed, 12)

        # Calculate latent space dimensions
        if hasattr(self.config, 'num_frames') and self.config.num_frames > 0:
            # Video case: include frame dimension
            self.latentlength = 4 * self.config.num_frames * self.config.latents_height * self.config.latents_width
        else:
            # Image case: no frame dimension
            self.latentlength = 4 * self.config.latents_height * self.config.latents_width

        # Calculate watermark length based on repetition factors
        self.marklength = self.latentlength // (self.config.k_f * self.config.k_c * self.config.k_h * self.config.k_w)

        # Voting threshold for watermark extraction
        if self.config.k_f == 1 and self.config.k_c == 1 and self.config.k_h == 1 and self.config.k_w == 1:
            self.vote_threshold = 1
        else:
            self.vote_threshold = (self.config.k_f * self.config.k_c * self.config.k_h * self.config.k_w) // 2

    def _get_bytes_with_seed(self, seed: int, n: int) -> bytes:
        """Generate deterministic bytes using a seed."""
        random.seed(seed)
        return bytes(random.getrandbits(8) for _ in range(n))

    def _stream_key_encrypt(self, sd: np.ndarray) -> np.ndarray:
        """Encrypt the watermark using ChaCha20 cipher."""
        try:
            cipher = ChaCha20.new(key=self.chacha_key, nonce=self.chacha_nonce)
            m_byte = cipher.encrypt(np.packbits(sd).tobytes())
            m_bit = np.unpackbits(np.frombuffer(m_byte, dtype=np.uint8))
            return m_bit[:len(sd)]  # Ensure same length as input
        except Exception as e:
            logger.error(f"Encryption failed: {e}")
            raise RuntimeError("Encryption failed") from e
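
    # Hedged note (illustration, not part of the original source): ChaCha20 is a
    # stream cipher, so encryption XORs the packed bits with a keystream derived
    # from (key, nonce). Re-applying the same operation with the same key and
    # nonce therefore recovers the original bits, which is presumably what the
    # detector does with the keys it receives:
    #
    #   cipher = ChaCha20.new(key=self.chacha_key, nonce=self.chacha_nonce)
    #   sd_bytes = cipher.decrypt(np.packbits(m_bit).tobytes())
    #   sd_bits = np.unpackbits(np.frombuffer(sd_bytes, dtype=np.uint8))[:len(m_bit)]
    #
    # (Exact round-tripping assumes the bit stream length is a multiple of 8,
    # which holds for the 4 * frames * height * width latent layouts used here.)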
    def _truncated_sampling(self, message: np.ndarray) -> torch.Tensor:
        """Truncated Gaussian sampling for watermarking.

        Args:
            message: Binary message as a numpy array of 0s and 1s

        Returns:
            Watermarked latents tensor
        """
        z = np.zeros(self.latentlength)
        denominator = 2.0
        ppf = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)]

        for i in range(self.latentlength):
            dec_mes = reduce(lambda a, b: 2 * a + b, message[i : i + 1])
            dec_mes = int(dec_mes)
            z[i] = truncnorm.rvs(ppf[dec_mes], ppf[dec_mes + 1])

        # Reshape based on whether this is video or image
        if hasattr(self.config, 'num_frames') and self.config.num_frames > 0:
            # Video: (batch, channels, frames, height, width)
            z = torch.from_numpy(z).reshape(
                1, 4, self.config.num_frames, self.config.latents_height, self.config.latents_width
            ).float()
        else:
            # Image: (batch, channels, height, width)
            z = torch.from_numpy(z).reshape(
                1, 4, self.config.latents_height, self.config.latents_width
            ).float()

        return z.to(self.config.device)
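
    # Hedged sketch (illustration, not part of the original source): with
    # denominator = 2 the ppf table above is [-inf, 0.0, +inf], so each encrypted
    # bit selects one half of a standard Gaussian:
    #
    #   bit 0 -> truncnorm.rvs(-inf, 0.0)   # negative latent value
    #   bit 1 -> truncnorm.rvs(0.0, +inf)   # positive latent value
    #
    # A detector can then read the encrypted bit back from a reconstructed latent
    # simply as (latent_value > 0), before decryption and majority voting.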
    def create_watermark_and_return_w(self) -> torch.Tensor:
        """Create watermark pattern and return watermarked initial latents."""
        # Create repeated watermark pattern
        if hasattr(self.config, 'num_frames') and self.config.num_frames > 0:
            # Video case: repeat along all dimensions including frames
            sd = self.config.watermark.repeat(1, self.config.k_c, self.config.k_f, self.config.k_h, self.config.k_w)
        else:
            # Image case: repeat along spatial dimensions only
            sd = self.config.watermark.repeat(1, self.config.k_c, self.config.k_h, self.config.k_w)

        # Encrypt the repeated watermark
        m = self._stream_key_encrypt(sd.flatten().cpu().numpy())

        # Generate watermarked latents using truncated sampling
        w = self._truncated_sampling(m)

        return w
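
# Hedged end-to-end sketch of the three steps above (illustrative shapes, assuming
# a video config with F frames and an h x w latent grid):
#
#   watermark [1, 4/k_c, F/k_f, h/k_h, w/k_w] --repeat--> sd [1, 4, F, h, w]
#   sd --ChaCha20 encrypt--> m (flat bit stream of length 4 * F * h * w)
#   m --truncated sampling--> w [1, 4, F, h, w]
#
# When the encrypted bits are roughly balanced, w is approximately standard
# Gaussian overall, so it can be handed to the diffusion pipeline as its initial
# noise latents.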
class VideoShieldWatermark(BaseWatermark):
    """Main class for VideoShield watermarking algorithm."""

    def __init__(self, watermark_config: VideoShieldConfig, *args, **kwargs) -> None:
        """Initialize the VideoShield watermarking algorithm.

        Args:
            watermark_config: Configuration instance of the VideoShield algorithm
        """
        self.config = watermark_config
        self.utils = VideoShieldUtils(self.config)

        # Initialize detector with encryption keys from utils
        self.detector = VideoShieldDetector(
            watermark=self.config.watermark,
            threshold=self.config.threshold,
            device=self.config.device,
            chacha_key=self.utils.chacha_key,
            chacha_nonce=self.utils.chacha_nonce,
            height=self.config.image_size[0],
            width=self.config.image_size[1],
            num_frames=self.config.num_frames,
            k_f=self.config.k_f,
            k_c=self.config.k_c,
            k_h=self.config.k_h,
            k_w=self.config.k_w
        )

    def _generate_watermarked_video(self, prompt: str, num_frames: Optional[int] = None, *args, **kwargs) -> List[Image.Image]:
        """Generate watermarked video using VideoShield algorithm.

        Args:
            prompt: The input prompt for video generation
            num_frames: Number of frames to generate (uses config value if None)

        Returns:
            List of generated watermarked video frames
        """
        if not is_video_pipeline(self.config.pipe):
            raise ValueError(f"This pipeline ({self.config.pipe.__class__.__name__}) does not support video generation.")

        # Set random seed for reproducibility
        set_random_seed(self.config.gen_seed)

        # Use config frames if not specified
        frames_to_generate = num_frames if num_frames is not None else self.config.num_frames

        # Set num_frames in config for watermark generation
        original_num_frames = getattr(self.config, 'num_frames', None)
        self.config.num_frames = frames_to_generate

        try:
            # Generate watermarked latents
            watermarked_latents = self.utils.create_watermark_and_return_w().to(self.config.pipe.unet.dtype)

            # Save watermarked latents for visualization
            self.set_orig_watermarked_latents(watermarked_latents)

            # Construct video generation parameters
            generation_params = {
                "num_inference_steps": self.config.num_inference_steps,
                "guidance_scale": self.config.guidance_scale,
                "height": self.config.image_size[0],
                "width": self.config.image_size[1],
                "num_frames": frames_to_generate,
                "latents": watermarked_latents
            }

            # Add parameters from config.gen_kwargs
            if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
                for key, value in self.config.gen_kwargs.items():
                    if key not in generation_params:
                        generation_params[key] = value

            # Use kwargs to override default parameters
            for key, value in kwargs.items():
                if key != "num_frames":  # Prevent overriding processed parameters
                    generation_params[key] = value

            # Handle I2V pipelines that need dimension permutation (like SVD)
            final_latents = watermarked_latents
            if is_i2v_pipeline(self.config.pipe):
                logger.info("I2V pipeline detected, permuting latent dimensions.")
                final_latents = final_latents.permute(0, 2, 1, 3, 4)  # (b,c,f,h,w) -> (b,f,c,h,w)

            generation_params["latents"] = final_latents

            # Generate video
            output = self.config.pipe(
                prompt,
                **generation_params
            )

            # Extract frames from output
            if hasattr(output, 'frames'):
                frames = output.frames[0]
            elif hasattr(output, 'videos'):
                frames = output.videos[0]
            else:
                frames = output[0] if isinstance(output, tuple) else output

            # Convert frames to PIL Images
            frame_list = []
            for frame in frames:
                if not isinstance(frame, Image.Image):
                    if isinstance(frame, np.ndarray):
                        if frame.dtype == np.uint8:
                            frame_pil = Image.fromarray(frame)
                        else:
                            frame_scaled = (frame * 255).astype(np.uint8)
                            frame_pil = Image.fromarray(frame_scaled)
                    elif isinstance(frame, torch.Tensor):
                        if frame.dim() == 3 and frame.shape[-1] in [1, 3]:
                            if frame.max() <= 1.0:
                                frame = (frame * 255).byte()
                            frame_np = frame.cpu().numpy()
                            frame_pil = Image.fromarray(frame_np)
                        else:
                            raise ValueError(f"Unexpected tensor shape for frame: {frame.shape}")
                    else:
                        raise TypeError(f"Unexpected type for frame: {type(frame)}")
                else:
                    frame_pil = frame

                frame_list.append(frame_pil)

            return frame_list

        finally:
            # Restore original num_frames
            if original_num_frames is not None:
                self.config.num_frames = original_num_frames
            elif hasattr(self.config, 'num_frames'):
                delattr(self.config, 'num_frames')
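
    # Hedged usage sketch (illustrative only; how the config and pipeline are
    # constructed lives outside this file and is assumed here):
    #
    #   watermarker = VideoShieldWatermark(VideoShieldConfig(...))
    #   frames = watermarker._generate_watermarked_video("a cat surfing a wave", num_frames=16)
    #   # `frames` is a list of PIL.Image frames whose initial latents carry the watermark.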
    # def _detect_watermark_in_image(self, image: Image.Image, prompt: str = "",
    #                                *args, **kwargs) -> Dict[str, float]:
    #     """Detect VideoShield watermark in image.
    #
    #     Args:
    #         image: Input PIL image
    #         prompt: Text prompt used for generation
    #
    #     Returns:
    #         Dictionary containing detection results
    #     """
    #     # Use config values as defaults if not explicitly provided
    #     guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
    #     num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

    #     # Get text embeddings
    #     do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
    #     prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
    #         prompt=prompt,
    #         device=self.config.device,
    #         do_classifier_free_guidance=do_classifier_free_guidance,
    #         num_images_per_prompt=1,
    #     )

    #     if do_classifier_free_guidance:
    #         text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
    #     else:
    #         text_embeddings = prompt_embeds

    #     # Preprocess image
    #     image_tensor = transform_to_model_format(
    #         image, target_size=self.config.image_size[0]
    #     ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)

    #     # Get image latents
    #     image_latents = get_media_latents(
    #         pipe=self.config.pipe,
    #         media=image_tensor,
    #         sample=False,
    #         decoder_inv=kwargs.get("decoder_inv", False)
    #     )

    #     # Perform DDIM inversion
    #     inversion_kwargs = {k: v for k, v in kwargs.items()
    #                         if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}

    #     reversed_latents = self.config.inversion.forward_diffusion(
    #         latents=image_latents,
    #         text_embeddings=text_embeddings,
    #         guidance_scale=guidance_scale_to_use,
    #         num_inference_steps=num_steps_to_use,
    #         **inversion_kwargs
    #     )[-1]

    #     # Use detector or utils for evaluation
    #     if 'detector_type' in kwargs:
    #         return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
    #     else:
    #         return self.utils.eval_watermark(reversed_latents)
    def _get_video_latents(self, vae, video_frames, sample=True, rng_generator=None, permute=True):
        """Encode video frames (frames, channels, H, W) into scaled VAE latents.

        If `permute` is True, the result is reordered to (batch, channels, frames, height, width).
        """
        encoding_dist = vae.encode(video_frames).latent_dist
        if sample:
            encoding = encoding_dist.sample(generator=rng_generator)
        else:
            encoding = encoding_dist.mode()
        # 0.18215 is the Stable Diffusion VAE latent scaling factor
        latents = (encoding * 0.18215).unsqueeze(0)
        if permute:
            latents = latents.permute(0, 2, 1, 3, 4)
        return latents
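
    # Hedged shape sketch (illustrative, assuming an SD-style 4-channel VAE and
    # 16 frames at 512x512):
    #
    #   video_frames [16, 3, 512, 512] -> vae.encode -> [16, 4, 64, 64]
    #   * 0.18215, unsqueeze(0)        ->               [1, 16, 4, 64, 64]
    #   permute(0, 2, 1, 3, 4)         ->               [1, 4, 16, 64, 64]
    #
    # which matches the (batch, channels, frames, height, width) layout of the
    # watermarked latents produced by VideoShieldUtils.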
    def _detect_watermark_in_video(self,
                                   video_frames: Union[torch.Tensor, List[Image.Image]],
                                   prompt: str = "",
                                   detector_type: str = 'bit_acc',
                                   *args, **kwargs) -> Dict[str, float]:
        """Detect VideoShield watermark in video.

        Args:
            video_frames: Input video frames as tensor or list of PIL images
            prompt: Text prompt used for generation
            detector_type: Detection metric to use (defaults to 'bit_acc')

        Returns:
            Dictionary containing detection results
        """
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Convert frames to tensor if needed
        if isinstance(video_frames, list):
            from torchvision import transforms
            frames_tensor = torch.stack([transforms.ToTensor()(frame) for frame in video_frames])
            video_frames = 2.0 * frames_tensor - 1.0  # Normalize to [-1, 1]

        video_frames = video_frames.to(self.config.device).to(self.config.pipe.vae.dtype)

        # Get video latents
        with torch.no_grad():
            # TODO: Add support for I2V pipeline
            video_latents = self._get_video_latents(self.config.pipe.vae, video_frames, sample=False)

        # Perform DDIM inversion
        inversion_kwargs = {k: v for k, v in kwargs.items()
                            if k not in ['guidance_scale', 'num_inference_steps']}

        from diffusers import DDIMInverseScheduler
        original_scheduler = self.config.pipe.scheduler
        inverse_scheduler = DDIMInverseScheduler.from_config(original_scheduler.config)
        self.config.pipe.scheduler = inverse_scheduler

        video_latents = video_latents.to(self.config.pipe.unet.dtype)

        final_reversed_latents = self.config.pipe(
            prompt=prompt,
            latents=video_latents,
            num_inference_steps=num_steps_to_use,
            guidance_scale=guidance_scale_to_use,
            output_type='latent',
            **inversion_kwargs
        ).frames  # [B, F, H, W, C] (T2V)
        self.config.pipe.scheduler = original_scheduler

        # Use detector for evaluation
        return self.detector.eval_watermark(final_reversed_latents, detector_type=detector_type)
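
    # Hedged usage sketch (illustrative): detection inverts the frames back toward
    # the initial noise and lets the detector compare against the watermark
    # template, e.g.
    #
    #   result = watermarker._detect_watermark_in_video(frames, prompt="", detector_type='bit_acc')
    #   # `result` is a Dict[str, float] of detection metrics; presumably the bit
    #   # accuracy is judged against self.config.threshold.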
    def get_data_for_visualize(self,
                               video_frames: List[Image.Image],
                               prompt: str = "",
                               guidance_scale: float = 1,
                               *args, **kwargs) -> DataForVisualization:
        """Get VideoShield visualization data.

        This method generates the necessary data for visualizing VideoShield watermarks,
        including the original watermarked latents and the reversed latents from inversion.

        Args:
            video_frames: The video frames to visualize watermarks for
            prompt: The text prompt used for generation
            guidance_scale: Guidance scale for inversion (note: the config value is currently used)

        Returns:
            DataForVisualization object containing visualization data
        """
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Convert frames to tensor if needed
        if isinstance(video_frames, list):
            from torchvision import transforms
            frames_tensor = torch.stack([transforms.ToTensor()(frame) for frame in video_frames])
            video_frames = 2.0 * frames_tensor - 1.0  # Normalize to [-1, 1]

        video_frames = video_frames.to(self.config.device).to(self.config.pipe.vae.dtype)

        # Get video latents
        with torch.no_grad():
            # TODO: Add support for I2V pipeline
            video_latents = self._get_video_latents(self.config.pipe.vae, video_frames, sample=False)

        # Perform DDIM inversion
        inversion_kwargs = {k: v for k, v in kwargs.items()
                            if k not in ['guidance_scale', 'num_inference_steps']}

        from diffusers import DDIMInverseScheduler
        original_scheduler = self.config.pipe.scheduler
        inverse_scheduler = DDIMInverseScheduler.from_config(original_scheduler.config)
        self.config.pipe.scheduler = inverse_scheduler
        collector = DenoisingLatentsCollector(save_every_n_steps=1, to_cpu=True)

        video_latents = video_latents.to(self.config.pipe.unet.dtype)

        final_reversed_latents = self.config.pipe(
            prompt=prompt,
            latents=video_latents,
            num_inference_steps=num_steps_to_use,
            guidance_scale=guidance_scale_to_use,
            output_type='latent',
            callback=collector,
            callback_steps=1,
            **inversion_kwargs
        ).frames  # [B, F, H, W, C] (T2V)
        self.config.pipe.scheduler = original_scheduler

        reversed_latents = collector.latents_list  # List[Tensor]

        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            orig_watermarked_latents=self.get_orig_watermarked_latents(),
            reversed_latents=reversed_latents,
            video_frames=video_frames,
        )