Skip to content

Commit 4da9e58

Browse files
0.18.5
split
1 parent e8ff574 commit 4da9e58

3 files changed

Lines changed: 48 additions & 2 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.18.4"
10+
version = "0.18.5"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/utils/split.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
def calculate_data_split(test_size, full_size, verbosity=0, stage=None) -> tuple:
22
"""
33
Calculates the split sizes for training, validation, and test datasets.
4+
Returns a tuple containing the sizes (full_train_size, val_size, train_size, test_size),
5+
where full_train_size is the size of the full dataset minus the test set.
6+
7+
Note:
8+
The first return value is full_train_size, i.e.,
9+
the size of the full dataset minus the test set.
410
511
Args:
612
test_size (float or int):
@@ -15,16 +21,39 @@ def calculate_data_split(test_size, full_size, verbosity=0, stage=None) -> tuple
1521
1622
Returns:
1723
tuple: A tuple containing the sizes (full_train_size, val_size, train_size, test_size).
24+
25+
Examples:
26+
>>> from spotpython.utils.split import calculate_data_split
27+
>>> # Using proportion for test size
28+
>>> calculate_data_split(0.2, 1000)
29+
(0.8, 0.16, 0.64, 0.2)
30+
>>> # Using absolute number for test size
31+
>>> calculate_data_split(200, 1000)
32+
(800, 160, 640, 200)
33+
34+
Raises:
35+
ValueError: If an integer test_size exceeds full_size, or if full_train_size + test_size does not equal full_size (1.0 when test_size is a float).
1836
"""
1937
if isinstance(test_size, float):
2038
full_train_size = round(1.0 - test_size, 2)
2139
val_size = round(full_train_size * test_size, 2)
22-
train_size = round(full_train_size - val_size, 2)
40+
train_size = 1.0 - test_size - val_size
41+
# check that the proportions are consistent, i.e., 1.0 == full_train_size + test_size
42+
if full_train_size + test_size != 1.0:
43+
raise ValueError(f"full_size ({full_size}) != full_train_size ({full_train_size}) + test_size ({test_size})")
2344
else:
2445
# test_size is considered an int, training size calculation directly based on it
46+
# everything is calculated as an int
47+
# return values are also ints
48+
# check if test_size does not exceed full_size
49+
if test_size > full_size:
50+
raise ValueError(f"test_size ({test_size}) > full_size ({full_size})")
2551
full_train_size = full_size - test_size
2652
val_size = int(full_train_size * test_size / full_size)
2753
train_size = full_train_size - val_size
54+
# check that the sizes are consistent, i.e., full_size == full_train_size + test_size
55+
if full_train_size + test_size != full_size:
56+
raise ValueError(f"full_size ({full_size}) != full_train_size ({full_train_size}) + test_size ({test_size})")
2857

2958
if verbosity > 0:
3059
print(f"stage: {stage}")

test/test_split_data.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import pytest
2+
from spotpython.utils.split import calculate_data_split
3+
4+
def test_calculate_data_split_proportion():
5+
# Test with proportion for test size
6+
result = calculate_data_split(0.2, 1000)
7+
assert result == (0.8, 0.16, 0.64, 0.2), f"Unexpected result: {result}"
8+
9+
def test_calculate_data_split_absolute():
10+
# Test with absolute number for test size
11+
result = calculate_data_split(200, 1000)
12+
assert result == (800, 160, 640, 200), f"Unexpected result: {result}"
13+
14+
def test_calculate_data_split_invalid():
15+
# Test with invalid input where test size exceeds full size
16+
with pytest.raises(ValueError):
17+
calculate_data_split(1200, 1000)

0 commit comments

Comments
 (0)