Coverage for evaluation / dataset.py: 100.00%

87 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-22 10:24 +0000

1# Copyright 2025 THU-BPM MarkDiffusion. 

2# 

3# Licensed under the Apache License, Version 2.0 (the "License"); 

4# you may not use this file except in compliance with the License. 

5# You may obtain a copy of the License at 

6# 

7# http://www.apache.org/licenses/LICENSE-2.0 

8# 

9# Unless required by applicable law or agreed to in writing, software 

10# distributed under the License is distributed on an "AS IS" BASIS, 

11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 

12# See the License for the specific language governing permissions and 

13# limitations under the License. 

14 

15 

16import ujson as json 

17from datasets import load_dataset 

18import pandas as pd 

19from PIL import Image 

20import requests 

21from io import BytesIO 

22from tqdm import tqdm 

23import random 

24from typing import List 

25 

class BaseDataset:
    """Base dataset class.

    Subclasses populate ``self.prompts`` (and, when they have them,
    ``self.references``) in ``_load_data``.
    """

    def __init__(self, max_samples: int = 200):
        """Initialize the dataset.

        Parameters:
            max_samples: Maximum number of samples to load.
        """
        self.max_samples = max_samples
        # Text prompts, one per sample.
        self.prompts = []
        # Reference images aligned with ``prompts`` by index. Left empty by
        # prompt-only datasets; individual entries may be None when a subclass
        # failed to fetch an image (e.g. MSCOCODataset on a download error).
        self.references = []

    @property
    def num_samples(self) -> int:
        """Number of samples (prompts) in the dataset."""
        return len(self.prompts)

    @property
    def num_references(self) -> int:
        """Number of reference images in the dataset."""
        return len(self.references)

    def get_prompt(self, idx) -> str:
        """Get the prompt at the given index."""
        return self.prompts[idx]

    def get_reference(self, idx) -> "Image.Image":
        """Get the reference image at the given index."""
        return self.references[idx]

    def __len__(self) -> int:
        """Number of samples in the dataset (equivalent to ``num_samples``)."""
        return self.num_samples

    def __getitem__(self, idx) -> "str | tuple[str, Image.Image]":
        """Get the sample at the given index.

        Returns:
            The prompt alone when the dataset has no references; otherwise a
            ``(prompt, reference)`` tuple.
        """
        if not self.references:
            return self.prompts[idx]
        return self.prompts[idx], self.references[idx]

    def _load_data(self):
        """Load data into ``prompts``/``references``.

        The base implementation does nothing; subclasses override it.
        """
        pass

71 

72 

class StableDiffusionPromptsDataset(BaseDataset):
    """Stable Diffusion prompts dataset (prompts only, no reference images)."""

    def __init__(self, max_samples: int = 200, split: str = "test", shuffle: bool = False):
        """Initialize the dataset and load its prompts.

        Parameters:
            max_samples: Maximum number of samples to load.
            split: Split to load.
            shuffle: Whether to shuffle the dataset before taking samples.
        """
        super().__init__(max_samples)
        self.split = split
        self.shuffle = shuffle
        self._load_data()

    @property
    def name(self):
        """Name of the dataset."""
        return "Stable Diffusion Prompts"

    def _load_data(self):
        """Append up to ``max_samples`` values of the ``Prompt`` column to ``prompts``."""
        dataset = load_dataset("dataset/stable_diffusion_prompts", split=self.split)
        if self.shuffle:
            dataset = dataset.shuffle()
        # Slice the rows first and then pick the column, so only the first
        # ``max_samples`` rows are read. Taking the column first can read the
        # entire column on `datasets` versions where column access is eager.
        self.prompts.extend(dataset[: self.max_samples]["Prompt"])

100 

class MSCOCODataset(BaseDataset):
    """MSCOCO 2017 dataset: text prompts paired with reference images fetched by URL."""

    def __init__(self, max_samples: int = 200, shuffle: bool = False):
        """Initialize the dataset and load prompts and reference images.

        Parameters:
            max_samples: Maximum number of samples to load. If the parquet file
                has fewer rows, all rows are loaded.
            shuffle: Whether to shuffle the rows before taking samples.
        """
        super().__init__(max_samples)
        self.shuffle = shuffle
        self._load_data()

    @property
    def name(self):
        """Name of the dataset."""
        return "MS-COCO 2017"

    def _load_image_from_url(self, url, timeout: float = 30.0):
        """Download and decode an image.

        Parameters:
            url: URL of the image.
            timeout: Seconds to wait for the server (connect/read) before
                giving up; prevents a stalled host from hanging the loader.

        Returns:
            The decoded PIL image, or None if downloading or decoding failed.
        """
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
            # Image.open is lazy: force decoding now so a corrupt payload is
            # caught here (and mapped to None) instead of failing later.
            image.load()
            return image
        except Exception as e:
            # Deliberate best-effort: one bad URL must not abort the whole
            # load, so the failure is reported and None is stored instead.
            print(f"Load image from url failed: {e}")
            return None

    def _load_data(self):
        """Load prompts (``TEXT``) and reference images (``URL``) from the parquet file."""
        df = pd.read_parquet("dataset/mscoco/mscoco.parquet")
        if self.shuffle:
            df = df.sample(frac=1).reset_index(drop=True)
        # Clamp to the available rows; indexing past the end used to raise
        # IndexError when max_samples exceeded the file size.
        num_to_load = max(0, min(self.max_samples, len(df)))
        rows = zip(df["TEXT"].iloc[:num_to_load], df["URL"].iloc[:num_to_load])
        for text, url in tqdm(rows, total=num_to_load, desc="Loading MSCOCO dataset"):
            self.prompts.append(text)
            self.references.append(self._load_image_from_url(url))

140 

class VBenchDataset(BaseDataset):
    """VBench dataset (prompts only, no reference images)."""

    def __init__(self, max_samples: int = 200, dimension: str = "subject_consistency", shuffle: bool = False):
        """Initialize the dataset and load the prompts for one dimension.

        Parameters:
            max_samples: Maximum number of samples to load.
            dimension: Dimension whose prompt file is loaded. Selected from
                "subject_consistency", "background_consistency",
                "imaging_quality", "motion_smoothness", "dynamic_degree".
            shuffle: Whether to shuffle the prompts before taking samples.
                Defaults to False.
        """
        super().__init__(max_samples)
        self.shuffle = shuffle
        self.dimension = dimension
        self._load_data()

    @property
    def name(self):
        """Name of the dataset."""
        return "VBench"

    def _load_data(self):
        """Load prompts from ``prompts_per_dimension/<dimension>.txt``.

        Each non-blank line is one prompt, with surrounding whitespace stripped.
        """
        path = f"dataset/vbench/prompts_per_dimension/{self.dimension}.txt"
        # Explicit encoding so decoding does not depend on the platform locale.
        with open(path, "r", encoding="utf-8") as f:
            # Skip blank lines (e.g. trailing newlines) so they do not become
            # empty prompts.
            prompts = [text for line in f if (text := line.strip())]
        if self.shuffle:
            random.shuffle(prompts)
        self.prompts.extend(prompts[: self.max_samples])

169