@@ -217,6 +217,43 @@ def _setup_test_data_provided(self, stage) -> None:
217217 # Transform the predict data
218218 self .data_predict = self .transform_dataset (self .data_predict )
219219
220+ def _setup_data_val_provided (self , stage ) -> None :
221+ # New functionality for predefined train, validation and test data in the fun_control
222+ # Get the data set sizes
223+ if self .data_full_train is None :
224+ raise ValueError ("If data_val is defined, data_full_train must also be defined." )
225+ train_size = len (self .data_full_train )
226+ val_size = len (self .data_val )
227+ test_size = len (self .data_test )
228+ # Assign train and validation data sets
229+ if stage == "fit" or stage is None :
230+ if self .verbosity > 0 :
231+ print (f"train_size: { train_size } , val_size: { val_size } used for train & val data." )
232+ generator_fit = torch .Generator ().manual_seed (self .test_seed )
233+ # Use all data from data_full_train as training data
234+ self .data_train = self .data_full_train
235+ # Handle scaling and transformation if scaler is provided
236+ if self .scaler is not None :
237+ self .handle_scaling_and_transform ()
238+
239+ # Assign test dataset for use in dataloader(s)
240+ if stage == "test" or stage is None :
241+ if self .verbosity > 0 :
242+ print (f"test_size: { test_size } used for test dataset." )
243+ self .data_test = self .data_test
244+ if self .scaler is not None :
245+ # Transform the test data
246+ self .data_test = self .transform_dataset (self .data_test )
247+
248+ # Assign pred dataset for use in dataloader(s)
249+ if stage == "predict" or stage is None :
250+ if self .verbosity > 0 :
251+ print (f"test_size: { test_size } used for predict dataset." )
252+ self .data_predict = self .data_test
253+ if self .scaler is not None :
254+ # Transform the predict data
255+ self .data_predict = self .transform_dataset (self .data_predict )
256+
220257 def setup (self , stage : Optional [str ] = None ) -> None :
221258 """
222259 Splits the data for use in training, validation, and testing.
@@ -243,6 +280,8 @@ def setup(self, stage: Optional[str] = None) -> None:
243280 """
244281 if self .data_full is not None :
245282 self ._setup_full_data_provided (stage )
283+ elif self .data_val is not None :
284+ self ._setup_data_val_provided (stage )
246285 else :
247286 self ._setup_test_data_provided (stage )
248287
0 commit comments