Coverage for utils/callbacks.py: 100.00%
29 statements
Report generated by coverage.py v7.13.0 on 2025-12-22 at 10:24 +0000.
1import torch
2from typing import List
class DenoisingLatentsCollector:
    """Callable that records intermediate latents passed to it during a loop.

    Every invocation increments an internal counter; on each
    ``save_every_n_steps``-th invocation a detached copy of ``latents`` is
    appended to :attr:`data` together with the ``step`` and ``timestep``
    values supplied by the caller.

    NOTE(review): the ``(step, timestep, latents)`` call signature presumably
    matches a diffusion pipeline's per-step callback hook — confirm against
    the caller.
    """

    def __init__(self, save_every_n_steps: int = 1, to_cpu: bool = True):
        """Initialize the latents collector.

        Args:
            save_every_n_steps (int, optional): Save latents on every n-th
                call. Calls are counted by invocation, not by the ``step``
                argument. Must be >= 1. Defaults to 1.
            to_cpu (bool, optional): Whether to store the saved copies on the
                CPU. Defaults to True.

        Raises:
            ValueError: If ``save_every_n_steps`` is less than 1.
        """
        # Without this check a value of 0 only surfaces later, as a
        # ZeroDivisionError from the modulo in ``__call__``; negative
        # intervals have no meaning.
        if save_every_n_steps < 1:
            raise ValueError(
                f"save_every_n_steps must be >= 1, got {save_every_n_steps}"
            )
        self.save_every_n_steps = save_every_n_steps
        self.to_cpu = to_cpu
        # One dict per saved call, with keys 'step', 'timestep', 'latents'
        # and 'call_count'.
        self.data: List[dict] = []
        # Number of times __call__ has been invoked since creation/clear().
        self._call_count = 0

    def __call__(self, step: int, timestep: int, latents: torch.Tensor):
        """Count this call and record ``latents`` if it falls on the interval.

        Args:
            step (int): Step index supplied by the caller; stored as-is.
            timestep (int): Timestep value supplied by the caller; stored
                as-is (may be a tensor depending on the caller — unverified).
            latents (torch.Tensor): Current latents. A detached copy is
                stored, so later in-place updates by the caller do not alter
                the saved value.
        """
        self._call_count += 1
        if self._call_count % self.save_every_n_steps != 0:
            return

        # detach() ensures a stored tensor never keeps an autograd graph
        # (and all of its intermediate buffers) alive.
        detached = latents.detach()
        if self.to_cpu:
            # Copy straight into host memory. copy=True guarantees a fresh
            # tensor even if ``latents`` already lives on the CPU, without
            # first allocating a temporary clone on the source device.
            latents_to_save = detached.to("cpu", copy=True)
        else:
            latents_to_save = detached.clone()

        self.data.append({
            'step': step,
            'timestep': timestep,
            'latents': latents_to_save,
            'call_count': self._call_count
        })

    @property
    def latents_list(self) -> List[torch.Tensor]:
        """Return the saved latents, in the order they were recorded."""
        return [item['latents'] for item in self.data]

    @property
    def timesteps_list(self) -> List[int]:
        """Return the saved timesteps, in the order they were recorded."""
        return [item['timestep'] for item in self.data]

    def get_latents_at_step(self, step: int) -> torch.Tensor:
        """Return the first saved latents whose recorded ``step`` equals ``step``.

        Raises:
            ValueError: If no saved entry has that step (for example because
                that call fell between saving intervals).
        """
        for item in self.data:
            if item['step'] == step:
                return item['latents']
        raise ValueError(f"No latents found for step {step}")

    def clear(self):
        """Discard all collected data and reset the call counter."""
        self.data.clear()
        self._call_count = 0