Coverage for inversions / ddim_inversion.py: 84.44%

45 statements  

coverage.py v7.13.0, created at 2025-12-22 11:32 +0000

import warnings
from functools import partial
from typing import Callable, Optional

import torch
from tqdm import tqdm

from .base_inversion import BaseInversion


class DDIMInversion(BaseInversion):
    def __init__(self, scheduler, unet, device):
        super().__init__(scheduler, unet, device)
        # Forward diffusion (image -> noise) is the same update run with the
        # timesteps and cumulative alphas reversed.
        self.forward_diffusion = partial(self.backward_diffusion, reverse_process=True)

    def _backward_ddim(self, x_t, alpha_t, alpha_tm1, eps_xt):
        """One deterministic DDIM step (from noise towards image)."""
        return (
            alpha_tm1**0.5
            * (
                (alpha_t**-0.5 - alpha_tm1**-0.5) * x_t
                + ((1 / alpha_tm1 - 1) ** 0.5 - (1 / alpha_t - 1) ** 0.5) * eps_xt
            )
            + x_t
        )
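    # For reference, the expression above is algebraically the usual deterministic
    # DDIM update, written with x_t pulled out as an additive term:
    #     x_{t-1} = sqrt(alpha_tm1) * x0_hat + sqrt(1 - alpha_tm1) * eps_xt,
    #     x0_hat  = (x_t - sqrt(1 - alpha_t) * eps_xt) / sqrt(alpha_t).
    # Because the step is deterministic, swapping alpha_t and alpha_tm1 (and walking
    # the timesteps in the opposite order) runs it backwards, which is how
    # forward_diffusion reuses backward_diffusion via reverse_process=True.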

    @torch.inference_mode()
    def backward_diffusion(
        self,
        use_old_emb_i=25,
        text_embeddings=None,
        old_text_embeddings=None,
        new_text_embeddings=None,
        latents: Optional[torch.FloatTensor] = None,
        num_inference_steps: int = 50,
        guidance_scale: float = 7.5,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: Optional[int] = 1,
        reverse_process: bool = False,
        **kwargs,
    ):
        """Generate image latents from text embeddings and initial latents
        (or run the inversion when reverse_process=True)."""
        # inv_order / inverse_opt are accepted for interface compatibility but ignored here
        if "inv_order" in kwargs:
            warnings.warn("inv_order is ignored for DDIM Inversion")
        if "inverse_opt" in kwargs:
            warnings.warn("inverse_opt is ignored for DDIM Inversion")

        # Keep a list of inverted latents as the process goes on
        intermediate_latents = []
        # Here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier-free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0
        # set timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        # Some schedulers like PNDM have timesteps as arrays;
        # it's more efficient to move all timesteps to the correct device beforehand.
        timesteps_tensor = self.scheduler.timesteps.to(self.device)
        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma

        # Prompt-to-prompt editing is enabled when both old and new embeddings are given
        prompt_to_prompt = old_text_embeddings is not None and new_text_embeddings is not None

        for i, t in enumerate(tqdm(timesteps_tensor if not reverse_process else reversed(timesteps_tensor))):
            if prompt_to_prompt:
                # switch from the old to the new prompt embeddings after `use_old_emb_i` steps
                if i < use_old_emb_i:
                    text_embeddings = old_text_embeddings
                else:
                    text_embeddings = new_text_embeddings

            # expand the latents if we are doing classifier-free guidance
            # latent_model_input = (
            #     torch.cat([latents] * 2) if do_classifier_free_guidance else latents
            # )
            latent_model_input, info = self._prepare_latent_for_unet(
                latents, do_classifier_free_guidance, self.unet
            )
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # predict the noise residual
            noise_pred_raw = self.unet(
                latent_model_input, t, encoder_hidden_states=text_embeddings
            ).sample

            # reshape back if needed
            noise_pred = self._restore_latent_from_unet(noise_pred_raw, info, guidance_scale)

            # # perform guidance
            # noise_pred = self._apply_guidance_scale(noise_pred, guidance_scale)
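            # The commented-out call above suggests classifier-free guidance is now folded
            # into `_restore_latent_from_unet` (defined in BaseInversion, not shown here).
            # In typical diffusers-style pipelines this amounts to something like:
            #     noise_uncond, noise_text = noise_pred_raw.chunk(2)
            #     noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
            # but the exact handling depends on the base class implementation.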

            prev_timestep = (
                t
                - self.scheduler.config.num_train_timesteps
                // self.scheduler.num_inference_steps
            )
            # call the callback, if provided
            if callback is not None and i % callback_steps == 0:
                callback(i, t, latents)

            # ddim step
            alpha_prod_t = self.scheduler.alphas_cumprod[t]
            alpha_prod_t_prev = (
                self.scheduler.alphas_cumprod[prev_timestep]
                if prev_timestep >= 0
                else getattr(self.scheduler, "final_alpha_cumprod", 1.0)
            )
            if reverse_process:
                # inversion: swap the cumulative alphas so the same update steps t -> t+1
                alpha_prod_t, alpha_prod_t_prev = alpha_prod_t_prev, alpha_prod_t
            latents = self._backward_ddim(
                x_t=latents,
                alpha_t=alpha_prod_t,
                alpha_tm1=alpha_prod_t_prev,
                eps_xt=noise_pred,
            )
            # save intermediate latents
            intermediate_latents.append(latents.clone())
        return intermediate_latents
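
For orientation, a minimal usage sketch follows. It assumes a diffusers-style DDIMScheduler and UNet2DConditionModel plus CLIP text embeddings; the checkpoint id, the latent shape, and the guidance setting are illustrative assumptions, not part of this module.

# Hypothetical usage sketch (assumes diffusers + transformers are installed).
import torch
from diffusers import DDIMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer

from inversions.ddim_inversion import DDIMInversion

device = "cuda" if torch.cuda.is_available() else "cpu"
model_id = "CompVis/stable-diffusion-v1-4"  # assumed checkpoint; any SD 1.x layout works

scheduler = DDIMScheduler.from_pretrained(model_id, subfolder="scheduler")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet").to(device)
tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder").to(device)

# Encode an empty prompt; unconditional embeddings suffice for guidance_scale = 1.0
tokens = tokenizer(
    "", padding="max_length", max_length=tokenizer.model_max_length, return_tensors="pt"
).to(device)
text_embeddings = text_encoder(tokens.input_ids)[0]

inverter = DDIMInversion(scheduler, unet, device)

# Invert some image latents (e.g. from a VAE encoder) back towards noise
image_latents = torch.randn(1, 4, 64, 64, device=device)  # placeholder latents
inverted = inverter.forward_diffusion(
    latents=image_latents,
    text_embeddings=text_embeddings,
    num_inference_steps=50,
    guidance_scale=1.0,  # guidance-free inversion is the common choice
)
noise_latents = inverted[-1]  # the last intermediate is the fully inverted latent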