Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 18 additions & 1 deletion ax/generation_strategy/generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,15 +349,32 @@ def _fitted_adapter(self) -> Adapter | None:
"""Private property to return optional fitted_adapter from
self.generator_spec_to_gen_from for convenience. If no model is fit,
this will return None.

If the current adapter (e.g. from a Sobol fallback after
``_try_gen_with_fallback``) cannot predict, prefer a predictive adapter
from the original ``generator_specs`` when available. This ensures that
analysis code which relies on model predictions (e.g. cross-validation,
sensitivity, surface plots) can still use the fitted surrogate model
even after a transient fallback during candidate generation.
"""
try:
# Using the private attribute since using the non-private `fitted_adapter`
# property will raise a UserInputError if there is no fitted model.
return self.generator_spec_to_gen_from._fitted_adapter
adapter = self.generator_spec_to_gen_from._fitted_adapter
except ModelError:
# ModelError is raised if there are no fitted adapters to select from.
return None

if adapter is not None and not adapter.can_predict:
for spec in self.generator_specs:
if (
spec._fitted_adapter is not None
and spec._fitted_adapter.can_predict
):
return spec._fitted_adapter

return adapter

def __repr__(self) -> str:
"""String representation of this ``GenerationNode`` (note that it
will abridge some aspects of ``TransitionCriterion`` and
Expand Down
35 changes: 35 additions & 0 deletions ax/generation_strategy/tests/test_generation_node.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,41 @@ def test_gen_with_no_trial_type(self) -> None:
self.assertIsNotNone(gr)
self.assertNotIn("trial_type", none_throws(gr.gen_metadata))

@mock_botorch_optimize
def test_fitted_adapter_prefers_predictive_over_fallback(self) -> None:
"""After a Sobol fallback, _fitted_adapter should still return the
original predictive TorchAdapter rather than the fallback's
RandomAdapter. This ensures analysis code can generate model-dependent
plots even after a transient fallback during candidate generation."""
node = GenerationNode(
name="test",
generator_specs=[
GeneratorSpec(
generator_enum=Generators.BOTORCH_MODULAR,
generator_kwargs={},
generator_gen_kwargs={},
),
],
)
node._fit(experiment=self.branin_experiment)
original_adapter = none_throws(node._fitted_adapter)
self.assertTrue(original_adapter.can_predict)

# Simulate fallback: fit a Sobol fallback spec and override
# _generator_spec_to_gen_from, mimicking _try_gen_with_fallback.
fallback_spec = GeneratorSpec(
generator_enum=Generators.SOBOL,
generator_key_override="Fallback_Sobol",
)
fallback_spec.fit(experiment=self.branin_experiment)
self.assertFalse(none_throws(fallback_spec._fitted_adapter).can_predict)
node._generator_spec_to_gen_from = fallback_spec

# _fitted_adapter should still return the original predictive adapter.
adapter_after_fallback = none_throws(node._fitted_adapter)
self.assertTrue(adapter_after_fallback.can_predict)
self.assertIs(adapter_after_fallback, original_adapter)

@mock_botorch_optimize
def test_generator_gen_kwargs_deepcopy(self) -> None:
sampler = SobolQMCNormalSampler(torch.Size([1]))
Expand Down
Loading