Skip to content

Commit 8afc4b7

Browse files
Revert "033.7"
This reverts commit 15cb116.
1 parent 15cb116 commit 8afc4b7

2 files changed

Lines changed: 6 additions & 12 deletions

File tree

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.33.7"
10+
version = "0.33.6"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]
@@ -45,7 +45,9 @@ dependencies = [
4545
"plotly",
4646
"pytest",
4747
"pytest-mock",
48+
"PyQt6",
4849
"python-markdown-math",
50+
"pytorch-lightning>=1.4",
4951
"river>=0.22.0",
5052
"scikit-learn",
5153
"scipy",

src/spotpython/light/trainmodel.py

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,6 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
129129
)
130130
else:
131131
dm = fun_control["data_module"]
132-
dm.setup() # Manually call setup to prepare the datasets
133132

134133
model = build_model_instance(config, fun_control)
135134
# TODO: Check if this is necessary or if this is handled by the trainer
@@ -239,7 +238,7 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
239238
gradient_clip_algorithm="norm",
240239
)
241240

242-
trainer.fit(model=model, train_dataloaders=train_dl, val_dataloaders=test_dl, ckpt_path=None)
241+
trainer.fit(model=model, train_dataloaders=train_dl, ckpt_path=None)
243242
result = trainer.validate(model=model, dataloaders=test_dl, ckpt_path=None, verbose=verbose)
244243
result = result[0]
245244

@@ -351,13 +350,10 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
351350
# Could also be one of two special keywords "last" and "hpc".
352351
# If there is no checkpoint file at the path, an exception is raised.
353352
try:
354-
trainer.fit(model=model, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader(), ckpt_path=None)
353+
trainer.fit(model=model, datamodule=dm, ckpt_path=None)
355354
except Exception as e:
356355
print(f"train_model(): trainer.fit failed with exception: {e}")
357-
return None
358356
# Test best model on validation and test set
359-
# The validate and test methods expect a datamodule or dataloaders.
360-
# Using the datamodule is cleaner.
361357
verbose = fun_control["verbosity"] > 0
362358

363359
# Validate the model
@@ -459,7 +455,6 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
459455
)
460456
else:
461457
dm = fun_control["data_module"]
462-
dm.setup() # Manually call setup to prepare the datasets
463458

464459
model = build_model_instance(config, fun_control)
465460
# TODO: Check if this is necessary or if this is handled by the trainer
@@ -624,13 +619,10 @@ def train_model_xai(config: dict, fun_control: dict, timestamp: bool = True) ->
624619
# Could also be one of two special keywords "last" and "hpc".
625620
# If there is no checkpoint file at the path, an exception is raised.
626621
try:
627-
trainer.fit(model=model, train_dataloaders=dm.train_dataloader(), val_dataloaders=dm.val_dataloader(), ckpt_path=None)
622+
trainer.fit(model=model, datamodule=dm, ckpt_path=None)
628623
except Exception as e:
629624
print(f"train_model(): trainer.fit failed with exception: {e}")
630-
return None
631625
# Test best model on validation and test set
632-
# The validate and test methods expect a datamodule or dataloaders.
633-
# Using the datamodule is cleaner.
634626
verbose = fun_control["verbosity"] > 0
635627

636628
# Validate the model

0 commit comments

Comments
 (0)