Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 161 additions & 9 deletions ax/generators/torch/botorch_modular/acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,11 @@
AnalyticExpectedUtilityOfBestOption,
qExpectedUtilityOfBestOption,
)
from botorch.exceptions.errors import BotorchError, InputDataError
from botorch.exceptions.errors import (
BotorchError,
CandidateGenerationError,
InputDataError,
)
from botorch.generation.sampling import SamplingStrategy
from botorch.models.model import Model
from botorch.optim.optimize import (
Expand All @@ -72,7 +76,10 @@
optimize_acqf_mixed_alternating,
should_use_mixed_alternating_optimizer,
)
from botorch.optim.parameter_constraints import evaluate_feasibility
from botorch.optim.parameter_constraints import (
evaluate_feasibility,
project_to_feasible_space_via_slsqp,
)
from botorch.utils.constraints import get_outcome_constraint_transforms
from pyre_extensions import assert_is_instance, none_throws
from torch import Tensor
Expand Down Expand Up @@ -892,6 +899,7 @@ def optimize(
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
fixed_features=fixed_features,
bounds=bounds,
)
# Validate candidates before returning
validate_candidates(
Expand Down Expand Up @@ -1007,6 +1015,7 @@ def _prune_irrelevant_parameters(
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
fixed_features: dict[int, float] | None = None,
bounds: Tensor | None = None,
) -> tuple[Tensor, Tensor]:
r"""Prune irrelevant parameters from the candidates using BONSAI.

Expand Down Expand Up @@ -1042,6 +1051,11 @@ def _prune_irrelevant_parameters(
corresponds to the `l_i`-th feature of that element.
fixed_features: A map `{feature_index: value}` for features that
should be fixed to a particular value during generation.
bounds: A `2 x d`-dim tensor of lower and upper parameter bounds.
Required when `inequality_constraints` or `equality_constraints`
are provided: pruned candidates are projected onto the feasible
set via SLSQP, and the projection needs the parameter bounds to
define the feasible region. Unused when no constraints are set.

Returns:
A two-element tuple containing an `q x d`-dim tensor of generated
Expand Down Expand Up @@ -1085,12 +1099,14 @@ def _prune_irrelevant_parameters(
# dense AF val
final_af_val = dense_af_val
# If the current incremental AF value is zero, then we skip pruning
has_constraints = bool(inequality_constraints or equality_constraints)
if dense_incremental_af_val > 0.0:
remaining_indices = set(range(candidates.shape[-1])) - excluded_indices
# remove features that are already set to target_point
remaining_indices -= set(
(candidates[i] == target_point).nonzero().view(-1).tolist()
)
initial_remaining = set(remaining_indices)
# len(remaining_indices) - 1 is used here so that we do not prune
# every dimension
for _ in range(len(remaining_indices) - 1):
Expand All @@ -1107,13 +1123,23 @@ def _prune_irrelevant_parameters(
indices=indices,
targets=target_point[indices],
)
# remove candidates that violate constraints after pruning
pruned_candidates, indices = _remove_infeasible_candidates(
candidates=pruned_candidates,
indices=indices,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
)
# Project pruned candidates onto the feasible set
# (pinning the pruned dim and previously pruned dims),
# then filter any that remain infeasible.
if has_constraints:
previously_pruned = initial_remaining - remaining_indices
pruned_candidates, indices = (
_project_and_filter_pruned_candidates(
candidates=pruned_candidates,
indices=indices,
target_point=target_point,
pruned_dims=previously_pruned,
bounds=none_throws(bounds),
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
fixed_features=fixed_features,
)
)
if pruned_candidates.shape[0] == 0:
# no feasible points, continue to
# next candidate
Expand Down Expand Up @@ -1253,3 +1279,129 @@ def _remove_infeasible_candidates(
candidates = candidates[is_feasible]
indices = indices[is_feasible]
return candidates, indices


def _project_and_filter_pruned_candidates(
candidates: Tensor,
indices: Tensor,
target_point: Tensor,
pruned_dims: set[int],
bounds: Tensor,
inequality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
equality_constraints: list[tuple[Tensor, Tensor, float]] | None = None,
fixed_features: dict[int, float] | None = None,
) -> tuple[Tensor, Tensor]:
r"""Project pruned candidates onto the feasible set, then filter infeasible.

Helper for ``Acquisition._prune_irrelevant_parameters`` (BONSAI). It is
only meaningful in the context of that greedy-pruning loop and is not
intended for standalone use.

Background: BONSAI pruning evaluates a candidate dimension-by-dimension
by setting one dimension at a time to its target-point value. Each row
of ``candidates`` is one such trial -- the dense candidate with the
dimension at ``indices[i]`` swapped to ``target_point[indices[i]]``.
Under linear constraints, swapping a single dimension to the target
typically violates the constraints; rather than discarding the trial
(the prior behavior), we adjust the *other* free dimensions to recover
feasibility while keeping the swapped dimension and all previously
pruned dimensions pinned. Trials whose pins make the constraint system
infeasible -- and the rare case where projection succeeds but the
result still violates constraints -- are filtered out via the mask
returned to the caller.

Args:
candidates: A ``b x 1 x d``-dim tensor of pruned candidates (one row
per single-dimension prune attempt for the current BONSAI
iteration).
indices: A ``b``-dim tensor indicating which dimension was pruned
in each batch element.
target_point: A ``d``-dim tensor of target values for pruning.
pruned_dims: Set of dimension indices already pruned in prior
greedy iterations (to be kept pinned during projection).
bounds: A ``2 x d``-dim tensor of lower and upper bounds.
inequality_constraints: Inequality constraints in BoTorch format.
equality_constraints: Equality constraints in BoTorch format.
fixed_features: A map ``{feature_index: value}`` from the caller.
These dimensions are excluded from pruning at the outer loop and
must also be pinned during projection so SLSQP cannot adjust
them while satisfying the constraints. Without this, fixed
features could be silently altered.

Returns:
A two-element tuple of filtered ``(candidates, indices)``.
"""
# Pre-compute which dims participate in any constraint, and check whether
# any constraint is inter-point (2D index tensor). Inter-point constraints
# apply across the q-batch, but each row here is a single-candidate prune
# attempt -- ``project_to_feasible_space_via_slsqp`` cannot evaluate
# inter-point constraints on a 1 x d input. Fall back to the original
# filter-only behavior in that case.
constrained_dims: set[int] = set()
has_interpoint_constraint = False
for constraints in (inequality_constraints, equality_constraints):
if constraints is not None:
for c_indices, _, _ in constraints:
if c_indices.dim() == 1:
constrained_dims.update(c_indices.tolist())
else:
constrained_dims.update(c_indices[:, -1].tolist())
has_interpoint_constraint = True
if has_interpoint_constraint:
return _remove_infeasible_candidates(
candidates=candidates,
indices=indices,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
)

# Build fixed_features for previously pruned dims and the caller's
# fixed_features (both shared across all candidates in this iteration).
prev_fixed: dict[int, float] = {k: target_point[k].item() for k in pruned_dims}
if fixed_features is not None:
prev_fixed.update(fixed_features)

feasible_mask = torch.ones(candidates.shape[0], dtype=torch.bool)
result = candidates.clone()

for i in range(candidates.shape[0]):
j: int = int(indices[i].item())
# If the pruned dim doesn't participate in any constraint,
# pruning it can't violate anything — skip projection.
if j not in constrained_dims:
continue
# Pin the currently pruned dim, all previously pruned dims, and the
# caller's fixed features.
fixed: dict[int, float | Tensor] = {
j: float(target_point[j].item()),
**prev_fixed,
}
try:
projected = project_to_feasible_space_via_slsqp(
X=candidates[i], # 1 x d
bounds=bounds,
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
fixed_features=fixed,
)
result[i] = projected
except CandidateGenerationError:
# Pin makes the system infeasible — mark for removal.
# The post-projection feasibility check below is the safety net
# for any candidates that project but still violate constraints.
feasible_mask[i] = False

# Final safety-net feasibility check after projection.
if feasible_mask.any():
is_feasible = evaluate_feasibility(
X=result[feasible_mask],
inequality_constraints=inequality_constraints,
equality_constraints=equality_constraints,
)
# Map back to the full mask.
feasible_subset_indices = feasible_mask.nonzero(as_tuple=True)[0]
for idx, feas in zip(feasible_subset_indices, is_feasible):
if not feas:
feasible_mask[idx] = False

return result[feasible_mask], indices[feasible_mask]
125 changes: 125 additions & 0 deletions ax/generators/torch/tests/test_acquisition.py
Original file line number Diff line number Diff line change
Expand Up @@ -1831,6 +1831,7 @@ def test_prune_irrelevant_parameters_with_inequality_constraints(self) -> None:
candidates=candidates,
search_space_digest=search_space_digest,
inequality_constraints=inequality_constraints,
bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
)
self.assertTrue(torch.equal(pruned_candidates, torch.tensor([[0.2, 0.8]])))
self.assertTrue(torch.equal(pruned_values, torch.tensor([0.91])))
Expand All @@ -1848,6 +1849,7 @@ def test_prune_irrelevant_parameters_with_inequality_constraints(self) -> None:
inequality_constraints=[
(torch.tensor([0, 1]), torch.tensor([1.0, 1.0]), 1.5)
],
bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
)
# No pruning: setting either dim to 0.2 gives sum=1.0 < 1.5 (infeasible)
self.assertTrue(torch.equal(pruned_candidates, torch.tensor([[0.8, 0.8]])))
Expand Down Expand Up @@ -2055,13 +2057,136 @@ def test_prune_irrelevant_parameters_with_constraints_exact_values(self) -> None
1.0,
)
],
bounds=torch.tensor([[0.0, 0.0], [1.0, 1.0]]),
)

# Only dimension 0 should be pruned
expected_candidate = torch.tensor([[0.1, 1.0]])
self.assertTrue(torch.equal(pruned_candidates, expected_candidate))
self.assertTrue(torch.equal(pruned_values, torch.tensor([1.0])))

def test_prune_irrelevant_parameters_with_equality_constraints(self) -> None:
# Test pruning with an equality constraint (x1 + x2 + x3 = 1).
# When a dimension is pruned to its target, the remaining dims should
# be projected onto the equality constraint hyperplane.
search_space_digest = SearchSpaceDigest(
feature_names=["x1", "x2", "x3"],
bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)],
)
target_point = torch.tensor([1.0 / 3, 1.0 / 3, 1.0 / 3])
acq = Acquisition(
surrogate=self.surrogate,
search_space_digest=search_space_digest,
torch_opt_config=dataclasses.replace(
self.torch_opt_config,
pruning_target_point=target_point,
),
botorch_acqf_class=DummyAcquisitionFunction,
)
mock_acqf = Mock()
mock_acqf._log = False
acq.acqf = mock_acqf
acq._instantiate_acquisition = Mock()

# Candidate that satisfies x1 + x2 + x3 = 1.
candidates = torch.tensor([[0.5, 0.3, 0.2]])
# Equality constraint: x1 + x2 + x3 = 1
equality_constraints = [
(torch.tensor([0, 1, 2]), torch.tensor([1.0, 1.0, 1.0]), 1.0)
]
bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])

mock_evaluate = Mock(
side_effect=[
torch.tensor([0.0]), # baseline af val
torch.tensor([1.0]), # dense af val
# After pruning dim 0 to 1/3 and projecting, the candidate
# still satisfies x1+x2+x3=1. Two pruning candidates
# (dim 1 and dim 2) survive projection.
torch.tensor([0.95, 0.90]), # pruned af vals
torch.tensor([0.93]), # second round pruned af val
]
)
acq.evaluate = mock_evaluate

pruned_candidates, pruned_values = acq._prune_irrelevant_parameters(
candidates=candidates,
search_space_digest=search_space_digest,
equality_constraints=equality_constraints,
bounds=bounds,
)
# Verify that pruning occurred and the result satisfies the constraint.
self.assertEqual(pruned_candidates.shape[-1], 3)
for i in range(pruned_candidates.shape[0]):
self.assertAlmostEqual(
pruned_candidates[i].sum().item(),
1.0,
places=4,
)

def test_prune_irrelevant_parameters_fixed_features_pinned_in_projection(
self,
) -> None:
# When constraints are active and `fixed_features` is provided, the
# SLSQP projection must pin the fixed dims so they cannot be silently
# adjusted to satisfy the constraint.
search_space_digest = SearchSpaceDigest(
feature_names=["x1", "x2", "x3"],
bounds=[(0.0, 1.0), (0.0, 1.0), (0.0, 1.0)],
)
target_point = torch.tensor([1.0 / 3, 1.0 / 3, 1.0 / 3])
acq = Acquisition(
surrogate=self.surrogate,
search_space_digest=search_space_digest,
torch_opt_config=dataclasses.replace(
self.torch_opt_config,
pruning_target_point=target_point,
),
botorch_acqf_class=DummyAcquisitionFunction,
)
mock_acqf = Mock()
mock_acqf._log = False
acq.acqf = mock_acqf
acq._instantiate_acquisition = Mock()

# Candidate that satisfies x1 + x2 + x3 = 1 with x1 fixed at 0.6.
candidates = torch.tensor([[0.6, 0.3, 0.1]])
# Equality constraint: x1 + x2 + x3 = 1
equality_constraints = [
(torch.tensor([0, 1, 2]), torch.tensor([1.0, 1.0, 1.0]), 1.0)
]
bounds = torch.tensor([[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]])
# Fix x1 to its current value. Pruning dim 1 (x2 -> 1/3) breaks the
# constraint; without pinning x1 in the projection, SLSQP could move
# x1 to recover feasibility, silently overwriting the fixed value.
fixed_features = {0: 0.6}

mock_evaluate = Mock(
side_effect=[
torch.tensor([0.0]), # baseline af val
torch.tensor([1.0]), # dense af val
# Only dim 1 and dim 2 are eligible (dim 0 is fixed). Both
# pruning attempts should yield projected candidates that
# keep x1 == 0.6 exactly.
torch.tensor([0.95, 0.90]), # pruned af vals
torch.tensor([0.93]), # second-round pruned af val
]
)
acq.evaluate = mock_evaluate

pruned_candidates, _ = acq._prune_irrelevant_parameters(
candidates=candidates,
search_space_digest=search_space_digest,
equality_constraints=equality_constraints,
bounds=bounds,
fixed_features=fixed_features,
)
# The fixed feature must be preserved exactly through projection,
# and the constraint must still be satisfied.
self.assertEqual(pruned_candidates.shape[-1], 3)
self.assertAlmostEqual(pruned_candidates[0, 0].item(), 0.6, places=6)
self.assertAlmostEqual(pruned_candidates[0].sum().item(), 1.0, places=4)

def test_prune_irrelevant_parameters_with_task_and_fidelity_features(self) -> None:
# Test pruning with both task and fidelity features that should be excluded
# from pruning
Expand Down
Loading