From 491b30140de2ef84e3df6edb00809339cbf42d52 Mon Sep 17 00:00:00 2001 From: Andy Lin Date: Thu, 7 May 2026 15:24:57 -0700 Subject: [PATCH] Prefer predictive adapter over fallback in GenerationNode._fitted_adapter Summary: When BoTorch candidate generation fails after MAX_GEN_ATTEMPTS (e.g. due to search space exhaustion), the generation node falls back to Sobol for that particular gen call. This fallback overwrites `_generator_spec_to_gen_from` with a RandomAdapter, which cannot make predictions. The problem is that downstream analysis code reads `GenerationStrategy.adapter` to generate model-dependent plots (cross-validation, sensitivity, surface, modeled arm effects, etc.). Since the adapter now points to the Sobol fallback's RandomAdapter, all these analyses fail with "does not support predictions" or "TorchAdapter is required" errors -- even though the original fitted TorchAdapter is still preserved on the generator spec. This diff fixes `GenerationNode._fitted_adapter` to check: if the current adapter cannot predict, look for a fitted predictive adapter among the original `generator_specs` and prefer that instead. This is safe because the original TorchAdapter is never destroyed by the fallback -- it's just shadowed by the `_generator_spec_to_gen_from` override. Reviewed By: ItsMrLin Differential Revision: D99358260 --- ax/generation_strategy/generation_node.py | 19 +++++++++- .../tests/test_generation_node.py | 35 +++++++++++++++++++ 2 files changed, 53 insertions(+), 1 deletion(-) diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 936f07ef1c2..7375bea6a68 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -349,15 +349,32 @@ def _fitted_adapter(self) -> Adapter | None: """Private property to return optional fitted_adapter from self.generator_spec_to_gen_from for convenience. If no model is fit, this will return None. + + If the current adapter (e.g. from a Sobol fallback after + ``_try_gen_with_fallback``) cannot predict, prefer a predictive adapter + from the original ``generator_specs`` when available. This ensures that + analysis code which relies on model predictions (e.g. cross-validation, + sensitivity, surface plots) can still use the fitted surrogate model + even after a transient fallback during candidate generation. """ try: # Using the private attribute since using the non-private `fitted_adapter` # property will raise a UserInputError if there is no fitted model. - return self.generator_spec_to_gen_from._fitted_adapter + adapter = self.generator_spec_to_gen_from._fitted_adapter except ModelError: # ModelError is raised if there are no fitted adapters to select from. return None + if adapter is not None and not adapter.can_predict: + for spec in self.generator_specs: + if ( + spec._fitted_adapter is not None + and spec._fitted_adapter.can_predict + ): + return spec._fitted_adapter + + return adapter + def __repr__(self) -> str: """String representation of this ``GenerationNode`` (note that it will abridge some aspects of ``TransitionCriterion`` and diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 9e726165d11..0bd1595d08f 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -285,6 +285,41 @@ def test_gen_with_no_trial_type(self) -> None: self.assertIsNotNone(gr) self.assertNotIn("trial_type", none_throws(gr.gen_metadata)) + @mock_botorch_optimize + def test_fitted_adapter_prefers_predictive_over_fallback(self) -> None: + """After a Sobol fallback, _fitted_adapter should still return the + original predictive TorchAdapter rather than the fallback's + RandomAdapter. This ensures analysis code can generate model-dependent + plots even after a transient fallback during candidate generation.""" + node = GenerationNode( + name="test", + generator_specs=[ + GeneratorSpec( + generator_enum=Generators.BOTORCH_MODULAR, + generator_kwargs={}, + generator_gen_kwargs={}, + ), + ], + ) + node._fit(experiment=self.branin_experiment) + original_adapter = none_throws(node._fitted_adapter) + self.assertTrue(original_adapter.can_predict) + + # Simulate fallback: fit a Sobol fallback spec and override + # _generator_spec_to_gen_from, mimicking _try_gen_with_fallback. + fallback_spec = GeneratorSpec( + generator_enum=Generators.SOBOL, + generator_key_override="Fallback_Sobol", + ) + fallback_spec.fit(experiment=self.branin_experiment) + self.assertFalse(none_throws(fallback_spec._fitted_adapter).can_predict) + node._generator_spec_to_gen_from = fallback_spec + + # _fitted_adapter should still return the original predictive adapter. + adapter_after_fallback = none_throws(node._fitted_adapter) + self.assertTrue(adapter_after_fallback.can_predict) + self.assertIs(adapter_after_fallback, original_adapter) + @mock_botorch_optimize def test_generator_gen_kwargs_deepcopy(self) -> None: sampler = SobolQMCNormalSampler(torch.Size([1]))