@@ -129,7 +129,6 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
129129 )
130130 else :
131131 dm = fun_control ["data_module" ]
132- dm .setup () # Manually call setup to prepare the datasets
133132
134133 model = build_model_instance (config , fun_control )
135134 # TODO: Check if this is necessary or if this is handled by the trainer
@@ -239,7 +238,7 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
239238 gradient_clip_algorithm = "norm" ,
240239 )
241240
242- trainer .fit (model = model , train_dataloaders = train_dl , val_dataloaders = test_dl , ckpt_path = None )
241+ trainer .fit (model = model , train_dataloaders = train_dl , ckpt_path = None )
243242 result = trainer .validate (model = model , dataloaders = test_dl , ckpt_path = None , verbose = verbose )
244243 result = result [0 ]
245244
@@ -351,13 +350,10 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
351350 # Could also be one of two special keywords "last" and "hpc".
352351 # If there is no checkpoint file at the path, an exception is raised.
353352 try :
354- trainer .fit (model = model , train_dataloaders = dm . train_dataloader (), val_dataloaders = dm . val_dataloader () , ckpt_path = None )
353+ trainer .fit (model = model , datamodule = dm , ckpt_path = None )
355354 except Exception as e :
356355 print (f"train_model(): trainer.fit failed with exception: { e } " )
357- return None
358356 # Test best model on validation and test set
359- # The validate and test methods expect a datamodule or dataloaders.
360- # Using the datamodule is cleaner.
361357 verbose = fun_control ["verbosity" ] > 0
362358
363359 # Validate the model
@@ -459,7 +455,6 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
459455 )
460456 else :
461457 dm = fun_control ["data_module" ]
462- dm .setup () # Manually call setup to prepare the datasets
463458
464459 model = build_model_instance (config , fun_control )
465460 # TODO: Check if this is necessary or if this is handled by the trainer
@@ -624,13 +619,10 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
624619 # Could also be one of two special keywords "last" and "hpc".
625620 # If there is no checkpoint file at the path, an exception is raised.
626621 try :
627- trainer .fit (model = model , train_dataloaders = dm . train_dataloader (), val_dataloaders = dm . val_dataloader () , ckpt_path = None )
622+ trainer .fit (model = model , datamodule = dm , ckpt_path = None )
628623 except Exception as e :
629624 print (f"train_model(): trainer.fit failed with exception: { e } " )
630- return None
631625 # Test best model on validation and test set
632- # The validate and test methods expect a datamodule or dataloaders.
633- # Using the datamodule is cleaner.
634626 verbose = fun_control ["verbosity" ] > 0
635627
636628 # Validate the model
0 commit comments