11def calculate_data_split (test_size , full_size , verbosity = 0 , stage = None ) -> tuple :
22 """
33 Calculates the split sizes for training, validation, and test datasets.
4+ Returns a tuple containing the sizes (full_train_size, val_size, train_size, test_size),
5+ where full_train_size is the size of the full dataset minus the test set.
6+
7+ Note:
8+ The first return value is full_train_size, i.e.,
9+ the size of the full dataset minus the test set.
410
511 Args:
612 test_size (float or int):
@@ -15,16 +21,39 @@ def calculate_data_split(test_size, full_size, verbosity=0, stage=None) -> tuple
1521
1622 Returns:
1723 tuple: A tuple containing the sizes (full_train_size, val_size, train_size, test_size).
24+
25+ Examples:
26+ >>> from spotpython.utils.split import calculate_data_split
27+ # Using proportion for test size
28+ calculate_data_split(0.2, 1000)
29+ (0.8, 0.16, 0.64, 0.2)
30+ # Using absolute number for test size
31+ calculate_data_split(200, 1000)
32+ (800, 160, 640, 200)
33+
34+ Raises:
35+ ValueError: If the sizes are not correct, i.e., full_size != train_size + val_size + test_size.
1836 """
1937 if isinstance (test_size , float ):
2038 full_train_size = round (1.0 - test_size , 2 )
2139 val_size = round (full_train_size * test_size , 2 )
22- train_size = round (full_train_size - val_size , 2 )
40+ train_size = 1.0 - test_size - val_size
41+ # check if the sizes are correct, i.e., 1.0 = train_size + val_size + test_size
42+ if full_train_size + test_size != 1.0 :
43+ raise ValueError (f"full_size ({ full_size } ) != full_train_size ({ full_train_size } ) + test_size ({ test_size } )" )
2344 else :
2445 # test_size is considered an int, training size calculation directly based on it
46+ # everything is calculated as an int
47+ # return values are also ints
48+ # check if test_size does not exceed full_size
49+ if test_size > full_size :
50+ raise ValueError (f"test_size ({ test_size } ) > full_size ({ full_size } )" )
2551 full_train_size = full_size - test_size
2652 val_size = int (full_train_size * test_size / full_size )
2753 train_size = full_train_size - val_size
54+ # check if the sizes are correct, i.e., full_size = train_size + val_size + test_size
55+ if full_train_size + test_size != full_size :
56+ raise ValueError (f"full_size ({ full_size } ) != full_train_size ({ full_train_size } ) + test_size ({ test_size } )" )
2857
2958 if verbosity > 0 :
3059 print (f"stage: { stage } " )
0 commit comments