Coverage for utils/media_utils.py: 90.18%
163 statements
coverage.py v7.13.0, created at 2025-12-22 10:24 +0000

import torch
import numpy as np
from torchvision import transforms
from torchvision.transforms.functional import pil_to_tensor
from PIL import Image
import cv2
from diffusers import StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline
from transformers import get_cosine_schedule_with_warmup
from typing import Optional, Callable, Union, List, Tuple, Dict, Any
from tqdm import tqdm
import copy
from utils.pipeline_utils import (
    get_pipeline_type,
    PIPELINE_TYPE_IMAGE,
    PIPELINE_TYPE_TEXT_TO_VIDEO,
    PIPELINE_TYPE_IMAGE_TO_VIDEO
)

# ===== Common Utility Functions =====

def torch_to_numpy(tensor) -> np.ndarray:
    """Convert tensor to numpy array with proper scaling."""
    tensor = (tensor / 2 + 0.5).clamp(0, 1)
    if tensor.dim() == 4:  # Image: B, C, H, W
        return tensor.cpu().permute(0, 2, 3, 1).numpy()
    elif tensor.dim() == 5:  # Video: B, C, F, H, W
        return tensor.cpu().permute(0, 2, 3, 4, 1).numpy()
    else:
        raise ValueError(f"Unsupported tensor dimension: {tensor.dim()}")

def pil_to_torch(image: Image.Image, normalize: bool = True) -> torch.Tensor:
    """Convert PIL image to torch tensor."""
    tensor = pil_to_tensor(image) / 255.0
    if normalize:
        tensor = 2.0 * tensor - 1.0  # Normalize to [-1, 1]
    return tensor

def numpy_to_pil(img: np.ndarray) -> Image.Image:
    """Convert numpy array to PIL image."""
    if img.dtype != np.uint8:
        img = (img * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(img)

def cv2_to_pil(img: np.ndarray) -> Image.Image:
    """Convert cv2 image (numpy array) to PIL image."""
    if img.dtype != np.uint8:
        img = (img * 255).clip(0, 255).astype(np.uint8)
    return Image.fromarray(img)

def pil_to_cv2(pil_img: Image.Image) -> np.ndarray:
    """Convert PIL image to cv2 format (numpy array)."""
    return np.asarray(pil_img) / 255.0
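
# Illustrative usage sketch (not part of the original module): round-trip an image
# through the conversion helpers above to check the scaling conventions. The file
# path "example.png" is hypothetical.
def _example_image_roundtrip(path: str = "example.png") -> Image.Image:
    img = Image.open(path).convert("RGB")
    t = pil_to_torch(img).unsqueeze(0)   # [1, C, H, W], values in [-1, 1]
    arr = torch_to_numpy(t)[0]           # [H, W, C], values in [0, 1]
    return numpy_to_pil(arr)             # back to a PIL image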

def transform_to_model_format(media: Union[Image.Image, List[Image.Image], np.ndarray, torch.Tensor],
                              target_size: Optional[int] = None) -> torch.Tensor:
    """
    Transform an image or video frames to the model input format.

    For an image, `media` is a PIL image that is resized to `target_size` (if provided),
    normalized to [-1, 1], and permuted from [H, W, C] to [C, H, W].
    For a video, `media` is a list of frames (PIL images or numpy arrays) that are
    normalized to [-1, 1] and permuted from [F, H, W, C] to [F, C, H, W].

    Args:
        media: PIL image, list of frames, or video tensor
        target_size: Target size for resize operations (for images)

    Returns:
        torch.Tensor: Normalized tensor ready for model input
    """
    if isinstance(media, Image.Image):
        # Single image
        if target_size is not None:
            transform = transforms.Compose([
                transforms.Resize(target_size),
                transforms.CenterCrop(target_size),
                transforms.ToTensor(),
            ])
        else:
            transform = transforms.ToTensor()
        return 2.0 * transform(media) - 1.0

    elif isinstance(media, list):
        # List of frames (PIL images or numpy arrays)
        if all(isinstance(frame, Image.Image) for frame in media):
            return torch.stack([2.0 * transforms.ToTensor()(frame) - 1.0 for frame in media])
        elif all(isinstance(frame, np.ndarray) for frame in media):
            return torch.stack([2.0 * transforms.ToTensor()(numpy_to_pil(frame)) - 1.0 for frame in media])
        else:
            raise ValueError("All frames must be either PIL images or numpy arrays")

    elif isinstance(media, np.ndarray) and media.ndim >= 3:
        # Video numpy array
        if media.ndim == 3:  # Single frame: H, W, C
            return 2.0 * transforms.ToTensor()(media) - 1.0
        elif media.ndim == 4:  # Multiple frames: F, H, W, C
            return torch.stack([2.0 * transforms.ToTensor()(frame) - 1.0 for frame in media])
        else:
            raise ValueError(f"Unsupported numpy array shape: {media.shape}")

    else:
        raise ValueError(f"Unsupported media type: {type(media)}")
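
# Illustrative usage sketch (not part of the original module): prepare a single image
# and a list of frames for a model. "img" and "frames" are hypothetical inputs; the
# shapes in the comments follow the docstring above.
def _example_transform(img: Image.Image, frames: List[Image.Image]) -> None:
    image_tensor = transform_to_model_format(img, target_size=512)  # [C, 512, 512] in [-1, 1]
    video_tensor = transform_to_model_format(frames)                # [F, C, H, W] in [-1, 1]
    print(image_tensor.shape, video_tensor.shape)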

# ===== Latent Processing Functions =====

def set_inversion(pipe: Union[StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline], inversion_type: str):
    """Set the inversion for the given pipe."""
    from inversions import DDIMInversion, ExactInversion

    if inversion_type == "ddim":
        return DDIMInversion(pipe.scheduler, pipe.unet, pipe.device)
    elif inversion_type == "exact":
        return ExactInversion(pipe.scheduler, pipe.unet, pipe.device)
    else:
        raise ValueError(f"Invalid inversion type: {inversion_type}")

def get_random_latents(pipe: Union[StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline],
                       latents=None, num_frames=None, height=512, width=512, generator=None) -> torch.Tensor:
    """Get random latents for the given pipe."""
    pipeline_type = get_pipeline_type(pipe)
    height = height or pipe.unet.config.sample_size * pipe.vae_scale_factor
    width = width or pipe.unet.config.sample_size * pipe.vae_scale_factor

    batch_size = 1
    device = pipe._execution_device
    num_channels_latents = pipe.unet.config.in_channels

    # Handle different pipeline types
    if pipeline_type == PIPELINE_TYPE_IMAGE or num_frames is None:
        latents = pipe.prepare_latents(
            batch_size,
            num_channels_latents,
            height,
            width,
            pipe.text_encoder.dtype,
            device,
            generator,
            latents,
        )
    else:
        # Video pipelines with frames
        latents = pipe.prepare_latents(
            batch_size,
            num_channels_latents,
            num_frames,
            height,
            width,
            pipe.text_encoder.dtype,
            device,
            generator,
            latents,
        )

    return latents
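
# Illustrative usage sketch (not part of the original module): draw reproducible initial
# latents for an image pipeline. The checkpoint id is an assumption for the example; any
# StableDiffusionPipeline checkpoint should work the same way with get_random_latents.
def _example_random_latents(seed: int = 0) -> torch.Tensor:
    pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    generator = torch.Generator(device=pipe._execution_device).manual_seed(seed)
    return get_random_latents(pipe, height=512, width=512, generator=generator)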

# ===== Image-Specific Functions =====

def _get_image_latents(pipe: StableDiffusionPipeline, image: torch.Tensor,
                       sample: bool = True, rng_generator: Optional[torch.Generator] = None,
                       decoder_inv: bool = False) -> torch.Tensor:
    """Get the image latents for the given image."""
    encoding_dist = pipe.vae.encode(image).latent_dist
    if sample:
        encoding = encoding_dist.sample(generator=rng_generator)
    else:
        encoding = encoding_dist.mode()
    latents = encoding * 0.18215
    if decoder_inv:
        latents = decoder_inv_optimization(pipe, latents, image)
    return latents

def _decode_image_latents(pipe: StableDiffusionPipeline, latents: torch.FloatTensor) -> torch.Tensor:
    """Decode the image from the given latents."""
    scaled_latents = 1 / 0.18215 * latents
    image = pipe.vae.decode(scaled_latents, return_dict=False)[0]
    image = (image / 2 + 0.5).clamp(0, 1)
    return image

def decoder_inv_optimization(pipe: StableDiffusionPipeline, latents: torch.FloatTensor,
                             image: torch.FloatTensor, num_steps: int = 100) -> torch.Tensor:
    """
    Optimize latents to better reconstruct the input image by minimizing the error between
    decoded latents and the original image.

    Args:
        pipe: The diffusion pipeline
        latents: Initial latents
        image: Target image
        num_steps: Number of optimization steps

    Returns:
        torch.Tensor: Optimized latents
    """
    input_image = image.clone().float()
    z = latents.clone().float().detach()
    z.requires_grad_(True)

    loss_function = torch.nn.MSELoss(reduction='sum')
    optimizer = torch.optim.Adam([z], lr=0.1)
    lr_scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=10, num_training_steps=num_steps)

    for _ in tqdm(range(num_steps)):
        # Decode without normalization to match the original implementation
        scaled_latents = 1 / 0.18215 * z
        x_pred = pipe.vae.decode(scaled_latents, return_dict=False)[0]

        loss = loss_function(x_pred, input_image)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        lr_scheduler.step()

    return z.detach()
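
# Illustrative usage sketch (not part of the original module): encode an image tensor to
# VAE latents with decoder-inversion refinement, then decode it back. "image" is assumed
# to already be a [1, 3, H, W] tensor in [-1, 1] on the pipeline's device/dtype.
def _example_image_latents(pipe: StableDiffusionPipeline, image: torch.Tensor) -> torch.Tensor:
    latents = _get_image_latents(pipe, image, sample=False, decoder_inv=True)
    return _decode_image_latents(pipe, latents)  # reconstruction in [0, 1]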

# ===== Video-Specific Functions =====

def _get_video_latents(pipe: Union[TextToVideoSDPipeline, StableVideoDiffusionPipeline],
                       video_frames: torch.Tensor, sample: bool = True,
                       rng_generator: Optional[torch.Generator] = None,
                       permute: bool = True,
                       decoder_inv: bool = False) -> torch.Tensor:
    """
    Encode video frames to latents.

    Args:
        pipe: Video diffusion pipeline
        video_frames: Tensor of video frames [F, C, H, W]
        sample: Whether to sample from the latent distribution
        rng_generator: Random generator for sampling
        permute: Whether to permute the latents to [B, C, F, H, W] format
        decoder_inv: Whether to apply decoder inversion optimization (not yet implemented for video)

    Returns:
        torch.Tensor: Video latents
    """
    encoding_dist = pipe.vae.encode(video_frames).latent_dist
    if sample:
        encoding = encoding_dist.sample(generator=rng_generator)
    else:
        encoding = encoding_dist.mode()
    latents = (encoding * 0.18215).unsqueeze(0)
    if permute:
        latents = latents.permute(0, 2, 1, 3, 4)
    if decoder_inv:  # TODO: Implement decoder inversion for video latents
        raise NotImplementedError("Decoder inversion is not implemented for video latents")
    return latents
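
# Illustrative usage sketch (not part of the original module): encode a stack of frames
# with the text-to-video layout. "frames" is assumed to be a [F, 3, H, W] tensor in [-1, 1]
# on the pipeline's device; with permute=True the result has shape [1, C, F, h, w].
def _example_encode_video(pipe: TextToVideoSDPipeline, frames: torch.Tensor) -> torch.Tensor:
    return _get_video_latents(pipe, frames, sample=False, permute=True)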

def tensor2vid(video: torch.Tensor, processor, output_type: str = "np"):
    """
    Convert video tensor to desired output format.

    Args:
        video: Video tensor [B, C, F, H, W]
        processor: Video processor from the diffusion pipeline
        output_type: Output type - 'np', 'pt', or 'pil'

    Returns:
        Video in requested format
    """
    batch_size, channels, num_frames, height, width = video.shape
    outputs = []
    for batch_idx in range(batch_size):
        batch_vid = video[batch_idx].permute(1, 0, 2, 3)
        batch_output = processor.postprocess(batch_vid, output_type)
        outputs.append(batch_output)

    if output_type == "np":
        outputs = np.stack(outputs)
    elif output_type == "pt":
        outputs = torch.stack(outputs)
    elif output_type != "pil":
        raise ValueError(f"Unsupported output type: {output_type}. Please choose one of ['np', 'pt', 'pil']")

    return outputs

def _decode_video_latents(pipe: Union[TextToVideoSDPipeline, StableVideoDiffusionPipeline],
                          latents: torch.Tensor,
                          num_frames: Optional[int] = None) -> np.ndarray:
    """
    Decode latents to video frames.

    Args:
        pipe: Video diffusion pipeline
        latents: Video latents
        num_frames: Number of frames to decode

    Returns:
        np.ndarray: Video frames
    """
    if num_frames is None:
        video_tensor = pipe.decode_latents(latents)
    else:
        video_tensor = pipe.decode_latents(latents, num_frames)
    video = tensor2vid(video_tensor, pipe.video_processor)
    return video
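
# Illustrative usage sketch (not part of the original module): decode video latents into
# the default numpy representation produced by tensor2vid (output_type="np"), then turn
# the first batch entry into a list of PIL frames for downstream use.
def _example_decode_video(pipe: TextToVideoSDPipeline, latents: torch.Tensor) -> List[Image.Image]:
    video = _decode_video_latents(pipe, latents)          # numpy array of decoded frames
    return convert_video_frames_to_images(list(video[0]))  # PIL frames for the first batch item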

def convert_video_frames_to_images(frames: List[Union[np.ndarray, Image.Image]]) -> List[Image.Image]:
    """
    Convert video frames to a list of PIL.Image objects.

    Args:
        frames: List of video frames (numpy arrays or PIL images)

    Returns:
        List[Image.Image]: List of PIL images
    """
    pil_frames = []
    for frame in frames:
        if isinstance(frame, np.ndarray):
            # Convert numpy array to PIL
            pil_frames.append(numpy_to_pil(frame))
        elif isinstance(frame, Image.Image):
            # Already a PIL image
            pil_frames.append(frame)
        else:
            raise ValueError(f"Unsupported frame type: {type(frame)}")
    return pil_frames

def save_video_frames(frames: List[Union[np.ndarray, Image.Image]], save_dir: str) -> None:
    """
    Save video frames to a directory.

    Args:
        frames: List of video frames (numpy arrays or PIL images)
        save_dir: Directory to save frames
    """
    if isinstance(frames[0], np.ndarray):
        frames = [(frame * 255).astype(np.uint8) if frame.dtype != np.uint8 else frame for frame in frames]
    elif isinstance(frames[0], Image.Image):
        frames = [np.array(frame) for frame in frames]

    for i, frame in enumerate(frames):
        img = cv2.cvtColor(frame, cv2.COLOR_RGB2BGR)
        cv2.imwrite(f'{save_dir}/{i:02d}.png', img)
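
# Illustrative usage sketch (not part of the original module): save a list of decoded
# frames to disk, creating the target directory first since save_video_frames itself
# does not. The directory name is hypothetical.
def _example_save_frames(frames: List[Image.Image], save_dir: str = "output/frames") -> None:
    import os
    os.makedirs(save_dir, exist_ok=True)
    save_video_frames(frames, save_dir)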

# ===== Utility Functions for Working with Different Pipeline Types =====

def get_media_latents(pipe: Union[StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline],
                      media: Union[torch.Tensor, List[torch.Tensor]],
                      sample: bool = True,
                      rng_generator: Optional[torch.Generator] = None,
                      decoder_inv: bool = False) -> torch.Tensor:
    """
    Get latents from media (either image or video) based on pipeline type.

    Args:
        pipe: Diffusion pipeline
        media: Image tensor or video frames tensor
        sample: Whether to sample from the latent distribution
        rng_generator: Random generator for sampling
        decoder_inv: Whether to use decoder inversion optimization

    Returns:
        torch.Tensor: Media latents
    """
    pipeline_type = get_pipeline_type(pipe)

    if pipeline_type == PIPELINE_TYPE_IMAGE:
        return _get_image_latents(pipe, media, sample, rng_generator, decoder_inv)
    elif pipeline_type in [PIPELINE_TYPE_TEXT_TO_VIDEO, PIPELINE_TYPE_IMAGE_TO_VIDEO]:
        permute = pipeline_type == PIPELINE_TYPE_TEXT_TO_VIDEO
        return _get_video_latents(pipe, media, sample, rng_generator, permute, decoder_inv)
    else:
        raise ValueError(f"Unsupported pipeline type: {pipeline_type}")

def decode_media_latents(pipe: Union[StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline],
                         latents: torch.Tensor,
                         num_frames: Optional[int] = None) -> Union[torch.Tensor, np.ndarray]:
    """
    Decode latents to media (either image or video) based on pipeline type.

    Args:
        pipe: Diffusion pipeline
        latents: Media latents
        num_frames: Number of frames (for video)

    Returns:
        Union[torch.Tensor, np.ndarray]: Decoded media
    """
    pipeline_type = get_pipeline_type(pipe)

    if pipeline_type == PIPELINE_TYPE_IMAGE:
        return _decode_image_latents(pipe, latents)
    elif pipeline_type in [PIPELINE_TYPE_TEXT_TO_VIDEO, PIPELINE_TYPE_IMAGE_TO_VIDEO]:
        return _decode_video_latents(pipe, latents, num_frames)
    else:
        raise ValueError(f"Unsupported pipeline type: {pipeline_type}")
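
# Illustrative usage sketch (not part of the original module): a pipeline-agnostic
# encode/decode round trip through the dispatch helpers above. "media" is assumed to
# already be in the format get_media_latents expects for the given pipeline (an image
# tensor for image pipelines, or a frame tensor for video pipelines).
def _example_media_roundtrip(pipe, media, num_frames: Optional[int] = None):
    latents = get_media_latents(pipe, media, sample=False)
    return decode_media_latents(pipe, latents, num_frames=num_frames)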