Coverage for utils / diffusion_config.py: 100.00%

43 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 10:24 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16from dataclasses import dataclass 

17from typing import Optional, Union, Any, Dict 

18import torch 

19from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline 

20from utils.pipeline_utils import ( 

21 get_pipeline_type, 

22 PIPELINE_TYPE_IMAGE, 

23 PIPELINE_TYPE_TEXT_TO_VIDEO, 

24 PIPELINE_TYPE_IMAGE_TO_VIDEO 

25) 

26 

class DiffusionConfig:
    """Configuration for diffusion-model generation and inversion.

    Bundles the scheduler, the (image or video) pipeline, and the sampling
    hyper-parameters used throughout generation and latent inversion.

    NOTE(review): the class previously carried a ``@dataclass`` decorator
    despite declaring no class-level fields and defining ``__init__`` by
    hand.  A zero-field dataclass synthesizes an ``__eq__`` under which
    every pair of instances compares equal and a ``__repr__`` that hides
    all state; the decorator was removed to restore identity-based
    equality and the default repr.
    """

    def __init__(
        self,
        scheduler: DPMSolverMultistepScheduler,
        pipe: Union[StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline],
        device: str,
        guidance_scale: float = 7.5,
        num_images: int = 1,
        num_inference_steps: int = 50,
        num_inversion_steps: Optional[int] = None,
        image_size: tuple = (512, 512),
        dtype: torch.dtype = torch.float16,
        gen_seed: int = 0,
        init_latents_seed: int = 0,
        inversion_type: str = "ddim",
        num_frames: int = -1,  # -1 means image generation; >=1 means video generation
        **kwargs
    ):
        """Initialize the configuration and validate it.

        Args:
            scheduler: Noise scheduler used by the pipeline.
            pipe: Diffusion pipeline (image, text-to-video, or image-to-video).
            device: Torch device string (e.g. ``"cuda"`` or ``"cpu"``).
            guidance_scale: Classifier-free guidance scale.
            num_images: Number of images to generate per prompt.
            num_inference_steps: Denoising steps used for generation.
            num_inversion_steps: Steps used for inversion; falls back to
                ``num_inference_steps`` when ``None``.
            image_size: Output ``(height, width)`` in pixels.
            dtype: Torch dtype for model execution.
            gen_seed: RNG seed for the generation pass.
            init_latents_seed: RNG seed for the initial latents.
            inversion_type: One of ``"ddim"`` or ``"exact"``.
            num_frames: ``-1`` for image generation; ``>= 1`` selects video
                generation with that many frames.
            **kwargs: Extra keyword arguments forwarded to the pipeline call.

        Raises:
            ValueError: If ``inversion_type`` is unsupported, the pipeline
                type is unrecognized, or ``num_frames`` is incompatible with
                a video pipeline.
        """
        self.device = device
        self.scheduler = scheduler
        self.pipe = pipe
        self.guidance_scale = guidance_scale
        self.num_images = num_images
        self.num_inference_steps = num_inference_steps
        # Inversion defaults to the same step count as generation.
        self.num_inversion_steps = num_inversion_steps or num_inference_steps
        self.image_size = image_size
        self.dtype = dtype
        self.gen_seed = gen_seed
        self.init_latents_seed = init_latents_seed
        self.inversion_type = inversion_type
        self.num_frames = num_frames
        # Store additional kwargs for later pipeline invocation.
        self.gen_kwargs = kwargs

        # Use an explicit exception rather than `assert`: assertions are
        # stripped when Python runs with -O, silently disabling validation.
        if self.inversion_type not in ("ddim", "exact"):
            raise ValueError(f"Invalid inversion type: {self.inversion_type}")

        # Validate pipeline type and parameter compatibility.
        self._validate_pipeline_config()

    def _validate_pipeline_config(self) -> None:
        """Validate pipeline type and `num_frames` compatibility.

        Raises:
            ValueError: If the pipeline type is unsupported, or a video
                pipeline was given ``num_frames < 1``.
        """
        pipeline_type = get_pipeline_type(self.pipe)

        if pipeline_type is None:
            raise ValueError(f"Unsupported pipeline type: {type(self.pipe)}")

        # Validate num_frames setting based on pipeline type.
        if pipeline_type == PIPELINE_TYPE_IMAGE:
            if self.num_frames >= 1:
                # Auto-correct for image pipelines: frames are meaningless,
                # so silently reset to the image sentinel instead of raising.
                self.num_frames = -1
        elif pipeline_type in [PIPELINE_TYPE_TEXT_TO_VIDEO, PIPELINE_TYPE_IMAGE_TO_VIDEO]:
            if self.num_frames < 1:
                raise ValueError(f"For {pipeline_type} pipelines, num_frames must be >= 1, got {self.num_frames}")

    @property
    def pipeline_type(self) -> str:
        """Get the pipeline type (as reported by `get_pipeline_type`)."""
        return get_pipeline_type(self.pipe)

    @property
    def is_video_pipeline(self) -> bool:
        """Check if this is a video pipeline."""
        return self.pipeline_type in [PIPELINE_TYPE_TEXT_TO_VIDEO, PIPELINE_TYPE_IMAGE_TO_VIDEO]

    @property
    def is_image_pipeline(self) -> bool:
        """Check if this is an image pipeline."""
        return self.pipeline_type == PIPELINE_TYPE_IMAGE