Skip to content

Commit 0b0f91a

Browse files
force_run
1 parent 48b8bdb commit 0b0f91a

13 files changed

Lines changed: 284 additions & 146 deletions

notebooks/00_spotPython_tests.ipynb

Lines changed: 169 additions & 69 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.23.1"
10+
version = "0.23.2"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/spot/spot.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -802,11 +802,12 @@ def run(self, X_start: np.ndarray = None) -> Spot:
802802
#
803803
PREFIX = self.fun_control["PREFIX"]
804804
filename = get_result_filename(PREFIX)
805-
if os.path.exists(filename):
805+
if os.path.exists(filename) and not self.fun_control.get("force_run"):
806806
# print a warning and load the result
807807
print(f"Result file {filename} exists. Loading the result.")
808-
spot_tuner = load_result(filename=filename)
809-
return spot_tuner
808+
S = load_result(filename=filename)
809+
self._copy_from(S)
810+
return self
810811
else:
811812
self.initialize_design(X_start)
812813
self.update_stats()
@@ -829,6 +830,10 @@ def run(self, X_start: np.ndarray = None) -> Spot:
829830
self.save_result(verbosity=self.verbosity)
830831
return self
831832

833+
def _copy_from(self, other):
834+
for attr in other.__dict__:
835+
setattr(self, attr, getattr(other, attr))
836+
832837
def initialize_design(self, X_start=None) -> None:
833838
"""
834839
Initialize design. Generate and evaluate initial design.

src/spotpython/utils/init.py

Lines changed: 27 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@
44
from scipy.optimize import differential_evolution
55
import numpy as np
66
import socket
7+
import copy
78
import datetime
89
from dateutil.tz import tzlocal
10+
from importlib.metadata import version, PackageNotFoundError
911
from spotpython.hyperparameters.values import (
1012
add_core_model_to_fun_control,
1113
get_core_model_from_name,
@@ -43,6 +45,7 @@ def fun_control_init(
4345
enable_progress_bar=False,
4446
EXPERIMENT_NAME=None,
4547
eval=None,
48+
force_run=False,
4649
fun_evals=15,
4750
fun_repeats=1,
4851
horizon=None,
@@ -162,6 +165,10 @@ def fun_control_init(
162165
The name of the experiment.
163166
Default is None. If None, the experiment name is generated based on the
164167
current date and time.
168+
force_run (bool):
169+
Whether to force the run or not. If a result file (PREFIX+"_run.pkl") exists, the run is mot
170+
performed and the result is loaded from the file.
171+
Default is False.
165172
fun_evals (int):
166173
The number of function evaluations.
167174
fun_repeats (int):
@@ -382,7 +389,7 @@ def fun_control_init(
382389
L.seed_everything(seed)
383390

384391
if PREFIX is None:
385-
PREFIX = _init_PREFIX()
392+
PREFIX = _init_prefix()
386393

387394
CHECKPOINT_PATH, DATASET_PATH, RESULTS_PATH, TENSORBOARD_PATH = setup_paths(TENSORBOARD_CLEAN)
388395
spot_tensorboard_path = create_spot_tensorboard_path(tensorboard_log, PREFIX)
@@ -420,6 +427,7 @@ def fun_control_init(
420427
"devices": devices,
421428
"enable_progress_bar": enable_progress_bar,
422429
"eval": eval,
430+
"force_run": force_run,
423431
"fun_evals": fun_evals,
424432
"fun_repeats": fun_repeats,
425433
"horizon": horizon,
@@ -521,21 +529,26 @@ def fun_control_init(
521529
return fun_control
522530

523531

524-
def _init_PREFIX() -> str:
525-
"""Initialize the PREFIX for the experiment name.
532+
def _init_prefix() -> str:
533+
"""Initialize the prefix for the experiment name.
534+
Attempts to derive the prefix from the package version. If unsuccessful,
535+
defaults to '000'.
526536
527537
Returns:
528-
PREFIX (str):
529-
The PREFIX for the experiment name.
538+
str: The prefix for the experiment name.
530539
531540
Examples:
532-
>>> from spotpython.utils.init import _init_PREFIX
533-
>>> _init_PREFIX()
541+
>>> from spotpython.utils.init import _init_prefix
542+
>>> _init_prefix()
534543
'00'
535544
"""
536-
# set the prefix to the actual date and time
537-
PREFIX = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
538-
return PREFIX
545+
DEFAULT_PREFIX = "000"
546+
try:
547+
package_version = version("package_name")
548+
except PackageNotFoundError:
549+
package_version = DEFAULT_PREFIX
550+
551+
return package_version
539552

