Skip to content

Commit a08435a

Browse files
testmodel, trainmodel. cvmodel, loadmodel
1 parent 377c371 commit a08435a

20 files changed

Lines changed: 365 additions & 935 deletions

notebooks/00_spotPython_tests.ipynb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -686,12 +686,12 @@
686686
"metadata": {},
687687
"outputs": [],
688688
"source": [
689-
"from spotPython.light.netlightbase import NetLightBase\n",
689+
"from spotPython.light.netlightregressione import NetLightRegression\n",
690690
"from spotPython.utils.init import fun_control_init\n",
691691
"from spotPython.hyperdict.light_hyper_dict import LightHyperDict\n",
692692
"from spotPython.hyperparameters.values import add_core_model_to_fun_control\n",
693693
"fun_control = fun_control_init()\n",
694-
"add_core_model_to_fun_control(core_model=NetLightBase,\n",
694+
"add_core_model_to_fun_control(core_model=NetLightRegression,\n",
695695
" fun_control=fun_control,\n",
696696
" hyper_dict=LightHyperDict)\n",
697697
"fun_control[\"core_model\"].__name__"

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.6.58"
10+
version = "0.6.60"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/csvdataset.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from torch.utils.data import Dataset
44
from sklearn.preprocessing import LabelEncoder
55
import pathlib
6-
from typing import Any
76

87

98
class CSVDataset(Dataset):

src/spotPython/fun/hyperlight.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import logging
22
import numpy as np
33
from numpy.random import default_rng
4-
from spotPython.light.traintest import train_model
4+
from spotPython.light.trainmodel import train_model
55
from spotPython.hyperparameters.values import assign_values, generate_one_config_from_var_dict, get_var_name
66

77
logger = logging.getLogger(__name__)

src/spotPython/fun/hyperlightning.py

Lines changed: 0 additions & 158 deletions
This file was deleted.

src/spotPython/hyperparameters/values.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -677,16 +677,16 @@ def add_core_model_to_fun_control(core_model, fun_control, hyper_dict=None, file
677677
The updated fun_control dictionary.
678678
679679
Examples:
680-
>>> from spotPython.light.netlightbase import NetLightBase
680+
>>> from spotPython.light.netlightregressione import NetLightRegression
681681
from spotPython.hyperdict.light_hyper_dict import LightHyperDict
682682
from spotPython.hyperparameters.values import add_core_model_to_fun_control
683-
add_core_model_to_fun_control(core_model=NetLightBase,
683+
add_core_model_to_fun_control(core_model=NetLightRegression,
684684
fun_control=fun_control,
685685
hyper_dict=LightHyperDict)
686686
# or, if a user wants to use a custom hyper_dict:
687-
>>> from spotPython.light.netlightbase import NetLightBase
687+
>>> from spotPython.light.netlightregression import NetLightRegression
688688
from spotPython.hyperparameters.values import add_core_model_to_fun_control
689-
add_core_model_to_fun_control(core_model=NetLightBase,
689+
add_core_model_to_fun_control(core_model=NetLightRegression,
690690
fun_control=fun_control,
691691
filename="./hyperdict/user_hyper_dict.json")
692692

src/spotPython/light/cnn/googlenet.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
from types import SimpleNamespace
22
import torch.nn as nn
33
from spotPython.light.cnn.inceptionblock import InceptionBlock
4-
from typing import Any
54

65

76
class GoogleNet(nn.Module):

src/spotPython/light/cnn/inceptionblock.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ class InceptionBlock(nn.Module):
2727
Dictionary with keys "1x1", "3x3", "5x5", and "max"
2828
act_fn (nn.Module):
2929
Activation class constructor (e.g. nn.ReLU)
30-
30+
3131
3232
Examples:
3333
>>> from spotPython.light.cnn.googlenet import InceptionBlock
@@ -43,6 +43,7 @@ class InceptionBlock(nn.Module):
4343
torch.Size([1, 64, 32, 32])
4444
4545
"""
46+
4647
def __init__(self, c_in, c_red: dict, c_out: dict, act_fn):
4748
super().__init__()
4849

@@ -79,7 +80,7 @@ def __init__(self, c_in, c_red: dict, c_out: dict, act_fn):
7980
act_fn(),
8081
)
8182

82-
def forward(self, x)->torch.Tensor:
83+
def forward(self, x) -> torch.Tensor:
8384
x_1x1 = self.conv_1x1(x)
8485
x_3x3 = self.conv_3x3(x)
8586
x_5x5 = self.conv_5x5(x)

src/spotPython/light/cvmodel.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
import lightning as L
2+
from spotPython.data.lightcrossvalidationdatamodule import LightCrossValidationDataModule
3+
from spotPython.utils.eda import generate_config_id
4+
from pytorch_lightning.loggers import TensorBoardLogger
5+
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
6+
from spotPython.torch.initialization import kaiming_init, xavier_init
7+
import os
8+
9+
10+
def cv_model(config: dict, fun_control: dict) -> float:
11+
"""
12+
Performs k-fold cross-validation on a model using the given configuration and function control parameters.
13+
14+
Args:
15+
config (dict): A dictionary containing the configuration parameters for the model.
16+
fun_control (dict): A dictionary containing the function control parameters.
17+
18+
Returns:
19+
(float): The mean average precision at k (MAP@k) score of the model.
20+
21+
Examples:
22+
>>> config = {
23+
... "initialization": "Xavier",
24+
... "batch_size": 32,
25+
... "patience": 10,
26+
... }
27+
>>> fun_control = {
28+
... "_L_in": 10,
29+
... "_L_out": 1,
30+
... "enable_progress_bar": True,
31+
... "core_model": MyModel,
32+
... "num_workers": 4,
33+
... "DATASET_PATH": "./data",
34+
... "CHECKPOINT_PATH": "./checkpoints",
35+
... "TENSORBOARD_PATH": "./tensorboard",
36+
... "k_folds": 5,
37+
... }
38+
>>> mapk_score = cv_model(config, fun_control)
39+
"""
40+
_L_in = fun_control["_L_in"]
41+
_L_out = fun_control["_L_out"]
42+
if fun_control["enable_progress_bar"] is None:
43+
enable_progress_bar = False
44+
else:
45+
enable_progress_bar = fun_control["enable_progress_bar"]
46+
# Add "CV" postfix to config_id
47+
config_id = generate_config_id(config) + "_CV"
48+
results = []
49+
num_folds = fun_control["k_folds"]
50+
split_seed = 12345
51+
52+
for k in range(num_folds):
53+
print("k:", k)
54+
55+
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out)
56+
initialization = config["initialization"]
57+
if initialization == "Xavier":
58+
xavier_init(model)
59+
elif initialization == "Kaiming":
60+
kaiming_init(model)
61+
else:
62+
pass
63+
# print(f"model: {model}")
64+
65+
dm = LightCrossValidationDataModule(
66+
k=k,
67+
num_splits=num_folds,
68+
split_seed=split_seed,
69+
dataset=fun_control["data_set"],
70+
num_workers=fun_control["num_workers"],
71+
batch_size=config["batch_size"],
72+
data_dir=fun_control["DATASET_PATH"],
73+
)
74+
dm.prepare_data()
75+
dm.setup()
76+
77+
# Init trainer
78+
trainer = L.Trainer(
79+
# Where to save models
80+
default_root_dir=os.path.join(fun_control["CHECKPOINT_PATH"], config_id),
81+
max_epochs=model.hparams.epochs,
82+
accelerator="auto",
83+
devices=1,
84+
logger=TensorBoardLogger(
85+
save_dir=fun_control["TENSORBOARD_PATH"], version=config_id, default_hp_metric=True, log_graph=True
86+
),
87+
callbacks=[
88+
EarlyStopping(monitor="val_loss", patience=config["patience"], mode="min", strict=False, verbose=False)
89+
],
90+
enable_progress_bar=enable_progress_bar,
91+
)
92+
# Pass the datamodule as arg to trainer.fit to override model hooks :)
93+
trainer.fit(model=model, datamodule=dm)
94+
# Test best model on validation and test set
95+
# result = trainer.validate(model=model, datamodule=dm, ckpt_path="last")
96+
score = trainer.validate(model=model, datamodule=dm)
97+
# unlist the result (from a list of one dict)
98+
score = score[0]
99+
print(f"train_model result: {score}")
100+
101+
results.append(score["val_loss"])
102+
103+
score = sum(results) / num_folds
104+
# print(f"cv_model mapk result: {mapk_score}")
105+
return score

0 commit comments

Comments
 (0)