Coverage for inversions/ddim_inversion.py: 84.44% (45 statements)
from functools import partial
from typing import Callable, Optional
import warnings

import torch
from tqdm import tqdm

from .base_inversion import BaseInversion

class DDIMInversion(BaseInversion):
    def __init__(self, scheduler, unet, device):
        super().__init__(scheduler, unet, device)
        self.forward_diffusion = partial(self.backward_diffusion, reverse_process=True)
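        # forward_diffusion (image -> noise) re-enters backward_diffusion with
        # reverse_process=True: the loop there then walks the timesteps in
        # reverse and swaps the alpha products, so one routine serves both
        # directions.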

    def _backward_ddim(self, x_t, alpha_t, alpha_tm1, eps_xt):
        """One deterministic DDIM update from x_t to x_{t-1} (noise -> image).

        Called with alpha_t and alpha_tm1 swapped, it performs the inverse step.
        """
        return (
            alpha_tm1**0.5
            * (
                (alpha_t**-0.5 - alpha_tm1**-0.5) * x_t
                + ((1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5) * eps_xt
            )
            + x_t
        )
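
    # _backward_ddim is an algebraic rearrangement of the standard
    # deterministic (eta = 0) DDIM step,
    #     x_{t-1} = sqrt(a_{t-1}) * (x_t - sqrt(1 - a_t) * eps) / sqrt(a_t)
    #               + sqrt(1 - a_{t-1}) * eps,
    # with a_t = alphas_cumprod[t]; expanding both forms gives identical
    # coefficients for x_t and eps.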

    @torch.inference_mode()
    def backward_diffusion(
        self,
        use_old_emb_i=25,
        text_embeddings=None,
        old_text_embeddings=None,
        new_text_embeddings=None,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        reverse_process: bool = False,
        **kwargs,
    ):
        """Run DDIM sampling from `latents` conditioned on `text_embeddings`;
        with reverse_process=True, invert image latents back toward noise.
        Returns the list of intermediate latents, one entry per step.
        """
        # Warn that kwargs specific to other inversion methods are ignored here.
        if "inv_order" in kwargs:
            warnings.warn("inv_order is ignored for DDIM Inversion")
        if "inverse_opt" in kwargs:
            warnings.warn("inverse_opt is ignored for DDIM Inversion")

        # Keep a list of the latents produced at each step of the process.
        intermediate_latents = []

        # `guidance_scale` is defined analogously to the guidance weight `w` in
        # equation (2) of the Imagen paper (https://arxiv.org/pdf/2205.11487.pdf);
        # `guidance_scale = 1` corresponds to doing no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
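        # When enabled, the conditional and unconditional noise predictions are
        # combined as eps = eps_uncond + guidance_scale * (eps_cond - eps_uncond);
        # in this class that combination is delegated to
        # _restore_latent_from_unet, which receives guidance_scale below.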

        # Set timesteps. Some schedulers (e.g. PNDM) store timesteps as arrays,
        # and it is cheaper to move them all to the right device up front.
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps_tensor = self.scheduler.timesteps.to(self.device)

        # Scale the initial noise by the standard deviation required by the scheduler.
        latents = latents * self.scheduler.init_noise_sigma

        prompt_to_prompt = (
            old_text_embeddings is not None and new_text_embeddings is not None
        )

        for i, t in enumerate(
            tqdm(reversed(timesteps_tensor) if reverse_process else timesteps_tensor)
        ):
            if prompt_to_prompt:
                # Use the old embeddings for the first use_old_emb_i steps,
                # then switch to the new ones.
                text_embeddings = (
                    old_text_embeddings if i < use_old_emb_i else new_text_embeddings
                )

            # Expand the latents if we are doing classifier-free guidance;
            # `info` records what is needed to undo the expansion afterwards.
            latent_model_input, info = self._prepare_latent_for_unet(
                latents, do_classifier_free_guidance, self.unet
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict the noise residual.
            noise_pred_raw = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample

            # Reshape back if needed; guidance is applied inside
            # _restore_latent_from_unet, which receives guidance_scale.
            noise_pred = self._restore_latent_from_unet(noise_pred_raw, info, guidance_scale)

            prev_timestep = (
                t
                - self.scheduler.config.num_train_timesteps
                // self.scheduler.num_inference_steps
            )
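            # The stride is num_train_timesteps // num_inference_steps (e.g.
            # 1000 // 50 = 20 for typical settings), so prev_timestep goes
            # negative on the last denoising step; the final_alpha_cumprod
            # fallback below covers that case.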

            # Call the callback, if provided.
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

            # DDIM update.
            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[prev_timestep]
                if prev_timestep >= 0
                else getattr(self.scheduler, "final_alpha_cumprod", 1.0)
            )
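            # For inversion the timesteps were iterated in reverse above, and
            # swapping the two alpha products here turns the denoising step
            # into its inverse (x_{t-1} -> x_t instead of x_t -> x_{t-1}).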
            if reverse_process:
                alpha_prod_t, alpha_prod_t_prev = alpha_prod_t_prev, alpha_prod_t
            latents = self._backward_ddim(
                x_t=latents,
                alpha_t=alpha_prod_t,
                alpha_tm1=alpha_prod_t_prev,
                eps_xt=noise_pred,
            )

            # Save the intermediate latents.
            intermediate_latents.append(latents.clone())

        return intermediate_latents
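

# ---------------------------------------------------------------------------
# Usage sketch: assumes a diffusers-style Stable Diffusion checkpoint and the
# diffusers/transformers APIs named below; the model id is a hypothetical
# choice, and guidance_scale=1.0 keeps classifier-free guidance disabled so a
# single text embedding suffices.
if __name__ == "__main__":
    from diffusers import DDIMScheduler, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer

    model_id = "runwayml/stable-diffusion-v1-5"  # assumed checkpoint
    device = "cuda" if torch.cuda.is_available() else "cpu"

    scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(device)
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)

    tokens = tokenizer(
        "a photo of a cat",
        padding="max_length",
        max_length=tokenizer.model_max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        emb = text_encoder(tokens.input_ids.to(device))[0]

    inversion = DDIMInversion(scheduler, unet, device)
    start_noise = torch.randn(1, unet.config.in_channels, 64, 64, device=device)

    # Noise -> image latents, then image latents -> noise again.
    denoise_traj = inversion.backward_diffusion(
        latents=start_noise, text_embeddings=emb,
        num_inference_steps=50, guidance_scale=1.0,
    )
    invert_traj = inversion.forward_diffusion(
        latents=denoise_traj[-1], text_embeddings=emb,
        num_inference_steps=50, guidance_scale=1.0,
    )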