Skip to content

Commit 13366e1

Browse files
v0.5.0
Tensorboard in core spot
1 parent fec4743 commit 13366e1

6 files changed

Lines changed: 89 additions & 21 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.4.0"
10+
version = "0.5.0"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/build/kriging.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,8 @@ def __init__(
5858
n_p=1,
5959
optim_p=False,
6060
log_level=50,
61+
spot_writer=None,
62+
counter=None,
6163
**kwargs
6264
):
6365
"""
@@ -172,6 +174,8 @@ def __init__(
172174
self.name = name
173175
self.seed = seed
174176
self.log_level = log_level
177+
self.spot_writer = spot_writer
178+
self.counter = counter
175179

176180
self.sigma = 0
177181
self.eps = sqrt(spacing(1))
@@ -299,6 +303,21 @@ def update_log(self):
299303
self.log["theta"] = append(self.log["theta"], self.theta)
300304
self.log["p"] = append(self.log["p"], self.p)
301305
self.log["Lambda"] = append(self.log["Lambda"], self.Lambda)
306+
# get the length of the log
307+
self.log_length = len(self.log["negLnLike"])
308+
if self.spot_writer is not None:
309+
writer = self.spot_writer
310+
negLnLike = self.negLnLike.copy()
311+
theta = self.theta.copy()
312+
p = self.p.copy()
313+
Lambda = self.Lambda.copy()
314+
writer.add_scalar("negLnLike", negLnLike, self.counter+self.log_length)
315+
writer.add_scalar("Lambda", Lambda, self.counter+self.log_length)
316+
# add the self.n_theta theta values to the writer with one key "theta", i.e, the same key for all theta values
317+
writer.add_scalars("theta", {f"theta_{i}": theta[i] for i in range(self.n_theta)}, self.counter+self.log_length)
318+
# add the self.n_p p values to the writer with one key "p", i.e, the same key for all p values
319+
writer.add_scalars("p", {f"p_{i}": p[i] for i in range(self.n_p)}, self.counter+self.log_length)
320+
writer.flush()
302321

303322
def fit_old(self, nat_X, nat_y):
304323
"""

src/spotPython/fun/hypertorch.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ def fun_torch(self, X, fun_control=None):
7979
device=self.fun_control["device"],
8080
show_batch_interval=self.fun_control["show_batch_interval"],
8181
task=self.fun_control["task"],
82-
writer=self.fun_control["writer"],
82+
writer=self.fun_control["spot_writer"],
8383
writerId=config_id,
8484
)
8585
elif self.fun_control["eval"] == "test_cv":
@@ -90,7 +90,7 @@ def fun_torch(self, X, fun_control=None):
9090
device=self.fun_control["device"],
9191
show_batch_interval=self.fun_control["show_batch_interval"],
9292
task=self.fun_control["task"],
93-
writer=self.fun_control["writer"],
93+
writer=self.fun_control["spot_writer"],
9494
writerId=config_id,
9595
)
9696
elif self.fun_control["eval"] == "test_hold_out":
@@ -105,7 +105,7 @@ def fun_torch(self, X, fun_control=None):
105105
show_batch_interval=self.fun_control["show_batch_interval"],
106106
path=self.fun_control["path"],
107107
task=self.fun_control["task"],
108-
writer=self.fun_control["writer"],
108+
writer=self.fun_control["spot_writer"],
109109
writerId=config_id,
110110
)
111111
else: # eval == "train_hold_out"
@@ -119,16 +119,16 @@ def fun_torch(self, X, fun_control=None):
119119
show_batch_interval=self.fun_control["show_batch_interval"],
120120
path=self.fun_control["path"],
121121
task=self.fun_control["task"],
122-
writer=self.fun_control["writer"],
122+
writer=self.fun_control["spot_writer"],
123123
writerId=config_id,
124124
)
125125
except Exception as err:
126126
print(f"Error in fun_torch(). Call to evaluate_model failed. {err=}, {type(err)=}")
127127
print("Setting df_eval to np.nan")
128128
df_eval = np.nan
129129
z_val = fun_control["weights"] * df_eval
130-
if self.fun_control["writer"] is not None:
131-
writer = self.fun_control["writer"]
130+
if self.fun_control["spot_writer"] is not None:
131+
writer = self.fun_control["spot_writer"]
132132
writer.add_hparams(config, {"fun_torch: loss": z_val})
133133
writer.flush()
134134
z_res = np.append(z_res, z_val)

