@@ -64,17 +64,20 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
6464 # the config id is generated here without a timestamp. This differs from
6565 # the config id generated in cvmodel.py and trainmodel.py.
6666 config_id = generate_config_id (config , timestamp = False ) + "_TEST"
67- dm = LightDataModule (
68- dataset = fun_control ["data_set" ],
69- data_full_train = fun_control ["data_full_train" ],
70- data_test = fun_control ["data_test" ],
71- batch_size = config ["batch_size" ],
72- num_workers = fun_control ["num_workers" ],
73- test_size = fun_control ["test_size" ],
74- test_seed = fun_control ["test_seed" ],
75- scaler = fun_control ["scaler" ],
76- verbosity = fun_control ["verbosity" ],
77- )
67+ if fun_control ["data_module" ] is None :
68+ dm = LightDataModule (
69+ dataset = fun_control ["data_set" ],
70+ data_full_train = fun_control ["data_full_train" ],
71+ data_test = fun_control ["data_test" ],
72+ batch_size = config ["batch_size" ],
73+ num_workers = fun_control ["num_workers" ],
74+ test_size = fun_control ["test_size" ],
75+ test_seed = fun_control ["test_seed" ],
76+ scaler = fun_control ["scaler" ],
77+ verbosity = fun_control ["verbosity" ],
78+ )
79+ else :
80+ dm = fun_control ["data_module" ]
7881 # TODO: Check if this is necessary:
7982 # dm.setup(stage="train")
8083 # Init model from datamodule's attributes
0 commit comments