Coverage for inversions/exact_inversion.py: 91.58% (95 statements)
coverage.py v7.13.0, created at 2025-12-22 11:32 +0000
1from .base_inversion import BaseInversion
2import torch
3from typing import Optional, Callable
4from tqdm import tqdm
5from torch.optim.lr_scheduler import ReduceLROnPlateau
6# from utils.DPMSolverPatch import convert_model_output
7from diffusers import DPMSolverMultistepInverseScheduler
class ExactInversion(BaseInversion):
    """Exact DDIM/DPM-Solver inversion built on top of ``BaseInversion``.

    Replaces the supplied scheduler with a
    ``DPMSolverMultistepInverseScheduler`` configured from the same config,
    then provides:

    * ``forward_diffusion``  — inversion (image latents -> noise), returning
      every intermediate latent.
    * ``backward_diffusion`` — sampling (noise -> image latents), returning
      only the final latent.

    NOTE(review): ``_prepare_latent_for_unet`` / ``_restore_latent_from_unet``
    are inherited from ``BaseInversion`` (not visible here); presumably they
    duplicate the latent for classifier-free guidance and recombine the
    cond/uncond noise predictions — confirm against the base class.
    """

    def __init__(self,
                 scheduler,
                 unet,
                 device,
                 ):
        # Rebuild the scheduler as the *inverse* multistep DPM-Solver so the
        # timestep tables (lambda_t / sigma_t / alpha_t) used below exist and
        # run in inversion order.
        scheduler = DPMSolverMultistepInverseScheduler.from_config(scheduler.config)
        super(ExactInversion, self).__init__(scheduler, unet, device)

    @torch.inference_mode()
    def forward_diffusion(
        self,
        use_old_emb_i=25,
        text_embeddings=None,
        old_text_embeddings=None,
        new_text_embeddings=None,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 10,
        guidance_scale: float = 7.5,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        inverse_opt=False,
        inv_order=0,
        **kwargs,
    ):
        """Invert latents toward noise (z_0 -> z_T), one solver step per timestep.

        Parameters
        ----------
        use_old_emb_i : step index at which prompt-to-prompt editing switches
            from ``old_text_embeddings`` to ``new_text_embeddings``.
        text_embeddings : conditioning passed to the UNet (cast to float32).
        old_text_embeddings / new_text_embeddings : when both are given,
            prompt-to-prompt mode is enabled and they override
            ``text_embeddings`` per step.
        latents : starting latents; scaled by ``init_noise_sigma`` first.
        num_inference_steps : number of scheduler timesteps.
        guidance_scale : classifier-free guidance weight; CFG is active
            when > 1.0 (Imagen paper, eq. (2): https://arxiv.org/pdf/2205.11487.pdf).
        callback / callback_steps : optional progress hook, called every
            ``callback_steps`` iterations with ``(i, t, latents)``.
        inv_order : solver order for the inversion. ``None`` means "use the
            scheduler's configured ``solver_order``".

        Returns
        -------
        list of the latent tensor after each completed step (clones).

        NOTE(review): ``inverse_opt`` is recomputed below but never read
        afterwards in this method — it appears to be vestigial.
        NOTE(review): only first-order steps (and the first step of a
        second-order run) are implemented; for ``inv_order >= 2`` the loop
        body is a silent no-op (``else: pass``) and nothing is appended.
        """
        with torch.no_grad():
            # Keep a list of inverted latents as the process goes on
            intermediate_latents = []
            # `guidance_scale > 1` enables classifier-free guidance;
            # `guidance_scale = 1` corresponds to doing no CFG at all.
            do_classifier_free_guidance = guidance_scale > 1.0

            self.scheduler.set_timesteps(num_inference_steps)
            timesteps_tensor = self.scheduler.timesteps.to(self.device)
            # Scale the initial latents by the scheduler's expected noise sigma.
            latents = latents * self.scheduler.init_noise_sigma

            # Prompt-to-prompt editing is active only when both embedding
            # sets are provided.
            if old_text_embeddings is not None and new_text_embeddings is not None:
                prompt_to_prompt = True
            else:
                prompt_to_prompt = False

            # Default the inversion order to the scheduler's solver order.
            if inv_order is None:
                inv_order = self.scheduler.solver_order
                inverse_opt = (inv_order != 0)  # NOTE(review): unused after this point

            # timesteps_tensor = reversed(timesteps_tensor) # inversion process

            # Run everything in float32 for numerical stability of the
            # exponential/log-SNR arithmetic below.
            self.unet = self.unet.float()
            latents = latents.float()
            text_embeddings = text_embeddings.float()

            for i, t in enumerate(tqdm(timesteps_tensor)):
                # Lazily initialise the scheduler's internal step pointer;
                # we advance it manually at the end of each iteration.
                if self.scheduler.step_index is None:
                    self.scheduler._init_step_index(t)

                # Prompt-to-prompt: use the old prompt for the first
                # `use_old_emb_i` steps, then switch to the new one.
                if prompt_to_prompt:
                    if i < use_old_emb_i:
                        text_embeddings = old_text_embeddings
                    else:
                        text_embeddings = new_text_embeddings

                # The "next" (noisier) timestep; extrapolated past the end of
                # the schedule on the final iteration, clamped to the
                # training range.
                if i+1 < len(timesteps_tensor):
                    next_timestep = timesteps_tensor[i+1]
                else:
                    next_timestep = (
                        t
                        + self.scheduler.config.num_train_timesteps
                        // self.scheduler.num_inference_steps
                    )
                next_timestep = min(next_timestep, self.scheduler.config.num_train_timesteps - 1)

                # call the callback, if provided
                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

                # Our Algorithm

                # Algorithm 1: first-order inversion step (also used for the
                # very first step of a second-order run).
                if inv_order < 2 or (inv_order == 2 and i == 0):
                    # Re-label timesteps for the inverted update:
                    # s = the noisier target timestep, t = one schedule step
                    # below it (the current noise level). This intentionally
                    # shadows the loop variable `t`.
                    s = next_timestep
                    t = (
                        next_timestep
                        - self.scheduler.config.num_train_timesteps
                        // self.scheduler.num_inference_steps
                    )
                    t = max(t, 0)  # Ensure t is not negative

                    # Scheduler lookup tables, indexed by (training) timestep:
                    # lambda = log-SNR, sigma = noise scale, alpha = signal scale.
                    lambda_s, lambda_t = self.scheduler.lambda_t[s], self.scheduler.lambda_t[t]
                    sigma_s, sigma_t = self.scheduler.sigma_t[s], self.scheduler.sigma_t[t]
                    h = lambda_t - lambda_s
                    alpha_s, alpha_t = self.scheduler.alpha_t[s], self.scheduler.alpha_t[t]
                    phi_1 = torch.expm1(-h)  # e^{-h} - 1

                    # expand the latents if classifier free guidance is used
                    latent_model_input, info = self._prepare_latent_for_unet(latents, do_classifier_free_guidance, self.unet)
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                    # predict the noise residual; NOTE(review): the UNet is
                    # conditioned on `s` while the input is scaled with `t` —
                    # presumably deliberate for the inverse step; confirm
                    # against the reference implementation.
                    noise_pred_raw = self.unet(latent_model_input, s, encoder_hidden_states=text_embeddings).sample
                    noise_pred = self._restore_latent_from_unet(noise_pred_raw, info, guidance_scale)

                    # Convert to the prediction type the solver expects
                    # (epsilon / x0 / v, per scheduler config).
                    model_s = self.scheduler.convert_model_output(model_output=noise_pred, sample=latents)
                    x_t = latents

                    # Line 5: inverted first-order DPM-Solver update.
                    latents = (sigma_s / sigma_t) * (latents + alpha_t * phi_1 * model_s)

                    # Save intermediate latents
                    intermediate_latents.append(latents.clone())

                    # Manually advance the scheduler's private step pointer,
                    # since we bypass scheduler.step().
                    self.scheduler._step_index += 1

                else:
                    # NOTE(review): higher-order inversion not implemented —
                    # this branch silently skips the step.
                    pass

            return intermediate_latents

    @torch.inference_mode()
    def backward_diffusion(
        self,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 10,
        guidance_scale: float = 7.5,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        inv_order=None,
        **kwargs,
    ):
        """
        Reconstruct z_0 from z_T via the forward diffusion process.

        Denoises ``latents`` across the scheduler's timesteps (T -> 0) with a
        first-order (DDIM-style) explicit update and returns the final latent
        tensor. Text conditioning is taken from ``kwargs['text_embeddings']``.

        NOTE(review): only ``inv_order == 1`` (or the forced first step) is
        implemented; other orders fall into a silent ``else: pass`` and leave
        the latents unchanged for that step. ``old_model_output`` is tracked
        for a higher-order method but never consumed.
        """
        with torch.no_grad():
            # 1. Setup
            do_classifier_free_guidance = guidance_scale > 1.0
            self.scheduler.set_timesteps(num_inference_steps)
            timesteps_tensor = self.scheduler.timesteps.to(self.device)

            # If no inv_order provided, default to scheduler's configuration
            if inv_order is None:
                inv_order = self.scheduler.solver_order

            # float32 for numerical stability.
            self.unet = self.unet.float()
            latents = latents.float()

            # last output from the model to be used in higher order methods
            old_model_output = None

            # 2. Denoising Loop (T -> 0)
            for i, t in enumerate(tqdm(timesteps_tensor)):
                # Lazily initialise the scheduler's internal step pointer;
                # advanced manually at the end of each iteration.
                if self.scheduler.step_index is None:
                    self.scheduler._init_step_index(t)

                # s (prev_timestep in diffusion terms, lower noise);
                # 0 on the final iteration.
                if i + 1 < len(timesteps_tensor):
                    s = timesteps_tensor[i + 1]
                else:
                    s = torch.tensor(0, device=self.device)

                # 3. Prepare Model Input (duplicated for CFG if enabled)
                latent_model_input, info = self._prepare_latent_for_unet(latents, do_classifier_free_guidance, self.unet)
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # 4. Predict Noise/Data
                noise_pred_raw = self.unet(latent_model_input, t, encoder_hidden_states=kwargs.get("text_embeddings")).sample
                noise_pred = self._restore_latent_from_unet(noise_pred_raw, info, guidance_scale)

                # Transform prediction according to the type of prediction required by the scheduler
                model_output = self.scheduler.convert_model_output(model_output=noise_pred, sample=latents)

                # 5. Calculate Solver Parameters
                # Acquire alpha, sigma, lambda from the scheduler's
                # per-timestep tables (lambda = log-SNR).
                lambda_t, lambda_s = self.scheduler.lambda_t[t], self.scheduler.lambda_t[s]
                alpha_t, alpha_s = self.scheduler.alpha_t[t], self.scheduler.alpha_t[s]
                sigma_t, sigma_s = self.scheduler.sigma_t[t], self.scheduler.sigma_t[s]

                h = lambda_s - lambda_t  # step size
                phi_1 = torch.expm1(-h)  # e^{-h} - 1

                # 6. Sampling Step (Explicit)

                # Case 1: First Order (DDIM) or First Step of Second Order
                if inv_order == 1 or i == 0:
                    # Eq. (5): Forward Euler
                    # x_{t_i} = (sigma_{t_i} / sigma_{t_{i-1}}) * x_{t_{i-1}} - alpha_{t_i} * (e^{-h} - 1) * x_theta
                    # x_s = (sigma_s/sigma_t) * latents - alpha_s * phi_1 * model_output
                    latents = (sigma_s / sigma_t) * latents - (alpha_s * phi_1) * model_output
                else:
                    # NOTE(review): higher-order update not implemented;
                    # latents pass through unchanged for this step.
                    pass

                # Update history (consumed only by the unimplemented
                # higher-order path) and advance the scheduler's private
                # step pointer, since scheduler.step() is bypassed.
                old_model_output = model_output
                self.scheduler._step_index += 1

                if callback is not None and i % callback_steps == 0:
                    callback(i, t, latents)

            return latents