Skip to content

Commit b41b4c3

Browse files
light 2
1 parent 7b248c2 commit b41b4c3

5 files changed

Lines changed: 226 additions & 9 deletions

File tree

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
{ "GoogleNet":
2+
{
3+
"act_fn": {
4+
"levels": ["Sigmoid",
5+
"Tanh",
6+
"ReLU",
7+
"LeakyReLU",
8+
"ELU",
9+
"Swish"],
10+
"type": "factor",
11+
"default": "ReLU",
12+
"transform": "None",
13+
"class_name": "spotPython.torch.activation",
14+
"core_model_parameter_type": "instance",
15+
"lower": 0,
16+
"upper": 5},
17+
"optimizer_name": {
18+
"levels": [
19+
"Adam"
20+
],
21+
"type": "factor",
22+
"default": "Adam",
23+
"transform": "None",
24+
"class_name": "torch.optim",
25+
"core_model_parameter_type": "str",
26+
"lower": 0,
27+
"upper": 0}
28+
},
29+
"NetLinearBase":
30+
{
31+
"l1": {
32+
"type": "int",
33+
"default": 3,
34+
"transform": "transform_power_2_int",
35+
"lower": 3,
36+
"upper": 8},
37+
"epochs": {
38+
"type": "int",
39+
"default": 4,
40+
"transform": "transform_power_2_int",
41+
"lower": 4,
42+
"upper": 9},
43+
"batch_size": {
44+
"type": "int",
45+
"default": 4,
46+
"transform": "transform_power_2_int",
47+
"lower": 1,
48+
"upper": 4},
49+
"act_fn": {
50+
"levels": ["Sigmoid",
51+
"Tanh",
52+
"ReLU",
53+
"LeakyReLU",
54+
"ELU",
55+
"Swish"],
56+
"type": "factor",
57+
"default": "ReLU",
58+
"transform": "None",
59+
"class_name": "spotPython.torch.activation",
60+
"core_model_parameter_type": "instance()",
61+
"lower": 0,
62+
"upper": 5},
63+
"optimizer": {
64+
"levels": ["Adadelta",
65+
"Adagrad",
66+
"Adam",
67+
"AdamW",
68+
"SparseAdam",
69+
"Adamax",
70+
"ASGD",
71+
"NAdam",
72+
"RAdam",
73+
"RMSprop",
74+
"Rprop",
75+
"SGD"],
76+
"type": "factor",
77+
"default": "SGD",
78+
"transform": "None",
79+
"class_name": "torch.optim",
80+
"core_model_parameter_type": "str",
81+
"lower": 0,
82+
"upper": 11},
83+
"dropout_prob": {
84+
"type": "float",
85+
"default": 0.01,
86+
"transform": "None",
87+
"lower": 0.0,
88+
"upper": 0.25},
89+
"lr_mult": {
90+
"type": "float",
91+
"default": 1.0,
92+
"transform": "None",
93+
"lower": 0.1,
94+
"upper": 10.0},
95+
"patience": {
96+
"type": "int",
97+
"default": 2,
98+
"transform": "transform_power_2_int",
99+
"lower": 2,
100+
"upper": 6
101+
},
102+
"initialization": {
103+
"levels": ["Default", "Kaiming", "Xavier"],
104+
"type": "factor",
105+
"default": "Default",
106+
"transform": "None",
107+
"core_model_parameter_type": "str",
108+
"lower": 0,
109+
"upper": 2}
110+
}
111+
}
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
import json
2+
from spotPython.data import base
3+
4+
5+
class LightningHyperDict(base.FileConfig):
6+
"""Lightning hyperparameter dictionary.
7+
8+
This class extends the FileConfig class to provide a dictionary for storing hyperparameters.
9+
10+
Attributes:
11+
filename (str):
12+
The name of the file where the hyperparameters are stored.
13+
"""
14+
15+
def __init__(self):
16+
"""Initialize the LightHyperDict object.
17+
18+
Examples:
19+
>>> lhd = LightHyperDict()
20+
"""
21+
super().__init__(
22+
filename="lightning_hyper_dict.json",
23+
)
24+
25+
def load(self) -> dict:
26+
"""Load the hyperparameters from the file.
27+
28+
Returns:
29+
dict: A dictionary containing the hyperparameters.
30+
31+
Examples:
32+
>>> lhd = LightHyperDict()
33+
>>> hyperparams = lhd.load()
34+
>>> print(hyperparams)
35+
{'learning_rate': 0.001, 'batch_size': 32, 'epochs': 10}
36+
"""
37+
with open(self.path, "r") as f:
38+
d = json.load(f)
39+
return d

