Skip to content

Commit c7142ab

Browse files
v0.2.1
1 parent 71d96c7 commit c7142ab

6 files changed

Lines changed: 36 additions & 11 deletions

File tree

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotPython"
10-
version = "0.2.0"
10+
version = "0.2.1"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotPython/data/torch_hyper_dict.json

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,10 @@
3333
"upper": 4},
3434
"k_folds": {
3535
"type": "int",
36-
"default": 2,
36+
"default": 1,
3737
"transform": "None",
38-
"lower": 2,
39-
"upper": 3},
38+
"lower": 1,
39+
"upper": 1},
4040
"patience": {
4141
"type": "int",
4242
"default": 5,
@@ -93,10 +93,10 @@
9393
"upper": 4},
9494
"k_folds": {
9595
"type": "int",
96-
"default": 2,
96+
"default": 1,
9797
"transform": "None",
98-
"lower": 2,
99-
"upper": 3},
98+
"lower": 1,
99+
"upper": 1},
100100
"patience": {
101101
"type": "int",
102102
"default": 5,
@@ -157,10 +157,10 @@
157157
"upper": 1},
158158
"k_folds": {
159159
"type": "int",
160-
"default": 2,
160+
"default": 1,
161161
"transform": "None",
162-
"lower": 2,
163-
"upper": 3
162+
"lower": 1,
163+
"upper": 1
164164
},
165165
"patience": {
166166
"type": "int",

src/spotPython/data/torchdata.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from torchvision import datasets
2+
import torchvision.transforms as transforms
3+
4+
5+
def load_data_cifar10(data_dir="./data"):
6+
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
7+
8+
trainset = datasets.CIFAR10(root=data_dir, train=True, download=True, transform=transform)
9+
10+
testset = datasets.CIFAR10(root=data_dir, train=False, download=True, transform=transform)
11+
12+
return trainset, testset

src/spotPython/hyperparameters/values.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -576,7 +576,7 @@ def get_values_from_dict(dictionary) -> np.array:
576576
return np.array(list(dictionary.values()))
577577

578578

579-
def add_core_model_to_fun_control(core_model, fun_control, hyper_dict, filename) -> dict:
579+
def add_core_model_to_fun_control(core_model, fun_control, hyper_dict, filename=None) -> dict:
580580
"""Add the core model to the function control dictionary.
581581
Args:
582582
core_model (class): The core model.

src/spotPython/spot/spot.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -709,6 +709,18 @@ def plot_contour(
709709
if show:
710710
pylab.show()
711711

712+
def plot_important_hyperparameter_contour(self, threshold=0.025, filename=None):
713+
impo = self.print_importance(threshold=threshold, print_screen=True)
714+
var_plots = [i for i, x in enumerate(impo) if x[1] > threshold]
715+
min_z = min(self.y)
716+
max_z = max(self.y)
717+
for i in var_plots:
718+
for j in var_plots:
719+
if j > i:
720+
if filename is not None:
721+
filename = filename + "_contour_" + str(i) + "_" + str(j) + ".png"
722+
self.plot_contour(i=i, j=j, min_z=min_z, max_z=max_z, filename=filename)
723+
712724
def get_importance(self) -> list:
713725
"""Get importance of each variable and return the results as a list.
714726
Returns:

src/spotPython/utils/init.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,5 +48,6 @@ def fun_control_init():
4848
"show_batch_interval": 1_000_000,
4949
"path": None,
5050
"save_model": False,
51+
"weights": 1.0,
5152
}
5253
return fun_control

0 commit comments

Comments
 (0)