Skip to content

Commit eda3087

Browse files
0.15.13
global initialization (torch) removed
1 parent e240832 commit eda3087

6 files changed

Lines changed: 3 additions & 58 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.15.12"
10+
version = "0.15.13"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/light/cvmodel.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from spotpython.utils.eda import generate_config_id
44
from pytorch_lightning.loggers import TensorBoardLogger
55
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
6-
from spotpython.torch.initialization import kaiming_init, xavier_init
76
import os
87

98

@@ -54,14 +53,6 @@ def cv_model(config: dict, fun_control: dict) -> float:
5453
print("k:", k)
5554

5655
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
57-
initialization = config["initialization"]
58-
if initialization == "Xavier":
59-
xavier_init(model)
60-
elif initialization == "Kaiming":
61-
kaiming_init(model)
62-
else:
63-
pass
64-
# print(f"model: {model}")
6556

6657
dm = LightCrossValidationDataModule(
6758
k=k,

src/spotpython/light/predictmodel.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from pytorch_lightning.loggers import TensorBoardLogger
55
from lightning.pytorch.callbacks import ModelCheckpoint
66
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
7-
from spotpython.torch.initialization import kaiming_init, xavier_init
87
import os
98
from typing import Tuple
109

@@ -76,15 +75,7 @@ def predict_model(config: dict, fun_control: dict) -> Tuple[float, float]:
7675
# dm.setup(stage="train")
7776
# Init model from datamodule's attributes
7877
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
79-
initialization = config["initialization"]
80-
if initialization == "Xavier":
81-
xavier_init(model)
82-
elif initialization == "Kaiming":
83-
kaiming_init(model)
84-
else:
85-
pass
86-
# print(f"model: {model}")
87-
# Init trainer
78+
8879
trainer = L.Trainer(
8980
# Where to save models
9081
default_root_dir=os.path.join(fun_control["CHECKPOINT_PATH"], config_id),

src/spotpython/light/testmodel.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@
44
from pytorch_lightning.loggers import TensorBoardLogger
55
from lightning.pytorch.callbacks import ModelCheckpoint
66
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
7-
from spotpython.torch.initialization import kaiming_init, xavier_init
87
import os
98
from typing import Tuple
109

@@ -77,15 +76,7 @@ def test_model(config: dict, fun_control: dict) -> Tuple[float, float]:
7776
# dm.setup()
7877
# Init model from datamodule's attributes
7978
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
80-
initialization = config["initialization"]
81-
if initialization == "Xavier":
82-
xavier_init(model)
83-
elif initialization == "Kaiming":
84-
kaiming_init(model)
85-
else:
86-
pass
87-
# print(f"model: {model}")
88-
# Init trainer
79+
8980
trainer = L.Trainer(
9081
# Where to save models
9182
default_root_dir=os.path.join(fun_control["CHECKPOINT_PATH"], config_id),

src/spotpython/light/trainmodel.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
from spotpython.utils.eda import generate_config_id
44
from pytorch_lightning.loggers import TensorBoardLogger
55
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
6-
from spotpython.torch.initialization import kaiming_init, xavier_init
76
from lightning.pytorch.callbacks import ModelCheckpoint
87
import os
98

@@ -88,13 +87,6 @@ def train_model(config: dict, fun_control: dict, timestamp: bool = True) -> floa
8887
# the config id is generated here without a timestamp.
8988
config_id = generate_config_id(config, timestamp=False) + "_TRAIN"
9089
model = fun_control["core_model"](**config, _L_in=_L_in, _L_out=_L_out, _torchmetric=_torchmetric)
91-
initialization = config["initialization"]
92-
if initialization == "Xavier":
93-
xavier_init(model)
94-
elif initialization == "Kaiming":
95-
kaiming_init(model)
96-
else:
97-
pass
9890

9991
dm = LightDataModule(
10092
dataset=fun_control["data_set"],

src/spotpython/torch/initialization.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

0 commit comments

Comments
 (0)