diff --git a/ax/generation_strategy/generation_node.py b/ax/generation_strategy/generation_node.py index 936f07ef1c2..7375bea6a68 100644 --- a/ax/generation_strategy/generation_node.py +++ b/ax/generation_strategy/generation_node.py @@ -349,15 +349,32 @@ def _fitted_adapter(self) -> Adapter | None: """Private property to return optional fitted_adapter from self.generator_spec_to_gen_from for convenience. If no model is fit, this will return None. + + If the current adapter (e.g. from a Sobol fallback after + ``_try_gen_with_fallback``) cannot predict, prefer a predictive adapter + from the original ``generator_specs`` when available. This ensures that + analysis code which relies on model predictions (e.g. cross-validation, + sensitivity, surface plots) can still use the fitted surrogate model + even after a transient fallback during candidate generation. """ try: # Using the private attribute since using the non-private `fitted_adapter` # property will raise a UserInputError if there is no fitted model. - return self.generator_spec_to_gen_from._fitted_adapter + adapter = self.generator_spec_to_gen_from._fitted_adapter except ModelError: # ModelError is raised if there are no fitted adapters to select from. return None + if adapter is not None and not adapter.can_predict: + for spec in self.generator_specs: + if ( + spec._fitted_adapter is not None + and spec._fitted_adapter.can_predict + ): + return spec._fitted_adapter + + return adapter + def __repr__(self) -> str: """String representation of this ``GenerationNode`` (note that it will abridge some aspects of ``TransitionCriterion`` and diff --git a/ax/generation_strategy/tests/test_generation_node.py b/ax/generation_strategy/tests/test_generation_node.py index 9e726165d11..0bd1595d08f 100644 --- a/ax/generation_strategy/tests/test_generation_node.py +++ b/ax/generation_strategy/tests/test_generation_node.py @@ -285,6 +285,41 @@ def test_gen_with_no_trial_type(self) -> None: self.assertIsNotNone(gr) self.assertNotIn("trial_type", none_throws(gr.gen_metadata)) + @mock_botorch_optimize + def test_fitted_adapter_prefers_predictive_over_fallback(self) -> None: + """After a Sobol fallback, _fitted_adapter should still return the + original predictive TorchAdapter rather than the fallback's + RandomAdapter. This ensures analysis code can generate model-dependent + plots even after a transient fallback during candidate generation.""" + node = GenerationNode( + name="test", + generator_specs=[ + GeneratorSpec( + generator_enum=Generators.BOTORCH_MODULAR, + generator_kwargs={}, + generator_gen_kwargs={}, + ), + ], + ) + node._fit(experiment=self.branin_experiment) + original_adapter = none_throws(node._fitted_adapter) + self.assertTrue(original_adapter.can_predict) + + # Simulate fallback: fit a Sobol fallback spec and override + # _generator_spec_to_gen_from, mimicking _try_gen_with_fallback. + fallback_spec = GeneratorSpec( + generator_enum=Generators.SOBOL, + generator_key_override="Fallback_Sobol", + ) + fallback_spec.fit(experiment=self.branin_experiment) + self.assertFalse(none_throws(fallback_spec._fitted_adapter).can_predict) + node._generator_spec_to_gen_from = fallback_spec + + # _fitted_adapter should still return the original predictive adapter. + adapter_after_fallback = none_throws(node._fitted_adapter) + self.assertTrue(adapter_after_fallback.can_predict) + self.assertIs(adapter_after_fallback, original_adapter) + @mock_botorch_optimize def test_generator_gen_kwargs_deepcopy(self) -> None: sampler = SobolQMCNormalSampler(torch.Size([1]))