540553

541554
def setup_paths(tensorboard_clean) -> tuple:
@@ -584,8 +597,11 @@ def setup_paths(tensorboard_clean) -> tuple:
584597
now = datetime.datetime.now()
585598
os.makedirs("runs_OLD", exist_ok=True)
586599
# use [:-1] to remove "/" from the end of the path
587-
TENSORBOARD_PATH_OLD = "runs_OLD/" + TENSORBOARD_PATH[:-1] + "_" + now.strftime("%Y_%m_%d_%H_%M_%S")
600+
TENSORBOARD_PATH_OLD = "runs_OLD/" + TENSORBOARD_PATH[:-1] + "_" + now.strftime("%Y_%m_%d_%H_%M_%S") + "_" + "0"
588601
print(f"Moving TENSORBOARD_PATH: {TENSORBOARD_PATH} to TENSORBOARD_PATH_OLD: {TENSORBOARD_PATH_OLD}")
602+
# if TENSORBOARD_PATH_OLD already exists, change the name increasing the number at the end
603+
while os.path.exists(TENSORBOARD_PATH_OLD):
604+
TENSORBOARD_PATH_OLD = copy.deepcopy(TENSORBOARD_PATH_OLD[:-1] + str(int(TENSORBOARD_PATH_OLD[-1]) + 1))
589605
os.rename(TENSORBOARD_PATH[:-1], TENSORBOARD_PATH_OLD)
590606

591607
os.makedirs(TENSORBOARD_PATH, exist_ok=True)

test/test_get_spot_attributes_as_df.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import pandas as pd
44
from math import inf
55
from spotpython.fun.objectivefunctions import Analytical
6-
from spotpython.spot import spot
6+
from spotpython.spot import Spot
77
from spotpython.utils.init import fun_control_init, design_control_init
88

