Coverage for utils/media_utils.py: 90.18%

163 statements  

coverage.py v7.13.0, created at 2025-12-22 10:24 +0000

import torch
import numpy as np
from torchvision import transforms
from torchvision.transforms.functional import pil_to_tensor
from PIL import Image
import cv2
from diffusers import StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline
from transformers import get_cosine_schedule_with_warmup
from typing import Optional, Callable, Union, List, Tuple, Dict, Any
from tqdm import tqdm
import copy
from utils.pipeline_utils import (
    get_pipeline_type,
    PIPELINE_TYPE_IMAGE,
    PIPELINE_TYPE_TEXT_TO_VIDEO,
    PIPELINE_TYPE_IMAGE_TO_VIDEO
)


# ===== Common Utility Functions =====

def torch_to_numpy(tensor) -> np.ndarray:
    """Convert a model-space tensor in [-1, 1] to a numpy array in [0, 1]."""
    tensor = (tensor / 2 + 0.5).clamp(0, 1)
    if tensor.dim() == 4:  # Image: B, C, H, W
        return tensor.cpu().permute(0, 2, 3, 1).numpy()
    elif tensor.dim() == 5:  # Video: B, C, F, H, W
        return tensor.cpu().permute(0, 2, 3, 4, 1).numpy()
    else:
        raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")

def pil_to_torch(image: Image.Image, normalize: bool = True) -> torch.Tensor:
    """Convert a PIL image to a torch tensor."""
    tensor = pil_to_tensor(image) / 255.0
    if normalize:
        tensor = 2.0 * tensor - 1.0  # Normalize to [-1, 1]
    return tensor

def numpy_to_pil(img: np.ndarray) -> Image.Image:
    """Convert a numpy array to a PIL image."""
    if img.dtype != np.uint8:
        img = (img * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(img)

def cv2_to_pil(img: np.ndarray) -> Image.Image:
    """Convert a cv2 image (numpy array) to a PIL image."""
    if img.dtype != np.uint8:
        img = (img * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(img)

def pil_to_cv2(pil_img: Image.Image) -> np.ndarray:
    """Convert a PIL image to a cv2-style numpy array (HWC, RGB order) scaled to [0, 1]."""
    return np.asarray(pil_img) / 255.0
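
# Illustrative round trip through the helpers above (a sketch, not part of the
# original module; "example.png" is an assumed input path):
#
#   img = Image.open("example.png").convert("RGB")
#   t = pil_to_torch(img)                       # [C, H, W], values in [-1, 1]
#   arr = torch_to_numpy(t.unsqueeze(0))[0]     # [H, W, C], values in [0, 1]
#   restored = numpy_to_pil(arr)                # back to a PIL image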

def transform_to_model_format(media: Union[Image.Image, List[Image.Image], np.ndarray, torch.Tensor],
                              target_size: Optional[int] = None) -> torch.Tensor:
    """
    Transform an image or video frames to model input format.

    For an image, `media` is a PIL image that is resized and center-cropped to `target_size`
    (if provided), normalized to [-1, 1], and permuted from [H, W, C] to [C, H, W].
    For a video, `media` is a list of frames (PIL images or numpy arrays) that are
    normalized to [-1, 1] and permuted from [F, H, W, C] to [F, C, H, W].

    Args:
        media: PIL image, list of frames, or video array/tensor
        target_size: Target size for resize and center-crop (images only)

    Returns:
        torch.Tensor: Normalized tensor ready for model input
    """
    if isinstance(media, Image.Image):
        # Single image
        if target_size is not None:
            transform = transforms.Compose([
                transforms.Resize(target_size),
                transforms.CenterCrop(target_size),
                transforms.ToTensor(),
            ])
        else:
            transform = transforms.ToTensor()
        return 2.0 * transform(media) - 1.0

    elif isinstance(media, list):
        # List of frames (PIL images or numpy arrays)
        if all(isinstance(frame, Image.Image) for frame in media):
            return torch.stack([2.0 * transforms.ToTensor()(frame) - 1.0 for frame in media])
        elif all(isinstance(frame, np.ndarray) for frame in media):
            return torch.stack([2.0 * transforms.ToTensor()(numpy_to_pil(frame)) - 1.0 for frame in media])
        else:
            raise ValueError("All frames must be either PIL images or numpy arrays")

    elif isinstance(media, np.ndarray) and media.ndim >= 3:
        # Video numpy array
        if media.ndim == 3:  # Single frame: H, W, C
            return 2.0 * transforms.ToTensor()(media) - 1.0
        elif media.ndim == 4:  # Multiple frames: F, H, W, C
            return torch.stack([2.0 * transforms.ToTensor()(frame) - 1.0 for frame in media])
        else:
            raise ValueError(f"Unsupported numpy array shape: {media.shape}")

    else:
        raise ValueError(f"Unsupported media type: {type(media)}")
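
# Illustrative usage (a sketch, not part of the original module; `frames` is an
# assumed list of RGB PIL images):
#
#   image_input = transform_to_model_format(Image.open("photo.png").convert("RGB"),
#                                           target_size=512)   # [3, 512, 512] in [-1, 1]
#   video_input = transform_to_model_format(frames)            # [F, 3, H, W] in [-1, 1]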

# ===== Latent Processing Functions =====

def set_inversion(pipe: Union[StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline], inversion_type: str):
    """Build the inversion object (DDIM or exact) for the given pipeline."""
    from inversions import DDIMInversion, ExactInversion

    if inversion_type == "ddim":
        return DDIMInversion(pipe.scheduler, pipe.unet, pipe.device)
    elif inversion_type == "exact":
        return ExactInversion(pipe.scheduler, pipe.unet, pipe.device)
    else:
        raise ValueError(f"Invalid inversion type: {inversion_type}")

def get_random_latents(pipe: Union[StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline],
                       latents=None, num_frames=None, height=512, width=512, generator=None) -> torch.Tensor:
    """Get random latents for the given pipe."""
    pipeline_type = get_pipeline_type(pipe)
    height = height or pipe.unet.config.sample_size * pipe.vae_scale_factor
    width = width or pipe.unet.config.sample_size * pipe.vae_scale_factor

    batch_size = 1
    device = pipe._execution_device
    num_channels_latents = pipe.unet.config.in_channels

    # Handle different pipeline types
    if pipeline_type == PIPELINE_TYPE_IMAGE or num_frames is None:
        latents = pipe.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            pipe.text_encoder.dtype,
            device,
            generator,
            latents,
        )
    else:
        # Video pipelines with frames
        latents = pipe.prepare_latents(
            batch_size,
            num_channels_latents,
            num_frames,
            height,
            width,
            pipe.text_encoder.dtype,
            device,
            generator,
            latents,
        )

    return latents
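
# Illustrative usage (a sketch, not part of the original module; the checkpoint
# name is an assumption, any StableDiffusionPipeline works the same way):
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
#   init_latents = get_random_latents(pipe, height=512, width=512,
#                                     generator=torch.Generator().manual_seed(0))
#   # -> [1, 4, 64, 64] latents for SD v1.5, ready for the sampling loop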

# ===== Image-Specific Functions =====

def _get_image_latents(pipe: StableDiffusionPipeline, image: torch.Tensor,
                       sample: bool = True, rng_generator: Optional[torch.Generator] = None,
                       decoder_inv: bool = False) -> torch.Tensor:
    """Get the image latents for the given image."""
    encoding_dist = pipe.vae.encode(image).latent_dist
    if sample:
        encoding = encoding_dist.sample(generator=rng_generator)
    else:
        encoding = encoding_dist.mode()
    latents = encoding * 0.18215
    if decoder_inv:
        latents = decoder_inv_optimization(pipe, latents, image)
    return latents

def _decode_image_latents(pipe: StableDiffusionPipeline, latents: torch.FloatTensor) -> torch.Tensor:
    """Decode the image from the given latents."""
    scaled_latents = 1 / 0.18215 * latents
    image = pipe.vae.decode(scaled_latents, return_dict=False)[0]
    image = (image / 2 + 0.5).clamp(0, 1)
    return image

def decoder_inv_optimization(pipe: StableDiffusionPipeline, latents: torch.FloatTensor,
                             image: torch.FloatTensor, num_steps: int = 100) -> torch.Tensor:
    """
    Optimize latents to better reconstruct the input image by minimizing the error between
    the decoded latents and the original image.

    Args:
        pipe: The diffusion pipeline
        latents: Initial latents
        image: Target image
        num_steps: Number of optimization steps

    Returns:
        torch.Tensor: Optimized latents
    """
    input_image = image.clone().float()
    z = latents.clone().float().detach()
    z.requires_grad_(True)

    loss_function = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.Adam([z], lr=0.1)
    lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=num_steps)

    for i in tqdm(range(num_steps)):
        # Decode without normalization to match original implementation
        scaled_latents = 1 / 0.18215 * z
        x_pred = pipe.vae.decode(scaled_latents, return_dict=False)[0]

        loss = loss_function(x_pred, input_image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

    return z.detach()
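
# In effect, the loop above solves  min_z || D(z / 0.18215) - x ||^2  with Adam,
# where D is the VAE decoder and x the target image. Illustrative use (a sketch,
# not part of the original module; `image` is an assumed [-1, 1] image tensor on
# the pipeline's device):
#
#   latents = _get_image_latents(pipe, image, sample=False, decoder_inv=True)
#   recon = _decode_image_latents(pipe, latents)   # reconstruction in [0, 1]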

# ===== Video-Specific Functions =====

def _get_video_latents(pipe: Union[TextToVideoSDPipeline, StableVideoDiffusionPipeline],
                       video_frames: torch.Tensor, sample: bool = True,
                       rng_generator: Optional[torch.Generator] = None,
                       permute: bool = True,
                       decoder_inv: bool = False) -> torch.Tensor:
    """
    Encode video frames to latents.

    Args:
        pipe: Video diffusion pipeline
        video_frames: Tensor of video frames [F, C, H, W]
        sample: Whether to sample from the latent distribution
        rng_generator: Random generator for sampling
        permute: Whether to permute the latents to [B, C, F, H, W] format
        decoder_inv: Whether to apply decoder-inversion optimization (not yet implemented for video)

    Returns:
        torch.Tensor: Video latents
    """
    encoding_dist = pipe.vae.encode(video_frames).latent_dist
    if sample:
        encoding = encoding_dist.sample(generator=rng_generator)
    else:
        encoding = encoding_dist.mode()
    latents = (encoding * 0.18215).unsqueeze(0)
    if permute:
        latents = latents.permute(0, 2, 1, 3, 4)
    if decoder_inv:  # TODO: Implement decoder inversion for video latents
        raise NotImplementedError("Decoder inversion is not implemented for video latents")
    return latents

def tensor2vid(video: torch.Tensor, processor, output_type: str = "np"):
    """
    Convert video tensor to desired output format.

    Args:
        video: Video tensor [B, C, F, H, W]
        processor: Video processor from the diffusion pipeline
        output_type: Output type - 'np', 'pt', or 'pil'

    Returns:
        Video in requested format
    """
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)
        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)
    elif output_type == "pt":
        outputs = torch.stack(outputs)
    elif not output_type == "pil":
        raise ValueError(f"{output_type} does not exist. Please choose one of ['np', 'pt', 'pil']")

    return outputs
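
# Shape flow in tensor2vid (informational note, not part of the original module):
# [B, C, F, H, W] -> per-batch [F, C, H, W] -> processor.postprocess per batch ->
# stacked back to [B, F, H, W, C] for "np" (or [B, F, C, H, W] for "pt").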

def _decode_video_latents(pipe: Union[TextToVideoSDPipeline, StableVideoDiffusionPipeline],
                          latents: torch.Tensor,
                          num_frames: Optional[int] = None) -> np.ndarray:
    """
    Decode latents to video frames.

    Args:
        pipe: Video diffusion pipeline
        latents: Video latents
        num_frames: Number of frames to decode

    Returns:
        np.ndarray: Video frames
    """
    if num_frames is None:
        video_tensor = pipe.decode_latents(latents)
    else:
        video_tensor = pipe.decode_latents(latents, num_frames)
    video = tensor2vid(video_tensor, pipe.video_processor)
    return video

def convert_video_frames_to_images(frames: List[Union[np.ndarray, Image.Image]]) -> List[Image.Image]:
    """
    Convert video frames to a list of PIL.Image objects.

    Args:
        frames: List of video frames (numpy arrays or PIL images)

    Returns:
        List[Image.Image]: List of PIL images
    """
    pil_frames = []
    for frame in frames:
        if isinstance(frame, np.ndarray):
            # Convert numpy array to PIL
            pil_frames.append(numpy_to_pil(frame))
        elif isinstance(frame, Image.Image):
            # Already a PIL image
            pil_frames.append(frame)
        else:
            raise ValueError(f"Unsupported frame type: {type(frame)}")
    return pil_frames

def save_video_frames(frames: List[Union[np.ndarray, Image.Image]], save_dir: str) -> None:
    """
    Save video frames to a directory.

    Args:
        frames: List of video frames (numpy arrays or PIL images)
        save_dir: Directory to save frames
    """
    if isinstance(frames[0], np.ndarray):
        frames = [(frame * 255).astype(np.uint8) if frame.dtype != np.uint8 else frame for frame in frames]
    elif isinstance(frames[0], Image.Image):
        frames = [np.array(frame) for frame in frames]

    for i, frame in enumerate(frames):
        img = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        cv2.imwrite(f'{save_dir}/{i:02d}.png', img)
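
# Illustrative usage (a sketch, not part of the original module; `video` is an
# assumed [B, F, H, W, C] numpy output of a video pipeline, and the target
# directory is expected to exist):
#
#   save_video_frames(list(video[0]), "outputs/frames")   # writes 00.png, 01.png, ...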

# ===== Utility Functions for Working with Different Pipeline Types =====

def get_media_latents(pipe: Union[StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline],
                      media: Union[torch.Tensor, List[torch.Tensor]],
                      sample: bool = True,
                      rng_generator: Optional[torch.Generator] = None,
                      decoder_inv: bool = False) -> torch.Tensor:
    """
    Get latents from media (either image or video) based on pipeline type.

    Args:
        pipe: Diffusion pipeline
        media: Image tensor or video frames tensor
        sample: Whether to sample from the latent distribution
        rng_generator: Random generator for sampling
        decoder_inv: Whether to use decoder inversion optimization

    Returns:
        torch.Tensor: Media latents
    """
    pipeline_type = get_pipeline_type(pipe)

    if pipeline_type == PIPELINE_TYPE_IMAGE:
        return _get_image_latents(pipe, media, sample, rng_generator, decoder_inv)
    elif pipeline_type in [PIPELINE_TYPE_TEXT_TO_VIDEO, PIPELINE_TYPE_IMAGE_TO_VIDEO]:
        permute = pipeline_type == PIPELINE_TYPE_TEXT_TO_VIDEO
        return _get_video_latents(pipe, media, sample, rng_generator, permute, decoder_inv)
    else:
        raise ValueError(f"Unsupported pipeline type: {pipeline_type}")

def decode_media_latents(pipe: Union[StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline],
                         latents: torch.Tensor,
                         num_frames: Optional[int] = None) -> Union[torch.Tensor, np.ndarray]:
    """
    Decode latents to media (either image or video) based on pipeline type.

    Args:
        pipe: Diffusion pipeline
        latents: Media latents
        num_frames: Number of frames (for video)

    Returns:
        Union[torch.Tensor, np.ndarray]: Decoded media
    """
    pipeline_type = get_pipeline_type(pipe)

    if pipeline_type == PIPELINE_TYPE_IMAGE:
        return _decode_image_latents(pipe, latents)
    elif pipeline_type in [PIPELINE_TYPE_TEXT_TO_VIDEO, PIPELINE_TYPE_IMAGE_TO_VIDEO]:
        return _decode_video_latents(pipe, latents, num_frames)
    else:
        raise ValueError(f"Unsupported pipeline type: {pipeline_type}")
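
# Illustrative encode/decode round trip (a sketch, not part of the original module;
# the checkpoint name and "photo.png" are assumptions):
#
#   pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
#   img = transform_to_model_format(Image.open("photo.png").convert("RGB"), target_size=512)
#   img = img.unsqueeze(0).to(device=pipe.device, dtype=pipe.vae.dtype)
#   latents = get_media_latents(pipe, img, sample=False)   # [1, 4, 64, 64]
#   recon = decode_media_latents(pipe, latents)            # [1, 3, 512, 512] in [0, 1]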