3636from spotpython .utils .convert import get_shape
3737from spotpython .utils .init import fun_control_init , optimizer_control_init , surrogate_control_init , design_control_init
3838from spotpython .utils .compare import selectNew
39- from spotpython .utils .aggregate import aggregate_mean_var , select_distant_points
39+ from spotpython .utils .aggregate import aggregate_mean_var , select_distant_points , select_best_cluster
4040from spotpython .utils .repair import remove_nan , repair_non_numeric
4141from spotpython .utils .file import get_experiment_filename , get_result_filename
4242from spotpython .budget .ocba import get_ocba_X
@@ -215,10 +215,6 @@ def __init__(
215215 # Kernel selection from fun_control (NEW)
216216 self .kernel = None
217217 self .kernel_params = None
218- if fun_control is not None :
219- self .kernel = fun_control .get ("kernel" , "gauss" )
220- self .kernel_params = fun_control .get ("kernel_params" , {})
221-
222218 self .counter = 0
223219 self .success_rate = 0.0
224220 self .success_counter = 0
@@ -227,6 +223,7 @@ def __init__(
227223 # small value:
228224 self .eps = sqrt (spacing (1 ))
229225
226+ self .selection_method = "distant" # or "best"
230227 self ._set_fun (fun )
231228
232229 self ._set_bounds_and_dim ()
@@ -449,11 +446,16 @@ def _set_additional_attributes(self) -> None:
449446 self .progress_file = self .fun_control ["progress_file" ]
450447 self .tkagg = self .fun_control ["tkagg" ]
451448 self .min_success_rate = self .fun_control ["min_success_rate" ]
449+ self .selection_method = self .fun_control ["selection_method" ]
452450 # self.success_counter = 0
453451 if self .tkagg :
454452 matplotlib .use ("TkAgg" )
455453 self .verbosity = self .fun_control ["verbosity" ]
456454 self .acquisition_failure_strategy = self .fun_control ["acquisition_failure_strategy" ]
455+ self .kernel = self .fun_control ["kernel" ]
456+ self .kernel_params = self .fun_control ["kernel_params" ]
457+
458+ # Surrogate control attributes:
457459 self .max_surrogate_points = self .surrogate_control ["max_surrogate_points" ]
458460 self .use_nystrom = self .surrogate_control ["use_nystrom" ]
459461 self .nystrom_m = self .surrogate_control ["nystrom_m" ]
@@ -1640,6 +1642,57 @@ def update_stats(self) -> None:
16401642 # variance of the best mean y value so far:
16411643 self .min_var_y = self .var_y [argmin (self .mean_y )]
16421644
1645+ def selection_dispatcher (self ) -> Tuple [np .ndarray , np .ndarray ]:
1646+ """
1647+ Dispatcher for selection methods.
1648+ Depending on the value of `self.selection_method`,
1649+ it calls the appropriate selection function.
1650+ Args:
1651+ self (object): Spot object
1652+ Returns:
1653+ Tuple[numpy.ndarray, numpy.ndarray]:
1654+ selected design points and their corresponding function values
1655+ Attributes:
1656+ self.selection_method (str):
1657+ selection method to use
1658+ self.X (numpy.ndarray):
1659+ design points
1660+ self.y (numpy.ndarray):
1661+ function values
1662+ self.max_surrogate_points (int):
1663+ maximum number of points to select
1664+ Examples:
1665+ >>> import numpy as np
1666+ from spotpython.fun.objectivefunctions import Analytical
1667+ from spotpython.spot import spot
1668+ from spotpython.utils.init import (
1669+ fun_control_init, optimizer_control_init, surrogate_control_init, design_control_init
1670+ )
1671+ # number of initial points:
1672+ ni = 0
1673+ X_start = np.array([[0, 0], [0, 1], [ 1, 0], [1, 1], [1, 1]])
1674+ fun = Analytical().fun_sphere
1675+ fun_control = fun_control_init(
1676+ lower = np.array([-1, -1]),
1677+ upper = np.array([1, 1])
1678+ )
1679+ design_control=design_control_init(init_size=ni)
1680+ S = spot.Spot(fun=fun,
1681+ fun_control=fun_control,
1682+ design_control=design_control,)
1683+ S.initialize_design(X_start=X_start)
1684+ S.update_stats()
1685+ X_S, y_S = S.selection_dispatcher()
1686+ print(f"Selected X_S: {X_S}")
1687+ print(f"Selected y_S: {y_S}")
1688+ """
1689+ if self .selection_method == "distant" :
1690+ return select_distant_points (X = self .X , y = self .y , k = self .max_surrogate_points )
1691+ elif self .selection_method == "best" :
1692+ return select_best_cluster (X = self .X , y = self .y , k = self .max_surrogate_points )
1693+ # If no selection is needed, return all points:
1694+ return self .X , self .y
1695+
16431696 def fit_surrogate (self ) -> None :
16441697 """
16451698 Fit surrogate model. The surrogate model
@@ -1705,15 +1758,15 @@ def fit_surrogate(self) -> None:
17051758 logger .debug ("In fit_surrogate(): self.X.shape: %s" , self .X .shape )
17061759 logger .debug ("In fit_surrogate(): self.y.shape: %s" , self .y .shape )
17071760 # Pass kernel options to surrogate if Kriging is used
1708- if hasattr (self .surrogate , "kernel" ):
1761+ if hasattr (self .surrogate , "kernel" ) and isinstance ( self . surrogate , Kriging ) :
17091762 self .surrogate .kernel = self .kernel
17101763 self .surrogate .kernel_params = self .kernel_params
17111764 X_points = self .X .shape [0 ]
17121765 y_points = self .y .shape [0 ]
17131766 if X_points == y_points :
17141767 if (X_points > self .max_surrogate_points ) and (self .use_nystrom is False ):
1715- logger .info ("Selecting distant points for surrogate fitting." )
1716- X_S , y_S = select_distant_points ( X = self .X , y = self . y , k = self . max_surrogate_points )
1768+ logger .info ("Selecting subset of points for surrogate fitting." )
1769+ X_S , y_S = self .selection_dispatcher ( )
17171770 else :
17181771 X_S = self .X
17191772 y_S = self .y
0 commit comments