99
def test_get_spot_attributes_as_df():
@@ -12,24 +12,25 @@ def test_get_spot_attributes_as_df():
1212
n = 10
1313
fun = Analytical().fun_sphere
1414
fun_control = fun_control_init(
15+
PREFIX= "test_get_spot_attributes_as_df",
1516
lower=np.array([-1]),
1617
upper=np.array([1]),
1718
fun_evals=n
1819
)
1920
design_control = design_control_init(init_size=ni)
20-
21+
2122
# Create instance of the Spot class
22-
spot_instance = spot.Spot(
23+
S = Spot(
2324
fun=fun,
2425
fun_control=fun_control,
2526
design_control=design_control
2627
)
2728

2829
# Run the optimization
29-
spot_instance.run()
30+
S.run()
3031

3132
# Get the attributes as a DataFrame
32-
df = spot_instance.get_spot_attributes_as_df()
33+
df = S.get_spot_attributes_as_df()
3334

3435
# Define expected attribute names (ensure these match your Spot class' attributes)
3536
expected_attributes = ['X',

test/test_repair_non_numeric.py

Lines changed: 17 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -10,27 +10,30 @@ def test_repair_non_numeric():
1010
fun_control_init,
1111
design_control_init,
1212
)
13-
13+
1414
fun = Analytical().fun_branin_factor
1515
ni = 12
16-
spot_test = Spot(
16+
S = Spot(
1717
fun=fun,
1818
fun_control=fun_control_init(
19-
lower=np.array([-5, -0, 1]), upper=np.array([10, 15, 3]), var_type=["num", "num", "factor"]
19+
PREFIX="test_repair_non_numeric",
20+
lower=np.array([-5, -0, 1]),
21+
upper=np.array([10, 15, 3]),
22+
var_type=["num", "num", "factor"]
2023
),
2124
design_control=design_control_init(init_size=ni),
2225
)
23-
spot_test.run()
26+
S.run()
2427
# 3rd variable should be a rounded float, because it was labeled as a factor
25-
assert spot_test.min_X[2] == round(spot_test.min_X[2])
28+
assert S.min_X[2] == round(S.min_X[2])
2629

27-
spot_test.X = spot_test.generate_design(
28-
size=spot_test.design_control["init_size"],
29-
repeats=spot_test.design_control["repeats"],
30-
lower=spot_test.lower,
31-
upper=spot_test.upper,
30+
S.X = S.generate_design(
31+
size=S.design_control["init_size"],
32+
repeats=S.design_control["repeats"],
33+
lower=S.lower,
34+
upper=S.upper,
3235
)
33-
spot_test.X = repair_non_numeric(spot_test.X, spot_test.var_type)
34-
assert spot_test.X.ndim == 2
35-
assert spot_test.X.shape[0] == ni
36-
assert spot_test.X.shape[1] == 3
36+
S.X = repair_non_numeric(S.X, S.var_type)
37+
assert S.X.ndim == 2
38+
assert S.X.shape[0] == ni
39+
assert S.X.shape[1] == 3

test/test_run.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def test_run_method():
1616

1717
# Initialize function control
1818
fun_control = fun_control_init(
19+
PREFIX = "test_spot_run",
1920
lower=np.array([-1, -1]),
2021
upper=np.array([1, 1])
2122
)
@@ -39,7 +40,7 @@ def test_run_method():
3940
assert S.X.shape[0] == S.y.shape[0], "The design matrix X and response vector y should have the same number of rows."
4041

4142
# Check if the minimum value in y is as expected
42-
assert np.min(S.y) == 0.0, "The minimum value in y should be 0.0."
43+
assert np.min(np.abs(S.y)) == 0.0, "The minimum value in y should be 0.0."
4344

4445
# Check if the corresponding x values are as expected
4546
min_index = np.argmin(S.y)

test/test_save_and_load_experiment.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ def _compare_dicts(dict1, dict2, ignore_keys=None):
4848
return True
4949

5050
def test_save_and_load_experiment():
51-
PREFIX = "test_02"
51+
PREFIX = "test_save_and_load_experiment_02"
5252
# Initialize function control
5353
fun_control = fun_control_init(
5454
PREFIX=PREFIX,

test/test_save_experiment.py

Lines changed: 13 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,14 @@
55
from spotpython.spot import Spot
66
from spotpython.fun.objectivefunctions import Analytical
77
from spotpython.utils.init import fun_control_init, design_control_init
8+
from spotpython.utils.file import load_experiment, load_result
89

910
def test_save_experiment(tmp_path, capsys):
11+
PREFIX="test_save_experiment"
12+
1013
# Initialize function control
1114
fun_control = fun_control_init(
15+
PREFIX=PREFIX,
1216
lower=np.array([-1, -1]),
1317
upper=np.array([1, 1])
1418
)
@@ -25,36 +29,20 @@ def test_save_experiment(tmp_path, capsys):
2529
fun_control=fun_control,
2630
design_control=design_control,
2731
)
32+
33+
Sexp_load = load_experiment(PREFIX)
34+
assert Sexp_load.design_control["init_size"] == 7
2835

2936
# Run the optimization to generate some data
3037
X_start = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
3138
S.run(X_start=X_start)
39+
40+
Srun_load = load_result(PREFIX)
41+
assert Srun_load.design_control["init_size"] == 7
42+
assert Srun_load.X.shape[0] > 0
43+
assert Srun_load.y.shape[0] > 0
44+
3245

33-
# Define the filename and path
34-
filename = "test_experiment.pkl"
35-
path = tmp_path
36-
37-
# Save the experiment
38-
S.save_experiment(filename=filename, path=path)
39-
40-
# Check if the file was created
41-
filepath = os.path.join(path, filename)
42-
assert os.path.exists(filepath), f"File {filepath} should exist."
43-
44-
# Load the experiment and check its contents
45-
with open(filepath, "rb") as handle:
46-
experiment = pickle.load(handle)
47-
assert "design_control" in experiment, "design_control should be in the experiment dictionary."
48-
assert "fun_control" in experiment, "fun_control should be in the experiment dictionary."
49-
assert "optimizer_control" in experiment, "optimizer_control should be in the experiment dictionary."
50-
assert "spot_tuner" in experiment, "spot_tuner should be in the experiment dictionary."
51-
assert "surrogate_control" in experiment, "surrogate_control should be in the experiment dictionary."
52-
53-
# Test overwrite functionality
54-
S.save_experiment(filename=filename, path=path, overwrite=False)
55-
captured = capsys.readouterr()
56-
assert "Error: File" in captured.out
57-
assert "already exists. Use overwrite=True to overwrite the file." in captured.out
5846

5947
if __name__ == "__main__":
6048
pytest.main()

test/test_show_progress.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ def test_show_progress():
44
"""
55
import numpy as np
66
from spotpython.fun.objectivefunctions import Analytical
7-
from spotpython.spot import spot
7+
from spotpython.spot import Spot
88
from math import inf
99
from spotpython.utils.init import (
1010
fun_control_init,
@@ -20,9 +20,10 @@ def test_show_progress():
2020
lower = np.array([-1])
2121
upper = np.array([1])
2222

23-
spot_1 = spot.Spot(
23+
spot_1 = Spot(
2424
fun=fun,
25-
fun_control=fun_control_init(lower=lower, upper=upper, fun_evals=n, show_progress=False),
25+
fun_control=fun_control_init(PREFIX="test_show_progress_1",
26+
lower=lower, upper=upper, fun_evals=n, show_progress=False),
2627
design_control=design_control_init(init_size=ni),
2728
)
2829
spot_1.run()
@@ -31,9 +32,13 @@ def test_show_progress():
3132
# number of points.
3233
assert spot_1.y.shape[0] == n
3334

34-
spot_2 = spot.Spot(
35+
spot_2 = Spot(
3536
fun=fun,
36-
fun_control=fun_control_init(lower=lower, upper=upper, fun_evals=n, show_progress=False),
37+
fun_control=fun_control_init(PREFIX="test_show_progress_2",
38+
lower=lower,
39+
upper=upper,
40+
fun_evals=n,
41+
show_progress=False),
3742
design_control=design_control_init(init_size=ni),
3843
)
3944
spot_2.run()
@@ -42,9 +47,13 @@ def test_show_progress():
4247
# number of points.
4348
assert spot_2.y.shape[0] == n
4449

45-
spot_3 = spot.Spot(
50+
spot_3 = Spot(
4651
fun=fun,
47-
fun_control=fun_control_init(lower=lower, upper=upper, fun_evals=inf, max_time=0.1, show_progress=False),
52+
fun_control=fun_control_init(PREFIX="test_show_progress_3",
53+
lower=lower,
54+
upper=upper,
55+
fun_evals=inf,
56+
max_time=0.1, show_progress=False),
4857
design_control=design_control_init(init_size=ni),
4958
)
5059
spot_3.run()
@@ -53,9 +62,10 @@ def test_show_progress():
5362
# because we do not know how many points can be evaluated.
5463
assert spot_3.y.shape[0] > 0
5564

56-
spot_4 = spot.Spot(
65+
spot_4 = Spot(
5766
fun=fun,
58-
fun_control=fun_control_init(lower=lower, upper=upper, fun_evals=inf, max_time=0.1, show_progress=True),
67+
fun_control=fun_control_init(PREFIX="test_show_progress_4",
68+
lower=lower, upper=upper, fun_evals=inf, max_time=0.1, show_progress=True),
5969
design_control=design_control_init(init_size=ni),
6070
)
6171
spot_4.run()

0 commit comments

Comments
 (0)