src/spotPython/fun/hyperlightning.py

Lines changed: 44 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,21 @@
2121
class HyperLightning:
2222
"""
2323
Hyperparameter Tuning for Lightning.
24+
25+
Args:
26+
seed (int): seed for the random number generator. See Numpy Random Sampling.
27+
log_level (int): log level for the logger.
28+
29+
Attributes:
30+
seed (int): seed for the random number generator.
31+
rng (Generator): random number generator.
32+
fun_control (dict): dictionary containing control parameters for the hyperparameter tuning.
33+
log_level (int): log level for the logger.
34+
35+
Examples:
36+
>>> hyper_light = HyperLight(seed=126, log_level=50)
37+
>>> print(hyper_light.seed)
38+
126
2439
"""
2540

2641
def __init__(self, seed: int = 126, log_level: int = 50) -> None:
@@ -90,11 +105,33 @@ def fun(self, X: np.ndarray, fun_control: dict = None) -> np.ndarray:
90105
array containing the evaluation results.
91106
92107
Examples:
93-
>>> hyper_light = HyperLight(seed=126, log_level=50)
94-
X = np.array([[1, 2], [3, 4]])
95-
fun_control = {"weights": np.array([1, 0, 0])}
96-
hyper_light.fun(X, fun_control)
97-
array([nan, nan])
108+
>>> MAX_TIME = 1
109+
INIT_SIZE = 5
110+
WORKERS = 0
111+
PREFIX="TEST"
112+
from spotPython.utils.init import fun_control_init
113+
from spotPython.utils.file import get_experiment_name, get_spot_tensorboard_path
114+
from spotPython.utils.device import getDevice
115+
experiment_name = get_experiment_name(prefix=PREFIX)
116+
fun_control = fun_control_init(
117+
spot_tensorboard_path=get_spot_tensorboard_path(experiment_name),
118+
num_workers=WORKERS,
119+
device=getDevice(),
120+
_L_in=3,
121+
_L_out=10,
122+
TENSORBOARD_CLEAN=True)
123+
from spotPython.light.cnn.googlenet import GoogleNet
124+
from spotPython.data.lightning_hyper_dict import LightningHyperDict
125+
from spotPython.hyperparameters.values import add_core_model_to_fun_control
126+
add_core_model_to_fun_control(core_model=GoogleNet,
127+
fun_control=fun_control,
128+
hyper_dict= LightningHyperDict)
129+
from spotPython.hyperparameters.values import get_default_hyperparameters_as_array
130+
X_start = get_default_hyperparameters_as_array(fun_control)
131+
from spotPython.fun.hyperlightning import HyperLightning
132+
hyper_light = HyperLightning(seed=126, log_level=50)
133+
hyper_light.fun(X=X_start, fun_control=fun_control)
134+
98135
"""
99136
z_res = np.array([], dtype=float)
100137
if fun_control is not None:
@@ -104,6 +141,8 @@ def fun(self, X: np.ndarray, fun_control: dict = None) -> np.ndarray:
104141
# type information and transformations are considered in generate_one_config_from_var_dict:
105142
for config in generate_one_config_from_var_dict(var_dict, self.fun_control):
106143
logger.debug(f"\nconfig: {config}")
144+
print(f"\ncore_model: {fun_control['core_model']}")
145+
print(f"config: {config}")
107146
# extract parameters like epochs, batch_size, lr, etc. from config
108147
# config_id = generate_config_id(config)
109148
try:

src/spotPython/light/cnn/netcnnbase.py

Lines changed: 26 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,13 @@
11
import lightning as L
22
import torch
33
from torch import nn
4-
from spotPython.light.utils import create_model
4+
5+
# from spotPython.light.utils import create_model
56
import torch.optim as optim
67

8+
# from spotPython.light.cnn.googlenet import GoogleNet
9+
import spotPython.light.cnn.googlenet
10+
711

812
class NetCNNBase(L.LightningModule):
913
def __init__(self, config, fun_control):
@@ -31,11 +35,30 @@ def __init__(self, config, fun_control):
3135
torch.Size([1, 10])
3236
3337
"""
38+
print("NetCNNBase: Starting")
39+
print(f"NetCNNBase: config: {config}")
40+
print(f"NetCNNBase: fun_control['core_model']: {fun_control['core_model']}")
41+
config = {
42+
"c_in": 3,
43+
"c_out": 10,
44+
"act_fn": nn.ReLU,
45+
"optimizer_name": "Adam",
46+
"optimizer_hparams": {"lr": 1e-3, "weight_decay": 1e-4},
47+
}
48+
print("fun_control['core_model']: ", fun_control["core_model"])
49+
print("fun_control['core_model'].type: ", fun_control["core_model"].type)
50+
# fun_control = {"core_model": GoogleNet}
51+
fun_control = {"core_model": spotPython.light.cnn.googlenet.GoogleNet}
3452
super().__init__()
3553
# Exports the hyperparameters to a YAML file, and create "self.hparams" namespace
36-
self.save_hyperparameters()
54+
self.save_hyperparameters() # "fun_control" is not a hyperparameter )
55+
print(f"config: {config}")
3756
# Create model
38-
self.model = create_model(config, fun_control)
57+
print("Creating model")
58+
# self.model = create_model(config, fun_control)
59+
self.model = fun_control["core_model"](**config)
60+
print("Model created")
61+
print(f"self.model: {self.model}")
3962
# Create loss module
4063
self.loss_module = nn.CrossEntropyLoss()
4164
# Example input for visualizing the graph in Tensorboard

