@@ -91,7 +91,50 @@ def generate_mesh_grid(X: np.ndarray, i: int, j: int, n_grid: int = 100):
9191 return X_i , X_j , grid_points
9292
9393
94- def plot_values (
94+ def simple_error_color (z_actual : float , z_predicted : float , eps : float = 1e-3 ) -> str :
95+ """
96+ Returns a color string based on the error between actual and predicted values.
97+
98+ Args:
99+ z_actual (float): The actual value.
100+ z_predicted (float): The predicted value.
101+ eps (float): Tolerance for considering values as close. Default is 1e-3.
102+
103+ Returns:
104+ str: "black" if actual > predicted + eps,
105+ "white" if actual < predicted - eps,
106+ "grey" otherwise.
107+ """
108+ # predicted value is smaller than actual value
109+ if z_actual > z_predicted + eps :
110+ return "black"
111+ # predicted value is larger than actual value
112+ elif z_actual < z_predicted - eps :
113+ return "white"
114+ # predicted value is close to actual value
115+ else :
116+ return "grey"
117+
118+
119+ def error_color (z_actual : float , z_predicted : float , eps : float = 1e-4 , max_error : float = 1e-3 ) -> str :
120+ """
121+ Returns a grayscale color string based on the error between actual and predicted values.
122+ Underprediction (z_predicted < z_actual) is black, overprediction is white, zero error is gray.
123+ The mapping is linear between -max_error (black) and +max_error (white).
124+ """
125+ diff = z_predicted - z_actual
126+ if abs (diff ) <= eps :
127+ scale = 0.5 # gray
128+ else :
129+ # Clamp diff to [-max_error, max_error]
130+ diff = max (- max_error , min (diff , max_error ))
131+ scale = 0.5 + 0.5 * (diff / max_error )
132+ scale = min (max (scale , 0.0 ), 1.0 )
133+ grey = int (scale * 255 )
134+ return f"#{ grey :02x} { grey :02x} { grey :02x} "
135+
136+
137+ def plot_3d_surface (
95138 ax : "matplotlib.axes.Axes" ,
96139 X : np .ndarray ,
97140 y : np .ndarray ,
@@ -103,7 +146,8 @@ def plot_values(
103146 zlabel : str = "Prediction" ,
104147 var_names : Optional [List [str ]] = None ,
105148 alpha : float = 0.8 ,
106- eps : float = 1e-3 ,
149+ eps : float = 1e-4 ,
150+ max_error : float = 1e-3 ,
107151 cmap : str = "jet" ,
108152 error_surface : bool = False ,
109153) -> None :
@@ -123,6 +167,7 @@ def plot_values(
123167 var_names (list of str or None): List of axis labels or None.
124168 alpha (float): Surface transparency.
125169 eps (float): Tolerance for error coloring.
170+ max_error (float): Maximum error for color scaling.
126171 cmap (str): Colormap for the surface.
127172 error_surface (bool): If True, scatter z is abs(y_actual - y_predicted).
128173
@@ -143,26 +188,76 @@ def plot_values(
143188 z_scatter = abs (z_actual - z_predicted )
144189 else :
145190 z_scatter = z_actual
146- if z_actual > z_predicted + eps :
147- color = "red"
148- elif z_actual < z_predicted - eps :
149- color = "green"
150- else :
151- color = "white"
191+ color = error_color (z_actual , z_predicted , eps , max_error )
152192 ax .scatter (x_point , y_point , z_scatter , color = color , s = 50 , edgecolor = "black" )
153193
154194
195+ def plot_contour (
196+ ax ,
197+ X_i : np .ndarray ,
198+ X_j : np .ndarray ,
199+ Z : np .ndarray ,
200+ X : np .ndarray ,
201+ y : np .ndarray ,
202+ model ,
203+ i : int ,
204+ j : int ,
205+ eps : float = 1e-4 ,
206+ max_error : float = 1e-3 ,
207+ var_names : Optional [List [str ]] = None ,
208+ cmap : str = "jet" ,
209+ levels : int = 30 ,
210+ title : str = "Prediction Contour" ,
211+ ) -> None :
212+ """
213+ Plot a filled contour plot with scatter points colored by prediction error.
214+
215+ Args:
216+ ax (matplotlib.axes.Axes): The matplotlib axis to plot on.
217+ X_i (np.ndarray): Meshgrid for the i-th dimension.
218+ X_j (np.ndarray): Meshgrid for the j-th dimension.
219+ Z (np.ndarray): Contour values (predicted or error), shape matching meshgrid.
220+ X (np.ndarray): Input data, shape (n_samples, k).
221+ y (np.ndarray): Target values, shape (n_samples,).
222+ model (object): Fitted model with predict().
223+ i (int): Index of first varied dimension.
224+ j (int): Index of second varied dimension.
225+ eps (float): Tolerance for coloring points based on prediction error.
226+ max_error (float): Maximum error for color scaling.
227+ var_names (list of str or None): List of axis labels or None.
228+ cmap (str): Colormap for the contour plot.
229+ levels (int): Number of contour levels.
230+ title (str): Title for the plot.
231+
232+ Returns:
233+ None
234+ """
235+ contour = ax .contourf (X_i , X_j , Z , cmap = cmap , levels = levels )
236+ plt .colorbar (contour , ax = ax )
237+ for idx in range (X .shape [0 ]):
238+ x_point = X [idx , i ]
239+ y_point = X [idx , j ]
240+ z_actual = y [idx ]
241+ z_predicted = model .predict (X [idx ].reshape (1 , - 1 ))[0 ]
242+ color = error_color (z_actual , z_predicted , eps , max_error )
243+ ax .scatter (x_point , y_point , color = color , s = 50 , edgecolor = "black" )
244+ ax .set_title (title )
245+ ax .set_xlabel (var_names [0 ] if var_names else f"Dimension { i } " )
246+ ax .set_ylabel (var_names [1 ] if var_names else f"Dimension { j } " )
247+
248+
155249def plotkd (
156250 model ,
157251 X : np .ndarray ,
158252 y : np .ndarray ,
159253 i : int = 0 ,
160254 j : int = 1 ,
161255 show : Optional [bool ] = True ,
162- alpha = 0.8 ,
163- eps = 1e-3 ,
256+ alpha : float = 0.8 ,
257+ eps : float = 1e-4 ,
258+ max_error : float = 1e-3 ,
164259 var_names : Optional [List [str ]] = None ,
165- cmap = "jet" ,
260+ cmap : str = "jet" ,
166261 n_grid : int = 100 ,
167262) -> None :
168263 """
@@ -176,10 +271,24 @@ def plotkd(
176271 j (int): Index of the second dimension to vary. Default is 1.
177272 show (bool): If True, displays the plot. Default is True.
178273 alpha (float): Transparency of the surface plot. Default is 0.8.
179- eps (float): Tolerance for coloring points based on prediction error. Default is 1e-3.
274+ eps (float): Tolerance for coloring points based on prediction error. Default is 1e-4.
275+ max_error (float): Maximum error for color scaling. Default is 1e-3.
180276 var_names (list of str, optional): List of variable names for axis labeling. If None, generic labels are used.
181277 cmap (str): Colormap for the surface and contour plots. Default is "jet".
182278 n_grid (int): Number of grid points per dimension for the mesh grid. Default is 100.
279+
280+ Examples:
281+ >>> import numpy as np
282+ >>> from spotpython.surrogate.kriging import Kriging
283+ >>> from spotpython.surrogate.plot import plotkd
284+ >>> # Training data
285+ >>> X_train = np.random.rand(100, 3) # 100 samples with 3 dimensions
286+ >>> y_train = np.sin(X_train[:, 0]) + np.cos(X_train[:, 1]) + X_train[:, 2] # Example target function
287+ >>> # Initialize and fit the Kriging model
288+ >>> model = Kriging().fit(X_train, y_train)
289+ >>> # Plot the Kriging surrogate for dimensions 0 and 1
290+ >>> plotkd(model, X_train, y_train, i=0, j=1, show=True)
291+
183292 """
184293 k = X .shape [1 ]
185294 if i >= k or j >= k :
@@ -198,7 +307,7 @@ def plotkd(
198307
199308 # Plot predicted values
200309 ax1 = fig .add_subplot (221 , projection = "3d" )
201- plot_values (
310+ plot_3d_surface (
202311 ax1 ,
203312 X ,
204313 y ,
@@ -211,13 +320,14 @@ def plotkd(
211320 var_names = var_names ,
212321 alpha = alpha ,
213322 eps = eps ,
323+ max_error = max_error ,
214324 cmap = cmap ,
215325 error_surface = False ,
216326 )
217327
218328 # Plot prediction error
219329 ax2 = fig .add_subplot (222 , projection = "3d" )
220- plot_values (
330+ plot_3d_surface (
221331 ax2 ,
222332 X ,
223333 y ,
@@ -230,49 +340,50 @@ def plotkd(
230340 var_names = var_names ,
231341 alpha = alpha ,
232342 eps = eps ,
343+ max_error = max_error ,
233344 cmap = cmap ,
234345 error_surface = True ,
235346 )
236347
237348 # Contour plot of predicted values
238349 ax3 = fig .add_subplot (223 )
239- contour = ax3 . contourf ( X_i , X_j , Z_pred , cmap = cmap , levels = 30 )
240- plt . colorbar ( contour , ax = ax3 )
241- for idx in range ( X . shape [ 0 ]):
242- x_point = X [ idx , i ]
243- y_point = X [ idx , j ]
244- z_actual = y [ idx ]
245- z_predicted = model . predict ( X [ idx ]. reshape ( 1 , - 1 ))[ 0 ]
246- if z_actual > z_predicted + eps :
247- color = "red"
248- elif z_actual < z_predicted - eps :
249- color = "green"
250- else :
251- color = "white"
252- ax3 . scatter ( x_point , y_point , color = color , s = 50 , edgecolor = "black" )
253- ax3 . set_title ( "Prediction Contour" )
254- ax3 . set_xlabel ( var_names [ 0 ] if var_names else f"Dimension { i } " )
255- ax3 . set_ylabel ( var_names [ 1 ] if var_names else f"Dimension { j } " )
350+ plot_contour (
351+ ax3 ,
352+ X_i ,
353+ X_j ,
354+ Z_pred ,
355+ X ,
356+ y ,
357+ model ,
358+ i ,
359+ j ,
360+ eps = eps ,
361+ max_error = max_error ,
362+ var_names = var_names ,
363+ cmap = cmap ,
364+ levels = 30 ,
365+ title = "Prediction Contour" ,
366+ )
256367
257368 # Contour plot of prediction error
258369 ax4 = fig .add_subplot (224 )
259- contour = ax4 . contourf ( X_i , X_j , Z_std , cmap = cmap , levels = 30 )
260- plt . colorbar ( contour , ax = ax4 )
261- for idx in range ( X . shape [ 0 ]):
262- x_point = X [ idx , i ]
263- y_point = X [ idx , j ]
264- z_actual = y [ idx ]
265- z_predicted = model . predict ( X [ idx ]. reshape ( 1 , - 1 ))[ 0 ]
266- if z_actual > z_predicted + eps :
267- color = "red"
268- elif z_actual < z_predicted - eps :
269- color = "green"
270- else :
271- color = "white"
272- ax4 . scatter ( x_point , y_point , color = color , s = 50 , edgecolor = "black" )
273- ax4 . set_title ( "Error Contour" )
274- ax4 . set_xlabel ( var_names [ 0 ] if var_names else f"Dimension { i } " )
275- ax4 . set_ylabel ( var_names [ 1 ] if var_names else f"Dimension { j } " )
370+ plot_contour (
371+ ax4 ,
372+ X_i ,
373+ X_j ,
374+ Z_std ,
375+ X ,
376+ y ,
377+ model ,
378+ i ,
379+ j ,
380+ eps = eps ,
381+ max_error = max_error ,
382+ var_names = var_names ,
383+ cmap = cmap ,
384+ levels = 30 ,
385+ title = "Error Contour" ,
386+ )
276387
277388 if show :
278389 plt .show ()
0 commit comments