Skip to content

Commit 118e6ec

Browse files
0.24.21
1 parent 27db166 commit 118e6ec

4 files changed

Lines changed: 67 additions & 3 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.24.20"
10+
version = "0.24.21"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/light/trainmodel.py

Lines changed: 63 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import lightning as L
2-
from spotpython.data.lightdatamodule import LightDataModule
2+
from spotpython.data.lightdatamodule import LightDataModule, PadSequenceManyToMany
33
from spotpython.utils.eda import generate_config_id
44
from pytorch_lightning.loggers import TensorBoardLogger
55
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
66
from lightning.pytorch.callbacks import ModelCheckpoint
7+
from torch.utils.data import DataLoader
8+
import torch
79
import os
810

11+
import numpy as np
12+
913

1014
def generate_config_id_with_timestamp(config: dict, timestamp: bool) -> str:
1115
"""
@@ -124,6 +128,7 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
124128
)
125129
else:
126130
dm = fun_control["data_module"]
131+
127132
model = build_model_instance(config, fun_control)
128133
# TODO: Check if this is necessary or if this is handled by the trainer
129134
# dm.setup()
@@ -183,6 +188,63 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
183188
dirpath = os.path.join(fun_control["CHECKPOINT_PATH"], config_id)
184189
callbacks.append(ModelCheckpoint(dirpath=dirpath, monitor=None, verbose=False, save_last=True)) # Save the last checkpoint
185190

191+
if fun_control["hacky"]:
192+
verbose = fun_control["verbosity"] > 0
193+
ds = fun_control["data_full_train"]
194+
indices = list(range(len(ds)))
195+
indice_results_val_loss = []
196+
indice_results_hp_metric = []
197+
for i in indices:
198+
print(f"train_model(): Hacky Implementation with Index {i}")
199+
test_indices = [indices[i]]
200+
train_indices = [index for index in indices if index != test_indices[0]]
201+
202+
train_dataset = torch.utils.data.Subset(ds, train_indices)
203+
test_dataset = torch.utils.data.Subset(ds, test_indices)
204+
205+
train_dl = DataLoader(train_dataset, batch_size=config["batch_size"], shuffle=False, collate_fn=PadSequenceManyToMany())
206+
test_dl = DataLoader(test_dataset, batch_size=config["batch_size"], shuffle=False, collate_fn=PadSequenceManyToMany())
207+
208+
model = build_model_instance(config, fun_control)
209+
210+
enable_progress_bar = fun_control["enable_progress_bar"] or False
211+
trainer = L.Trainer(
212+
# Where to save models
213+
default_root_dir=os.path.join(fun_control["CHECKPOINT_PATH"], config_id),
214+
max_epochs=model.hparams.epochs,
215+
accelerator=fun_control["accelerator"],
216+
devices=fun_control["devices"],
217+
strategy=fun_control["strategy"],
218+
num_nodes=fun_control["num_nodes"],
219+
precision=fun_control["precision"],
220+
logger=TensorBoardLogger(save_dir=fun_control["TENSORBOARD_PATH"], version=config_id, default_hp_metric=True, log_graph=fun_control["log_graph"], name=""),
221+
callbacks=callbacks,
222+
enable_progress_bar=enable_progress_bar,
223+
num_sanity_val_steps=fun_control["num_sanity_val_steps"],
224+
log_every_n_steps=fun_control["log_every_n_steps"],
225+
gradient_clip_val=None,
226+
gradient_clip_algorithm="norm",
227+
)
228+
229+
trainer.fit(model=model, train_dataloaders=train_dl, ckpt_path=None)
230+
result = trainer.validate(model=model, dataloaders=test_dl, ckpt_path=None, verbose=verbose)
231+
result = result[0]
232+
233+
print(f"results_dict: {result}")
234+
235+
indice_results_val_loss.append(result["val_loss"])
236+
indice_results_hp_metric.append(result["hp_metric"])
237+
238+
mean_val_loss = np.mean(indice_results_val_loss)
239+
mean_hp_metric = np.mean(indice_results_hp_metric)
240+
241+
print(f"train_model(): Mean Validation Loss: {mean_val_loss}")
242+
print(f"train_model(): Mean Hyperparameter Metric: {mean_hp_metric}")
243+
244+
results_dict = {"val_loss": mean_val_loss, "hp_metric": mean_hp_metric}
245+
246+
return results_dict["val_loss"]
247+
186248
# Tensorboard logger. The tensorboard is passed to the trainer.
187249
# See: https://lightning.ai/docs/pytorch/stable/extensions/generated/lightning.pytorch.loggers.TensorBoardLogger.html
188250
# It uses the following arguments:

src/spotpython/utils/init.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def fun_control_init(
3333
core_model_name=None,
3434
data=None,
3535
data_full_train=None,
36+
hacky=False, # !TODO: Documentation
3637
data_val=None,
3738
data_dir="./data",
3839
data_module=None,
@@ -429,6 +430,7 @@ def fun_control_init(
429430
"data": data,
430431
"data_dir": data_dir,
431432
"data_full_train": data_full_train,
433+
"hacky": hacky,
432434
"data_module": data_module,
433435
"data_set": data_set,
434436
"data_set_name": data_set_name,

test/test_kriging.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from spotpython.build import Kriging
1+
from spotpython.build.kriging import Kriging
22
import numpy as np
33
from math import erf
44
from numpy import log, var

0 commit comments

Comments
 (0)