Coverage for utils / callbacks.py: 100.00%

29 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 10:24 +0000

1import torch 

2from typing import List 

3 

class DenoisingLatentsCollector:
    """Callback that records intermediate latents during a denoising loop.

    Instances are invoked as ``collector(step, timestep, latents)``. On every
    ``save_every_n_steps``-th call (counting from 1), a detached copy of
    ``latents`` is appended to :attr:`data` as a dict with the keys
    ``'step'``, ``'timestep'``, ``'latents'`` and ``'call_count'``.
    """

    def __init__(self, save_every_n_steps: int = 1, to_cpu: bool = True):
        """Initialize the latents collector.

        Args:
            save_every_n_steps (int, optional): Save latents every n calls.
                Must be at least 1. Defaults to 1.
            to_cpu (bool, optional): Whether to move latents to CPU. Defaults to True.

        Raises:
            ValueError: If ``save_every_n_steps`` is smaller than 1.
        """
        if save_every_n_steps < 1:
            # 0 would otherwise only surface later as a ZeroDivisionError in
            # __call__; negative intervals are meaningless.
            raise ValueError(
                f"save_every_n_steps must be a positive integer, got {save_every_n_steps}"
            )

        self.save_every_n_steps = save_every_n_steps
        self.to_cpu = to_cpu
        # One record per saved call; see the class docstring for the keys.
        self.data: List[dict] = []
        # Total number of __call__ invocations since construction / clear().
        self._call_count = 0

    def __call__(self, step: int, timestep: int, latents: torch.Tensor) -> None:
        """Record ``latents`` if this call falls on the save interval.

        Args:
            step: Step identifier, stored as-is and used by
                :meth:`get_latents_at_step`.
            timestep: Timestep value, stored as-is.
            latents: Current latents. A fresh copy is stored, so later
                in-place updates by the caller do not alter saved data.
        """
        self._call_count += 1
        if self._call_count % self.save_every_n_steps != 0:
            return

        # detach() so the stored copy does not keep an autograd graph alive.
        # .to(..., copy=True) always yields a new tensor (even if latents is
        # already on CPU) without first cloning on the source device.
        detached = latents.detach()
        if self.to_cpu:
            latents_to_save = detached.to("cpu", copy=True)
        else:
            latents_to_save = detached.clone()

        self.data.append({
            'step': step,
            'timestep': timestep,
            'latents': latents_to_save,
            'call_count': self._call_count,
        })

    @property
    def latents_list(self) -> List[torch.Tensor]:
        """Return the list of latents."""
        return [item['latents'] for item in self.data]

    @property
    def timesteps_list(self) -> List[int]:
        """Return the list of timesteps."""
        return [item['timestep'] for item in self.data]

    def get_latents_at_step(self, step: int) -> torch.Tensor:
        """Get the latents at a specific step.

        Returns the first saved record whose ``'step'`` equals ``step``.

        Raises:
            ValueError: If no saved record has that step.
        """
        for item in self.data:
            if item['step'] == step:
                return item['latents']
        raise ValueError(f"No latents found for step {step}")

    def clear(self):
        """Clear the collected data and reset the call counter."""
        self.data.clear()
        self._call_count = 0