Skip to content

Commit c9b7e5c

Browse files
0.18.11
1 parent f6ba61d commit c9b7e5c

5 files changed

Lines changed: 50 additions & 36 deletions

File tree

RELEASE_NOTES.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,8 @@
1+
spotpython 0.18.11:
2+
3+
- testmodel, predictmodel, and cvmodel functions updated, so that they can handle DataModules specified by the user in fun_control.
4+
5+
16
spotpython 0.18.8:
27

38
- lightdatamodule.py:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.18.10"
10+
version = "0.18.11"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/light/cvmodel.py

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,22 @@ def cv_model(config: dict, fun_control: dict) -> float:
5656

5757
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _L_cond=_L_cond, _torchmetric=_torchmetric)
5858

59-
dm = LightCrossValidationDataModule(
60-
k=k,
61-
num_splits=num_folds,
62-
split_seed=split_seed,
63-
dataset=fun_control["data_set"],
64-
data_full_train=fun_control["data_full_train"],
65-
data_test=fun_control["data_test"],
66-
num_workers=fun_control["num_workers"],
67-
batch_size=config["batch_size"],
68-
data_dir=fun_control["DATASET_PATH"],
69-
scaler=fun_control["scaler"],
70-
verbosity=fun_control["verbosity"],
71-
)
59+
if fun_control["data_module"] is None:
60+
dm = LightCrossValidationDataModule(
61+
k=k,
62+
num_splits=num_folds,
63+
split_seed=split_seed,
64+
dataset=fun_control["data_set"],
65+
data_full_train=fun_control["data_full_train"],
66+
data_test=fun_control["data_test"],
67+
num_workers=fun_control["num_workers"],
68+
batch_size=config["batch_size"],
69+
data_dir=fun_control["DATASET_PATH"],
70+
scaler=fun_control["scaler"],
71+
verbosity=fun_control["verbosity"],
72+
)
73+
else:
74+
dm = fun_control["data_module"]
7275
dm.setup()
7376
dm.prepare_data()
7477

src/spotpython/light/predictmodel.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -64,17 +64,20 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
6464
# the config id is generated here without a timestamp. This differs from
6565
# the config id generated in cvmodel.py and trainmodel.py.
6666
config_id = generate_config_id(config, timestamp=False) + "_TEST"
67-
dm = LightDataModule(
68-
dataset=fun_control["data_set"],
69-
data_full_train=fun_control["data_full_train"],
70-
data_test=fun_control["data_test"],
71-
batch_size=config["batch_size"],
72-
num_workers=fun_control["num_workers"],
73-
test_size=fun_control["test_size"],
74-
test_seed=fun_control["test_seed"],
75-
scaler=fun_control["scaler"],
76-
verbosity=fun_control["verbosity"],
77-
)
67+
if fun_control["data_module"] is None:
68+
dm = LightDataModule(
69+
dataset=fun_control["data_set"],
70+
data_full_train=fun_control["data_full_train"],
71+
data_test=fun_control["data_test"],
72+
batch_size=config["batch_size"],
73+
num_workers=fun_control["num_workers"],
74+
test_size=fun_control["test_size"],
75+
test_seed=fun_control["test_seed"],
76+
scaler=fun_control["scaler"],
77+
verbosity=fun_control["verbosity"],
78+
)
79+
else:
80+
dm = fun_control["data_module"]
7881
# TODO: Check if this is necessary:
7982
# dm.setup(stage="train")
8083
# Init model from datamodule's attributes

src/spotpython/light/testmodel.py

Lines changed: 14 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -65,17 +65,20 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
6565
# the config id is generated here without a timestamp. This differs from
6666
# the config id generated in cvmodel.py and trainmodel.py.
6767
config_id = generate_config_id(config, timestamp=False) + "_TEST"
68-
dm = LightDataModule(
69-
dataset=fun_control["data_set"],
70-
data_full_train=fun_control["data_full_train"],
71-
data_test=fun_control["data_test"],
72-
batch_size=config["batch_size"],
73-
num_workers=fun_control["num_workers"],
74-
test_size=fun_control["test_size"],
75-
test_seed=fun_control["test_seed"],
76-
scaler=fun_control["scaler"],
77-
verbosity=fun_control["verbosity"],
78-
)
68+
if fun_control["data_module"] is None:
69+
dm = LightDataModule(
70+
dataset=fun_control["data_set"],
71+
data_full_train=fun_control["data_full_train"],
72+
data_test=fun_control["data_test"],
73+
batch_size=config["batch_size"],
74+
num_workers=fun_control["num_workers"],
75+
test_size=fun_control["test_size"],
76+
test_seed=fun_control["test_seed"],
77+
scaler=fun_control["scaler"],
78+
verbosity=fun_control["verbosity"],
79+
)
80+
else:
81+
dm = fun_control["data_module"]
7982
# TODO: Check if this is necessary:
8083
# dm.setup()
8184
# Init model from datamodule's attributes

0 commit comments

Comments
 (0)