Coverage for evaluation / dataset.py: 100.00%
87 statements
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 10:24 +0000
« prev ^ index » next coverage.py v7.13.0, created at 2025-12-22 10:24 +0000
1# Copyright 2025 THU-BPM MarkDiffusion.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
16import ujson as json
17from datasets import load_dataset
18import pandas as pd
19from PIL import Image
20import requests
21from io import BytesIO
22from tqdm import tqdm
23import random
24from typing import List
class BaseDataset:
    """Base dataset class.

    Stores a list of text prompts and an optional list of reference images.
    ``__getitem__`` pairs a prompt with the reference stored at the same
    index. Subclasses populate both lists by overriding ``_load_data``.
    """

    # NOTE: annotations that mention PIL types are written as strings so they
    # are not evaluated when the class body is executed.

    def __init__(self, max_samples: int = 200):
        """Initialize the dataset.

        Parameters:
            max_samples: Maximum number of samples to load.
        """
        self.max_samples = max_samples
        self.prompts = []
        self.references = []

    @property
    def num_samples(self) -> int:
        """Number of samples (prompts) in the dataset."""
        return len(self.prompts)

    @property
    def num_references(self) -> int:
        """Number of references in the dataset."""
        return len(self.references)

    def get_prompt(self, idx) -> str:
        """Get the prompt at the given index."""
        return self.prompts[idx]

    def get_reference(self, idx) -> "Image.Image":
        """Get the reference Image at the given index."""
        return self.references[idx]

    def __len__(self) -> int:
        """Number of samples in the dataset (equivalent to ``num_samples``)."""
        return self.num_samples

    def __getitem__(self, idx) -> "str | tuple[str, Image.Image]":
        """Get the sample at the given index.

        Returns:
            The prompt alone when the dataset holds no references, otherwise
            a ``(prompt, reference)`` tuple.
        """
        # Prompt-only datasets never fill ``references``; return the bare
        # prompt for them instead of a tuple.
        if not self.references:
            return self.prompts[idx]
        return self.prompts[idx], self.references[idx]

    def _load_data(self):
        """Load data from the dataset.

        The base implementation does nothing; subclasses override it.
        """
        pass
class StableDiffusionPromptsDataset(BaseDataset):
    """Prompt-only dataset backed by the Stable Diffusion prompts collection."""

    def __init__(self, max_samples: int = 200, split: str = "test", shuffle: bool = False):
        """Create the dataset and load its prompts immediately.

        Parameters:
            max_samples: Maximum number of samples to load.
            split: Split to load.
            shuffle: Whether to shuffle the dataset.
        """
        super().__init__(max_samples)
        self.shuffle = shuffle
        self.split = split
        self._load_data()

    @property
    def name(self):
        """Human-readable name of the dataset."""
        return "Stable Diffusion Prompts"

    def _load_data(self):
        """Read the "Prompt" column and keep at most ``max_samples`` entries."""
        records = load_dataset("dataset/stable_diffusion_prompts", split=self.split)
        if self.shuffle:
            records = records.shuffle()
        selected = records["Prompt"][:self.max_samples]
        self.prompts.extend(selected)
class MSCOCODataset(BaseDataset):
    """MSCOCO 2017 dataset.

    Each sample is a caption (the ``TEXT`` column) paired with the image
    downloaded from the row's ``URL`` column.
    """

    def __init__(self, max_samples: int = 200, shuffle: bool = False):
        """Initialize the dataset and load it immediately.

        Parameters:
            max_samples: Maximum number of samples to load.
            shuffle: Whether to shuffle the dataset.
        """
        super().__init__(max_samples)
        self.shuffle = shuffle
        self._load_data()

    @property
    def name(self):
        """Name of the dataset."""
        return "MS-COCO 2017"

    def _load_image_from_url(self, url, timeout: float = 30.0):
        """Load image from url.

        Parameters:
            url: Address of the image to download.
            timeout: Seconds passed to ``requests.get`` as its ``timeout``,
                so an unresponsive server cannot block loading forever.

        Returns:
            The opened image, or ``None`` if the download or decoding failed.
        """
        try:
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            image = Image.open(BytesIO(response.content))
            return image
        except Exception as e:
            # Best effort: report the failure and let the caller store None.
            print(f"Load image from url failed: {e}")
            return None

    def _load_data(self):
        """Load data from the MSCOCO 2017 dataset.

        Reads ``dataset/mscoco/mscoco.parquet`` and appends one prompt and one
        reference per row, for at most ``max_samples`` rows. A failed image
        download stores ``None`` so prompts and references stay aligned.
        """
        df = pd.read_parquet("dataset/mscoco/mscoco.parquet")
        if self.shuffle:
            df = df.sample(frac=1).reset_index(drop=True)
        # Clamp to the number of available rows so a small parquet file does
        # not raise IndexError from ``df.iloc``.
        num_rows = min(self.max_samples, len(df))
        for i in tqdm(range(num_rows), desc="Loading MSCOCO dataset"):
            item = df.iloc[i]
            self.prompts.append(item['TEXT'])
            self.references.append(self._load_image_from_url(item['URL']))
class VBenchDataset(BaseDataset):
    """VBench dataset (prompt-only, one prompt file per evaluation dimension)."""

    def __init__(self, max_samples: int, dimension: str = "subject_consistency", shuffle: bool = False):
        """Initialize the dataset and load it immediately.

        Args:
            max_samples (int): Maximum number of samples to load.
            dimension (str, optional): Dimensions to load. Selected from "subject_consistency", "background_consistency", "imaging_quality", "motion_smoothness", "dynamic_degree".
            shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
        """
        super().__init__(max_samples)
        self.shuffle = shuffle
        self.dimension = dimension
        self._load_data()

    @property
    def name(self):
        """Name of the dataset."""
        return "VBench"

    def _load_data(self):
        """Load data from the VBench dataset.

        Reads ``dataset/vbench/prompts_per_dimension/<dimension>.txt`` with
        one prompt per line, optionally shuffles the prompts, and keeps at
        most ``max_samples`` of them.
        """
        path = f"dataset/vbench/prompts_per_dimension/{self.dimension}.txt"
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's default locale encoding.
        with open(path, "r", encoding="utf-8") as f:
            stripped = (line.strip() for line in f)
            # Skip blank lines (e.g. a trailing newline at end of file) so
            # they do not turn into empty prompts.
            prompts = [text for text in stripped if text]
        if self.shuffle:
            random.shuffle(prompts)
        self.prompts.extend(prompts[:self.max_samples])