Coverage for inversions / exact_inversion.py: 91.58%

95 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

1from .base_inversion import BaseInversion 

2import torch 

3from typing import Optional, Callable 

4from tqdm import tqdm 

5from torch.optim.lr_scheduler import ReduceLROnPlateau 

6# from utils.DPMSolverPatch import convert_model_output 

7from diffusers import DPMSolverMultistepInverseScheduler 

8 

class ExactInversion(BaseInversion):
    """Exact diffusion inversion built on DPM-Solver++ multistep inversion.

    Wraps a UNet/scheduler pair via ``BaseInversion``, replacing the supplied
    scheduler with a ``DPMSolverMultistepInverseScheduler`` built from the same
    config.  ``forward_diffusion`` walks the inversion (noising) trajectory
    z_0 -> z_T and collects intermediate latents; ``backward_diffusion``
    reconstructs z_0 from z_T with first-order DPM-Solver++ steps.

    NOTE(review): only the first-order solver branches are implemented in both
    loops below.  The higher-order paths (``inv_order >= 2`` beyond the first
    step) are ``else: pass`` stubs and leave the latents untouched on those
    iterations — confirm this is intentional before relying on order 2.
    """

    def __init__(self,
                 scheduler,
                 unet,
                 device,
                 ):
        # Rebuild the caller's scheduler as its *inverse* counterpart from the
        # same config; the original scheduler object itself is discarded.
        scheduler = DPMSolverMultistepInverseScheduler.from_config(scheduler.config)
        super(ExactInversion, self).__init__(scheduler, unet, device)

    @torch.inference_mode()
    def forward_diffusion(
        self,
        use_old_emb_i=25,
        text_embeddings=None,
        old_text_embeddings=None,
        new_text_embeddings=None,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 10,
        guidance_scale: float = 7.5,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        inverse_opt=False,
        inv_order=0,
        **kwargs,
    ):
        """Invert latents (z_0 -> z_T) with first-order DPM-Solver++ steps.

        Parameters
        ----------
        use_old_emb_i : int
            Prompt-to-prompt switch point: when both ``old_text_embeddings``
            and ``new_text_embeddings`` are given, steps ``i < use_old_emb_i``
            use the old embeddings and later steps use the new ones.
        text_embeddings : torch.Tensor
            UNet conditioning (cast to float32 below).
        old_text_embeddings, new_text_embeddings : Optional[torch.Tensor]
            Both non-None enables prompt-to-prompt embedding switching.
        latents : Optional[torch.FloatTensor]
            Starting latents; scaled by ``scheduler.init_noise_sigma``.
        num_inference_steps : int
            Number of scheduler timesteps for the inversion trajectory.
        guidance_scale : float
            Classifier-free guidance weight; CFG is active when > 1.0.
        callback, callback_steps
            Optional progress callback invoked every ``callback_steps`` steps
            with ``(i, t, latents)``.
            NOTE(review): ``callback_steps=None`` together with a callback
            would raise on ``i % callback_steps`` — confirm callers never
            combine the two.
        inverse_opt, inv_order
            Solver-order controls. ``inv_order=None`` falls back to the
            scheduler's ``solver_order`` (and derives ``inverse_opt`` from it).
            Only ``inv_order < 2`` — plus the very first step when
            ``inv_order == 2`` — actually performs an update here.
        **kwargs
            Ignored; accepted for interface compatibility.

        Returns
        -------
        list[torch.Tensor]
            Clones of the inverted latents, one per executed step.
        """
        # torch.no_grad() is redundant under @torch.inference_mode() but kept
        # as an extra guard.
        with torch.no_grad():
            # Keep a list of inverted latents as the process goes on
            intermediate_latents = []
            # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
            # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
            # corresponds to doing no classifier free guidance.
            do_classifier_free_guidance = guidance_scale > 1.0

            self.scheduler.set_timesteps(num_inference_steps)
            timesteps_tensor = self.scheduler.timesteps.to(self.device)
            latents = latents * self.scheduler.init_noise_sigma

            if old_text_embeddings is not None and new_text_embeddings is not None:
                prompt_to_prompt = True
            else:
                prompt_to_prompt = False

            # None means "use the scheduler's configured order"; inverse_opt is
            # then re-derived (any nonzero order implies optimization).
            if inv_order is None:
                inv_order = self.scheduler.solver_order
                inverse_opt = (inv_order != 0)

            # timesteps_tensor = reversed(timesteps_tensor) # inversion process

            # Run everything in float32 for numerical stability of the solver.
            self.unet = self.unet.float()
            latents = latents.float()
            text_embeddings = text_embeddings.float()

            for i, t in enumerate(tqdm(timesteps_tensor)):
                # Lazily initialize the scheduler's internal step counter on
                # the first iteration (private diffusers API).
                if self.scheduler.step_index is None:
                    self.scheduler._init_step_index(t)

                # Prompt-to-prompt: swap conditioning embeddings mid-trajectory.
                if prompt_to_prompt:
                    if i < use_old_emb_i:
                        text_embeddings = old_text_embeddings
                    else:
                        text_embeddings = new_text_embeddings

                # Look one step ahead on the inverse schedule; extrapolate (and
                # clamp to the training range) past the last entry.
                if i+1 < len(timesteps_tensor):
                    next_timestep = timesteps_tensor[i+1]
                else:
                    next_timestep = (
                        t
                        + self.scheduler.config.num_train_timesteps
                        // self.scheduler.num_inference_steps
                    )
                    next_timestep = min(next_timestep, self.scheduler.config.num_train_timesteps - 1)

                # call the callback, if provided
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

                # Our Algorithm

                # Algorithm 1 (first-order inversion step)
                if inv_order < 2 or (inv_order == 2 and i == 0):
                    # Re-bind (s, t) to the paper's convention: s is the higher
                    # (noisier) timestep, t the one solver interval below it.
                    # s = t
                    # t = prev_timestep
                    s = next_timestep
                    t = (
                        next_timestep
                        - self.scheduler.config.num_train_timesteps
                        // self.scheduler.num_inference_steps
                    )
                    t = max(t, 0)  # Ensure t is not negative

                    # Solver coefficients from the scheduler's precomputed
                    # per-timestep tables.
                    lambda_s, lambda_t = self.scheduler.lambda_t[s], self.scheduler.lambda_t[t]
                    sigma_s, sigma_t = self.scheduler.sigma_t[s], self.scheduler.sigma_t[t]
                    h = lambda_t - lambda_s
                    alpha_s, alpha_t = self.scheduler.alpha_t[s], self.scheduler.alpha_t[t]
                    phi_1 = torch.expm1(-h)

                    # expand the latents if classifier free guidance is used
                    latent_model_input, info = self._prepare_latent_for_unet(latents, do_classifier_free_guidance, self.unet)
                    # NOTE(review): the input is scaled at timestep t but the
                    # UNet is queried at s.  scale_model_input is an identity
                    # for DPM-Solver schedulers, so this is likely harmless,
                    # but confirm the intended timestep.
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                    # predict the noise residual
                    noise_pred_raw = self.unet(latent_model_input, s, encoder_hidden_states=text_embeddings).sample
                    noise_pred = self._restore_latent_from_unet(noise_pred_raw, info, guidance_scale)

                    # Convert raw prediction to the scheduler's expected
                    # parameterization (epsilon / sample / v, per config).
                    model_s = self.scheduler.convert_model_output(model_output=noise_pred, sample=latents)
                    # Kept for parity with the paper's notation; currently
                    # unused in this first-order branch.
                    x_t = latents

                    # Line 5: explicit first-order inversion update,
                    # x_s = (sigma_s / sigma_t) * (x_t + alpha_t * phi_1 * x_theta)
                    latents = (sigma_s / sigma_t) * (latents + alpha_t * phi_1 * model_s)

                    # Save intermediate latents
                    intermediate_latents.append(latents.clone())

                    # Advance the scheduler's private step counter manually,
                    # since scheduler.step() is bypassed here.
                    self.scheduler._step_index += 1

                else:
                    # Higher-order inversion step: not implemented — latents
                    # pass through unchanged on this iteration.
                    pass

            return intermediate_latents

    @torch.inference_mode()
    def backward_diffusion(
        self,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 10,
        guidance_scale: float = 7.5,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        inv_order=None,
        **kwargs,
    ):
        """
        Reconstruct z_0 from z_T via the forward diffusion process

        Denoises ``latents`` over ``num_inference_steps`` scheduler timesteps
        using explicit first-order DPM-Solver++ (DDIM-like) updates.

        Parameters
        ----------
        latents : Optional[torch.FloatTensor]
            Noisy latents z_T to denoise (cast to float32 below).
        num_inference_steps : int
            Number of scheduler timesteps.
        guidance_scale : float
            Classifier-free guidance weight; CFG is active when > 1.0.
        callback, callback_steps
            Optional progress callback invoked every ``callback_steps`` steps.
            NOTE(review): same ``callback_steps=None`` crash hazard as in
            ``forward_diffusion``.
        inv_order : Optional[int]
            Solver order; ``None`` defaults to the scheduler's
            ``solver_order``.  Only order 1 (and the first step of any order)
            performs an update — higher-order steps are stubs.
        **kwargs
            Must contain ``text_embeddings`` (UNet conditioning); other keys
            are ignored.

        Returns
        -------
        torch.Tensor
            The final denoised latents.
        """
        with torch.no_grad():
            # 1. Setup
            do_classifier_free_guidance = guidance_scale > 1.0
            self.scheduler.set_timesteps(num_inference_steps)
            timesteps_tensor = self.scheduler.timesteps.to(self.device)

            # If no inv_order provided, default to scheduler's configuration
            if inv_order is None:
                inv_order = self.scheduler.solver_order

            # Run in float32 for numerical stability of the solver math.
            self.unet = self.unet.float()
            latents = latents.float()

            # last output from the model to be used in higher order methods
            old_model_output = None

            # 2. Denoising Loop (T -> 0)
            for i, t in enumerate(tqdm(timesteps_tensor)):
                # Lazily initialize the scheduler's private step counter.
                if self.scheduler.step_index is None:
                    self.scheduler._init_step_index(t)

                # s (prev_timestep in diffusion terms, lower noise)
                if i + 1 < len(timesteps_tensor):
                    s = timesteps_tensor[i + 1]
                else:
                    s = torch.tensor(0, device=self.device)

                # 3. Prepare Model Input
                latent_model_input, info = self._prepare_latent_for_unet(latents, do_classifier_free_guidance, self.unet)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # 4. Predict Noise/Data
                # NOTE(review): conditioning is pulled from kwargs; a missing
                # "text_embeddings" key passes None to the UNet — confirm
                # callers always supply it.
                noise_pred_raw = self.unet(latent_model_input, t, encoder_hidden_states=kwargs.get("text_embeddings")).sample
                noise_pred = self._restore_latent_from_unet(noise_pred_raw, info, guidance_scale)

                # Transform prediction according to the type of prediction required by the scheduler
                model_output = self.scheduler.convert_model_output(model_output=noise_pred, sample=latents)

                # 5. Calculate Solver Parameters
                # Aquire alpha, sigma, lambda
                lambda_t, lambda_s = self.scheduler.lambda_t[t], self.scheduler.lambda_t[s]
                alpha_t, alpha_s = self.scheduler.alpha_t[t], self.scheduler.alpha_t[s]
                sigma_t, sigma_s = self.scheduler.sigma_t[t], self.scheduler.sigma_t[s]

                h = lambda_s - lambda_t  # step size
                phi_1 = torch.expm1(-h)  # e^{-h} - 1

                # 6. Sampling Step (Explicit)

                # Case 1: First Order (DDIM) or First Step of Second Order
                if inv_order == 1 or i == 0:
                    # Eq. (5): Forward Euler
                    # x_{t_i} = (sigma_{t_i} / sigma_{t_{i-1}}) * x_{t_{i-1}} - alpha_{t_i} * (e^{-h} - 1) * x_theta
                    # x_s = (sigma_s/sigma_t) * latents - alpha_s * phi_1 * model_output
                    latents = (sigma_s / sigma_t) * latents - (alpha_s * phi_1) * model_output
                else:
                    # Higher-order update: not implemented — latents are left
                    # unchanged on this iteration.
                    pass

                # Update history
                # old_model_output is retained for (unimplemented) multistep
                # branches; manual step advance since scheduler.step() is
                # bypassed.
                old_model_output = model_output
                self.scheduler._step_index += 1

                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

            return latents