diff --git a/ax/adapter/torch.py b/ax/adapter/torch.py index 63568ff66e2..8383d65b120 100644 --- a/ax/adapter/torch.py +++ b/ax/adapter/torch.py @@ -774,6 +774,7 @@ def _get_fit_args( search_space: SearchSpace, experiment_data: ExperimentData, update_outcomes_and_parameters: bool, + data_parameters: list[str] | None = None, ) -> tuple[ list[SupervisedDataset], list[list[TCandidateMetadata]] | None, @@ -791,6 +792,11 @@ def _get_fit_args( update_outcomes_and_parameters: Whether to update `self.outcomes` with all outcomes found in the observations and `self.parameters` with all parameters in the search space. Typically only used in `_fit`. + data_parameters: When provided, columns to extract from + ``experiment_data``. Defaults to ``self.parameters``. Useful when + the model space is larger than the data (e.g. transfer learning + with heterogeneous search spaces where the data only contains + target columns but the model operates in a joint feature space). Returns: The datasets & metadata, extracted from the ``experiment_data``, and the @@ -818,12 +824,23 @@ def _get_fit_args( search_space_digest = extract_search_space_digest( search_space=search_space, param_names=self.parameters ) + extract_params = ( + data_parameters if data_parameters is not None else self.parameters + ) + # When data_parameters differs from self.parameters, the SSD's + # task_feature indices don't match the data columns. Pass None + # to skip MultiTaskDataset wrapping (the caller handles it). 
+ extraction_ssd = ( + None + if data_parameters is not None and data_parameters != self.parameters + else search_space_digest + ) # Convert observations to datasets datasets, ordered_outcomes, candidate_metadata = self._convert_experiment_data( experiment_data=experiment_data, outcomes=self.outcomes, - parameters=self.parameters, - search_space_digest=search_space_digest, + parameters=extract_params, + search_space_digest=extraction_ssd, ) datasets = self._update_w_aux_exp_datasets(datasets=datasets) diff --git a/ax/adapter/transfer_learning/adapter.py b/ax/adapter/transfer_learning/adapter.py index 92585add888..0a180f9b346 100644 --- a/ax/adapter/transfer_learning/adapter.py +++ b/ax/adapter/transfer_learning/adapter.py @@ -7,7 +7,6 @@ from __future__ import annotations -import dataclasses import warnings from collections.abc import Mapping, Sequence from logging import Logger @@ -196,6 +195,24 @@ def __init__( default_model_gen_options=default_model_gen_options, ) + def _set_search_space(self, search_space: SearchSpace) -> None: + """Set search space and model space for transfer learning. + + Overrides the base class to add source-only params (as RangeParameters) + to ``_model_space`` while preserving target bounds for shared params. + This ensures the SSD naturally covers the full joint feature space + without needing post-hoc expansion, and Normalize is anchored to target + bounds so target data maps to [0, 1]. 
+ """ + self._search_space = search_space.clone() + model_space = search_space.clone() + self._source_only_params: set[str] = set() + for name, param in self.joint_search_space.parameters.items(): + if name not in model_space.parameters and isinstance(param, RangeParameter): + model_space.add_parameter(param.clone()) + self._source_only_params.add(name) + self._model_space = model_space + def _transform_data( self, experiment_data: ExperimentData, @@ -505,94 +522,18 @@ def _get_task_datasets( ) return task_datasets - def _expand_ssd_to_joint_space( + def _get_target_data_parameters( self, - search_space_digest: SearchSpaceDigest, - ) -> SearchSpaceDigest: - """Expand SSD bounds and feature_names to cover the joint search space. - - The SSD produced by ``_get_fit_args`` reflects the target search space. - When source experiments have additional parameters, the model operates - in the full joint feature space. This method appends bounds and feature - names for source-only parameters so that input transforms receive - correct full-space bounds. + all_params: list[str], + ) -> list[str]: + """Filter a joint parameter list to target-only params + task feature. + + Source-only params (those added by ``_set_search_space`` from the joint + space) are excluded because the target experiment data does not have + those columns. Uses untransformed names, which are stable across + transforms (Range params are never renamed by OneHot, IntToFloat, etc.). """ - existing_names = set(search_space_digest.feature_names) - extra_names: list[str] = [] - extra_bounds: list[tuple[int | float, int | float]] = [] - # Only collect parameters absent from the target SSD. Shared - # parameters that appear in both target and source keep the target - # bounds -- source observations outside those bounds will normalize - # outside [0, 1]. This is intentional, as the GP hyperprior is calibrated - # for a __target__ task in [0, 1]^D. 
- for name, param in self.joint_search_space.parameters.items(): - if name not in existing_names and isinstance(param, RangeParameter): - extra_names.append(name) - extra_bounds.append((param.lower, param.upper)) - if not extra_names: - return search_space_digest - # Insert source-only params before the task feature - task_features = search_space_digest.task_features - if len(task_features) == 1: - tf_idx = task_features[0] - names = list(search_space_digest.feature_names) - bounds = list(search_space_digest.bounds) - # Raise if index-based fields (other than the task feature - # itself) reference indices at or above tf_idx, since we would - # need to shift them when inserting extra params. - for field_name in ( - "ordinal_features", - "categorical_features", - "fidelity_features", - ): - indices = getattr(search_space_digest, field_name) - if any(i >= tf_idx for i in indices): - raise UnsupportedError( - f"Cannot expand SSD: {field_name} contains index >= {tf_idx}." - ) - if any( - i >= tf_idx and i not in task_features - for i in search_space_digest.discrete_choices - ): - raise UnsupportedError( - f"Cannot expand SSD: discrete_choices contains index >= {tf_idx}." - ) - if search_space_digest.hierarchical_dependencies is not None and any( - i >= tf_idx for i in search_space_digest.hierarchical_dependencies - ): - raise UnsupportedError( - "Cannot expand SSD: hierarchical_dependencies contains " - f"index >= {tf_idx}." 
- ) - names[tf_idx:tf_idx] = extra_names - bounds[tf_idx:tf_idx] = extra_bounds - n_extra = len(extra_names) - new_task_features = [tf_idx + n_extra] - new_target_values = dict(search_space_digest.target_values) - if tf_idx in new_target_values: - new_target_values[new_task_features[0]] = new_target_values.pop(tf_idx) - new_discrete = dict(search_space_digest.discrete_choices) - if tf_idx in new_discrete: - new_discrete[new_task_features[0]] = new_discrete.pop(tf_idx) - return dataclasses.replace( - search_space_digest, - feature_names=names, - bounds=bounds, - task_features=new_task_features, - target_values=new_target_values, - discrete_choices=new_discrete, - ) - elif len(task_features) == 0: - # No task feature -- just append. - return dataclasses.replace( - search_space_digest, - feature_names=search_space_digest.feature_names + extra_names, - bounds=search_space_digest.bounds + extra_bounds, - ) - else: - raise UnsupportedError( - "Multiple task features are not supported in transfer learning." - ) + return [p for p in all_params if p not in self._source_only_params] def _fit( self, @@ -610,15 +551,20 @@ def _fit( if experiment_data.arm_data.empty: # Temporarily unset self.outcomes to avoid an error in _get_fit_args. self.outcomes = [] + # Pre-compute the joint param ordering (mirrors _get_fit_args logic) + # so we can derive the target-only subset for data extraction. + all_params = list(search_space.parameters.keys()) + task_name = Keys.TASK_FEATURE_NAME.value + if task_name in all_params: + all_params.remove(task_name) + all_params.append(task_name) + target_data_params = self._get_target_data_parameters(all_params) datasets, candidate_metadata, search_space_digest = self._get_fit_args( search_space=search_space, experiment_data=experiment_data, update_outcomes_and_parameters=True, + data_parameters=target_data_params, ) - # Expand SSD bounds to cover source-only params from the joint search - # space. 
This ensures Normalize (and other input transforms) get bounds - # for the full feature space, not just the target dims. - search_space_digest = self._expand_ssd_to_joint_space(search_space_digest) if experiment_data.arm_data.empty: self.outcomes = outcomes # Temporarily set datasets to None. We will construct empty datasets @@ -656,12 +602,13 @@ def _cross_validate( ) -> list[ObservationData]: if self.parameters is None: raise ValueError(FIT_MODEL_ERROR.format(action="_cross_validate")) + target_data_params = self._get_target_data_parameters(self.parameters) datasets, _, search_space_digest = self._get_fit_args( search_space=search_space, experiment_data=cv_training_data, update_outcomes_and_parameters=False, + data_parameters=target_data_params, ) - search_space_digest = self._expand_ssd_to_joint_space(search_space_digest) # Add the task feature to SSD, to ensure that a multi-task model is selected. if len(search_space_digest.task_features) > 1: raise UnsupportedError( @@ -728,9 +675,10 @@ def gen( if fixed_features is None: fixed_features = ObservationFeatures(parameters={}) fixed_features.parameters.setdefault(name, target_p.value) - # Fix source-only params so the optimizer doesn't search over them. - # Center is a reasonable default; LearnedFeatureImputation overwrites - # these with learned values when configured. + # Fix source-only params that ARE in the search space (e.g. injected + # as FixedParam with a backfill value) so the optimizer doesn't search + # over them. Params NOT in the search space are handled by the model + # internally (HeterogeneousMTGP natively, LFI for MultiTaskGP). 
joint_center = self.joint_search_space.compute_naive_center() for name, param in self.joint_search_space.parameters.items(): if name not in search_space.parameters and isinstance( @@ -739,6 +687,11 @@ def gen( if fixed_features is None: fixed_features = ObservationFeatures(parameters={}) fixed_features.parameters.setdefault(name, joint_center[name]) + # At gen time, restrict self.parameters to params that exist in the + # gen-time search space. Source-only params absent from _search_space + # are handled by the model (LFI imputation or HeterogeneousMTGP). + saved_parameters = self.parameters + self.parameters = self._get_target_data_parameters(self.parameters) generator_run = super().gen( n=n, search_space=search_space, @@ -747,6 +700,7 @@ def gen( fixed_features=fixed_features, model_gen_options=model_gen_options, ) + self.parameters = saved_parameters # Remove the parameters that are not in the target experiment's search # space, and update candidate_metadata_by_arm_signature to reflect the # new arm. 
We use the experiment's search space rather than diff --git a/ax/adapter/transfer_learning/tests/test_adapter.py b/ax/adapter/transfer_learning/tests/test_adapter.py index 30902bbad7a..5f220326b86 100644 --- a/ax/adapter/transfer_learning/tests/test_adapter.py +++ b/ax/adapter/transfer_learning/tests/test_adapter.py @@ -5,113 +5,192 @@ # pyre-strict -from unittest.mock import MagicMock, PropertyMock +from __future__ import annotations -from ax.adapter.transfer_learning.adapter import TransferLearningAdapter -from ax.core.parameter import ParameterType, RangeParameter -from ax.core.search_space import SearchSpace, SearchSpaceDigest -from ax.exceptions.core import UnsupportedError +from unittest.mock import MagicMock, patch + +import torch +from ax.adapter.transfer_learning.adapter import TL_EXP, TransferLearningAdapter +from ax.adapter.transforms.metadata_to_task import MetadataToTask +from ax.core.arm import Arm +from ax.core.auxiliary_source import AuxiliarySource +from ax.core.experiment import Experiment +from ax.core.parameter import FixedParameter, ParameterType, RangeParameter +from ax.core.search_space import SearchSpace +from ax.generators.torch.botorch_modular.generator import BoTorchGenerator +from ax.utils.common.constants import Keys from ax.utils.common.testutils import TestCase +from ax.utils.testing.core_stubs import get_experiment_with_observations +from ax.utils.testing.mock import mock_botorch_optimize +from pyre_extensions import none_throws -class ExpandSsdToJointSpaceTest(TestCase): - def setUp(self) -> None: - super().setUp() - self.adapter = MagicMock(spec=TransferLearningAdapter) - - def _make_joint_ss(self, params: dict[str, tuple[float, float]]) -> SearchSpace: - return SearchSpace( - parameters=[ - RangeParameter( - name=n, - lower=lo, - upper=hi, - parameter_type=ParameterType.FLOAT, - ) - for n, (lo, hi) in params.items() - ] - ) +def _make_ss(params: dict[str, tuple[float, float]]) -> SearchSpace: + return SearchSpace( + parameters=[ + 
RangeParameter( + name=n, + lower=lo, + upper=hi, + parameter_type=ParameterType.FLOAT, + ) + for n, (lo, hi) in params.items() + ] + ) - def test_no_extra_params_returns_unchanged(self) -> None: - type(self.adapter).joint_search_space = PropertyMock( - return_value=self._make_joint_ss({"x1": (0, 1), "x2": (0, 1)}) - ) - ssd = SearchSpaceDigest( - feature_names=["x1", "x2", "task"], - bounds=[(0, 1), (0, 1), (0, 2)], - task_features=[2], - target_values={2: 0}, - ) - result = TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd) - self.assertIs(result, ssd) - def test_single_task_feature_inserts_before_task(self) -> None: - type(self.adapter).joint_search_space = PropertyMock( - return_value=self._make_joint_ss( - {"x1": (0, 1), "x2": (0, 1), "x3": (-2, 5)} - ) - ) - ssd = SearchSpaceDigest( - feature_names=["x1", "x2", "task"], - bounds=[(0, 1), (0, 1), (0, 2)], - task_features=[2], - target_values={2: 0}, - ) - result = TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd) - self.assertEqual(result.feature_names, ["x1", "x2", "x3", "task"]) - self.assertEqual(result.bounds, [(0, 1), (0, 1), (-2, 5), (0, 2)]) - self.assertEqual(result.task_features, [3]) - self.assertEqual(result.target_values, {3: 0}) - - def test_zero_task_features_appends(self) -> None: - type(self.adapter).joint_search_space = PropertyMock( - return_value=self._make_joint_ss({"x1": (0, 1), "x2": (-1, 3)}) - ) - ssd = SearchSpaceDigest( - feature_names=["x1"], - bounds=[(0, 1)], - ) - result = TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd) - self.assertEqual(result.feature_names, ["x1", "x2"]) - self.assertEqual(result.bounds, [(0, 1), (-1, 3)]) +def _gen_experiment( + experiment_name: str, + num_trials: int, + search_space: SearchSpace | None = None, +) -> Experiment: + exp = get_experiment_with_observations( + observations=torch.rand(num_trials, 1).tolist(), + search_space=search_space, + ) + exp.name = experiment_name + return exp + + 
+class SetSearchSpaceTest(TestCase): + """_set_search_space adds source-only params to _model_space while + preserving target bounds for shared params.""" - def test_discrete_choices_on_task_feature_shifted(self) -> None: - type(self.adapter).joint_search_space = PropertyMock( - return_value=self._make_joint_ss({"x1": (0, 1), "x2": (0, 1), "x3": (0, 1)}) + def test_model_space_has_source_only_params(self) -> None: + target_ss = _make_ss({"x": (0, 1), "y": (0, 1)}) + source_ss = _make_ss({"x": (0, 5), "y": (0, 5), "z": (0, 5)}) + target_exp = _gen_experiment("target", num_trials=3, search_space=target_ss) + source_exp = _gen_experiment("source", num_trials=3, search_space=source_ss) + source_exp.status_quo = Arm(parameters={"x": 1.0, "y": 1.0, "z": 2.5}) + target_exp.auxiliary_experiments_by_purpose[TL_EXP] = [ + AuxiliarySource(experiment=source_exp) + ] + adapter = TransferLearningAdapter( + experiment=target_exp, + search_space=target_ss, + data=target_exp.lookup_data(), + generator=BoTorchGenerator(), + transforms=[MetadataToTask], + fit_on_init=False, ) - ssd = SearchSpaceDigest( - feature_names=["x1", "x2", "task"], - bounds=[(0, 1), (0, 1), (0, 2)], - task_features=[2], - target_values={2: 0}, - discrete_choices={2: [0, 1, 2]}, + with self.subTest("model_space_has_z"): + self.assertIn("z", adapter._model_space.parameters) + with self.subTest("search_space_has_z_as_fixed"): + self.assertIn("z", adapter._search_space.parameters) + self.assertIsInstance(adapter._search_space.parameters["z"], FixedParameter) + with self.subTest("shared_params_keep_target_bounds"): + x_param = adapter._model_space.parameters["x"] + self.assertEqual(x_param.lower, 0.0) + self.assertEqual(x_param.upper, 1.0) + with self.subTest("source_only_without_backfill"): + source_ss2 = _make_ss({"x": (0, 5), "w": (0, 10)}) + source_exp2 = _gen_experiment( + "source2", num_trials=3, search_space=source_ss2 + ) + target_exp.auxiliary_experiments_by_purpose[TL_EXP] = [ + 
AuxiliarySource(experiment=source_exp2) + ] + adapter2 = TransferLearningAdapter( + experiment=target_exp, + search_space=target_ss, + data=target_exp.lookup_data(), + generator=BoTorchGenerator(), + transforms=[MetadataToTask], + fit_on_init=False, + ) + self.assertIn("w", adapter2._model_space.parameters) + self.assertIsInstance(adapter2._model_space.parameters["w"], RangeParameter) + + +class GetTargetDataParametersTest(TestCase): + """_get_target_data_parameters filters joint params to target-only + task.""" + + def test_filters_source_only_params(self) -> None: + adapter = MagicMock(spec=TransferLearningAdapter) + adapter._source_only_params = {"z"} + joint_params = ["x", "y", "z", Keys.TASK_FEATURE_NAME.value] + result = TransferLearningAdapter._get_target_data_parameters( + adapter, joint_params ) - result = TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd) - self.assertEqual(result.discrete_choices, {3: [0, 1, 2]}) - self.assertEqual(result.task_features, [3]) + self.assertEqual(result, ["x", "y", Keys.TASK_FEATURE_NAME.value]) + + def test_no_source_only_params_returns_all(self) -> None: + adapter = MagicMock(spec=TransferLearningAdapter) + adapter._source_only_params = set() + params = ["x", "y", Keys.TASK_FEATURE_NAME.value] + result = TransferLearningAdapter._get_target_data_parameters(adapter, params) + self.assertEqual(result, params) - def test_hierarchical_dependencies_at_task_idx_raises(self) -> None: - type(self.adapter).joint_search_space = PropertyMock( - return_value=self._make_joint_ss({"x1": (0, 1), "x2": (0, 1), "x3": (0, 1)}) + +class FitWithDataParametersTest(TestCase): + """After _fit, self.parameters = joint params and SSD has joint bounds, + without needing _expand_ssd_to_joint_space.""" + + @mock_botorch_optimize + def test_fit_heterogeneous_ssd_has_joint_bounds(self) -> None: + target_ss = _make_ss({"x": (0, 1), "y": (0, 1)}) + source_ss = _make_ss({"x": (0, 5), "y": (0, 5), "z": (0, 5)}) + target_exp = 
_gen_experiment("target", num_trials=3, search_space=target_ss) + source_exp = _gen_experiment("source", num_trials=5, search_space=source_ss) + source_exp.status_quo = Arm(parameters={"x": 1.0, "y": 1.0, "z": 2.5}) + target_exp.auxiliary_experiments_by_purpose[TL_EXP] = [ + AuxiliarySource(experiment=source_exp) + ] + adapter = TransferLearningAdapter( + experiment=target_exp, + search_space=target_ss, + data=target_exp.lookup_data(), + generator=BoTorchGenerator(), + transforms=[MetadataToTask], + fit_on_init=False, ) - ssd = SearchSpaceDigest( - feature_names=["x1", "x2", "task"], - bounds=[(0, 1), (0, 1), (0, 2)], - task_features=[2], - target_values={2: 0}, - hierarchical_dependencies={2: {0: [1]}}, + adapter.outcomes = list( + none_throws(target_exp.optimization_config).objective.metric_names ) - with self.assertRaisesRegex(UnsupportedError, "hierarchical_dependencies"): - TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd) + adapter.parameters = list(target_exp.search_space.parameters) - def test_multiple_task_features_raises(self) -> None: - type(self.adapter).joint_search_space = PropertyMock( - return_value=self._make_joint_ss({"x1": (0, 1), "x2": (0, 1), "x3": (0, 1)}) - ) - ssd = SearchSpaceDigest( - feature_names=["x1", "task1", "task2"], - bounds=[(0, 1), (0, 1), (0, 1)], - task_features=[1, 2], + with patch.object( + adapter.generator, "fit", wraps=none_throws(adapter.generator).fit + ) as gen_fit: + experiment_data, search_space = adapter._process_and_transform_data( + experiment=target_exp, + ) + adapter._fit(search_space=search_space, experiment_data=experiment_data) + + gen_fit.assert_called_once() + ssd = gen_fit.call_args[1]["search_space_digest"] + # SSD feature names include source-only z (from the joint model space) + self.assertIn("z", ssd.feature_names) + # SSD bounds cover the joint space + z_idx = ssd.feature_names.index("z") + self.assertEqual(ssd.bounds[z_idx], (0.0, 5.0)) + # self.parameters is the joint set + 
self.assertIn("z", adapter.parameters) + # Task feature is last + self.assertEqual(adapter.parameters[-1], Keys.TASK_FEATURE_NAME.value) + + @mock_botorch_optimize + def test_fit_and_gen_heterogeneous(self) -> None: + """Full fit+gen round-trip with heterogeneous search spaces.""" + target_ss = _make_ss({"x": (0, 1), "y": (0, 1)}) + source_ss = _make_ss({"x": (0, 5), "y": (0, 5), "z": (0, 5)}) + target_exp = _gen_experiment("target", num_trials=3, search_space=target_ss) + source_exp = _gen_experiment("source", num_trials=5, search_space=source_ss) + source_exp.status_quo = Arm(parameters={"x": 1.0, "y": 1.0, "z": 2.5}) + target_exp.auxiliary_experiments_by_purpose[TL_EXP] = [ + AuxiliarySource(experiment=source_exp) + ] + adapter = TransferLearningAdapter( + experiment=target_exp, + search_space=target_ss, + data=target_exp.lookup_data(), + generator=BoTorchGenerator(), + transforms=[MetadataToTask], + fit_on_init=True, ) - with self.assertRaisesRegex(UnsupportedError, "Multiple task features"): - TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd) + gr = adapter.gen(n=1) + # Generated arms should only have target params + for arm in gr.arms: + self.assertIn("x", arm.parameters) + self.assertIn("y", arm.parameters) + self.assertNotIn("z", arm.parameters) diff --git a/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py b/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py index 133218de7f0..5ca2e6ff72a 100644 --- a/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py +++ b/ax/generators/torch/botorch_modular/input_constructors/input_transforms.py @@ -402,10 +402,14 @@ def _input_transform_argparse_learned_feature_imputation( torch.ones(d, dtype=dtype, device=torch_device), ] ) + # The target task is at position 0 (target_dataset is prepended above), so + # at posterior time — when X arrives without a task column — LFI applies + # the target task's imputation pattern. 
kwargs: dict[str, Any] = { "feature_indices": feature_indices, "d": d, "task_feature_index": task_feature_index, + "target_task": 0, "bounds": bounds, "device": torch_device, "dtype": dtype, diff --git a/ax/generators/torch/botorch_modular/surrogate.py b/ax/generators/torch/botorch_modular/surrogate.py index c586c61caec..09800143e2b 100644 --- a/ax/generators/torch/botorch_modular/surrogate.py +++ b/ax/generators/torch/botorch_modular/surrogate.py @@ -12,7 +12,7 @@ from collections import OrderedDict from collections.abc import Mapping, Sequence from copy import deepcopy -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from logging import Logger from typing import Any, cast @@ -67,6 +67,7 @@ from botorch.models.transforms.input import ( ChainedInputTransform, InputTransform, + LearnedFeatureImputation, Normalize, ) from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform @@ -1253,6 +1254,22 @@ def _submodel_input_constructor_mtgp( ) -> dict[str, Any]: if len(dataset.outcome_names) > 1: raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.") + # If LearnedFeatureImputation is in the model config, tell construct_inputs + # to map heterogeneous per-task features to the full joint feature space. + # This must happen before the base call so construct_inputs can handle + # heterogeneous MultiTaskDatasets without raising. 
+ uses_lfi = isinstance(model_config.input_transform_classes, list) and any( + issubclass(cls, LearnedFeatureImputation) + for cls in model_config.input_transform_classes + ) + if uses_lfi and "map_heterogeneous_to_full" not in model_config.model_options: + model_config = replace( + model_config, + model_options={ + **model_config.model_options, + "map_heterogeneous_to_full": True, + }, + ) formatted_model_inputs = _submodel_input_constructor_base( botorch_model_class=botorch_model_class, model_config=model_config, @@ -1266,9 +1283,12 @@ def _submodel_input_constructor_mtgp( # specify output tasks so that model.num_outputs = 1 # since the model only models a single outcome if formatted_model_inputs.get("output_tasks") is None: - # SSD doesn't use -1, so we need to normalize here + # SSD doesn't use -1, so we need to normalize here. Use the SSD's bound + # length since target_values is keyed by SSD column index — for + # heterogeneous MultiTaskDatasets this differs from the per-task + # dataset's feature_names length. task_feature = none_throws( - normalize_indices(indices=[task_feature], d=len(dataset.feature_names)) + normalize_indices(indices=[task_feature], d=len(search_space_digest.bounds)) )[0] if (search_space_digest.target_values is not None) and ( target_value := search_space_digest.target_values.get(task_feature) diff --git a/ax/generators/torch/botorch_modular/utils.py b/ax/generators/torch/botorch_modular/utils.py index f9f52a91eed..36b16557ac3 100644 --- a/ax/generators/torch/botorch_modular/utils.py +++ b/ax/generators/torch/botorch_modular/utils.py @@ -221,6 +221,32 @@ def use_model_list( return True +def _ensure_input_transform( + model_config: ModelConfig, + transform_cls: type[InputTransform], + position: int | None = None, +) -> None: + """Ensure ``transform_cls`` is in ``model_config.input_transform_classes``. + + If the user hasn't specified any transforms (``DEFAULT``), initialise the + list with ``[transform_cls]``. 
Otherwise append (or insert at ``position``) + only when the class isn't already present. Mutates ``model_config`` + in-place. + """ + itc = model_config.input_transform_classes + if isinstance(itc, list): + if transform_cls not in itc: + if position is not None: + itc.insert(position, transform_cls) + else: + itc.append(transform_cls) + else: + model_config.input_transform_classes = [transform_cls] + ito = model_config.input_transform_options or {} + ito.setdefault(transform_cls.__name__, {}) + model_config.input_transform_options = ito + + def copy_model_config_with_default_values( model_config: ModelConfig, dataset: SupervisedDataset, @@ -235,43 +261,15 @@ def copy_model_config_with_default_values( specified_model_class=model_config_copy.botorch_model_class, ) - # Handle heterogeneous multi-task datasets. + # Handle heterogeneous multi-task datasets: ensure Normalize is present + # and add LearnedFeatureImputation for models that don't handle + # heterogeneity natively. if isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features: - if model_config_copy.botorch_model_class is HeterogeneousMTGP: - # HeterogeneousMTGP handles heterogeneity natively; just ensure - # Normalize is present (bounds are set later by the TL adapter). - itc = model_config_copy.input_transform_classes - if isinstance(itc, list): - if Normalize not in itc: - itc.insert(0, Normalize) - ito = model_config_copy.input_transform_options or {} - ito.setdefault("Normalize", {"bounds": None}) - model_config_copy.input_transform_options = ito - else: - model_config_copy.input_transform_classes = [Normalize] - ito = model_config_copy.input_transform_options or {} - ito.setdefault("Normalize", {"bounds": None}) - model_config_copy.input_transform_options = ito - else: - # Other models need Normalize + LFI to pad features via - # map_heterogeneous_to_full. 
- itc = model_config_copy.input_transform_classes - if isinstance(itc, list): - if Normalize not in itc: - itc.insert(0, Normalize) - ito = model_config_copy.input_transform_options or {} - ito.setdefault("Normalize", {"bounds": None}) - model_config_copy.input_transform_options = ito - if LearnedFeatureImputation not in itc: - itc.append(LearnedFeatureImputation) - else: - model_config_copy.input_transform_classes = [ - Normalize, - LearnedFeatureImputation, - ] - ito = model_config_copy.input_transform_options or {} - ito.setdefault("Normalize", {"bounds": None}) - model_config_copy.input_transform_options = ito + _ensure_input_transform(model_config_copy, Normalize, position=0) + if model_config_copy.botorch_model_class is not None and not issubclass( + model_config_copy.botorch_model_class, HeterogeneousMTGP + ): + _ensure_input_transform(model_config_copy, LearnedFeatureImputation) if model_config_copy.mll_class is None: model_config_copy.mll_class = ( @@ -321,16 +319,15 @@ def choose_model_class( ) # Check for heterogeneous multi-task datasets. If a model class was - # explicitly specified, respect it; otherwise default to HeterogeneousMTGP. + # explicitly specified, respect it; otherwise default to MultiTaskGP + # (LearnedFeatureImputation handles missing features). 
if ( search_space_digest.task_features and isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features ): model_class = ( - specified_model_class - if specified_model_class is not None - else HeterogeneousMTGP + specified_model_class if specified_model_class is not None else MultiTaskGP ) logger.debug(f"Chose BoTorch model class: {model_class}.") return model_class diff --git a/ax/generators/torch/tests/test_surrogate.py b/ax/generators/torch/tests/test_surrogate.py index 82e4c824bdc..80dfa36b178 100644 --- a/ax/generators/torch/tests/test_surrogate.py +++ b/ax/generators/torch/tests/test_surrogate.py @@ -32,6 +32,7 @@ _construct_specified_input_transforms, _extract_model_kwargs, _make_botorch_input_transform, + _submodel_input_constructor_mtgp, submodel_input_constructor, Surrogate, SurrogateSpec, @@ -59,7 +60,12 @@ from botorch.models.model import Model, ModelList # noqa: F401 -- used in Mocks. from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood -from botorch.models.transforms.input import ChainedInputTransform, Log10, Normalize +from botorch.models.transforms.input import ( + ChainedInputTransform, + LearnedFeatureImputation, + Log10, + Normalize, +) from botorch.models.transforms.outcome import OutcomeTransform, Standardize from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset from botorch.utils.evaluation import compute_in_sample_model_fit_metric @@ -277,6 +283,74 @@ def test__make_botorch_input_transform(self) -> None: self.assertEqual(transform.indices.tolist(), [0]) self.assertEqual(transform.bounds.tolist(), [[1.0], [5.0]]) + def test_submodel_input_constructor_mtgp_map_heterogeneous(self) -> None: + """_submodel_input_constructor_mtgp passes map_heterogeneous_to_full + to construct_inputs when LFI is configured, enabling zero-padded + heterogeneous datasets to be used with MultiTaskGP.""" + ds_target = SupervisedDataset( + 
X=torch.tensor([[1.0, 0.0], [2.0, 0.0]]), + Y=torch.tensor([[1.0], [2.0]]), + feature_names=["x0", "task"], + outcome_names=["y_task_0"], + ) + ds_source = SupervisedDataset( + X=torch.tensor([[3.0, 4.0, 1.0], [5.0, 6.0, 1.0]]), + Y=torch.tensor([[3.0], [4.0]]), + feature_names=["x0", "x1", "task"], + outcome_names=["y_task_1"], + ) + mt_dataset = MultiTaskDataset( + datasets=[ds_target, ds_source], + target_outcome_name="y_task_0", + task_feature_index=-1, + ) + self.assertTrue(mt_dataset.has_heterogeneous_features) + ssd = SearchSpaceDigest( + feature_names=["x0", "x1", "task"], + bounds=[(0.0, 5.0), (0.0, 6.0), (0.0, 1.0)], + task_features=[2], + target_values={2: 0.0}, + ) + surrogate = Surrogate( + surrogate_spec=SurrogateSpec( + model_configs=[ModelConfig(botorch_model_class=MultiTaskGP)] + ) + ) + + with self.subTest("with LFI — construct_inputs succeeds"): + config_with_lfi = ModelConfig( + botorch_model_class=MultiTaskGP, + input_transform_classes=[Normalize, LearnedFeatureImputation], + ) + result = _submodel_input_constructor_mtgp( + botorch_model_class=MultiTaskGP, + model_config=config_with_lfi, + dataset=mt_dataset, + search_space_digest=ssd, + surrogate=surrogate, + ) + self.assertEqual(result["train_X"].shape[-1], 3) + + with self.subTest("without LFI — construct_inputs raises"): + from botorch.exceptions.errors import ( + UnsupportedError as BotorchUnsupportedError, + ) + + config_no_lfi = ModelConfig( + botorch_model_class=MultiTaskGP, + input_transform_classes=[Normalize], + ) + with self.assertRaisesRegex( + BotorchUnsupportedError, "heterogeneous feature sets" + ): + _submodel_input_constructor_mtgp( + botorch_model_class=MultiTaskGP, + model_config=config_no_lfi, + dataset=mt_dataset, + search_space_digest=ssd, + surrogate=surrogate, + ) + class SurrogateTest(TestCase): def setUp(self, cuda: bool = False) -> None: diff --git a/ax/generators/torch/tests/test_utils.py b/ax/generators/torch/tests/test_utils.py index 66ae1e11669..8126998aa73 
100644 --- a/ax/generators/torch/tests/test_utils.py +++ b/ax/generators/torch/tests/test_utils.py @@ -70,7 +70,7 @@ from botorch.models.model_list_gp_regression import ModelListGP from botorch.models.multitask import MultiTaskGP from botorch.models.pairwise_gp import PairwiseGP -from botorch.models.transforms.input import LearnedFeatureImputation, Normalize +from botorch.models.transforms.input import LearnedFeatureImputation, Normalize, Warp from botorch.posteriors.ensemble import EnsemblePosterior from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset from botorch.utils.types import DEFAULT @@ -186,9 +186,9 @@ def test_choose_model_class_heterogeneous_task_features(self) -> None: mt_dataset = self._get_heterogeneous_mt_dataset() ssd = dataclasses.replace(self.search_space_digest, task_features=[-1]) - # Default: HeterogeneousMTGP. + # Default: MultiTaskGP (LearnedFeatureImputation handles missing features). self.assertEqual( - HeterogeneousMTGP, + MultiTaskGP, choose_model_class(dataset=mt_dataset, search_space_digest=ssd), ) @@ -233,19 +233,23 @@ def test_copy_model_config_heterogeneous_mtgp(self) -> None: mt_dataset = self._get_heterogeneous_mt_dataset() ssd = dataclasses.replace(self.search_space_digest, task_features=[-1]) - # Default (no model class specified) -> HeterogeneousMTGP. - # LFI is NOT injected; input_transform_classes stays DEFAULT. + # Default (no model class specified) -> MultiTaskGP. + # LFI is injected for MultiTaskGP with heterogeneous data. 
updated_config = copy_model_config_with_default_values( model_config=ModelConfig(), dataset=mt_dataset, search_space_digest=ssd, ) - self.assertEqual(updated_config.botorch_model_class, HeterogeneousMTGP) - self.assertEqual(updated_config.input_transform_classes, [Normalize]) + self.assertEqual(updated_config.botorch_model_class, MultiTaskGP) self.assertEqual( - none_throws(updated_config.input_transform_options), - {"Normalize": {"bounds": None}}, + updated_config.input_transform_classes, + [Normalize, LearnedFeatureImputation], ) + # LFI is present in transform classes but absent from options; its + # argparse computes kwargs from the dataset at construction time. + ito = none_throws(updated_config.input_transform_options) + self.assertEqual(ito, {"Normalize": {}}) + self.assertNotIn("LearnedFeatureImputation", ito) # Explicit HeterogeneousMTGP behaves the same. updated_config = copy_model_config_with_default_values( @@ -257,7 +261,7 @@ def test_copy_model_config_heterogeneous_mtgp(self) -> None: self.assertEqual(updated_config.input_transform_classes, [Normalize]) self.assertEqual( none_throws(updated_config.input_transform_options), - {"Normalize": {"bounds": None}}, + {"Normalize": {}}, ) def test_copy_model_config_mtgp_with_lfi_injection(self) -> None: @@ -302,6 +306,54 @@ def test_copy_model_config_does_not_add_normalize_for_other_models(self) -> None self.assertEqual(updated_config.input_transform_classes, DEFAULT) self.assertEqual(updated_config.input_transform_options, {}) + def test_copy_model_config_adds_imputation_for_heterogeneous(self) -> None: + mt_dataset = self._get_heterogeneous_mt_dataset() + ssd = dataclasses.replace(self.search_space_digest, task_features=[-1]) + + with self.subTest("no_input_transform_classes"): + model_config = ModelConfig(botorch_model_class=MultiTaskGP) + updated_config = copy_model_config_with_default_values( + model_config=model_config, + dataset=mt_dataset, + search_space_digest=ssd, + ) + 
self.assertEqual(updated_config.botorch_model_class, MultiTaskGP) + self.assertEqual( + updated_config.input_transform_classes, + [Normalize, LearnedFeatureImputation], + ) + + with self.subTest("existing_transform_classes"): + model_config = ModelConfig( + botorch_model_class=MultiTaskGP, + input_transform_classes=[Warp], + input_transform_options={"Warp": {}}, + ) + updated_config = copy_model_config_with_default_values( + model_config=model_config, + dataset=mt_dataset, + search_space_digest=ssd, + ) + self.assertEqual( + updated_config.input_transform_classes, + [Normalize, Warp, LearnedFeatureImputation], + ) + + with self.subTest("imputation_already_present"): + model_config = ModelConfig( + botorch_model_class=MultiTaskGP, + input_transform_classes=[Normalize, LearnedFeatureImputation], + ) + updated_config = copy_model_config_with_default_values( + model_config=model_config, + dataset=mt_dataset, + search_space_digest=ssd, + ) + self.assertEqual( + updated_config.input_transform_classes, + [Normalize, LearnedFeatureImputation], + ) + def test_choose_model_class_discrete_features(self) -> None: # With discrete features, use MixedSingleTaskGP. self.assertEqual(