Coverage for utils / diffusion_config.py: 100.00%
43 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 10:24 +0000
1# Copyright 2025 THU-BPM MarkDiffusion.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16from dataclasses import dataclass
17from typing import Optional, Union, Any, Dict
18import torch
19from diffusers import DPMSolverMultistepScheduler, StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline
20from utils.pipeline_utils import (
21 get_pipeline_type,
22 PIPELINE_TYPE_IMAGE,
23 PIPELINE_TYPE_TEXT_TO_VIDEO,
24 PIPELINE_TYPE_IMAGE_TO_VIDEO
25)
# NOTE(review): the @dataclass decorator was removed.  The class declares no
# dataclass fields and defines __init__ by hand, so the decorator's only effect
# was to inject an __eq__ comparing empty field tuples (making *all* instances
# compare equal) and a __repr__ that printed no state — both misleading.
class DiffusionConfig:
    """Configuration for diffusion pipelines used in generation and inversion.

    Bundles the scheduler, a supported diffusers pipeline, the target device,
    and the sampling / inversion hyperparameters, and validates that the
    frame settings are compatible with the pipeline type.
    """

    def __init__(
        self,
        scheduler: DPMSolverMultistepScheduler,
        pipe: Union[StableDiffusionPipeline, TextToVideoSDPipeline, StableVideoDiffusionPipeline],
        device: str,
        guidance_scale: float = 7.5,
        num_images: int = 1,
        num_inference_steps: int = 50,
        num_inversion_steps: Optional[int] = None,
        image_size: tuple = (512, 512),
        dtype: torch.dtype = torch.float16,
        gen_seed: int = 0,
        init_latents_seed: int = 0,
        inversion_type: str = "ddim",
        num_frames: int = -1,  # -1 means image generation; >=1 means video generation
        **kwargs
    ):
        """Store the configuration and validate it.

        Args:
            scheduler: Noise scheduler driving the denoising loop.
            pipe: A supported diffusers pipeline (image, text-to-video, or
                image-to-video).
            device: Torch device string, e.g. ``"cuda"`` or ``"cpu"``.
            guidance_scale: Classifier-free guidance scale.
            num_images: Number of images generated per prompt.
            num_inference_steps: Denoising steps used for generation.
            num_inversion_steps: Steps used for inversion; falls back to
                ``num_inference_steps`` when ``None`` (note: a value of 0 also
                falls back, because the original used ``or``-fallback —
                preserved here for compatibility).
            image_size: Generated image size — presumably ``(height, width)``;
                TODO(review) confirm ordering against the pipeline call sites.
            dtype: Torch dtype used by the pipeline.
            gen_seed: RNG seed for generation.
            init_latents_seed: RNG seed for the initial latents.
            inversion_type: Either ``"ddim"`` or ``"exact"``.
            num_frames: ``-1`` for image generation; ``>= 1`` for video.
            **kwargs: Extra keyword arguments stored in ``gen_kwargs``.

        Raises:
            ValueError: If ``inversion_type`` is invalid, the pipeline type is
                unsupported, or ``num_frames`` conflicts with the pipeline type.
        """
        self.device = device
        self.scheduler = scheduler
        self.pipe = pipe
        self.guidance_scale = guidance_scale
        self.num_images = num_images
        self.num_inference_steps = num_inference_steps
        # `or` (not an explicit None check) so 0 also falls back — kept as-is.
        self.num_inversion_steps = num_inversion_steps or num_inference_steps
        self.image_size = image_size
        self.dtype = dtype
        self.gen_seed = gen_seed
        self.init_latents_seed = init_latents_seed
        self.inversion_type = inversion_type
        self.num_frames = num_frames
        # Store additional kwargs for later pipeline calls
        self.gen_kwargs = kwargs

        # Fix: raise ValueError instead of `assert` — asserts are stripped
        # under `python -O`, and ValueError matches the errors raised below
        # in _validate_pipeline_config.
        if self.inversion_type not in ("ddim", "exact"):
            raise ValueError(f"Invalid inversion type: {self.inversion_type}")

        # Validate pipeline type and parameter compatibility
        self._validate_pipeline_config()

    def _validate_pipeline_config(self) -> None:
        """Validate pipeline type and parameter compatibility.

        Raises:
            ValueError: If the pipeline type is unrecognized, or if a video
                pipeline was configured with ``num_frames < 1``.
        """
        pipeline_type = get_pipeline_type(self.pipe)

        if pipeline_type is None:
            raise ValueError(f"Unsupported pipeline type: {type(self.pipe)}")

        # Validate num_frames setting based on pipeline type
        if pipeline_type == PIPELINE_TYPE_IMAGE:
            if self.num_frames >= 1:
                # Auto-correct (rather than raise) for image pipelines
                self.num_frames = -1
        elif pipeline_type in [PIPELINE_TYPE_TEXT_TO_VIDEO, PIPELINE_TYPE_IMAGE_TO_VIDEO]:
            if self.num_frames < 1:
                raise ValueError(f"For {pipeline_type} pipelines, num_frames must be >= 1, got {self.num_frames}")

    @property
    def pipeline_type(self) -> str:
        """Get the pipeline type string for the configured pipe."""
        return get_pipeline_type(self.pipe)

    @property
    def is_video_pipeline(self) -> bool:
        """Check if this is a video (text-to-video or image-to-video) pipeline."""
        return self.pipeline_type in [PIPELINE_TYPE_TEXT_TO_VIDEO, PIPELINE_TYPE_IMAGE_TO_VIDEO]

    @property
    def is_image_pipeline(self) -> bool:
        """Check if this is an image pipeline."""
        return self.pipeline_type == PIPELINE_TYPE_IMAGE