Skip to content

Commit 1973616

Browse files
0.31.5
1 parent a49c3c7 commit 1973616

3 files changed

Lines changed: 225 additions & 97 deletions

File tree

notebooks/00_spotPython_tests.ipynb

Lines changed: 65 additions & 48 deletions
Large diffs are not rendered by default.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ build-backend = "setuptools.build_meta"
77

88
[project]
99
name = "spotpython"
10-
version = "0.31.4"
10+
version = "0.31.5"
1111
authors = [
1212
{ name="T. Bartz-Beielstein", email="tbb@bartzundbartz.de" }
1313
]

src/spotpython/surrogate/plot.py

Lines changed: 159 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,50 @@ def generate_mesh_grid(X: np.ndarray, i: int, j: int, n_grid: int = 100):
9191
return X_i, X_j, grid_points
9292

9393

94-
def plot_values(
94+
def simple_error_color(z_actual: float, z_predicted: float, eps: float = 1e-3) -> str:
95+
"""
96+
Returns a color string based on the error between actual and predicted values.
97+
98+
Args:
99+
z_actual (float): The actual value.
100+
z_predicted (float): The predicted value.
101+
eps (float): Tolerance for considering values as close. Default is 1e-3.
102+
103+
Returns:
104+
str: "black" if actual > predicted + eps,
105+
"white" if actual < predicted - eps,
106+
"grey" otherwise.
107+
"""
108+
# predicted value is smaller than actual value
109+
if z_actual > z_predicted + eps:
110+
return "black"
111+
# predicted value is larger than actual value
112+
elif z_actual < z_predicted - eps:
113+
return "white"
114+
# predicted value is close to actual value
115+
else:
116+
return "grey"
117+
118+
119+
def error_color(z_actual: float, z_predicted: float, eps: float = 1e-4, max_error: float = 1e-3) -> str:
120+
"""
121+
Returns a grayscale color string based on the error between actual and predicted values.
122+
Underprediction (z_predicted < z_actual) is black, overprediction is white, zero error is gray.
123+
The mapping is linear between -max_error (black) and +max_error (white).
124+
"""
125+
diff = z_predicted - z_actual
126+
if abs(diff) <= eps:
127+
scale = 0.5 # gray
128+
else:
129+
# Clamp diff to [-max_error, max_error]
130+
diff = max(-max_error, min(diff, max_error))
131+
scale = 0.5 + 0.5 * (diff / max_error)
132+
scale = min(max(scale, 0.0), 1.0)
133+
grey = int(scale * 255)
134+
return f"#{grey:02x}{grey:02x}{grey:02x}"
135+
136+
137+
def plot_3d_surface(
95138
ax: "matplotlib.axes.Axes",
96139
X: np.ndarray,
97140
y: np.ndarray,
@@ -103,7 +146,8 @@ def plot_values(
103146
zlabel: str = "Prediction",
104147
var_names: Optional[List[str]] = None,
105148
alpha: float = 0.8,
106-
eps: float = 1e-3,
149+
eps: float = 1e-4,
150+
max_error: float = 1e-3,
107151
cmap: str = "jet",
108152
error_surface: bool = False,
109153
) -> None:
@@ -123,6 +167,7 @@ def plot_values(
123167
var_names (list of str or None): List of axis labels or None.
124168
alpha (float): Surface transparency.
125169
eps (float): Tolerance for error coloring.
170+
max_error (float): Maximum error for color scaling.
126171
cmap (str): Colormap for the surface.
127172
error_surface (bool): If True, scatter z is abs(y_actual - y_predicted).
128173
@@ -143,26 +188,76 @@ def plot_values(
143188
z_scatter = abs(z_actual - z_predicted)
144189
else:
145190
z_scatter = z_actual
146-
if z_actual > z_predicted + eps:
147-
color = "red"
148-
elif z_actual < z_predicted - eps:
149-
color = "green"
150-
else:
151-
color = "white"
191+
color = error_color(z_actual, z_predicted, eps, max_error)
152192
ax.scatter(x_point, y_point, z_scatter, color=color, s=50, edgecolor="black")
153193

154194

195+
def plot_contour(
196+
ax,
197+
X_i: np.ndarray,
198+
X_j: np.ndarray,
199+
Z: np.ndarray,
200+
X: np.ndarray,
201+
y: np.ndarray,
202+
model,
203+
i: int,
204+
j: int,
205+
eps: float = 1e-4,
206+
max_error: float = 1e-3,
207+
var_names: Optional[List[str]] = None,
208+
cmap: str = "jet",
209+
levels: int = 30,
210+
title: str = "Prediction Contour",
211+
) -> None:
212+
"""
213+
Plot a filled contour plot with scatter points colored by prediction error.
214+
215+
Args:
216+
ax (matplotlib.axes.Axes): The matplotlib axis to plot on.
217+
X_i (np.ndarray): Meshgrid for the i-th dimension.
218+
X_j (np.ndarray): Meshgrid for the j-th dimension.
219+
Z (np.ndarray): Contour values (predicted or error), shape matching meshgrid.
220+
X (np.ndarray): Input data, shape (n_samples, k).
221+
y (np.ndarray): Target values, shape (n_samples,).
222+
model (object): Fitted model with predict().
223+
i (int): Index of first varied dimension.
224+
j (int): Index of second varied dimension.
225+
eps (float): Tolerance for coloring points based on prediction error.
226+
max_error (float): Maximum error for color scaling.
227+
var_names (list of str or None): List of axis labels or None.
228+
cmap (str): Colormap for the contour plot.
229+
levels (int): Number of contour levels.
230+
title (str): Title for the plot.
231+
232+
Returns:
233+
None
234+
"""
235+
contour = ax.contourf(X_i, X_j, Z, cmap=cmap, levels=levels)
236+
plt.colorbar(contour, ax=ax)
237+
for idx in range(X.shape[0]):
238+
x_point = X[idx, i]
239+
y_point = X[idx, j]
240+
z_actual = y[idx]
241+
z_predicted = model.predict(X[idx].reshape(1, -1))[0]
242+
color = error_color(z_actual, z_predicted, eps, max_error)
243+
ax.scatter(x_point, y_point, color=color, s=50, edgecolor="black")
244+
ax.set_title(title)
245+
ax.set_xlabel(var_names[0] if var_names else f"Dimension {i}")
246+
ax.set_ylabel(var_names[1] if var_names else f"Dimension {j}")
247+
248+
155249
def plotkd(
156250
model,
157251
X: np.ndarray,
158252
y: np.ndarray,
159253
i: int = 0,
160254
j: int = 1,
161255
show: Optional[bool] = True,
162-
alpha=0.8,
163-
eps=1e-3,
256+
alpha: float = 0.8,
257+
eps: float = 1e-4,
258+
max_error: float = 1e-3,
164259
var_names: Optional[List[str]] = None,
165-
cmap="jet",
260+
cmap: str = "jet",
166261
n_grid: int = 100,
167262
) -> None:
168263
"""
@@ -176,10 +271,24 @@ def plotkd(
176271
j (int): Index of the second dimension to vary. Default is 1.
177272
show (bool): If True, displays the plot. Default is True.
178273
alpha (float): Transparency of the surface plot. Default is 0.8.
179-
eps (float): Tolerance for coloring points based on prediction error. Default is 1e-3.
274+
eps (float): Tolerance for coloring points based on prediction error. Default is 1e-4.
275+
max_error (float): Maximum error for color scaling. Default is 1e-3.
180276
var_names (list of str, optional): List of variable names for axis labeling. If None, generic labels are used.
181277
cmap (str): Colormap for the surface and contour plots. Default is "jet".
182278
n_grid (int): Number of grid points per dimension for the mesh grid. Default is 100.
279+
280+
Examples:
281+
>>> import numpy as np
282+
>>> from spotpython.surrogate.kriging import Kriging
283+
>>> from spotpython.surrogate.plot import plotkd
284+
>>> # Training data
285+
>>> X_train = np.random.rand(100, 3) # 100 samples with 3 dimensions
286+
>>> y_train = np.sin(X_train[:, 0]) + np.cos(X_train[:, 1]) + X_train[:, 2] # Example target function
287+
>>> # Initialize and fit the Kriging model
288+
>>> model = Kriging().fit(X_train, y_train)
289+
>>> # Plot the Kriging surrogate for dimensions 0 and 1
290+
>>> plotkd(model, X_train, y_train, i=0, j=1, show=True)
291+
183292
"""
184293
k = X.shape[1]
185294
if i >= k or j >= k:
@@ -198,7 +307,7 @@ def plotkd(
198307

199308
# Plot predicted values
200309
ax1 = fig.add_subplot(221, projection="3d")
201-
plot_values(
310+
plot_3d_surface(
202311
ax1,
203312
X,
204313
y,
@@ -211,13 +320,14 @@ def plotkd(
211320
var_names=var_names,
212321
alpha=alpha,
213322
eps=eps,
323+
max_error=max_error,
214324
cmap=cmap,
215325
error_surface=False,
216326
)
217327

218328
# Plot prediction error
219329
ax2 = fig.add_subplot(222, projection="3d")
220-
plot_values(
330+
plot_3d_surface(
221331
ax2,
222332
X,
223333
y,
@@ -230,49 +340,50 @@ def plotkd(
230340
var_names=var_names,
231341
alpha=alpha,
232342
eps=eps,
343+
max_error=max_error,
233344
cmap=cmap,
234345
error_surface=True,
235346
)
236347

237348
# Contour plot of predicted values
238349
ax3 = fig.add_subplot(223)
239-
contour = ax3.contourf(X_i, X_j, Z_pred, cmap=cmap, levels=30)
240-
plt.colorbar(contour, ax=ax3)
241-
for idx in range(X.shape[0]):
242-
x_point = X[idx, i]
243-
y_point = X[idx, j]
244-
z_actual = y[idx]
245-
z_predicted = model.predict(X[idx].reshape(1, -1))[0]
246-
if z_actual > z_predicted + eps:
247-
color = "red"
248-
elif z_actual < z_predicted - eps:
249-
color = "green"
250-
else:
251-
color = "white"
252-
ax3.scatter(x_point, y_point, color=color, s=50, edgecolor="black")
253-
ax3.set_title("Prediction Contour")
254-
ax3.set_xlabel(var_names[0] if var_names else f"Dimension {i}")
255-
ax3.set_ylabel(var_names[1] if var_names else f"Dimension {j}")
350+
plot_contour(
351+
ax3,
352+
X_i,
353+
X_j,
354+
Z_pred,
355+
X,
356+
y,
357+
model,
358+
i,
359+
j,
360+
eps=eps,
361+
max_error=max_error,
362+
var_names=var_names,
363+
cmap=cmap,
364+
levels=30,
365+
title="Prediction Contour",
366+
)
256367

257368
# Contour plot of prediction error
258369
ax4 = fig.add_subplot(224)
259-
contour = ax4.contourf(X_i, X_j, Z_std, cmap=cmap, levels=30)
260-
plt.colorbar(contour, ax=ax4)
261-
for idx in range(X.shape[0]):
262-
x_point = X[idx, i]
263-
y_point = X[idx, j]
264-
z_actual = y[idx]
265-
z_predicted = model.predict(X[idx].reshape(1, -1))[0]
266-
if z_actual > z_predicted + eps:
267-
color = "red"
268-
elif z_actual < z_predicted - eps:
269-
color = "green"
270-
else:
271-
color = "white"
272-
ax4.scatter(x_point, y_point, color=color, s=50, edgecolor="black")
273-
ax4.set_title("Error Contour")
274-
ax4.set_xlabel(var_names[0] if var_names else f"Dimension {i}")
275-
ax4.set_ylabel(var_names[1] if var_names else f"Dimension {j}")
370+
plot_contour(
371+
ax4,
372+
X_i,
373+
X_j,
374+
Z_std,
375+
X,
376+
y,
377+
model,
378+
i,
379+
j,
380+
eps=eps,
381+
max_error=max_error,
382+
var_names=var_names,
383+
cmap=cmap,
384+
levels=30,
385+
title="Error Contour",
386+
)
276387

277388
if show:
278389
plt.show()

0 commit comments

Comments
 (0)