Skip to content

Commit c2c172e

Browse files
committed
validation data option for lightdatamodule
1 parent af1295b commit c2c172e

1 file changed

Lines changed: 39 additions & 0 deletions

File tree

src/spotpython/data/lightdatamodule.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -217,6 +217,43 @@ def _setup_test_data_provided(self, stage) -> None:
217217
# Transform the predict data
218218
self.data_predict = self.transform_dataset(self.data_predict)
219219

220+
def _setup_data_val_provided(self, stage) -> None:
    """Set up the datasets when train, validation and test data are predefined.

    No splitting is done here: ``data_full_train`` is used unchanged as the
    training set, and ``data_test`` serves as both the test and the predict
    dataset. ``data_val`` is not reassigned by this method.

    Args:
        stage: ``"fit"``, ``"test"``, ``"predict"`` or ``None``. ``None``
            sets up all three stages.

    Raises:
        ValueError: If ``data_full_train`` is not defined, or if
            ``data_test`` is not defined for the test/predict stage.
    """
    if self.data_full_train is None:
        raise ValueError("If data_val is defined, data_full_train must also be defined.")

    all_stages = stage is None
    verbose = self.verbosity > 0

    # Train and validation data sets
    if stage == "fit" or all_stages:
        if verbose:
            train_size = len(self.data_full_train)
            val_size = len(self.data_val)
            print(f"train_size: {train_size}, val_size: {val_size} used for train & val data.")
        # Use all data from data_full_train as training data
        self.data_train = self.data_full_train
        # Handle scaling and transformation if scaler is provided
        if self.scaler is not None:
            self.handle_scaling_and_transform()

    wants_test = stage == "test" or all_stages
    wants_predict = stage == "predict" or all_stages
    if not (wants_test or wants_predict):
        return

    # The test data is only required by the stages that actually read it.
    if self.data_test is None:
        raise ValueError("If data_val is defined, data_test must also be defined for the test and predict stages.")
    # Size is taken before any transformation, as reported in the messages below.
    test_size = len(self.data_test)

    # Test dataset for use in dataloader(s)
    test_already_transformed = False
    if wants_test:
        if verbose:
            print(f"test_size: {test_size} used for test dataset.")
        if self.scaler is not None:
            # Transform the test data
            self.data_test = self.transform_dataset(self.data_test)
            test_already_transformed = True

    # Predict dataset for use in dataloader(s); it is the test data
    if wants_predict:
        if verbose:
            print(f"test_size: {test_size} used for predict dataset.")
        self.data_predict = self.data_test
        # When the test stage ran in this same call, data_test is already
        # transformed; transforming it again would scale the predict data twice.
        # NOTE(review): separate calls (setup("test") followed by
        # setup("predict")) still re-transform the stored data_test -- confirm
        # whether transform_dataset is idempotent or keep the raw data separately.
        if self.scaler is not None and not test_already_transformed:
            # Transform the predict data
            self.data_predict = self.transform_dataset(self.data_predict)
256+
220257
def setup(self, stage: Optional[str] = None) -> None:
221258
"""
222259
Splits the data for use in training, validation, and testing.
@@ -243,6 +280,8 @@ def setup(self, stage: Optional[str] = None) -> None:
243280
"""
244281
if self.data_full is not None:
245282
self._setup_full_data_provided(stage)
283+
elif self.data_val is not None:
284+
self._setup_data_val_provided(stage)
246285
else:
247286
self._setup_test_data_provided(stage)
248287

0 commit comments

Comments
 (0)