src/spotPython/light/trainmodel.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,8 @@ def train_model(config: dict, fun_control: dict):
3939
{'test': 0.8772, 'val': 0.8772}
4040
4141
"""
42+
print("train_model: Starting")
43+
print(f"train_model: config: {config}")
4244
save_name = "saved_models"
4345
# Create PyTorch Lightning data loaders
4446
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/ConvNets")
@@ -85,13 +87,14 @@ def train_model(config: dict, fun_control: dict):
8587
# END TODO
8688

8789
# Create a PyTorch Lightning trainer with the generation callback
90+
print("train_model: Creating trainer")
8891
trainer = L.Trainer(
8992
default_root_dir=os.path.join(CHECKPOINT_PATH, save_name), # Where to save models
9093
# We run on a single GPU (if possible)
9194
accelerator="auto",
9295
devices=1,
9396
# How many epochs to train for if no patience is set
94-
max_epochs=180,
97+
max_epochs=4,
9598
callbacks=[
9699
ModelCheckpoint(
97100
save_weights_only=True, mode="max", monitor="val_acc"
@@ -101,6 +104,7 @@ def train_model(config: dict, fun_control: dict):
101104
) # In case your notebook crashes due to the progress bar, consider increasing the refresh rate
102105
trainer.logger._log_graph = True # If True, we plot the computation graph in tensorboard
103106
trainer.logger._default_hp_metric = None # Optional logging argument that we don't need
107+
print("train_model: Created trainer")
104108

105109
# Check whether pretrained model exists. If yes, load it and skip training
106110
pretrained_filename = os.path.join(CHECKPOINT_PATH, save_name + ".ckpt")
@@ -110,6 +114,7 @@ def train_model(config: dict, fun_control: dict):
110114
model = NetCNNBase.load_from_checkpoint(pretrained_filename)
111115
else:
112116
L.seed_everything(42) # To be reproducable
117+
print("train_model: Creating model")
113118
model = NetCNNBase(config=config, fun_control=fun_control) # Create model
114119
trainer.fit(model, train_loader, val_loader)
115120
model = NetCNNBase.load_from_checkpoint(

0 commit comments

Comments
 (0)