Coverage for watermark/videoshield/video_shield.py: 94.12%
204 statements
coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

from ..base import BaseWatermark, BaseConfig
import torch
import numpy as np
from typing import Dict, Tuple, Any, Optional, List, Union
from PIL import Image
from Crypto.Cipher import ChaCha20
from Crypto.Random import get_random_bytes
import logging
from scipy.stats import norm, truncnorm
from functools import reduce
from visualize.data_for_visualization import DataForVisualization
from detection.videoshield.videoshield_detection import VideoShieldDetector
from utils.media_utils import *
from utils.utils import set_random_seed
from utils.pipeline_utils import is_video_pipeline, is_t2v_pipeline, is_i2v_pipeline
from utils.callbacks import DenoisingLatentsCollector
import random

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Constants
VAE_DOWNSAMPLE_FACTOR = 8
DEFAULT_CONFIDENCE_THRESHOLD = 0.6


class VideoShieldConfig(BaseConfig):
    """Config class for VideoShield algorithm."""

    def initialize_parameters(self) -> None:
        """Initialize VideoShield configuration."""
        # Repetition factors for video watermarking
        self.k_f: int = self.config_dict['k_f']  # Frame repetition factor
        self.k_c: int = self.config_dict['k_c']  # Channel repetition factor
        self.k_h: int = self.config_dict['k_h']  # Height repetition factor
        self.k_w: int = self.config_dict['k_w']  # Width repetition factor

        # Temporal threshold for localization
        self.t_temp: float = self.config_dict['t_temp']

        # HSTR (Hierarchical Spatio-Temporal Refinement) parameters
        self.hstr_levels: int = self.config_dict['hstr_levels']
        self.t_wm: List[float] = self.config_dict['t_wm']
        self.t_orig: List[float] = self.config_dict['t_orig']

        # Watermark generation parameters
        self.wm_key: int = self.config_dict.get('wm_key', 42)
        self.chacha_key_seed: int = self.config_dict.get('chacha_key_seed', 123456)
        self.chacha_nonce_seed: int = self.config_dict.get('chacha_nonce_seed', 789012)

        # Detection threshold
        self.threshold: float = self.config_dict.get('threshold', 0.6)

        # Calculate latent dimensions
        self.latents_height = self.image_size[0] // VAE_DOWNSAMPLE_FACTOR
        self.latents_width = self.image_size[1] // VAE_DOWNSAMPLE_FACTOR

        # Adjust repetition factors that exceed their dimension to avoid empty watermark tensors
        if hasattr(self, 'num_frames') and self.num_frames > 0:
            if self.k_f > self.num_frames:
                logger.warning(f"k_f ({self.k_f}) is larger than num_frames ({self.num_frames}). Adjusting k_f to {self.num_frames}.")
                self.k_f = self.num_frames

        if self.k_h > self.latents_height:
            logger.warning(f"k_h ({self.k_h}) is larger than latents_height ({self.latents_height}). Adjusting k_h to {self.latents_height}.")
            self.k_h = self.latents_height

        if self.k_w > self.latents_width:
            logger.warning(f"k_w ({self.k_w}) is larger than latents_width ({self.latents_width}). Adjusting k_w to {self.latents_width}.")
            self.k_w = self.latents_width

        # Generate the base watermark pattern deterministically from wm_key
        generator = torch.Generator(device=self.device)
        generator.manual_seed(self.wm_key)

        # For video, the watermark needs a frame dimension
        if hasattr(self, 'num_frames') and self.num_frames > 0:
            self.watermark = torch.randint(
                0, 2,
                [1, 4 // self.k_c, self.num_frames // self.k_f,
                 self.latents_height // self.k_h, self.latents_width // self.k_w],
                generator=generator, device=self.device
            )
        else:
            # Fallback for the image-only case
            self.watermark = torch.randint(
                0, 2,
                [1, 4 // self.k_c, self.latents_height // self.k_h, self.latents_width // self.k_w],
                generator=generator, device=self.device
            )

    @property
    def algorithm_name(self) -> str:
        """Return the algorithm name."""
        return 'VideoShield'

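# Shape bookkeeping (illustrative worked example, not part of the algorithm itself):
# with 512x512 frames, num_frames=16 and k_f=k_c=k_h=k_w=2, the initial latents are
# (1, 4, 16, 64, 64) and the base watermark generated by VideoShieldConfig above is
# (1, 4//2, 16//2, 64//2, 64//2) = (1, 2, 8, 32, 32); VideoShieldUtils below tiles it
# back up with .repeat(1, k_c, k_f, k_h, k_w), so every watermark bit is embedded
# k_c * k_f * k_h * k_w = 16 times and can later be majority-voted.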

class VideoShieldUtils:
    """Utility class for VideoShield algorithm."""

    def __init__(self, config: VideoShieldConfig, *args, **kwargs) -> None:
        """Initialize the VideoShield watermarking utility."""
        self.config = config

        # Generate deterministic cryptographic keys using seeds
        self.chacha_key = self._get_bytes_with_seed(self.config.chacha_key_seed, 32)
        self.chacha_nonce = self._get_bytes_with_seed(self.config.chacha_nonce_seed, 12)

        # Calculate latent space dimensions
        if hasattr(self.config, 'num_frames') and self.config.num_frames > 0:
            # Video case: include frame dimension
            self.latentlength = 4 * self.config.num_frames * self.config.latents_height * self.config.latents_width
        else:
            # Image case: no frame dimension
            self.latentlength = 4 * self.config.latents_height * self.config.latents_width

        # Calculate watermark length based on repetition factors
        self.marklength = self.latentlength // (self.config.k_f * self.config.k_c * self.config.k_h * self.config.k_w)

        # Voting threshold for watermark extraction
        if self.config.k_f == 1 and self.config.k_c == 1 and self.config.k_h == 1 and self.config.k_w == 1:
            self.vote_threshold = 1
        else:
            self.vote_threshold = (self.config.k_f * self.config.k_c * self.config.k_h * self.config.k_w) // 2

    def _get_bytes_with_seed(self, seed: int, n: int) -> bytes:
        """Generate deterministic bytes using a seed."""
        random.seed(seed)
        return bytes(random.getrandbits(8) for _ in range(n))


    def _stream_key_encrypt(self, sd: np.ndarray) -> np.ndarray:
        """Encrypt the watermark using ChaCha20 cipher."""
        try:
            cipher = ChaCha20.new(key=self.chacha_key, nonce=self.chacha_nonce)
            m_byte = cipher.encrypt(np.packbits(sd).tobytes())
            m_bit = np.unpackbits(np.frombuffer(m_byte, dtype=np.uint8))
            return m_bit[:len(sd)]  # Ensure same length as input
        except Exception as e:
            logger.error(f"Encryption failed: {e}")
            raise RuntimeError("Encryption failed") from e

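    # Note on _stream_key_encrypt above (added for clarity): ChaCha20 is a stream
    # cipher (XOR with a keystream), so applying the same key/nonce keystream a
    # second time recovers the repeated watermark bits. Detection can therefore
    # invert this step by simply re-encrypting the extracted bitstream, without a
    # separate decrypt routine.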

    def _truncated_sampling(self, message: np.ndarray) -> torch.Tensor:
        """Truncated Gaussian sampling for watermark embedding.

        Args:
            message: Binary message as a numpy array of 0s and 1s

        Returns:
            Watermarked initial latents tensor
        """
        z = np.zeros(self.latentlength)
        denominator = 2.0
        ppf = [norm.ppf(j / denominator) for j in range(int(denominator) + 1)]

        for i in range(self.latentlength):
            # With one bit per latent element, this reduces to dec_mes = message[i]
            dec_mes = reduce(lambda a, b: 2 * a + b, message[i : i + 1])
            dec_mes = int(dec_mes)
            # Sample from the standard normal truncated to the interval assigned to this bit
            z[i] = truncnorm.rvs(ppf[dec_mes], ppf[dec_mes + 1])

        # Reshape based on whether this is video or image
        if hasattr(self.config, 'num_frames') and self.config.num_frames > 0:
            # Video: (batch, channels, frames, height, width)
            z = torch.from_numpy(z).reshape(
                1, 4, self.config.num_frames, self.config.latents_height, self.config.latents_width
            ).float()
        else:
            # Image: (batch, channels, height, width)
            z = torch.from_numpy(z).reshape(
                1, 4, self.config.latents_height, self.config.latents_width
            ).float()

        return z.to(self.config.device)

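    # Note on _truncated_sampling above (added for clarity): because ppf(0.5) = 0,
    # bit 0 is drawn from (-inf, 0] and bit 1 from [0, +inf), so the encrypted bit
    # behind each latent element can later be recovered by a simple sign threshold
    # (value > 0 -> bit 1).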

    def create_watermark_and_return_w(self) -> torch.Tensor:
        """Create the watermark pattern and return watermarked initial latents."""
        # Create the repeated (tiled) watermark pattern
        if hasattr(self.config, 'num_frames') and self.config.num_frames > 0:
            # Video case: repeat along all dimensions, including frames
            sd = self.config.watermark.repeat(1, self.config.k_c, self.config.k_f, self.config.k_h, self.config.k_w)
        else:
            # Image case: repeat along channel and spatial dimensions only
            sd = self.config.watermark.repeat(1, self.config.k_c, self.config.k_h, self.config.k_w)

        # Encrypt the repeated watermark
        m = self._stream_key_encrypt(sd.flatten().cpu().numpy())

        # Generate watermarked latents using truncated sampling
        w = self._truncated_sampling(m)

        return w

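
# Illustrative, self-contained sketch (added for exposition; not used by the module):
# demonstrates the VideoShield-style round trip on a toy 1-D "latent" -- repeat the
# watermark bits, encrypt with ChaCha20, embed via truncated-Gaussian sampling, then
# recover by sign thresholding, re-applying the keystream, and majority voting.
# All sizes and seeds below are arbitrary toy assumptions.
def _toy_round_trip_example(num_bits: int = 64, repeat: int = 4, seed: int = 0) -> bool:
    rng = np.random.default_rng(seed)
    bits = rng.integers(0, 2, num_bits).astype(np.uint8)

    # Repeat each bit `repeat` times (stand-in for the k_f/k_c/k_h/k_w tiling)
    repeated = np.repeat(bits, repeat)

    # Encrypt the repeated bits with a ChaCha20 keystream
    key, nonce = get_random_bytes(32), get_random_bytes(12)
    cipher = ChaCha20.new(key=key, nonce=nonce)
    encrypted = np.unpackbits(
        np.frombuffer(cipher.encrypt(np.packbits(repeated).tobytes()), dtype=np.uint8)
    )[:len(repeated)]

    # Embed: bit 0 -> truncated normal on (-inf, 0], bit 1 -> truncated normal on [0, +inf)
    ppf = [norm.ppf(0.0), norm.ppf(0.5), norm.ppf(1.0)]
    latents = np.array([truncnorm.rvs(ppf[b], ppf[b + 1]) for b in encrypted])

    # Recover: sign threshold, re-apply the same keystream, then majority-vote the repeats
    extracted = (latents > 0).astype(np.uint8)
    cipher = ChaCha20.new(key=key, nonce=nonce)
    decrypted = np.unpackbits(
        np.frombuffer(cipher.encrypt(np.packbits(extracted).tobytes()), dtype=np.uint8)
    )[:len(extracted)]
    votes = decrypted.reshape(num_bits, repeat).sum(axis=1)
    recovered = (votes > repeat // 2).astype(np.uint8)
    return bool(np.array_equal(recovered, bits))
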


class VideoShieldWatermark(BaseWatermark):
    """Main class for VideoShield watermarking algorithm."""

    def __init__(self, watermark_config: VideoShieldConfig, *args, **kwargs) -> None:
        """Initialize the VideoShield watermarking algorithm.

        Args:
            watermark_config: Configuration instance of the VideoShield algorithm
        """
        self.config = watermark_config
        self.utils = VideoShieldUtils(self.config)

        # Initialize detector with encryption keys from utils
        self.detector = VideoShieldDetector(
            watermark=self.config.watermark,
            threshold=self.config.threshold,
            device=self.config.device,
            chacha_key=self.utils.chacha_key,
            chacha_nonce=self.utils.chacha_nonce,
            height=self.config.image_size[0],
            width=self.config.image_size[1],
            num_frames=self.config.num_frames,
            k_f=self.config.k_f,
            k_c=self.config.k_c,
            k_h=self.config.k_h,
            k_w=self.config.k_w
        )


    def _generate_watermarked_video(self, prompt: str, num_frames: Optional[int] = None, *args, **kwargs) -> List[Image.Image]:
        """Generate a watermarked video using the VideoShield algorithm.

        Args:
            prompt: The input prompt for video generation
            num_frames: Number of frames to generate (uses the config value if None)

        Returns:
            List of generated watermarked video frames
        """
        if not is_video_pipeline(self.config.pipe):
            raise ValueError(f"This pipeline ({self.config.pipe.__class__.__name__}) does not support video generation.")

        # Set random seed for reproducibility
        set_random_seed(self.config.gen_seed)

        # Use the configured frame count if not specified
        frames_to_generate = num_frames if num_frames is not None else self.config.num_frames

        # Set num_frames in the config for watermark generation
        original_num_frames = getattr(self.config, 'num_frames', None)
        self.config.num_frames = frames_to_generate

        try:
            # Generate watermarked latents
            watermarked_latents = self.utils.create_watermark_and_return_w().to(self.config.pipe.unet.dtype)

            # Save watermarked latents for visualization
            self.set_orig_watermarked_latents(watermarked_latents)

            # Construct video generation parameters
            generation_params = {
                "num_inference_steps": self.config.num_inference_steps,
                "guidance_scale": self.config.guidance_scale,
                "height": self.config.image_size[0],
                "width": self.config.image_size[1],
                "num_frames": frames_to_generate,
                "latents": watermarked_latents
            }

            # Add parameters from config.gen_kwargs
            if hasattr(self.config, "gen_kwargs") and self.config.gen_kwargs:
                for key, value in self.config.gen_kwargs.items():
                    if key not in generation_params:
                        generation_params[key] = value

            # Use kwargs to override default parameters
            for key, value in kwargs.items():
                if key != "num_frames":  # Prevent overriding processed parameters
                    generation_params[key] = value

            # Handle I2V pipelines that need dimension permutation (like SVD)
            final_latents = watermarked_latents
            if is_i2v_pipeline(self.config.pipe):
                logger.info("I2V pipeline detected, permuting latent dimensions.")
                final_latents = final_latents.permute(0, 2, 1, 3, 4)  # (b, c, f, h, w) -> (b, f, c, h, w)

            generation_params["latents"] = final_latents

            # Generate video
            output = self.config.pipe(
                prompt,
                **generation_params
            )

            # Extract frames from the pipeline output
            if hasattr(output, 'frames'):
                frames = output.frames[0]
            elif hasattr(output, 'videos'):
                frames = output.videos[0]
            else:
                frames = output[0] if isinstance(output, tuple) else output

            # Convert frames to PIL Images
            frame_list = []
            for frame in frames:
                if not isinstance(frame, Image.Image):
                    if isinstance(frame, np.ndarray):
                        if frame.dtype == np.uint8:
                            frame_pil = Image.fromarray(frame)
                        else:
                            frame_scaled = (frame * 255).astype(np.uint8)
                            frame_pil = Image.fromarray(frame_scaled)
                    elif isinstance(frame, torch.Tensor):
                        if frame.dim() == 3 and frame.shape[-1] in [1, 3]:
                            if frame.max() <= 1.0:
                                frame = (frame * 255).byte()
                            frame_np = frame.cpu().numpy()
                            frame_pil = Image.fromarray(frame_np)
                        else:
                            raise ValueError(f"Unexpected tensor shape for frame: {frame.shape}")
                    else:
                        raise TypeError(f"Unexpected type for frame: {type(frame)}")
                else:
                    frame_pil = frame

                frame_list.append(frame_pil)

            return frame_list

        finally:
            # Restore the original num_frames
            if original_num_frames is not None:
                self.config.num_frames = original_num_frames
            elif hasattr(self.config, 'num_frames'):
                delattr(self.config, 'num_frames')


    # def _detect_watermark_in_image(self, image: Image.Image, prompt: str = "",
    #                                *args, **kwargs) -> Dict[str, float]:
    #     """Detect VideoShield watermark in image.
    #
    #     Args:
    #         image: Input PIL image
    #         prompt: Text prompt used for generation
    #
    #     Returns:
    #         Dictionary containing detection results
    #     """
    #     # Use config values as defaults if not explicitly provided
    #     guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
    #     num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)
    #
    #     # Get text embeddings
    #     do_classifier_free_guidance = (guidance_scale_to_use > 1.0)
    #     prompt_embeds, negative_prompt_embeds = self.config.pipe.encode_prompt(
    #         prompt=prompt,
    #         device=self.config.device,
    #         do_classifier_free_guidance=do_classifier_free_guidance,
    #         num_images_per_prompt=1,
    #     )
    #
    #     if do_classifier_free_guidance:
    #         text_embeddings = torch.cat([negative_prompt_embeds, prompt_embeds])
    #     else:
    #         text_embeddings = prompt_embeds
    #
    #     # Preprocess image
    #     image_tensor = transform_to_model_format(
    #         image, target_size=self.config.image_size[0]
    #     ).unsqueeze(0).to(text_embeddings.dtype).to(self.config.device)
    #
    #     # Get image latents
    #     image_latents = get_media_latents(
    #         pipe=self.config.pipe,
    #         media=image_tensor,
    #         sample=False,
    #         decoder_inv=kwargs.get("decoder_inv", False)
    #     )
    #
    #     # Perform DDIM inversion
    #     inversion_kwargs = {k: v for k, v in kwargs.items()
    #                         if k not in ['decoder_inv', 'guidance_scale', 'num_inference_steps']}
    #
    #     reversed_latents = self.config.inversion.forward_diffusion(
    #         latents=image_latents,
    #         text_embeddings=text_embeddings,
    #         guidance_scale=guidance_scale_to_use,
    #         num_inference_steps=num_steps_to_use,
    #         **inversion_kwargs
    #     )[-1]
    #
    #     # Use detector or utils for evaluation
    #     if 'detector_type' in kwargs:
    #         return self.detector.eval_watermark(reversed_latents, detector_type=kwargs['detector_type'])
    #     else:
    #         return self.utils.eval_watermark(reversed_latents)


    def _get_video_latents(self, vae, video_frames, sample=True, rng_generator=None, permute=True):
        """Encode video frames into VAE latents."""
        encoding_dist = vae.encode(video_frames).latent_dist
        if sample:
            encoding = encoding_dist.sample(generator=rng_generator)
        else:
            encoding = encoding_dist.mode()
        # 0.18215 is the Stable Diffusion VAE scaling factor
        latents = (encoding * 0.18215).unsqueeze(0)
        if permute:
            latents = latents.permute(0, 2, 1, 3, 4)  # (b, f, c, h, w) -> (b, c, f, h, w)
        return latents


    def _detect_watermark_in_video(self,
                                   video_frames: Union[torch.Tensor, List[Image.Image]],
                                   prompt: str = "",
                                   detector_type: str = 'bit_acc',
                                   *args, **kwargs) -> Dict[str, float]:
        """Detect the VideoShield watermark in a video.

        Args:
            video_frames: Input video frames as a tensor or a list of PIL images
            prompt: Text prompt used for generation
            detector_type: Detection metric to compute (e.g. 'bit_acc')

        Returns:
            Dictionary containing detection results
        """
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Convert frames to a tensor if needed
        if isinstance(video_frames, list):
            from torchvision import transforms
            frames_tensor = torch.stack([transforms.ToTensor()(frame) for frame in video_frames])
            video_frames = 2.0 * frames_tensor - 1.0  # Normalize to [-1, 1]

        video_frames = video_frames.to(self.config.device).to(self.config.pipe.vae.dtype)

        # Get video latents
        with torch.no_grad():
            # TODO: Add support for I2V pipelines
            video_latents = self._get_video_latents(self.config.pipe.vae, video_frames, sample=False)

        # Perform DDIM inversion
        inversion_kwargs = {k: v for k, v in kwargs.items()
                            if k not in ['guidance_scale', 'num_inference_steps']}

        from diffusers import DDIMInverseScheduler
        original_scheduler = self.config.pipe.scheduler
        inverse_scheduler = DDIMInverseScheduler.from_config(original_scheduler.config)
        self.config.pipe.scheduler = inverse_scheduler

        video_latents = video_latents.to(self.config.pipe.unet.dtype)

        final_reversed_latents = self.config.pipe(
            prompt=prompt,
            latents=video_latents,
            num_inference_steps=num_steps_to_use,
            guidance_scale=guidance_scale_to_use,
            output_type='latent',
            **inversion_kwargs
        ).frames  # [B, F, H, W, C] (T2V)
        self.config.pipe.scheduler = original_scheduler

        # Use the detector for evaluation
        return self.detector.eval_watermark(final_reversed_latents, detector_type=detector_type)

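    # Note on the detection flow above (added for clarity): detection approximates the
    # original noise latents by swapping in DDIMInverseScheduler and running the pipeline
    # with output_type='latent', then lets VideoShieldDetector compare the bits recovered
    # from those latents against the configured watermark pattern.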


    def get_data_for_visualize(self,
                               video_frames: List[Image.Image],
                               prompt: str = "",
                               guidance_scale: float = 1,
                               *args, **kwargs) -> DataForVisualization:
        """Get VideoShield visualization data.

        This method generates the data needed to visualize VideoShield watermarks,
        including the original watermarked latents and the reversed latents from inversion.

        Args:
            video_frames: The video frames to visualize watermarks for
            prompt: The text prompt used for generation
            guidance_scale: Guidance scale for generation and inversion

        Returns:
            DataForVisualization object containing visualization data
        """
        # Use config values as defaults if not explicitly provided
        guidance_scale_to_use = kwargs.get('guidance_scale', self.config.guidance_scale)
        num_steps_to_use = kwargs.get('num_inference_steps', self.config.num_inference_steps)

        # Convert frames to a tensor if needed
        if isinstance(video_frames, list):
            from torchvision import transforms
            frames_tensor = torch.stack([transforms.ToTensor()(frame) for frame in video_frames])
            video_frames = 2.0 * frames_tensor - 1.0  # Normalize to [-1, 1]

        video_frames = video_frames.to(self.config.device).to(self.config.pipe.vae.dtype)

        # Get video latents
        with torch.no_grad():
            # TODO: Add support for I2V pipelines
            video_latents = self._get_video_latents(self.config.pipe.vae, video_frames, sample=False)

        # Perform DDIM inversion, collecting intermediate latents at every step
        inversion_kwargs = {k: v for k, v in kwargs.items()
                            if k not in ['guidance_scale', 'num_inference_steps']}

        from diffusers import DDIMInverseScheduler
        original_scheduler = self.config.pipe.scheduler
        inverse_scheduler = DDIMInverseScheduler.from_config(original_scheduler.config)
        self.config.pipe.scheduler = inverse_scheduler
        collector = DenoisingLatentsCollector(save_every_n_steps=1, to_cpu=True)

        video_latents = video_latents.to(self.config.pipe.unet.dtype)

        final_reversed_latents = self.config.pipe(
            prompt=prompt,
            latents=video_latents,
            num_inference_steps=num_steps_to_use,
            guidance_scale=guidance_scale_to_use,
            output_type='latent',
            callback=collector,
            callback_steps=1,
            **inversion_kwargs
        ).frames  # [B, F, H, W, C] (T2V)
        self.config.pipe.scheduler = original_scheduler

        reversed_latents = collector.latents_list  # List[Tensor]

        return DataForVisualization(
            config=self.config,
            utils=self.utils,
            orig_watermarked_latents=self.get_orig_watermarked_latents(),
            reversed_latents=reversed_latents,
            video_frames=video_frames,
        )
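

# Usage sketch (illustrative, added for exposition): assumes a diffusers text-to-video
# pipeline and a populated VideoShieldConfig instance named `config`.
#
#   wm = VideoShieldWatermark(config)
#   frames = wm._generate_watermarked_video("a dog running on the beach", num_frames=16)
#   result = wm._detect_watermark_in_video(frames, prompt="", detector_type='bit_acc')
#   print(result)  # e.g. a dict of detection metrics such as bit accuracy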