From 09c1741c540832da9871241fae03f169a6d09bc0 Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Tue, 12 May 2026 13:28:29 -0700 Subject: [PATCH] Wire LearnedFeatureImputation and map_heterogeneous_to_full for MultiTaskGP (#5192) Summary: X-link: https://github.com/meta-pytorch/botorch/pull/3296 Automatically configures learned feature imputation for models that pad heterogeneous per-task data to the full joint feature space. Models with native heterogeneity support are excluded from this automatic configuration. Reviewed By: saitcakmak Differential Revision: D101841497 --- .../torch/botorch_modular/surrogate.py | 26 ++++++- ax/generators/torch/botorch_modular/utils.py | 70 +++++++++-------- ax/generators/torch/tests/test_surrogate.py | 76 ++++++++++++++++++- ax/generators/torch/tests/test_utils.py | 54 ++++++++++++- 4 files changed, 183 insertions(+), 43 deletions(-) diff --git a/ax/generators/torch/botorch_modular/surrogate.py b/ax/generators/torch/botorch_modular/surrogate.py index c586c61caec..09800143e2b 100644 --- a/ax/generators/torch/botorch_modular/surrogate.py +++ b/ax/generators/torch/botorch_modular/surrogate.py @@ -12,7 +12,7 @@ from collections import OrderedDict from collections.abc import Mapping, Sequence from copy import deepcopy -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from logging import Logger from typing import Any, cast @@ -67,6 +67,7 @@ from botorch.models.transforms.input import ( ChainedInputTransform, InputTransform, + LearnedFeatureImputation, Normalize, ) from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform @@ -1253,6 +1254,22 @@ def _submodel_input_constructor_mtgp( ) -> dict[str, Any]: if len(dataset.outcome_names) > 1: raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.") + # If LearnedFeatureImputation is in the model config, tell construct_inputs + # to map heterogeneous per-task features to the full joint feature space. + # This must happen before the base call so construct_inputs can handle + # heterogeneous MultiTaskDatasets without raising. + uses_lfi = isinstance(model_config.input_transform_classes, list) and any( + issubclass(cls, LearnedFeatureImputation) + for cls in model_config.input_transform_classes + ) + if uses_lfi and "map_heterogeneous_to_full" not in model_config.model_options: + model_config = replace( + model_config, + model_options={ + **model_config.model_options, + "map_heterogeneous_to_full": True, + }, + ) formatted_model_inputs = _submodel_input_constructor_base( botorch_model_class=botorch_model_class, model_config=model_config, @@ -1266,9 +1283,12 @@ def _submodel_input_constructor_mtgp( # specify output tasks so that model.num_outputs = 1 # since the model only models a single outcome if formatted_model_inputs.get("output_tasks") is None: - # SSD doesn't use -1, so we need to normalize here + # SSD doesn't use -1, so we need to normalize here. Use the SSD's bound + # length since target_values is keyed by SSD column index — for + # heterogeneous MultiTaskDatasets this differs from the per-task + # dataset's feature_names length. task_feature = none_throws( - normalize_indices(indices=[task_feature], d=len(dataset.feature_names)) + normalize_indices(indices=[task_feature], d=len(search_space_digest.bounds)) )[0] if (search_space_digest.target_values is not None) and ( target_value := search_space_digest.target_values.get(task_feature) diff --git a/ax/generators/torch/botorch_modular/utils.py b/ax/generators/torch/botorch_modular/utils.py index f9f52a91eed..8770b6d321f 100644 --- a/ax/generators/torch/botorch_modular/utils.py +++ b/ax/generators/torch/botorch_modular/utils.py @@ -221,6 +221,32 @@ def use_model_list( return True +def _ensure_input_transform( + model_config: ModelConfig, + transform_cls: type[InputTransform], + position: int | None = None, +) -> None: + """Ensure ``transform_cls`` is in ``model_config.input_transform_classes``. + + If the user hasn't specified any transforms (``DEFAULT``), initialise the + list with ``[transform_cls]``. Otherwise append (or insert at ``position``) + only when the class isn't already present. Mutates ``model_config`` + in-place. + """ + itc = model_config.input_transform_classes + if isinstance(itc, list): + if transform_cls not in itc: + if position is not None: + itc.insert(position, transform_cls) + else: + itc.append(transform_cls) + else: + model_config.input_transform_classes = [transform_cls] + ito = model_config.input_transform_options or {} + ito.setdefault(transform_cls.__name__, {}) + model_config.input_transform_options = ito + + def copy_model_config_with_default_values( model_config: ModelConfig, dataset: SupervisedDataset, @@ -235,43 +261,15 @@ def copy_model_config_with_default_values( specified_model_class=model_config_copy.botorch_model_class, ) - # Handle heterogeneous multi-task datasets. + # Handle heterogeneous multi-task datasets: ensure Normalize is present + # and add LearnedFeatureImputation for models that don't handle + # heterogeneity natively. if isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features: - if model_config_copy.botorch_model_class is HeterogeneousMTGP: - # HeterogeneousMTGP handles heterogeneity natively; just ensure - # Normalize is present (bounds are set later by the TL adapter). - itc = model_config_copy.input_transform_classes - if isinstance(itc, list): - if Normalize not in itc: - itc.insert(0, Normalize) - ito = model_config_copy.input_transform_options or {} - ito.setdefault("Normalize", {"bounds": None}) - model_config_copy.input_transform_options = ito - else: - model_config_copy.input_transform_classes = [Normalize] - ito = model_config_copy.input_transform_options or {} - ito.setdefault("Normalize", {"bounds": None}) - model_config_copy.input_transform_options = ito - else: - # Other models need Normalize + LFI to pad features via - # map_heterogeneous_to_full. - itc = model_config_copy.input_transform_classes - if isinstance(itc, list): - if Normalize not in itc: - itc.insert(0, Normalize) - ito = model_config_copy.input_transform_options or {} - ito.setdefault("Normalize", {"bounds": None}) - model_config_copy.input_transform_options = ito - if LearnedFeatureImputation not in itc: - itc.append(LearnedFeatureImputation) - else: - model_config_copy.input_transform_classes = [ - Normalize, - LearnedFeatureImputation, - ] - ito = model_config_copy.input_transform_options or {} - ito.setdefault("Normalize", {"bounds": None}) - model_config_copy.input_transform_options = ito + _ensure_input_transform(model_config_copy, Normalize, position=0) + if model_config_copy.botorch_model_class is not None and not issubclass( + model_config_copy.botorch_model_class, HeterogeneousMTGP + ): + _ensure_input_transform(model_config_copy, LearnedFeatureImputation) if model_config_copy.mll_class is None: model_config_copy.mll_class = ( diff --git a/ax/generators/torch/tests/test_surrogate.py b/ax/generators/torch/tests/test_surrogate.py index 82e4c824bdc..80dfa36b178 100644 --- a/ax/generators/torch/tests/test_surrogate.py +++ b/ax/generators/torch/tests/test_surrogate.py @@ -32,6 +32,7 @@ _construct_specified_input_transforms, _extract_model_kwargs, _make_botorch_input_transform, + _submodel_input_constructor_mtgp, submodel_input_constructor, Surrogate, SurrogateSpec, @@ -59,7 +60,12 @@ from botorch.models.model import Model, ModelList # noqa: F401 -- used in Mocks. from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood -from botorch.models.transforms.input import ChainedInputTransform, Log10, Normalize +from botorch.models.transforms.input import ( + ChainedInputTransform, + LearnedFeatureImputation, + Log10, + Normalize, +) from botorch.models.transforms.outcome import OutcomeTransform, Standardize from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset from botorch.utils.evaluation import compute_in_sample_model_fit_metric @@ -277,6 +283,74 @@ def test__make_botorch_input_transform(self) -> None: self.assertEqual(transform.indices.tolist(), [0]) self.assertEqual(transform.bounds.tolist(), [[1.0], [5.0]]) + def test_submodel_input_constructor_mtgp_map_heterogeneous(self) -> None: + """_submodel_input_constructor_mtgp passes map_heterogeneous_to_full + to construct_inputs when LFI is configured, enabling zero-padded + heterogeneous datasets to be used with MultiTaskGP.""" + ds_target = SupervisedDataset( + X=torch.tensor([[1.0, 0.0], [2.0, 0.0]]), + Y=torch.tensor([[1.0], [2.0]]), + feature_names=["x0", "task"], + outcome_names=["y_task_0"], + ) + ds_source = SupervisedDataset( + X=torch.tensor([[3.0, 4.0, 1.0], [5.0, 6.0, 1.0]]), + Y=torch.tensor([[3.0], [4.0]]), + feature_names=["x0", "x1", "task"], + outcome_names=["y_task_1"], + ) + mt_dataset = MultiTaskDataset( + datasets=[ds_target, ds_source], + target_outcome_name="y_task_0", + task_feature_index=-1, + ) + self.assertTrue(mt_dataset.has_heterogeneous_features) + ssd = SearchSpaceDigest( + feature_names=["x0", "x1", "task"], + bounds=[(0.0, 5.0), (0.0, 6.0), (0.0, 1.0)], + task_features=[2], + target_values={2: 0.0}, + ) + surrogate = Surrogate( + surrogate_spec=SurrogateSpec( + model_configs=[ModelConfig(botorch_model_class=MultiTaskGP)] + ) + ) + + with self.subTest("with LFI — construct_inputs succeeds"): + config_with_lfi = ModelConfig( + botorch_model_class=MultiTaskGP, + input_transform_classes=[Normalize, LearnedFeatureImputation], + ) + result = _submodel_input_constructor_mtgp( + botorch_model_class=MultiTaskGP, + model_config=config_with_lfi, + dataset=mt_dataset, + search_space_digest=ssd, + surrogate=surrogate, + ) + self.assertEqual(result["train_X"].shape[-1], 3) + + with self.subTest("without LFI — construct_inputs raises"): + from botorch.exceptions.errors import ( + UnsupportedError as BotorchUnsupportedError, + ) + + config_no_lfi = ModelConfig( + botorch_model_class=MultiTaskGP, + input_transform_classes=[Normalize], + ) + with self.assertRaisesRegex( + BotorchUnsupportedError, "heterogeneous feature sets" + ): + _submodel_input_constructor_mtgp( + botorch_model_class=MultiTaskGP, + model_config=config_no_lfi, + dataset=mt_dataset, + search_space_digest=ssd, + surrogate=surrogate, + ) + class SurrogateTest(TestCase): def setUp(self, cuda: bool = False) -> None: diff --git a/ax/generators/torch/tests/test_utils.py b/ax/generators/torch/tests/test_utils.py index 66ae1e11669..b2bfd7eccdf 100644 --- a/ax/generators/torch/tests/test_utils.py +++ b/ax/generators/torch/tests/test_utils.py @@ -70,7 +70,7 @@ from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP -from botorch.models.transforms.input import LearnedFeatureImputation, Normalize +from botorch.models.transforms.input import LearnedFeatureImputation, Normalize, Warp from botorch.posteriors.ensemble import EnsemblePosterior from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset from botorch.utils.types import DEFAULT @@ -244,7 +244,7 @@ def test_copy_model_config_heterogeneous_mtgp(self) -> None: self.assertEqual(updated_config.input_transform_classes, [Normalize]) self.assertEqual( none_throws(updated_config.input_transform_options), - {"Normalize": {"bounds": None}}, + {"Normalize": {}}, ) # Explicit HeterogeneousMTGP behaves the same. @@ -257,7 +257,7 @@ def test_copy_model_config_heterogeneous_mtgp(self) -> None: self.assertEqual(updated_config.input_transform_classes, [Normalize]) self.assertEqual( none_throws(updated_config.input_transform_options), - {"Normalize": {"bounds": None}}, + {"Normalize": {}}, ) def test_copy_model_config_mtgp_with_lfi_injection(self) -> None: @@ -302,6 +302,54 @@ def test_copy_model_config_does_not_add_normalize_for_other_models(self) -> None self.assertEqual(updated_config.input_transform_classes, DEFAULT) self.assertEqual(updated_config.input_transform_options, {}) + def test_copy_model_config_adds_imputation_for_heterogeneous(self) -> None: + mt_dataset = self._get_heterogeneous_mt_dataset() + ssd = dataclasses.replace(self.search_space_digest, task_features=[-1]) + + with self.subTest("no_input_transform_classes"): + model_config = ModelConfig(botorch_model_class=MultiTaskGP) + updated_config = copy_model_config_with_default_values( + model_config=model_config, + dataset=mt_dataset, + search_space_digest=ssd, + ) + self.assertEqual(updated_config.botorch_model_class, MultiTaskGP) + self.assertEqual( + updated_config.input_transform_classes, + [Normalize, LearnedFeatureImputation], + ) + + with self.subTest("existing_transform_classes"): + model_config = ModelConfig( + botorch_model_class=MultiTaskGP, + input_transform_classes=[Warp], + input_transform_options={"Warp": {}}, + ) + updated_config = copy_model_config_with_default_values( + model_config=model_config, + dataset=mt_dataset, + search_space_digest=ssd, + ) + self.assertEqual( + updated_config.input_transform_classes, + [Normalize, Warp, LearnedFeatureImputation], + ) + + with self.subTest("imputation_already_present"): + model_config = ModelConfig( + botorch_model_class=MultiTaskGP, + input_transform_classes=[Normalize, LearnedFeatureImputation], + ) + updated_config = copy_model_config_with_default_values( + model_config=model_config, + dataset=mt_dataset, + search_space_digest=ssd, + ) + self.assertEqual( + updated_config.input_transform_classes, + [Normalize, LearnedFeatureImputation], + ) + def test_choose_model_class_discrete_features(self) -> None: # With discrete features, use MixedSingleTaskyGP. self.assertEqual(