From 4b2b441e5adb6eda629b3252358ef4e4845d5aaf Mon Sep 17 00:00:00 2001
From: Sam Daulton
Date: Fri, 8 May 2026 09:23:54 -0700
Subject: [PATCH] BONSAI: pin-and-project for constrained pruning (#5180)

Summary:
Extend BONSAI's irrelevance pruning to handle both equality and inequality
constraints via a pin-and-project approach.

Previously, BONSAI simply discarded pruned candidates that violated
constraints. For inequality constraints this was overly conservative; for
equality constraints it was effectively broken, since almost every
single-dimension prune violates the constraint.

The new approach:
1. Set x_j = target[j] (as before)
2. Project the other dimensions onto the feasible set via SLSQP, keeping
   x_j and all previously pruned dims pinned
3. Filter out any candidates that remain infeasible after projection

This is strictly better than discarding: it recovers feasibility when
possible by adjusting other dimensions, while infeasible pins (where no
adjustment can satisfy the constraints) are still caught.

Key implementation details:
- `_project_and_filter_pruned_candidates`: new function that uses
  `project_to_feasible_space_via_slsqp` with `fixed_features` to pin the
  pruned dim and all previously pruned dims.
- Optimization: skip projection for dims not in any constraint's index set
  (pruning them can't violate anything).
- Falls back to the previous filter-only behavior for 2D inter-point
  constraint indices, which SLSQP projection cannot evaluate on a single
  candidate.
- `_prune_irrelevant_parameters` now accepts a `bounds` parameter.

Reviewed By: esantorella

Differential Revision: D100256483
---
 .../torch/botorch_modular/acquisition.py      | 170 +++++++++++++++++-
 ax/generators/torch/tests/test_acquisition.py | 125 +++++++++++++
 2 files changed, 286 insertions(+), 9 deletions(-)

diff --git a/ax/generators/torch/botorch_modular/acquisition.py b/ax/generators/torch/botorch_modular/acquisition.py
index 167598cb961..97abfb68cd6 100644
--- a/ax/generators/torch/botorch_modular/acquisition.py
+++ b/ax/generators/torch/botorch_modular/acquisition.py
@@ -57,7 +57,11 @@
     AnalyticExpectedUtilityOfBestOption,
     qExpectedUtilityOfBestOption,
 )
-from botorch.exceptions.errors import BotorchError, InputDataError
+from botorch.exceptions.errors import (
+    BotorchError,
+    CandidateGenerationError,
+    InputDataError,
+)
 from botorch.generation.sampling import SamplingStrategy
 from botorch.models.model import Model
 from botorch.optim.optimize import (
@@ -72,7 +76,10 @@
     optimize_acqf_mixed_alternating,
     should_use_mixed_alternating_optimizer,
 )
-from botorch.optim.parameter_constraints import evaluate_feasibility
+from botorch.optim.parameter_constraints import (
+    evaluate_feasibility,
+    project_to_feasible_space_via_slsqp,
+)
 from botorch.utils.constraints import get_outcome_constraint_transforms
 from pyre_extensions import assert_is_instance, none_throws
 from torch import Tensor
@@ -892,6 +899,7 @@
             inequality_constraints=inequality_constraints,
             equality_constraints=equality_constraints,
             fixed_features=fixed_features,
+            bounds=bounds,
         )
         # Validate candidates before returning
         validate_candidates(
@@ -1007,6 +1015,7 @@
         inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
         equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
         fixed_features: dict[int, float] | None = None,
+        bounds: Tensor | None = None,
     ) -> tuple[Tensor, Tensor]:
         r"""Prune irrelevant parameters from the candidates using BONSAI.
 
@@ -1042,6 +1051,11 @@
                 corresponds to the `l_i`-th feature of that element.
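+                For example (an illustrative value -- the same constraint
+                format exercised by the tests in this diff),
+                `(torch.tensor([0, 1]), torch.tensor([1.0, 1.0]), 1.5)`
+                encodes `1.0 * x_0 + 1.0 * x_1 >= 1.5` as an inequality
+                constraint, or `= 1.5` as an equality constraint.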
             fixed_features: A map `{feature_index: value}` for features that
                 should be fixed to a particular value during generation.
+            bounds: A `2 x d`-dim tensor of lower and upper parameter bounds.
+                Required when `inequality_constraints` or `equality_constraints`
+                are provided: pruned candidates are projected onto the feasible
+                set via SLSQP, and the projection needs the parameter bounds to
+                define the feasible region. Unused when no constraints are set.
 
         Returns:
             A two-element tuple containing an `q x d`-dim tensor of generated
@@ -1085,12 +1099,14 @@
             # dense AF val
             final_af_val = dense_af_val
             # If the current incremental AF value is zero, then we skip pruning
+            has_constraints = bool(inequality_constraints or equality_constraints)
             if dense_incremental_af_val > 0.0:
                 remaining_indices = set(range(candidates.shape[-1])) - excluded_indices
                 # remove features that are already set to target_point
                 remaining_indices -= set(
                     (candidates[i] == target_point).nonzero().view(-1).tolist()
                 )
+                initial_remaining = set(remaining_indices)
                 # len(remaining_indices) - 1 is used here so that we do not prune
                 # every dimension
                 for _ in range(len(remaining_indices) - 1):
@@ -1107,13 +1123,23 @@
                         indices=indices,
                         targets=target_point[indices],
                     )
-                    # remove candidates that violate constraints after pruning
-                    pruned_candidates, indices = _remove_infeasible_candidates(
-                        candidates=pruned_candidates,
-                        indices=indices,
-                        inequality_constraints=inequality_constraints,
-                        equality_constraints=equality_constraints,
-                    )
+                    # Project pruned candidates onto the feasible set
+                    # (pinning the pruned dim and previously pruned dims),
+                    # then filter any that remain infeasible.
+                    if has_constraints:
+                        previously_pruned = initial_remaining - remaining_indices
+                        pruned_candidates, indices = (
+                            _project_and_filter_pruned_candidates(
+                                candidates=pruned_candidates,
+                                indices=indices,
+                                target_point=target_point,
+                                pruned_dims=previously_pruned,
+                                bounds=none_throws(bounds),
+                                inequality_constraints=inequality_constraints,
+                                equality_constraints=equality_constraints,
+                                fixed_features=fixed_features,
+                            )
+                        )
                     if pruned_candidates.shape[0] == 0:
                         # no feasible points, continue to
                         # next candidate
@@ -1253,3 +1279,129 @@ def _remove_infeasible_candidates(
     candidates = candidates[is_feasible]
     indices = indices[is_feasible]
     return candidates, indices
+
+
+def _project_and_filter_pruned_candidates(
+    candidates: Tensor,
+    indices: Tensor,
+    target_point: Tensor,
+    pruned_dims: set[int],
+    bounds: Tensor,
+    inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+    equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
+    fixed_features: dict[int, float] | None = None,
+) -> tuple[Tensor, Tensor]:
+    r"""Project pruned candidates onto the feasible set, then filter infeasible.
+
+    Helper for ``Acquisition._prune_irrelevant_parameters`` (BONSAI). It is
+    only meaningful in the context of that greedy-pruning loop and is not
+    intended for standalone use.
+
+    Background: BONSAI pruning evaluates a candidate one dimension at a
+    time, setting that dimension to its target-point value. Each row
+    of ``candidates`` is one such trial -- the dense candidate with the
+    dimension at ``indices[i]`` swapped to ``target_point[indices[i]]``.
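+
+    As a concrete illustration (using the values from the equality-constraint
+    test added in this diff): with dense candidate ``(0.5, 0.3, 0.2)``, target
+    point ``(1/3, 1/3, 1/3)``, and equality constraint ``x1 + x2 + x3 = 1``,
+    the trial that prunes dimension 0 is ``(1/3, 0.3, 0.2)``, which sums to
+    about ``0.83`` and violates the constraint. Projection pins dimension 0
+    at ``1/3`` and adjusts dimensions 1 and 2 until they sum to ``2/3``,
+    recovering feasibility instead of discarding the trial.
+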
+    Under linear constraints, swapping a single dimension to the target
+    typically violates the constraints; rather than discarding the trial
+    (the prior behavior), we adjust the *other* free dimensions to recover
+    feasibility while keeping the swapped dimension and all previously
+    pruned dimensions pinned. Trials whose pins make the constraint system
+    infeasible -- and the rare case where projection succeeds but the
+    result still violates constraints -- are filtered out via the mask
+    returned to the caller.
+
+    Args:
+        candidates: A ``b x 1 x d``-dim tensor of pruned candidates (one row
+            per single-dimension prune attempt for the current BONSAI
+            iteration).
+        indices: A ``b``-dim tensor indicating which dimension was pruned
+            in each batch element.
+        target_point: A ``d``-dim tensor of target values for pruning.
+        pruned_dims: Set of dimension indices already pruned in prior
+            greedy iterations (to be kept pinned during projection).
+        bounds: A ``2 x d``-dim tensor of lower and upper bounds.
+        inequality_constraints: Inequality constraints in BoTorch format.
+        equality_constraints: Equality constraints in BoTorch format.
+        fixed_features: A map ``{feature_index: value}`` from the caller.
+            These dimensions are excluded from pruning at the outer loop and
+            must also be pinned during projection so SLSQP cannot adjust
+            them while satisfying the constraints. Without this, fixed
+            features could be silently altered.
+
+    Returns:
+        A two-element tuple of filtered ``(candidates, indices)``.
+    """
+    # Pre-compute which dims participate in any constraint, and check whether
+    # any constraint is inter-point (2D index tensor). Inter-point constraints
+    # apply across the q-batch, but each row here is a single-candidate prune
+    # attempt -- ``project_to_feasible_space_via_slsqp`` cannot evaluate
+    # inter-point constraints on a 1 x d input. Fall back to the original
+    # filter-only behavior in that case.
+    constrained_dims: set[int] = set()
+    has_interpoint_constraint = False
+    for constraints in (inequality_constraints, equality_constraints):
+        if constraints is not None:
+            for c_indices, _, _ in constraints:
+                if c_indices.dim() == 1:
+                    constrained_dims.update(c_indices.tolist())
+                else:
+                    constrained_dims.update(c_indices[:, -1].tolist())
+                    has_interpoint_constraint = True
+    if has_interpoint_constraint:
+        return _remove_infeasible_candidates(
+            candidates=candidates,
+            indices=indices,
+            inequality_constraints=inequality_constraints,
+            equality_constraints=equality_constraints,
+        )
+
+    # Build fixed_features for previously pruned dims and the caller's
+    # fixed_features (both shared across all candidates in this iteration).
+    prev_fixed: dict[int, float] = {k: target_point[k].item() for k in pruned_dims}
+    if fixed_features is not None:
+        prev_fixed.update(fixed_features)
+
+    feasible_mask = torch.ones(candidates.shape[0], dtype=torch.bool)
+    result = candidates.clone()
+
+    for i in range(candidates.shape[0]):
+        j: int = int(indices[i].item())
+        # If the pruned dim doesn't participate in any constraint,
+        # pruning it can't violate anything -- skip projection.
+        if j not in constrained_dims:
+            continue
+        # Pin the currently pruned dim, all previously pruned dims, and the
+        # caller's fixed features.
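+        # Note: ``j`` is never a key of ``prev_fixed`` -- already-pruned dims
+        # are removed from ``remaining_indices`` by the caller and fixed
+        # features are excluded from pruning -- so the dict merge below
+        # cannot overwrite the current pin.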
+        fixed: dict[int, float | Tensor] = {
+            j: float(target_point[j].item()),
+            **prev_fixed,
+        }
+        try:
+            projected = project_to_feasible_space_via_slsqp(
+                X=candidates[i],  # 1 x d
+                bounds=bounds,
+                inequality_constraints=inequality_constraints,
+                equality_constraints=equality_constraints,
+                fixed_features=fixed,
+            )
+            result[i] = projected
+        except CandidateGenerationError:
+            # Pin makes the system infeasible -- mark for removal.
+            # The post-projection feasibility check below is the safety net
+            # for any candidates that project but still violate constraints.
+            feasible_mask[i] = False
+
+    # Final safety-net feasibility check after projection.
+    if feasible_mask.any():
+        is_feasible = evaluate_feasibility(
+            X=result[feasible_mask],
+            inequality_constraints=inequality_constraints,
+            equality_constraints=equality_constraints,
+        )
+        # Map the subset verdicts back to the full mask in one indexed
+        # assignment (equivalent to looping over the surviving rows).
+        feasible_subset_indices = feasible_mask.nonzero(as_tuple=True)[0]
+        feasible_mask[feasible_subset_indices] = is_feasible
+
+    return result[feasible_mask], indices[feasible_mask]
diff --git a/ax/generators/torch/tests/test_acquisition.py b/ax/generators/torch/tests/test_acquisition.py
index d18ed99200e..c4232c21f7f 100644
--- a/ax/generators/torch/tests/test_acquisition.py
+++ b/ax/generators/torch/tests/test_acquisition.py
@@ -1831,6 +1831,7 @@ def test_prune_irrelevant_parameters_with_inequality_constraints(self) -> None:
             candidates=candidates,
             search_space_digest=search_space_digest,
             inequality_constraints=inequality_constraints,
+            bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
         )
         self.assertTrue(torch.equal(pruned_candidates, torch.tensor([[0.2, 0.8]])))
         self.assertTrue(torch.equal(pruned_values, torch.tensor([0.91])))
@@ -1848,6 +1849,7 @@
             inequality_constraints=[
                 (torch.tensor([0, 1]), torch.tensor([1.0, 1.0]), 1.5)
             ],
+            bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
         )
         # No pruning: setting either dim to 0.2 gives sum=1.0 < 1.5 (infeasible)
         self.assertTrue(torch.equal(pruned_candidates, torch.tensor([[0.8, 0.8]])))
@@ -2055,6 +2057,7 @@ def test_prune_irrelevant_parameters_with_constraints_exact_values(self) -> None
                     1.0,
                 )
             ],
+            bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
         )
 
         # Only dimension 0 should be pruned
@@ -2062,6 +2065,128 @@
         self.assertTrue(torch.equal(pruned_candidates, expected_candidate))
         self.assertTrue(torch.equal(pruned_values, torch.tensor([1.0])))
 
+    def test_prune_irrelevant_parameters_with_equality_constraints(self) -> None:
+        # Test pruning with an equality constraint (x1 + x2 + x3 = 1).
+        # When a dimension is pruned to its target, the remaining dims should
+        # be projected onto the equality constraint hyperplane.
+        search_space_digest = SearchSpaceDigest(
+            feature_names=["x1", "x2", "x3"],
+            bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)],
+        )
+        target_point = torch.tensor([1.0 / 3, 1.0 / 3, 1.0 / 3])
+        acq = Acquisition(
+            surrogate=self.surrogate,
+            search_space_digest=search_space_digest,
+            torch_opt_config=dataclasses.replace(
+                self.torch_opt_config,
+                pruning_target_point=target_point,
+            ),
+            botorch_acqf_class=DummyAcquisitionFunction,
+        )
+        mock_acqf = Mock()
+        mock_acqf._log = False
+        acq.acqf = mock_acqf
+        acq._instantiate_acquisition = Mock()
+
+        # Candidate that satisfies x1 + x2 + x3 = 1.
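+        # Pinning any single dim to 1/3 leaves the other two dims free to
+        # restore the sum to 1 (they must total 2/3, which is achievable
+        # within [0, 1]), so every single-dimension prune here is
+        # recoverable by projection.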
+        candidates = torch.tensor([[0.5, 0.3, 0.2]])
+        # Equality constraint: x1 + x2 + x3 = 1
+        equality_constraints = [
+            (torch.tensor([0, 1, 2]), torch.tensor([1.0, 1.0, 1.0]), 1.0)
+        ]
+        bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
+
+        mock_evaluate = Mock(
+            side_effect=[
+                torch.tensor([0.0]),  # baseline af val
+                torch.tensor([1.0]),  # dense af val
+                # First round: all three single-dim prunes survive
+                # projection (see the comment above), so three pruned
+                # candidates are evaluated.
+                torch.tensor([0.95, 0.90, 0.85]),  # pruned af vals
+                # Second round: two dims remain, so two pruned candidates.
+                torch.tensor([0.93, 0.88]),  # second-round pruned af vals
+            ]
+        )
+        acq.evaluate = mock_evaluate
+
+        pruned_candidates, pruned_values = acq._prune_irrelevant_parameters(
+            candidates=candidates,
+            search_space_digest=search_space_digest,
+            equality_constraints=equality_constraints,
+            bounds=bounds,
+        )
+        # Verify that the returned candidates still have all three dims and
+        # satisfy the equality constraint.
+        self.assertEqual(pruned_candidates.shape[-1], 3)
+        for i in range(pruned_candidates.shape[0]):
+            self.assertAlmostEqual(
+                pruned_candidates[i].sum().item(),
+                1.0,
+                places=4,
+            )
+
+    def test_prune_irrelevant_parameters_fixed_features_pinned_in_projection(
+        self,
+    ) -> None:
+        # When constraints are active and `fixed_features` is provided, the
+        # SLSQP projection must pin the fixed dims so they cannot be silently
+        # adjusted to satisfy the constraint.
+        search_space_digest = SearchSpaceDigest(
+            feature_names=["x1", "x2", "x3"],
+            bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)],
+        )
+        target_point = torch.tensor([1.0 / 3, 1.0 / 3, 1.0 / 3])
+        acq = Acquisition(
+            surrogate=self.surrogate,
+            search_space_digest=search_space_digest,
+            torch_opt_config=dataclasses.replace(
+                self.torch_opt_config,
+                pruning_target_point=target_point,
+            ),
+            botorch_acqf_class=DummyAcquisitionFunction,
+        )
+        mock_acqf = Mock()
+        mock_acqf._log = False
+        acq.acqf = mock_acqf
+        acq._instantiate_acquisition = Mock()
+
+        # Candidate that satisfies x1 + x2 + x3 = 1 with x1 fixed at 0.6.
+        candidates = torch.tensor([[0.6, 0.3, 0.1]])
+        # Equality constraint: x1 + x2 + x3 = 1
+        equality_constraints = [
+            (torch.tensor([0, 1, 2]), torch.tensor([1.0, 1.0, 1.0]), 1.0)
+        ]
+        bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
+        # Fix x1 to its current value. Pruning dim 1 (x2 -> 1/3) breaks the
+        # constraint; without pinning x1 in the projection, SLSQP could move
+        # x1 to recover feasibility, silently overwriting the fixed value.
+        fixed_features = {0: 0.6}
+
+        mock_evaluate = Mock(
+            side_effect=[
+                torch.tensor([0.0]),  # baseline af val
+                torch.tensor([1.0]),  # dense af val
+                # Only dim 1 and dim 2 are eligible (dim 0 is fixed), so
+                # there is a single pruning round with two prune attempts.
+                # Both projected candidates keep x1 == 0.6 exactly.
+                torch.tensor([0.95, 0.90]),  # pruned af vals
+            ]
+        )
+        acq.evaluate = mock_evaluate
+
+        pruned_candidates, _ = acq._prune_irrelevant_parameters(
+            candidates=candidates,
+            search_space_digest=search_space_digest,
+            equality_constraints=equality_constraints,
+            bounds=bounds,
+            fixed_features=fixed_features,
+        )
+        # The fixed feature must be preserved exactly through projection,
+        # and the constraint must still be satisfied.
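+        # With x1 pinned at 0.6, the free dims must satisfy x2 + x3 = 0.4,
+        # which is achievable within [0, 1], so projection succeeds without
+        # moving x1.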
+        self.assertEqual(pruned_candidates.shape[-1], 3)
+        self.assertAlmostEqual(pruned_candidates[0, 0].item(), 0.6, places=6)
+        self.assertAlmostEqual(pruned_candidates[0].sum().item(), 1.0, places=4)
+
     def test_prune_irrelevant_parameters_with_task_and_fidelity_features(self) -> None:
         # Test pruning with both task and fidelity features that should be excluded
         # from pruning