Coverage for inversions / base_inversion.py: 86.96%

46 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1import torch 

2from typing import Optional, Callable 

3 

class BaseInversion:
    """Shared scaffolding for diffusion-latent inversion methods.

    Stores the scheduler, UNet and device given by the caller, and provides
    two helpers:

    * ``_prepare_latent_for_unet`` adapts an image latent ``[B, C, H, W]`` or
      a video latent ``[B, F, C, H, W]`` to the layout a UNet consumes.
    * ``_restore_latent_from_unet`` maps the UNet's noise prediction back to
      the original layout, applying classifier-free guidance on the way.

    ``forward_diffusion`` is a no-op stub here (it returns ``None``);
    presumably concrete inversion classes override it.
    """

    def __init__(self,
                 scheduler,
                 unet,
                 device,
                 ):
        # Stored as given; no validation or device placement happens here.
        self.scheduler = scheduler
        self.unet = unet
        self.device = device

    def _prepare_latent_for_unet(self, latents, do_cfg, unet):
        """Reshape ``latents`` into the layout expected by ``unet``.

        Args:
            latents: Image latent ``[B, C, H, W]`` or video latent
                ``[B, F, C, H, W]``.
            do_cfg: If true, the prepared batch is duplicated along dim 0 so
                one UNet call can produce both halves needed for
                classifier-free guidance.
            unet: Module whose layout decides how a video latent is fed in.
                It is treated as a video UNet if it contains any
                ``torch.nn.Conv3d`` submodule.

        Returns:
            ``(latent_model_input, info)``. ``info`` holds ``"do_cfg"``,
            ``"is_video_unet"``, the original shape under ``"shp"`` and, when
            video frames were flattened for an image UNet, ``"flatten": (B, F)``.
            Pass ``info`` unchanged to ``_restore_latent_from_unet``.

        Raises:
            ValueError: If ``latents`` is neither 4-D nor 5-D.
        """
        # Module types are the only signal used to tell a video UNet from an
        # image UNet. Note that this walks every submodule on each call.
        is_video_unet = any(isinstance(m, torch.nn.Conv3d) for m in unet.modules())

        info = {
            "do_cfg": do_cfg,
            "is_video_unet": is_video_unet,
        }

        if latents.ndim == 4:
            # Image latent [B, C, H, W] is already in UNet layout.
            info["shp"] = latents.shape
        elif latents.ndim == 5:
            B, F, C, H, W = latents.shape
            info["shp"] = (B, F, C, H, W)
            if is_video_unet:
                # Conv3d UNet: channels come before frames -> [B, C, F, H, W].
                latents = latents.permute(0, 2, 1, 3, 4).contiguous()
            else:
                # Image UNet given a video: treat every frame as its own sample.
                latents = latents.reshape(B * F, C, H, W)
                info["flatten"] = (B, F)
        else:
            # Raise explicitly rather than `assert` so the check survives `python -O`.
            raise ValueError(
                f"Latent input must be 4D or 5D, got {latents.ndim}D "
                f"tensor with shape {tuple(latents.shape)}."
            )

        if do_cfg:
            # Duplicate the batch; both halves are split apart again in
            # _restore_latent_from_unet.
            latents = torch.cat([latents, latents], dim=0)
        return latents, info

    def _restore_latent_from_unet(self, noise_pred, info, guidance_scale):
        """Undo ``_prepare_latent_for_unet`` on a UNet noise prediction.

        Args:
            noise_pred: Raw UNet output for the input produced by
                ``_prepare_latent_for_unet``.
            info: The ``info`` dict returned by ``_prepare_latent_for_unet``.
            guidance_scale: Classifier-free guidance scale. It is only used
                when ``info["do_cfg"]`` is true.

        Returns:
            The noise prediction in the caller's original layout:
            ``[B, C, H, W]`` for images, ``[B, F, C, H, W]`` for videos.
        """
        do_cfg = info["do_cfg"]
        is_video_unet = info["is_video_unet"]
        shp = info["shp"]

        # 1. Merge the two CFG halves into one guided prediction.
        if do_cfg:
            # The first half of the batch is treated as the conditional
            # prediction and the second half as the unconditional one.
            # NOTE(review): the commented-out _apply_guidance_scale helper in
            # this file (and diffusers pipelines) use the opposite order,
            # [uncond, cond]. Confirm that callers build their text
            # embeddings as [cond, uncond] to match this split.
            noise_pred_cond, noise_pred_uncond = noise_pred.chunk(2, dim=0)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

        # 2a. Image latent: [B, C, H, W] needs no reshaping.
        if len(shp) == 4:
            return noise_pred

        # 2b. Video latent: return to [B, F, C, H, W].
        B, F, C, H, W = shp
        if is_video_unet:
            # Video UNet output is [B, C, F, H, W]; swap channels and frames back.
            return noise_pred.permute(0, 2, 1, 3, 4).contiguous()
        # Image UNet output is [B*F, C, H, W]; unflatten the frames.
        return noise_pred.reshape(B, F, C, H, W)

    @torch.inference_mode()
    def forward_diffusion(self,
                          use_old_emb_i=25,
                          text_embeddings=None,
                          old_text_embeddings=None,
                          new_text_embeddings=None,
                          latents: Optional[torch.FloatTensor] = None,
                          num_inference_steps: int = 10,
                          guidance_scale: float = 7.5,
                          callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
                          callback_steps: Optional[int] = 1,
                          inverse_opt=True,
                          inv_order=None,
                          **kwargs,
                          ):
        """Run the inversion loop. This is a no-op in the base class.

        The base implementation does nothing and returns ``None``; presumably
        subclasses override it and define what each argument means.
        It runs under ``torch.inference_mode()``, so no autograd graph is
        recorded for any override that keeps this decorator.
        """
        pass

122 

123 # def _apply_guidance_scale(self, model_output, guidance_scale): 

124 # if guidance_scale > 1.0: 

125 # noise_pred_uncond, noise_pred_text = model_output.chunk(2) 

126 # noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) 

127 # return noise_pred 

128 # else: 

129 # return model_output