src/spotPython/spot/spot.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,18 @@ def __init__(
213213
"seed": 124,
214214
"use_cod_y": False,
215215
}
216+
# Logging information:
217+
self.counter = 0
218+
self.min_y = None
219+
self.min_X = None
220+
self.min_mean_X = None
221+
self.min_mean_y = None
222+
self.mean_X = None
223+
self.mean_y = None
224+
self.var_y = None
225+
logger.setLevel(self.log_level)
226+
logger.info(f"Starting the logger at level {self.log_level} for module {__name__}:")
227+
216228
self.surrogate_control.update(surrogate_control)
217229
# If no surrogate model is specified, use the internal
218230
# spotPython kriging surrogate:
@@ -233,24 +245,15 @@ def __init__(
233245
cod_type=self.surrogate_control["cod_type"],
234246
var_type=self.surrogate_control["var_type"],
235247
use_cod_y=self.surrogate_control["use_cod_y"],
248+
spot_writer=self.fun_control["spot_writer"],
249+
counter=self.design_control["init_size"] * self.design_control["repeats"] - 1,
236250
)
237251
# Optimizer related information:
238252
self.optimizer = optimizer
239253
self.optimizer_control = {"max_iter": 1000, "seed": 125}
240254
self.optimizer_control.update(optimizer_control)
241255
if self.optimizer is None:
242256
self.optimizer = optimize.differential_evolution
243-
# Logging information:
244-
self.counter = 0
245-
self.min_y = None
246-
self.min_X = None
247-
self.min_mean_X = None
248-
self.min_mean_y = None
249-
self.mean_X = None
250-
self.mean_y = None
251-
self.var_y = None
252-
logger.setLevel(self.log_level)
253-
logger.info(f"Starting the logger at level {self.log_level} for module {__name__}:")
254257

255258
def to_red_dim(self):
256259
self.all_lower = self.lower
@@ -334,6 +337,9 @@ def run(self, X_start=None):
334337
self.fit_surrogate()
335338
# progress bar:
336339
self.show_progress_if_needed(timeout_start)
340+
if self.fun_control["spot_writer"] is not None:
341+
writer = self.fun_control["spot_writer"]
342+
writer.close()
337343
return self
338344

339345
def initialize_design(self, X_start=None):
@@ -404,6 +410,8 @@ def update_stats(self):
404410
405411
"""
406412
self.min_y = min(self.y)
413+
# get the last y value:
414+
self.last_y = self.y[-1]
407415
self.min_X = self.X[argmin(self.y)]
408416
self.counter = self.y.size
409417
# Update aggregated x and y values (if noise):
@@ -414,6 +422,14 @@ def update_stats(self):
414422
self.var_y = Z[2]
415423
self.min_mean_y = min(self.mean_y)
416424
self.min_mean_X = self.mean_X[argmin(self.mean_y)]
425+
if self.fun_control["spot_writer"] is not None:
426+
writer = self.fun_control["spot_writer"]
427+
y_min = self.min_y.copy()
428+
y_last = self.last_y.copy()
429+
X_min = self.min_X.copy()
430+
writer.add_scalars("spot_y", {"min": y_min, "last": y_last}, self.counter)
431+
writer.add_scalars("spot_X", {f"X_{i}": X_min[i] for i in range(self.k)}, self.counter)
432+
writer.flush()
417433

418434
def suggest_new_X_old(self):
419435
"""

src/spotPython/utils/file.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import socket
44
from datetime import datetime
55
from dateutil.tz import tzlocal
6+
import pickle
67

78

89
def load_data(data_dir="./data"):
@@ -41,3 +42,35 @@ def get_experiment_name(prefix: str = "00") -> str:
4142
experiment_name = prefix + "_" + HOSTNAME + "_" + str(start_time).split(".", 1)[0].replace(" ", "_")
4243
experiment_name = experiment_name.replace(":", "-")
4344
return experiment_name
45+
46+
47+
def save_pickle(obj, filename: str):
48+
"""Saves an object as a pickle file.
49+
Add .pkl to the filename.
50+
Args:
51+
obj (object): Object to be saved.
52+
filename (str): Name of the pickle file.
53+
Examples:
54+
>>> from spotPython.utils.file import save_pickle
55+
>>> save_pickle(obj, filename="obj.pkl")
56+
"""
57+
filename = filename + ".pkl"
58+
with open(filename, "wb") as f:
59+
pickle.dump(obj, f)
60+
61+
62+
def load_pickle(filename: str):
63+
"""Loads a pickle file.
64+
Add .pkl to the filename.
65+
Args:
66+
filename (str): Name of the pickle file.
67+
Returns:
68+
object: Loaded object.
69+
Examples:
70+
>>> from spotPython.utils.file import load_pickle
71+
>>> obj = load_pickle(filename="obj.pkl")
72+
"""
73+
filename = filename + ".pkl"
74+
with open(filename, "rb") as f:
75+
obj = pickle.load(f)
76+
return obj

src/spotPython/utils/init.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -71,9 +71,9 @@ def fun_control_init(
7171
# Starting with v0.2.41, Summary Writer should be not initialized here but by Lightning
7272
# it is only available for compatibility reasons.
7373
# So, set this to None and let Lightning manage the logging.
74-
writer = SummaryWriter(tensorboard_path)
74+
spot_writer = SummaryWriter(tensorboard_path)
7575
else:
76-
writer = None
76+
spot_writer = None
7777

7878
# Path to the folder where the pretrained models are saved
7979
CHECKPOINT_PATH = os.environ.get("PATH_CHECKPOINT", "saved_models/")
@@ -121,7 +121,7 @@ def fun_control_init(
121121
"task": task,
122122
"tensorboard_path": tensorboard_path,
123123
"weights": 1.0,
124-
"writer": writer,
124+
"spot_writer": spot_writer,
125125
}
126126
return fun_control
127127

0 commit comments

Comments
 (0)