From a326d55574b138889f9a0153e018eeddb300488e Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Wed, 29 Apr 2026 07:57:37 -0700 Subject: [PATCH 1/4] Wire LearnedFeatureImputation and map_heterogeneous_to_full for MultiTaskGP (#5192) Summary: X-link: https://github.com/meta-pytorch/botorch/pull/3296 Automatically configures learned feature imputation for models that pad heterogeneous per-task data to the full joint feature space. Models with native heterogeneity support are excluded from this automatic configuration. Differential Revision: D101841497 --- .../torch/botorch_modular/surrogate.py | 26 ++++++- ax/generators/torch/botorch_modular/utils.py | 70 +++++++++--------- ax/generators/torch/tests/test_surrogate.py | 74 ++++++++++++++++++- ax/generators/torch/tests/test_utils.py | 54 +++++++++++++- 4 files changed, 181 insertions(+), 43 deletions(-) diff --git a/ax/generators/torch/botorch_modular/surrogate.py b/ax/generators/torch/botorch_modular/surrogate.py index c586c61caec..09800143e2b 100644 --- a/ax/generators/torch/botorch_modular/surrogate.py +++ b/ax/generators/torch/botorch_modular/surrogate.py @@ -12,7 +12,7 @@ from collections import OrderedDict from collections.abc import Mapping, Sequence from copy import deepcopy -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from logging import Logger from typing import Any, cast @@ -67,6 +67,7 @@ from botorch.models.transforms.input import ( ChainedInputTransform, InputTransform, + LearnedFeatureImputation, Normalize, ) from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform @@ -1253,6 +1254,22 @@ def _submodel_input_constructor_mtgp( ) -> dict[str, Any]: if len(dataset.outcome_names) > 1: raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.") + # If LearnedFeatureImputation is in the model config, tell construct_inputs + # to map heterogeneous per-task features to the full joint feature space. + # This must happen before the base call so construct_inputs can handle + # heterogeneous MultiTaskDatasets without raising. + uses_lfi = isinstance(model_config.input_transform_classes, list) and any( + issubclass(cls, LearnedFeatureImputation) + for cls in model_config.input_transform_classes + ) + if uses_lfi and "map_heterogeneous_to_full" not in model_config.model_options: + model_config = replace( + model_config, + model_options={ + **model_config.model_options, + "map_heterogeneous_to_full": True, + }, + ) formatted_model_inputs = _submodel_input_constructor_base( botorch_model_class=botorch_model_class, model_config=model_config, @@ -1266,9 +1283,12 @@ def _submodel_input_constructor_mtgp( # specify output tasks so that model.num_outputs = 1 # since the model only models a single outcome if formatted_model_inputs.get("output_tasks") is None: - # SSD doesn't use -1, so we need to normalize here + # SSD doesn't use -1, so we need to normalize here. Use the SSD's bound + # length since target_values is keyed by SSD column index — for + # heterogeneous MultiTaskDatasets this differs from the per-task + # dataset's feature_names length. task_feature = none_throws( - normalize_indices(indices=[task_feature], d=len(dataset.feature_names)) + normalize_indices(indices=[task_feature], d=len(search_space_digest.bounds)) )[0] if (search_space_digest.target_values is not None) and ( target_value := search_space_digest.target_values.get(task_feature) diff --git a/ax/generators/torch/botorch_modular/utils.py b/ax/generators/torch/botorch_modular/utils.py index f9f52a91eed..8770b6d321f 100644 --- a/ax/generators/torch/botorch_modular/utils.py +++ b/ax/generators/torch/botorch_modular/utils.py @@ -221,6 +221,32 @@ def use_model_list( return True +def _ensure_input_transform( + model_config: ModelConfig, + transform_cls: type[InputTransform], + position: int | None = None, +) -> None: + """Ensure ``transform_cls`` is in ``model_config.input_transform_classes``. + + If the user hasn't specified any transforms (``DEFAULT``), initialise the + list with ``[transform_cls]``. Otherwise append (or insert at ``position``) + only when the class isn't already present. Mutates ``model_config`` + in-place. + """ + itc = model_config.input_transform_classes + if isinstance(itc, list): + if transform_cls not in itc: + if position is not None: + itc.insert(position, transform_cls) + else: + itc.append(transform_cls) + else: + model_config.input_transform_classes = [transform_cls] + ito = model_config.input_transform_options or {} + ito.setdefault(transform_cls.__name__, {}) + model_config.input_transform_options = ito + + def copy_model_config_with_default_values( model_config: ModelConfig, dataset: SupervisedDataset, @@ -235,43 +261,15 @@ def copy_model_config_with_default_values( specified_model_class=model_config_copy.botorch_model_class, ) - # Handle heterogeneous multi-task datasets. + # Handle heterogeneous multi-task datasets: ensure Normalize is present + # and add LearnedFeatureImputation for models that don't handle + # heterogeneity natively. if isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features: - if model_config_copy.botorch_model_class is HeterogeneousMTGP: - # HeterogeneousMTGP handles heterogeneity natively; just ensure - # Normalize is present (bounds are set later by the TL adapter). - itc = model_config_copy.input_transform_classes - if isinstance(itc, list): - if Normalize not in itc: - itc.insert(0, Normalize) - ito = model_config_copy.input_transform_options or {} - ito.setdefault("Normalize", {"bounds": None}) - model_config_copy.input_transform_options = ito - else: - model_config_copy.input_transform_classes = [Normalize] - ito = model_config_copy.input_transform_options or {} - ito.setdefault("Normalize", {"bounds": None}) - model_config_copy.input_transform_options = ito - else: - # Other models need Normalize + LFI to pad features via - # map_heterogeneous_to_full. - itc = model_config_copy.input_transform_classes - if isinstance(itc, list): - if Normalize not in itc: - itc.insert(0, Normalize) - ito = model_config_copy.input_transform_options or {} - ito.setdefault("Normalize", {"bounds": None}) - model_config_copy.input_transform_options = ito - if LearnedFeatureImputation not in itc: - itc.append(LearnedFeatureImputation) - else: - model_config_copy.input_transform_classes = [ - Normalize, - LearnedFeatureImputation, - ] - ito = model_config_copy.input_transform_options or {} - ito.setdefault("Normalize", {"bounds": None}) - model_config_copy.input_transform_options = ito + _ensure_input_transform(model_config_copy, Normalize, position=0) + if model_config_copy.botorch_model_class is not None and not issubclass( + model_config_copy.botorch_model_class, HeterogeneousMTGP + ): + _ensure_input_transform(model_config_copy, LearnedFeatureImputation) if model_config_copy.mll_class is None: model_config_copy.mll_class = ( diff --git a/ax/generators/torch/tests/test_surrogate.py b/ax/generators/torch/tests/test_surrogate.py index 82e4c824bdc..656d0bcfc3e 100644 --- a/ax/generators/torch/tests/test_surrogate.py +++ b/ax/generators/torch/tests/test_surrogate.py @@ -32,6 +32,7 @@ _construct_specified_input_transforms, _extract_model_kwargs, _make_botorch_input_transform, + _submodel_input_constructor_mtgp, submodel_input_constructor, Surrogate, SurrogateSpec, @@ -59,7 +60,12 @@ from botorch.models.model import Model, ModelList # noqa: F401 -- used in Mocks. from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood -from botorch.models.transforms.input import ChainedInputTransform, Log10, Normalize +from botorch.models.transforms.input import ( + ChainedInputTransform, + LearnedFeatureImputation, + Log10, + Normalize, +) from botorch.models.transforms.outcome import OutcomeTransform, Standardize from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset from botorch.utils.evaluation import compute_in_sample_model_fit_metric @@ -277,6 +283,72 @@ def test__make_botorch_input_transform(self) -> None: self.assertEqual(transform.indices.tolist(), [0]) self.assertEqual(transform.bounds.tolist(), [[1.0], [5.0]]) + def test_submodel_input_constructor_mtgp_map_heterogeneous(self) -> None: + """_submodel_input_constructor_mtgp passes map_heterogeneous_to_full + to construct_inputs when LFI is configured, enabling zero-padded + heterogeneous datasets to be used with MultiTaskGP.""" + ds_target = SupervisedDataset( + X=torch.tensor([[1.0, 0.0], [2.0, 0.0]]), + Y=torch.tensor([[1.0], [2.0]]), + feature_names=["x0", "task"], + outcome_names=["y_task_0"], + ) + ds_source = SupervisedDataset( + X=torch.tensor([[3.0, 4.0, 1.0], [5.0, 6.0, 1.0]]), + Y=torch.tensor([[3.0], [4.0]]), + feature_names=["x0", "x1", "task"], + outcome_names=["y_task_1"], + ) + mt_dataset = MultiTaskDataset( + datasets=[ds_target, ds_source], + target_outcome_name="y_task_0", + task_feature_index=-1, + ) + self.assertTrue(mt_dataset.has_heterogeneous_features) + ssd = SearchSpaceDigest( + feature_names=["x0", "x1", "task"], + bounds=[(0.0, 5.0), (0.0, 6.0), (0.0, 1.0)], + task_features=[2], + target_values={2: 0.0}, + ) + surrogate = Surrogate( + surrogate_spec=SurrogateSpec( + model_configs=[ModelConfig(botorch_model_class=MultiTaskGP)] + ) + ) + + with self.subTest("with LFI — construct_inputs succeeds"): + config_with_lfi = ModelConfig( + botorch_model_class=MultiTaskGP, + input_transform_classes=[Normalize, LearnedFeatureImputation], + ) + result = _submodel_input_constructor_mtgp( + botorch_model_class=MultiTaskGP, + model_config=config_with_lfi, + dataset=mt_dataset, + search_space_digest=ssd, + surrogate=surrogate, + ) + self.assertEqual(result["train_X"].shape[-1], 3) + + with self.subTest("without LFI — construct_inputs raises"): + from botorch.exceptions.errors import ( + UnsupportedError as BotorchUnsupportedError, + ) + + config_no_lfi = ModelConfig( + botorch_model_class=MultiTaskGP, + input_transform_classes=[Normalize], + ) + with self.assertRaises(BotorchUnsupportedError): + _submodel_input_constructor_mtgp( + botorch_model_class=MultiTaskGP, + model_config=config_no_lfi, + dataset=mt_dataset, + search_space_digest=ssd, + surrogate=surrogate, + ) + class SurrogateTest(TestCase): def setUp(self, cuda: bool = False) -> None: diff --git a/ax/generators/torch/tests/test_utils.py b/ax/generators/torch/tests/test_utils.py index 66ae1e11669..b2bfd7eccdf 100644 --- a/ax/generators/torch/tests/test_utils.py +++ b/ax/generators/torch/tests/test_utils.py @@ -70,7 +70,7 @@ from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP -from botorch.models.transforms.input import LearnedFeatureImputation, Normalize +from botorch.models.transforms.input import LearnedFeatureImputation, Normalize, Warp from botorch.posteriors.ensemble import EnsemblePosterior from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset from botorch.utils.types import DEFAULT @@ -244,7 +244,7 @@ def test_copy_model_config_heterogeneous_mtgp(self) -> None: self.assertEqual(updated_config.input_transform_classes, [Normalize]) self.assertEqual( none_throws(updated_config.input_transform_options), - {"Normalize": {"bounds": None}}, + {"Normalize": {}}, ) # Explicit HeterogeneousMTGP behaves the same. @@ -257,7 +257,7 @@ def test_copy_model_config_heterogeneous_mtgp(self) -> None: self.assertEqual(updated_config.input_transform_classes, [Normalize]) self.assertEqual( none_throws(updated_config.input_transform_options), - {"Normalize": {"bounds": None}}, + {"Normalize": {}}, ) def test_copy_model_config_mtgp_with_lfi_injection(self) -> None: @@ -302,6 +302,54 @@ def test_copy_model_config_does_not_add_normalize_for_other_models(self) -> None self.assertEqual(updated_config.input_transform_classes, DEFAULT) self.assertEqual(updated_config.input_transform_options, {}) + def test_copy_model_config_adds_imputation_for_heterogeneous(self) -> None: + mt_dataset = self._get_heterogeneous_mt_dataset() + ssd = dataclasses.replace(self.search_space_digest, task_features=[-1]) + + with self.subTest("no_input_transform_classes"): + model_config = ModelConfig(botorch_model_class=MultiTaskGP) + updated_config = copy_model_config_with_default_values( + model_config=model_config, + dataset=mt_dataset, + search_space_digest=ssd, + ) + self.assertEqual(updated_config.botorch_model_class, MultiTaskGP) + self.assertEqual( + updated_config.input_transform_classes, + [Normalize, LearnedFeatureImputation], + ) + + with self.subTest("existing_transform_classes"): + model_config = ModelConfig( + botorch_model_class=MultiTaskGP, + input_transform_classes=[Warp], + input_transform_options={"Warp": {}}, + ) + updated_config = copy_model_config_with_default_values( + model_config=model_config, + dataset=mt_dataset, + search_space_digest=ssd, + ) + self.assertEqual( + updated_config.input_transform_classes, + [Normalize, Warp, LearnedFeatureImputation], + ) + + with self.subTest("imputation_already_present"): + model_config = ModelConfig( + botorch_model_class=MultiTaskGP, + input_transform_classes=[Normalize, LearnedFeatureImputation], + ) + updated_config = copy_model_config_with_default_values( + model_config=model_config, + dataset=mt_dataset, + search_space_digest=ssd, + ) + self.assertEqual( + updated_config.input_transform_classes, + [Normalize, LearnedFeatureImputation], + ) + def test_choose_model_class_discrete_features(self) -> None: # With discrete features, use MixedSingleTaskyGP. self.assertEqual( From 93d1adda05b530fe531139fb45550ec916f3fe37 Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Wed, 29 Apr 2026 07:57:37 -0700 Subject: [PATCH 2/4] Default to MultiTaskGP + LearnedFeatureImputation for heterogeneous TL (#5193) Summary: Switches the default heterogeneous transfer learning model from a specialized per-task kernel model to a standard multi-task GP with learned feature imputation. The previous default model class is marked as deprecated. Differential Revision: D102197137 --- .../input_constructors/input_transforms.py | 4 ++++ ax/generators/torch/botorch_modular/utils.py | 7 +++---- ax/generators/torch/tests/test_utils.py | 20 +++++++++++-------- 3 files changed, 19 insertions(+), 12 deletions(-) diff --git a/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py b/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py index 133218de7f0..5ca2e6ff72a 100644 --- a/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py +++ b/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py @@ -402,10 +402,14 @@ def _input_transform_argparse_learned_feature_imputation( torch.ones(d, dtype=dtype, device=torch_device), ] ) + # The target task is at position 0 (target_dataset is prepended above), so + # at posterior time — when X arrives without a task column — LFI applies + # the target task's imputation pattern. kwargs: dict[str, Any] = { "feature_indices": feature_indices, "d": d, "task_feature_index": task_feature_index, + "target_task": 0, "bounds": bounds, "device": torch_device, "dtype": dtype, diff --git a/ax/generators/torch/botorch_modular/utils.py b/ax/generators/torch/botorch_modular/utils.py index 8770b6d321f..36b16557ac3 100644 --- a/ax/generators/torch/botorch_modular/utils.py +++ b/ax/generators/torch/botorch_modular/utils.py @@ -319,16 +319,15 @@ def choose_model_class( ) # Check for heterogeneous multi-task datasets. If a model class was - # explicitly specified, respect it; otherwise default to HeterogeneousMTGP. + # explicitly specified, respect it; otherwise default to MultiTaskGP + # (LearnedFeatureImputation handles missing features). if ( search_space_digest.task_features and isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features ): model_class = ( - specified_model_class - if specified_model_class is not None - else HeterogeneousMTGP + specified_model_class if specified_model_class is not None else MultiTaskGP ) logger.debug(f"Chose BoTorch model class: {model_class}.") return model_class diff --git a/ax/generators/torch/tests/test_utils.py b/ax/generators/torch/tests/test_utils.py index b2bfd7eccdf..8126998aa73 100644 --- a/ax/generators/torch/tests/test_utils.py +++ b/ax/generators/torch/tests/test_utils.py @@ -186,9 +186,9 @@ def test_choose_model_class_heterogeneous_task_features(self) -> None: mt_dataset = self._get_heterogeneous_mt_dataset() ssd = dataclasses.replace(self.search_space_digest, task_features=[-1]) - # Default: HeterogeneousMTGP. + # Default: MultiTaskGP (LearnedFeatureImputation handles missing features). self.assertEqual( - HeterogeneousMTGP, + MultiTaskGP, choose_model_class(dataset=mt_dataset, search_space_digest=ssd), ) @@ -233,19 +233,23 @@ def test_copy_model_config_heterogeneous_mtgp(self) -> None: mt_dataset = self._get_heterogeneous_mt_dataset() ssd = dataclasses.replace(self.search_space_digest, task_features=[-1]) - # Default (no model class specified) -> HeterogeneousMTGP. - # LFI is NOT injected; input_transform_classes stays DEFAULT. + # Default (no model class specified) -> MultiTaskGP. + # LFI is injected for MultiTaskGP with heterogeneous data. updated_config = copy_model_config_with_default_values( model_config=ModelConfig(), dataset=mt_dataset, search_space_digest=ssd, ) - self.assertEqual(updated_config.botorch_model_class, HeterogeneousMTGP) - self.assertEqual(updated_config.input_transform_classes, [Normalize]) + self.assertEqual(updated_config.botorch_model_class, MultiTaskGP) self.assertEqual( - none_throws(updated_config.input_transform_options), - {"Normalize": {}}, + updated_config.input_transform_classes, + [Normalize, LearnedFeatureImputation], ) + # LFI is present in transform classes but absent from options; its + # argparse computes kwargs from the dataset at construction time. + ito = none_throws(updated_config.input_transform_options) + self.assertEqual(ito, {"Normalize": {}}) + self.assertNotIn("LearnedFeatureImputation", ito) # Explicit HeterogeneousMTGP behaves the same. updated_config = copy_model_config_with_default_values( From 06d84cdce393e25fb55b1b82351dddd8c6d13059 Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Wed, 29 Apr 2026 07:57:37 -0700 Subject: [PATCH 3/4] Use StratifiedStandardize for per-task Y standardization in TL (#5194) Summary: Adds per-task outcome standardization to the transfer learning adapter, ensuring each task's observations are standardized independently rather than jointly. Updates the default transform pipeline to use TL-specific outcome transforms. This removes ambiguity on whether the right transforms have been applied (e.g. QuickBO/warm-starting), where standardization is not performed across, but within experiments. Differential Revision: D102197139 --- ax/adapter/registry.py | 1 + ax/adapter/transfer_learning/adapter.py | 10 ++++++---- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/ax/adapter/registry.py b/ax/adapter/registry.py index 52b14d3a6ba..0619c35c41e 100644 --- a/ax/adapter/registry.py +++ b/ax/adapter/registry.py @@ -139,6 +139,7 @@ ] Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY, StandardizeY] +TL_Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY] # Expected `List[Type[Transform]]` for 2nd anonymous parameter to # call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`. diff --git a/ax/adapter/transfer_learning/adapter.py b/ax/adapter/transfer_learning/adapter.py index 92585add888..73d7146a4f5 100644 --- a/ax/adapter/transfer_learning/adapter.py +++ b/ax/adapter/transfer_learning/adapter.py @@ -23,7 +23,7 @@ Generators, GeneratorSetup, MBM_X_trans, - Y_trans, + TL_Y_trans, ) from ax.adapter.torch import FIT_MODEL_ERROR, TorchAdapter from ax.adapter.transfer_learning.utils import get_joint_search_space @@ -54,6 +54,7 @@ from ax.utils.common.logger import get_logger from botorch.models.multitask import MultiTaskGP from botorch.models.transforms.input import InputTransform, Normalize +from botorch.models.transforms.outcome import StratifiedStandardize from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset from gpytorch.kernels.kernel import Kernel from pyre_extensions import assert_is_instance @@ -793,7 +794,7 @@ def transfer_learning_generator_specs_constructor( Args: model_class: The MultiTask BoTorch Model to use in the BOTL. transform: Optional list of transforms to use in the Adapter. - Defaults to MBM_X_trans + [MetadataToTask] + Y_trans. + Defaults to MBM_X_trans + [MetadataToTask] + TL_Y_trans. jit_compile: Whether to use jit compilation in Pyro when the fully Bayesian model is used. torch_device: What torch device to use (defaults to None, i.e. falls back to @@ -828,7 +829,7 @@ def transfer_learning_generator_specs_constructor( input_transform_options: dict[str, dict[str, Any]] = { "Normalize": {}, } - transforms = transforms or MBM_X_trans + [MetadataToTask] + Y_trans + transforms = transforms or MBM_X_trans + [MetadataToTask] + TL_Y_trans transform_configs = get_derelativize_config( derelativize_with_raw_status_quo=derelativize_with_raw_status_quo ) @@ -846,6 +847,7 @@ def transfer_learning_generator_specs_constructor( botorch_model_class=model_class, model_options=botorch_model_kwargs or {}, input_transform_classes=input_transform_classes, + outcome_transform_classes=[StratifiedStandardize], input_transform_options=input_transform_options, mll_options=mll_kwargs, covar_module_class=covar_module_class, @@ -887,5 +889,5 @@ def transfer_learning_generator_specs_constructor( GENERATOR_KEY_TO_GENERATOR_SETUP["BOTL"] = GeneratorSetup( adapter_class=TransferLearningAdapter, generator_class=BoTorchGenerator, - transforms=MBM_X_trans + [MetadataToTask] + Y_trans, + transforms=MBM_X_trans + [MetadataToTask] + TL_Y_trans, ) From 1b629a3f7b49d6635ba215d00b39ea358648d0d3 Mon Sep 17 00:00:00 2001 From: Carl Hvarfner Date: Wed, 29 Apr 2026 07:57:37 -0700 Subject: [PATCH 4/4] Use get_heterogeneous_feature_mapping in LFI argparse dispatcher (#5195) Summary: Refactors the learned imputation argument dispatcher to delegate feature index computation to the dataset's built-in mapping utility. This eliminates duplicated feature-ordering logic and ensures consistency with the canonical ordering convention. Differential Revision: D102197138 --- .../input_constructors/input_transforms.py | 36 +++---------------- .../tests/test_input_transform_argparse.py | 4 +-- 2 files changed, 5 insertions(+), 35 deletions(-) diff --git a/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py b/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py index 5ca2e6ff72a..fd0d97e0abc 100644 --- a/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py +++ b/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py @@ -357,40 +357,12 @@ def _input_transform_argparse_learned_feature_imputation( ) input_transform_options = input_transform_options or {} - # Order datasets: target first, then remaining (same as HeterogeneousMTGP). - child_datasets = dataset.datasets.copy() - target_dataset = child_datasets.pop(dataset.target_outcome_name) - all_datasets = [target_dataset] + list(child_datasets.values()) - - # The feature_names[:task_feature_index] slice only works when the task - # column is the last column (index == -1). Guard against other positions - # the same way ImputedMultiTaskGP.construct_inputs does. + # Delegate feature ordering and index computation to MultiTaskDataset. + all_datasets, feature_indices_list, d = dataset.get_heterogeneous_feature_mapping() + feature_indices = dict(enumerate(feature_indices_list)) task_feature_index = ( - dataset.task_feature_index if (dataset.task_feature_index is not None) else -1 + dataset.task_feature_index if dataset.task_feature_index is not None else -1 ) - if task_feature_index != -1: - raise NotImplementedError( - "LearnedFeatureImputation argparse only supports " - "task_feature_index == -1. Got " - f"task_feature_index={task_feature_index}." - ) - - # Use target's feature order as canonical (NO alphabetical sort). - # Source-only features are appended at the end. - all_features: list[str] = list(target_dataset.feature_names[:task_feature_index]) - for ds in all_datasets[1:]: - for fn in ds.feature_names[:task_feature_index]: - if fn not in all_features: - all_features.append(fn) - d = len(all_features) - - # Map each task's features to indices in the global feature space. - feature_indices = { - task_idx: [ - all_features.index(fn) for fn in ds.feature_names[:task_feature_index] - ] - for task_idx, ds in enumerate(all_datasets) - } dtype = torch_dtype or torch.float64 # Constrain imputation values to [0, 1] since the preceding Normalize diff --git a/ax/generators/torch/tests/test_input_transform_argparse.py b/ax/generators/torch/tests/test_input_transform_argparse.py index 880c9fc48dc..acecf98a278 100644 --- a/ax/generators/torch/tests/test_input_transform_argparse.py +++ b/ax/generators/torch/tests/test_input_transform_argparse.py @@ -475,9 +475,7 @@ def test_argparse_learned_feature_imputation(self) -> None: target_outcome_name="y0", task_feature_index=0, ) - with self.assertRaisesRegex( - NotImplementedError, "task_feature_index == -1" - ): + with self.assertRaisesRegex(NotImplementedError, "task_feature_index.*-1"): input_transform_argparse( LearnedFeatureImputation, dataset=bad_ds,