# Coverage for inversions/base_inversion.py: 86.96% (46 statements)
import torch
from typing import Optional, Callable
class BaseInversion:
    def __init__(self,
                 scheduler,
                 unet,
                 device,
                 ):
        self.scheduler = scheduler
        self.unet = unet
        self.device = device
    def _prepare_latent_for_unet(self, latents, do_cfg, unet):
        """
        Inputs:
            latents: [B,C,H,W] or [B,F,C,H,W]
            do_cfg: bool
        Outputs:
            latent_model_input: Tensor ready for UNet input
            info: dict containing shape info
        """
        # Heuristic: a UNet containing any Conv3d module is treated as a video UNet.
        is_video_unet = any(isinstance(m, torch.nn.Conv3d) for m in unet.modules())

        info = {
            "do_cfg": do_cfg,
            "is_video_unet": is_video_unet,
        }
        # ------------------------------
        # Case 1: image latent (4D)
        # ------------------------------
        if latents.ndim == 4:
            # [B, C, H, W]
            info["shp"] = latents.shape
            if do_cfg:
                latents = torch.cat([latents, latents], dim=0)
            return latents, info
        # ------------------------------
        # Case 2: video latent (5D)
        # ------------------------------
        assert latents.ndim == 5, "Latent input must be 4D or 5D."
        B, F, C, H, W = latents.shape
        info["shp"] = (B, F, C, H, W)
        if is_video_unet:
            # video UNet (Conv3d): [B, C, F, H, W]
            latents = latents.permute(0, 2, 1, 3, 4).contiguous()
            if do_cfg:
                latents = torch.cat([latents, latents], dim=0)
            return latents, info
        else:
            # image UNet with video input → flatten frames into the batch dim
            latents = latents.reshape(B * F, C, H, W)
            info["flatten"] = (B, F)
            if do_cfg:
                latents = torch.cat([latents, latents], dim=0)
            return latents, info
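
    # Usage sketch (illustrative only; `inv` and `unet` are assumed names, not
    # defined in this file): with a 5D video latent and an image UNet, frames
    # are flattened into the batch dimension and CFG doubles the batch:
    #
    #   latents = torch.randn(2, 8, 4, 64, 64)  # [B, F, C, H, W]
    #   x, info = inv._prepare_latent_for_unet(latents, do_cfg=True, unet=unet)
    #   # x.shape == (32, 4, 64, 64)            # (2 * B * F, C, H, W)
    #   # info["shp"] == (2, 8, 4, 64, 64); info["flatten"] == (2, 8)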
    def _restore_latent_from_unet(self, noise_pred, info, guidance_scale):
        """
        Inputs:
            noise_pred: UNet output
            info: shape info saved by _prepare_latent_for_unet
            guidance_scale: CFG scale
        Outputs:
            Noise prediction matching the original input layout.
            Image: [B,C,H,W]
            Video: [B,F,C,H,W]
        """
        do_cfg = info["do_cfg"]
        is_video_unet = info["is_video_unet"]
        shp = info["shp"]
        # 1. Merge the CFG branches (assumes the conditional batch was
        #    concatenated first, matching how the embeddings were stacked).
        if do_cfg:
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
        # --------------------------
        # Case 1: input was an image latent
        # --------------------------
        if len(shp) == 4:
            # [B, C, H, W]: no reshape needed
            return noise_pred
        # --------------------------
        # Case 2: input was a video latent
        # --------------------------
        B, F, C, H, W = shp
        if is_video_unet:
            # video UNet output layout: [B, C, F, H, W] → back to [B, F, C, H, W]
            noise_pred = noise_pred.permute(0, 2, 1, 3, 4).contiguous()
            return noise_pred
        else:
            # image UNet output layout: [B*F, C, H, W] → back to [B, F, C, H, W]
            noise_pred = noise_pred.reshape(B, F, C, H, W)
            return noise_pred
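
    # Usage sketch (illustrative only): the inverse of _prepare_latent_for_unet.
    # Under CFG the UNet output stacks both branches along the batch dim; this
    # merges them and restores the original layout. Continuing the example above:
    #
    #   eps = unet(x, t, encoder_hidden_states=emb).sample  # assumed call shape
    #   eps = inv._restore_latent_from_unet(eps, info, guidance_scale=7.5)
    #   # eps.shape == (2, 8, 4, 64, 64), matching the original latents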
    @torch.inference_mode()
    def forward_diffusion(self,
                          use_old_emb_i=25,
                          text_embeddings=None,
                          old_text_embeddings=None,
                          new_text_embeddings=None,
                          latents: Optional[torch.FloatTensor] = None,
                          num_inference_steps: int = 10,
                          guidance_scale: float = 7.5,
                          callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
                          callback_steps: Optional[int] = 1,
                          inverse_opt=True,
                          inv_order=None,
                          **kwargs,
                          ):
        raise NotImplementedError("Subclasses must implement forward_diffusion.")
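
    # Subclass sketch (hypothetical, not part of this file; `do_cfg` and the
    # scheduler step are assumed): a concrete inversion would implement
    # forward_diffusion by stepping through the timesteps, wrapping each UNet
    # call with the two helpers above:
    #
    #   for t in self.scheduler.timesteps:
    #       x, info = self._prepare_latent_for_unet(latents, do_cfg, self.unet)
    #       eps = self.unet(x, t, encoder_hidden_states=text_embeddings).sample
    #       eps = self._restore_latent_from_unet(eps, info, guidance_scale)
    #       latents = ...  # inverse scheduler step using eps (method-specific)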
    # def _apply_guidance_scale(self, model_output, guidance_scale):
    #     if guidance_scale > 1.0:
    #         noise_pred_uncond, noise_pred_text = model_output.chunk(2)
    #         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
    #         return noise_pred
    #     else:
    #         return model_output