Skip to content

Commit 1c3b712

Browse files
0.18.6
1 parent 7e28b23 commit 1c3b712

4 files changed

Lines changed: 88 additions & 13 deletions

File tree

RELEASE_NOTES.txt

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1-
spotpython 0.18.5:
1+
spotpython 0.18.6:
22

33
- split.py:
4-
computation fixed
4+
New function: compute_lengths_from_fractions()
5+
6+
- lightdatamodule.py:
7+
train, val, test set computation updated
58

69
spotpython 0.18.4:
710

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.18.5"
10+
version = "0.18.6"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/data/lightdatamodule.py

Lines changed: 16 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import torch
33
from torch.utils.data import DataLoader, random_split, TensorDataset
44
from typing import Optional
5-
from spotpython.utils.split import calculate_data_split
5+
from math import floor
66

77

88
class LightDataModule(L.LightningDataModule):
@@ -166,12 +166,19 @@ def setup(self, stage: Optional[str] = None) -> None:
166166
Training set size: 3
167167
168168
"""
169-
full_train_size, val_size, train_size, test_size = calculate_data_split(
170-
test_size=self.test_size,
171-
full_size=len(self.data_full),
172-
verbosity=self.verbosity,
173-
stage=stage,
174-
)
169+
full_size = len(self.data_full)
170+
test_size = self.test_size
171+
172+
# consider the case when test_size is a float
173+
if isinstance(self.test_size, float):
174+
full_train_size = 1.0 - self.test_size
175+
val_size = full_train_size * self.test_size
176+
train_size = full_train_size - val_size
177+
else:
178+
# test_size is an int, training size calculation directly based on it
179+
full_train_size = full_size - self.test_size
180+
val_size = floor(full_train_size * self.test_size / full_size)
181+
train_size = full_size - val_size - test_size
175182

176183
# Assign train/val datasets for use in dataloaders
177184
if stage == "fit" or stage is None:
@@ -188,7 +195,7 @@ def setup(self, stage: Optional[str] = None) -> None:
188195
if self.verbosity > 0:
189196
print(f"test_size: {test_size} used for test dataset.")
190197
generator_test = torch.Generator().manual_seed(self.test_seed)
191-
self.data_test, _ = random_split(self.data_full, [test_size, full_train_size], generator=generator_test)
198+
self.data_test, _, _ = random_split(self.data_full, [test_size, train_size, val_size], generator=generator_test)
192199
if self.scaler is not None:
193200
# Transform the test data
194201
self.data_test = self.transform_dataset(self.data_test)
@@ -198,7 +205,7 @@ def setup(self, stage: Optional[str] = None) -> None:
198205
if self.verbosity > 0:
199206
print(f"test_size: {test_size} used for predict dataset.")
200207
generator_predict = torch.Generator().manual_seed(self.test_seed)
201-
self.data_predict, _ = random_split(self.data_full, [test_size, full_train_size], generator=generator_predict)
208+
self.data_predict, _, _ = random_split(self.data_full, [test_size, train_size, val_size], generator=generator_predict)
202209
if self.scaler is not None:
203210
# Transform the predict data
204211
self.data_predict = self.transform_dataset(self.data_predict)

src/spotpython/utils/split.py

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,68 @@
1+
import math
2+
import warnings
3+
from typing import List
4+
5+
6+
def compute_lengths_from_fractions(fractions: List[float], dataset_length: int) -> List[int]:
    """Compute lengths of dataset splits from given fractions.

    Given a list of fractions that sum up to 1, compute the lengths of each
    corresponding partition of a dataset with a specified length. Each length
    starts as `floor(frac * dataset_length)`. Any remaining items (lost due to
    flooring) are handed out one at a time in a round-robin fashion, starting
    with the first partition.

    Args:
        fractions (List[float]): A list of fractions that should sum to 1.
            Each fraction must lie in the range [0, 1].
        dataset_length (int): The length of the dataset. Must be non-negative.

    Returns:
        List[int]: A list of lengths corresponding to each fraction. The
            lengths always sum to `dataset_length`.

    Raises:
        ValueError: If the fractions do not sum to 1 (checked with
            `math.isclose`), or if their sum is greater than 1.
        ValueError: If any fraction is outside the range [0, 1].
        ValueError: If `dataset_length` is negative.
        ValueError: If the sum of computed lengths does not equal the dataset length.

    Warns:
        UserWarning: If a computed split length is 0.

    Examples:
        >>> from spotpython.utils.split import compute_lengths_from_fractions
        >>> dataset_length = 5
        >>> fractions = [0.2, 0.3, 0.5]
        >>> compute_lengths_from_fractions(fractions, dataset_length)
        [2, 1, 2]

        In this example, 'dataset_length' is 5. Flooring the products gives
        [1, 1, 2], which leaves one item unassigned. That item is given to
        the first partition, so the result is [2, 1, 2].

    """
    # Compute the sum once; it is needed for both parts of the check.
    total = sum(fractions)
    if not math.isclose(total, 1) or total > 1:
        raise ValueError("Fractions must sum up to 1.")

    for i, frac in enumerate(fractions):
        if frac < 0 or frac > 1:
            raise ValueError(f"Fraction at index {i} is not between 0 and 1")

    # A negative length would otherwise silently produce negative split sizes.
    if dataset_length < 0:
        raise ValueError("dataset_length must be non-negative.")

    lengths: List[int] = [int(math.floor(dataset_length * frac)) for frac in fractions]

    remainder = dataset_length - sum(lengths)

    # Add 1 to the lengths in a round-robin fashion until the remainder is 0.
    # range() is empty for a non-positive remainder, so nothing is added then.
    for i in range(remainder):
        lengths[i % len(lengths)] += 1

    for i, length in enumerate(lengths):
        if length == 0:
            warnings.warn(f"Length of split at index {i} is 0. This might result in an empty dataset.")

    # Guard against floating point effects that make the floors overshoot.
    if sum(lengths) != dataset_length:
        raise ValueError("Sum of computed lengths does not equal the input dataset length!")

    return lengths
64+
65+
166
def calculate_data_split(test_size, full_size, verbosity=0, stage=None) -> tuple:
267
"""
368
Calculates the split sizes for training, validation, and test datasets.
@@ -52,7 +117,7 @@ def calculate_data_split(test_size, full_size, verbosity=0, stage=None) -> tuple
52117
val_size = int(full_train_size * test_size / full_size)
53118
train_size = full_train_size - val_size
54119
# check if the sizes are correct, i.e., full_size = train_size + val_size + test_size
55-
if full_train_size + test_size != full_size:
120+
if train_size + val_size + test_size != full_size:
56121
raise ValueError(f"full_size ({full_size}) != full_train_size ({full_train_size}) + test_size ({test_size})")
57122

58123
if verbosity > 0:

0 commit comments

Comments
 (0)