Skip to content

Commit 469126a

Browse files
0.21.3
1 parent 6a095a4 commit 469126a

11 files changed

Lines changed: 789 additions & 351 deletions

notebooks/00_spotPython_tests.ipynb

Lines changed: 183 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -7131,39 +7131,9 @@
71317131
},
71327132
{
71337133
"cell_type": "code",
7134-
"execution_count": 1,
7134+
"execution_count": null,
71357135
"metadata": {},
7136-
"outputs": [
7137-
{
7138-
"name": "stderr",
7139-
"output_type": "stream",
7140-
"text": [
7141-
"Seed set to 123\n",
7142-
"Seed set to 123\n"
7143-
]
7144-
},
7145-
{
7146-
"name": "stdout",
7147-
"output_type": "stream",
7148-
"text": [
7149-
"spotpython tuning: 0.0 [########--] 80.00% \n",
7150-
"spotpython tuning: 0.0 [#########-] 86.67% \n",
7151-
"spotpython tuning: 0.0 [#########-] 93.33% \n",
7152-
"spotpython tuning: 0.0 [##########] 100.00% Done...\n",
7153-
"\n"
7154-
]
7155-
},
7156-
{
7157-
"data": {
7158-
"text/plain": [
7159-
"<spotpython.spot.spot.Spot at 0x177589760>"
7160-
]
7161-
},
7162-
"execution_count": 1,
7163-
"metadata": {},
7164-
"output_type": "execute_result"
7165-
}
7166-
],
7136+
"outputs": [],
71677137
"source": [
71687138
"import numpy as np\n",
71697139
"from spotpython import Analytical\n",
@@ -7213,12 +7183,192 @@
72137183
"S.print_results()"
72147184
]
72157185
},
7186+
{
7187+
"cell_type": "markdown",
7188+
"metadata": {},
7189+
"source": [
7190+
"## initialize_design()"
7191+
]
7192+
},
72167193
{
72177194
"cell_type": "code",
72187195
"execution_count": null,
72197196
"metadata": {},
72207197
"outputs": [],
7221-
"source": []
7198+
"source": [
7199+
"import numpy as np\n",
7200+
"from spotpython.fun.objectivefunctions import Analytical\n",
7201+
"from spotpython.spot import spot\n",
7202+
"from spotpython.utils.init import (\n",
7203+
" fun_control_init, design_control_init\n",
7204+
" )\n",
7205+
"# number of initial points:\n",
7206+
"ni = 7\n",
7207+
"# start point X_0\n",
7208+
"X_start = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n",
7209+
"fun = Analytical().fun_sphere\n",
7210+
"fun_control = fun_control_init(\n",
7211+
" lower = np.array([-1, -1]),\n",
7212+
" upper = np.array([1, 1]))\n",
7213+
"design_control=design_control_init(init_size=ni)\n",
7214+
"S = spot.Spot(fun=fun,\n",
7215+
" fun_control=fun_control,\n",
7216+
" design_control=design_control,)\n",
7217+
"S.initialize_design(X_start=X_start)\n",
7218+
"print(f\"S.X: {S.X}\")\n",
7219+
"print(f\"S.y: {S.y}\")"
7220+
]
7221+
},
7222+
{
7223+
"cell_type": "markdown",
7224+
"metadata": {},
7225+
"source": [
7226+
"## write_tensorboard_log()"
7227+
]
7228+
},
7229+
{
7230+
"cell_type": "code",
7231+
"execution_count": null,
7232+
"metadata": {},
7233+
"outputs": [],
7234+
"source": [
7235+
"import numpy as np\n",
7236+
"from spotpython.fun import Analytical\n",
7237+
"from spotpython.spot import Spot\n",
7238+
"from spotpython.utils.init import fun_control_init\n",
7239+
"fun_control = fun_control_init(\n",
7240+
" tensorboard_log=True,\n",
7241+
" TENSORBOARD_CLEAN=True,\n",
7242+
" lower = np.array([-1]),\n",
7243+
" upper = np.array([1])\n",
7244+
" )\n",
7245+
"fun = Analytical().fun_sphere\n",
7246+
"\n",
7247+
"S = Spot(fun=fun,\n",
7248+
" fun_control=fun_control,\n",
7249+
" )\n",
7250+
"S.initialize_design()\n",
7251+
"S.write_tensorboard_log()"
7252+
]
7253+
},
7254+
{
7255+
"cell_type": "markdown",
7256+
"metadata": {},
7257+
"source": [
7258+
"## initialize_design_matrix()"
7259+
]
7260+
},
7261+
{
7262+
"cell_type": "code",
7263+
"execution_count": 2,
7264+
"metadata": {},
7265+
"outputs": [
7266+
{
7267+
"name": "stderr",
7268+
"output_type": "stream",
7269+
"text": [
7270+
"Seed set to 123\n"
7271+
]
7272+
},
7273+
{
7274+
"name": "stdout",
7275+
"output_type": "stream",
7276+
"text": [
7277+
"Design matrix: [[ 0.1 0.2 ]\n",
7278+
" [ 0.3 0.4 ]\n",
7279+
" [ 0.86352963 0.7892358 ]\n",
7280+
" [-0.24407197 -0.83687436]\n",
7281+
" [ 0.36481882 0.8375811 ]\n",
7282+
" [ 0.415331 0.54468512]\n",
7283+
" [-0.56395091 -0.77797854]\n",
7284+
" [-0.90259409 -0.04899292]\n",
7285+
" [-0.16484832 0.35724741]\n",
7286+
" [ 0.05170659 0.07401196]\n",
7287+
" [-0.78548145 -0.44638164]\n",
7288+
" [ 0.64017497 -0.30363301]]\n"
7289+
]
7290+
}
7291+
],
7292+
"source": [
7293+
"import numpy as np\n",
7294+
"from spotpython.fun import Analytical\n",
7295+
"from spotpython.spot import Spot\n",
7296+
"from spotpython.utils.init import fun_control_init\n",
7297+
"fun_control = fun_control_init(\n",
7298+
" lower = np.array([-1, -1]),\n",
7299+
" upper = np.array([1, 1])\n",
7300+
" )\n",
7301+
"fun = Analytical().fun_sphere\n",
7302+
"\n",
7303+
"S = Spot(fun=fun,\n",
7304+
" fun_control=fun_control,\n",
7305+
" )\n",
7306+
"X_start = np.array([[0.1, 0.2], [0.3, 0.4]])\n",
7307+
"S.initialize_design_matrix(X_start)\n",
7308+
"print(f\"Design matrix: {S.X}\")"
7309+
]
7310+
},
7311+
{
7312+
"cell_type": "markdown",
7313+
"metadata": {},
7314+
"source": [
7315+
"## evaluate_initial_design()"
7316+
]
7317+
},
7318+
{
7319+
"cell_type": "code",
7320+
"execution_count": 1,
7321+
"metadata": {},
7322+
"outputs": [
7323+
{
7324+
"name": "stderr",
7325+
"output_type": "stream",
7326+
"text": [
7327+
"Seed set to 123\n",
7328+
"Seed set to 123\n"
7329+
]
7330+
},
7331+
{
7332+
"name": "stdout",
7333+
"output_type": "stream",
7334+
"text": [
7335+
"S.X: [[ 0. 0. ]\n",
7336+
" [ 0. 1. ]\n",
7337+
" [ 1. 0. ]\n",
7338+
" [ 1. 1. ]\n",
7339+
" [ 0.86352963 0.7892358 ]\n",
7340+
" [-0.24407197 -0.83687436]\n",
7341+
" [ 0.36481882 0.8375811 ]\n",
7342+
" [ 0.415331 0.54468512]\n",
7343+
" [-0.56395091 -0.77797854]\n",
7344+
" [-0.90259409 -0.04899292]\n",
7345+
" [-0.16484832 0.35724741]\n",
7346+
" [ 0.05170659 0.07401196]\n",
7347+
" [-0.78548145 -0.44638164]\n",
7348+
" [ 0.64017497 -0.30363301]]\n",
7349+
"S.y: [0. 1. 1. 2. 1.36857656 0.75992983\n",
7350+
" 0.83463487 0.46918172 0.92329124 0.8170764 0.15480068 0.00815134\n",
7351+
" 0.81623768 0.502017 ]\n"
7352+
]
7353+
}
7354+
],
7355+
"source": [
7356+
"import numpy as np\n",
7357+
"from spotpython.fun.objectivefunctions import Analytical\n",
7358+
"from spotpython.spot import Spot\n",
7359+
"from spotpython.utils.init import fun_control_init\n",
7360+
"fun_control = fun_control_init(\n",
7361+
" lower=np.array([-1, -1]),\n",
7362+
" upper=np.array([1, 1])\n",
7363+
")\n",
7364+
"fun = Analytical().fun_sphere\n",
7365+
"S = Spot(fun=fun, fun_control=fun_control)\n",
7366+
"X0 = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])\n",
7367+
"S.initialize_design_matrix(X_start=X0)\n",
7368+
"S.evaluate_initial_design()\n",
7369+
"print(f\"S.X: {S.X}\")\n",
7370+
"print(f\"S.y: {S.y}\")"
7371+
]
72227372
}
72237373
],
72247374
"metadata": {

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.21.2"
10+
version = "0.21.3"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/fun/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
"""
2+
This module implements the objective functions.
3+
4+
"""
5+
6+
from .objectivefunctions import Analytical
7+
from .hyperlight import HyperLight
8+
from .hypersklearn import HyperSklearn
9+
from .hypertorch import HyperTorch
10+
11+
__all__ = ["Analytical", "HyperLight", "HyperSklearn", "HyperTorch"]

src/spotpython/fun/objectivefunctions.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -387,7 +387,9 @@ def fun_wingwt(self, X: np.ndarray, fun_control: Optional[Dict] = None) -> np.nd
387387
the wing weight \( W \) using the following formula:
388388
389389
\[
390-
W = 0.036 \times S_W^{0.758} \times W_{fw}^{0.0035} \times \left( \frac{A}{\cos^2 \Lambda} \right)^{0.6} \times q^{0.006} \times \lambda^{0.04} \times \left( \frac{100 \times R_{tc}}{\cos \Lambda} \right)^{-0.3} \times (N_z \times W_{dg})^{0.49} + S_W \times W_p
390+
W = 0.036 \times S_W^{0.758} \times W_{fw}^{0.0035} \times \left( \frac{A}{\cos^2 \Lambda} \right)^{0.6}
391+
\times q^{0.006} \times \lambda^{0.04} \times \left( \frac{100 \times R_{tc}}{\cos \Lambda} \right)^{-0.3}
392+
\times (N_z \times W_{dg})^{0.49} + S_W \times W_p
391393
\]
392394
393395
where:

src/spotpython/plot/xai.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@
1313
from spotpython.light.loadmodel import load_light_from_checkpoint
1414
from spotpython.utils.classes import get_removed_attributes_and_base_net
1515
import pandas as pd
16-
from captum.attr import LayerConductance, LayerActivation, LayerIntegratedGradients
17-
from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation
16+
from captum.attr import LayerConductance
17+
from captum.attr import IntegratedGradients, DeepLift, GradientShap, FeatureAblation
1818
from matplotlib.ticker import MaxNLocator
1919
from spotpython.data.lightdatamodule import LightDataModule
2020
from spotpython.torch.dimensions import extract_linear_dims
@@ -1040,7 +1040,8 @@ def viz_net(
10401040
show_attrs (bool, optional):
10411041
whether to display non-tensor attributes of backward nodes (Requires PyTorch version >= 1.9)
10421042
show_saved (bool, optional):
1043-
whether to display saved tensor nodes that are not by custom autograd functions. Saved tensor nodes for custom functions, if present, are always displayed. (Requires PyTorch version >= 1.9)
1043+
whether to display saved tensor nodes that are not by custom autograd functions. Saved tensor nodes for custom functions, if present, are always displayed.
1044+
(Requires PyTorch version >= 1.9)
10441045
max_attr_chars (int, optional):
10451046
if show_attrs is True, sets max number of characters to display for any given attribute. Defaults to 50.
10461047
filename (str, optional):

src/spotpython/spot/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""
2+
This module implements the Spot class.
3+
4+
"""
5+
6+
from .spot import Spot
7+
8+
__all__ = [
9+
"Spot",
10+
]

0 commit comments

Comments
 (0)