Skip to content

Commit e240832

Browse files
0.15.12
- hidden_size generatot in nn_linear_regressor.py - anisotropic is set to default in spot.py
1 parent 622220c commit e240832

6 files changed

Lines changed: 104 additions & 13 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4836,6 +4836,90 @@
48364836
"spot_tuner = load_and_run_spot_python_experiment(\"spot_000_experiment.pickle\")"
48374837
]
48384838
},
4839+
{
4840+
"cell_type": "markdown",
4841+
"metadata": {},
4842+
"source": [
4843+
"# Lightning "
4844+
]
4845+
},
4846+
{
4847+
"cell_type": "code",
4848+
"execution_count": 13,
4849+
"metadata": {},
4850+
"outputs": [],
4851+
"source": [
4852+
"def _generate_div2_list(n, n_min) -> list:\n",
4853+
" \"\"\"\n",
4854+
" Generate a list of numbers from n to n_min (inclusive) by dividing n by 2\n",
4855+
" until the result is less than n_min.\n",
4856+
" This function starts with n and keeps dividing it by 2 until n_min is reached.\n",
4857+
" The number of times each value is added to the list is determined by n // current.\n",
4858+
"\n",
4859+
" Args:\n",
4860+
" n (int): The number to start with.\n",
4861+
" n_min (int): The minimum number to stop at.\n",
4862+
"\n",
4863+
" Returns:\n",
4864+
" list: A list of numbers from n to n_min (inclusive).\n",
4865+
"\n",
4866+
" Examples:\n",
4867+
" _generate_div2_list(10, 1)\n",
4868+
" [10, 5, 5, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
4869+
" _generate_div2_list(10, 2)\n",
4870+
" [10, 5, 5, 2, 2, 2, 2, 2]\n",
4871+
" \"\"\"\n",
4872+
" result = []\n",
4873+
" current = n\n",
4874+
" repeats = 1\n",
4875+
" max_repeats = 4\n",
4876+
" while current >= n_min:\n",
4877+
" result.extend([current] * min(repeats, max_repeats))\n",
4878+
" current = current // 2\n",
4879+
" repeats = repeats + 1\n",
4880+
" return result"
4881+
]
4882+
},
4883+
{
4884+
"cell_type": "code",
4885+
"execution_count": 14,
4886+
"metadata": {},
4887+
"outputs": [
4888+
{
4889+
"data": {
4890+
"text/plain": [
4891+
"[10, 5, 5]"
4892+
]
4893+
},
4894+
"execution_count": 14,
4895+
"metadata": {},
4896+
"output_type": "execute_result"
4897+
}
4898+
],
4899+
"source": [
4900+
"_generate_div2_list(10, 3)"
4901+
]
4902+
},
4903+
{
4904+
"cell_type": "code",
4905+
"execution_count": 17,
4906+
"metadata": {},
4907+
"outputs": [
4908+
{
4909+
"data": {
4910+
"text/plain": [
4911+
"[128, 64, 64, 32, 32, 32]"
4912+
]
4913+
},
4914+
"execution_count": 17,
4915+
"metadata": {},
4916+
"output_type": "execute_result"
4917+
}
4918+
],
4919+
"source": [
4920+
"_generate_div2_list(128, 32)"
4921+
]
4922+
},
48394923
{
48404924
"cell_type": "code",
48414925
"execution_count": null,

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.15.11"
10+
version = "0.15.12"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/light/regression/netlightregression.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def __init__(
155155
if self.hparams.l1 < 4:
156156
raise ValueError("l1 must be at least 4")
157157

158-
# TODO: Implement a hidden_sizes generator function.
158+
# TODO: Implement a hidden_sizes generator function.
159159
# This function is implemented in the updadated version of this class which
160160
# is available as nn_linear_regression.py in the same directory.
161161
hidden_sizes = [self.hparams.l1, self.hparams.l1 // 2, self.hparams.l1 // 2, self.hparams.l1 // 4]

src/spotpython/light/regression/nn_linear_regressor.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from torch import nn
44
from spotpython.hyperparameters.optimizer import optimizer_handler
55
import torchmetrics.functional.regression
6+
from spotpython.utils.math import generate_div2_list
67

78

89
class NNLinearRegressor(L.LightningModule):
@@ -164,14 +165,7 @@ def __init__(
164165
self.example_input_array = torch.zeros((batch_size, self._L_in))
165166
if self.hparams.l1 < 4:
166167
raise ValueError("l1 must be at least 4")
167-
168-
# TODO: Implement a hidden_sizes generator function
169-
hidden_sizes = [self.hparams.l1, self.hparams.l1 // 2, self.hparams.l1 // 2, self.hparams.l1 // 4]
170-
# n_low = _L_in // 4
171-
# # ensure that n_high is larger than n_low
172-
# n_high = max(self.hparams.l1, 2 * n_low)
173-
# hidden_sizes = generate_div2_list(n_high, n_low)
174-
168+
hidden_sizes = self._get_hidden_sizes()
175169
# Create the network based on the specified hidden sizes
176170
layers = []
177171
layer_sizes = [self._L_in] + hidden_sizes
@@ -187,6 +181,19 @@ def __init__(
187181
# nn.Sequential summarizes a list of modules into a single module, applying them in sequence
188182
self.layers = nn.Sequential(*layers)
189183

184+
def _get_hidden_sizes(self):
185+
"""
186+
Generate the hidden layer sizes for the network.
187+
188+
Returns:
189+
list: A list of hidden layer sizes.
190+
191+
"""
192+
n_low = self._L_in // 4
193+
n_high = max(self.hparams.l1, 2 * n_low)
194+
hidden_sizes = generate_div2_list(n_high, n_low)
195+
return hidden_sizes
196+
190197
def forward(self, x: torch.Tensor) -> torch.Tensor:
191198
"""
192199
Performs a forward pass through the model.

src/spotpython/spot/spot.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,7 +311,7 @@ def __init__(
311311
if self.surrogate_control["n_theta"] == "anisotropic":
312312
surrogate_control.update({"n_theta": self.k})
313313
else:
314-
# case "isotropic":
314+
# case "isotropic":
315315
surrogate_control.update({"n_theta": 1})
316316
if isinstance(self.surrogate_control["n_theta"], int):
317317
if self.surrogate_control["n_theta"] > 1:
@@ -2146,7 +2146,8 @@ def get_importance(self) -> list:
21462146
21472147
Returns:
21482148
output (list):
2149-
list of results. If the surrogate has more than one theta values, the importance is calculated. Otherwise, a list of zeros is returned.
2149+
list of results. If the surrogate has more than one theta values,
2150+
the importance is calculated. Otherwise, a list of zeros is returned.
21502151
21512152
"""
21522153
if self.surrogate.n_theta > 1 and self.var_name is not None:

src/spotpython/utils/file.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import importlib
88
from spotpython.hyperparameters.values import get_tuned_architecture
99
from spotpython.utils.eda import gen_design_table
10-
from spotpython.utils.tensorboard import start_tensorboard, stop_tensorboard
1110
from spotpython.utils.init import setup_paths
1211

1312

0 commit comments

Comments
 (0)