Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ax/adapter/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
]

Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY, StandardizeY]
TL_Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY]

# Expected `List[Type[Transform]]` for 2nd anonymous parameter to
# call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`.
Expand Down
10 changes: 6 additions & 4 deletions ax/adapter/transfer_learning/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
Generators,
GeneratorSetup,
MBM_X_trans,
Y_trans,
TL_Y_trans,
)
from ax.adapter.torch import FIT_MODEL_ERROR, TorchAdapter
from ax.adapter.transfer_learning.utils import get_joint_search_space
Expand Down Expand Up @@ -54,6 +54,7 @@
from ax.utils.common.logger import get_logger
from botorch.models.multitask import MultiTaskGP
from botorch.models.transforms.input import InputTransform, Normalize
from botorch.models.transforms.outcome import StratifiedStandardize
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from gpytorch.kernels.kernel import Kernel
from pyre_extensions import assert_is_instance
Expand Down Expand Up @@ -793,7 +794,7 @@ def transfer_learning_generator_specs_constructor(
Args:
model_class: The MultiTask BoTorch Model to use in the BOTL.
transform: Optional list of transforms to use in the Adapter.
Defaults to MBM_X_trans + [MetadataToTask] + Y_trans.
Defaults to MBM_X_trans + [MetadataToTask] + TL_Y_trans.
jit_compile: Whether to use jit compilation in Pyro when the fully Bayesian
model is used.
torch_device: What torch device to use (defaults to None, i.e. falls back to
Expand Down Expand Up @@ -828,7 +829,7 @@ def transfer_learning_generator_specs_constructor(
input_transform_options: dict[str, dict[str, Any]] = {
"Normalize": {},
}
transforms = transforms or MBM_X_trans + [MetadataToTask] + Y_trans
transforms = transforms or MBM_X_trans + [MetadataToTask] + TL_Y_trans
transform_configs = get_derelativize_config(
derelativize_with_raw_status_quo=derelativize_with_raw_status_quo
)
Expand All @@ -846,6 +847,7 @@ def transfer_learning_generator_specs_constructor(
botorch_model_class=model_class,
model_options=botorch_model_kwargs or {},
input_transform_classes=input_transform_classes,
outcome_transform_classes=[StratifiedStandardize],
input_transform_options=input_transform_options,
mll_options=mll_kwargs,
covar_module_class=covar_module_class,
Expand Down Expand Up @@ -887,5 +889,5 @@ def transfer_learning_generator_specs_constructor(
GENERATOR_KEY_TO_GENERATOR_SETUP["BOTL"] = GeneratorSetup(
adapter_class=TransferLearningAdapter,
generator_class=BoTorchGenerator,
transforms=MBM_X_trans + [MetadataToTask] + Y_trans,
transforms=MBM_X_trans + [MetadataToTask] + TL_Y_trans,
)
Original file line number Diff line number Diff line change
Expand Up @@ -357,40 +357,12 @@ def _input_transform_argparse_learned_feature_imputation(
)
input_transform_options = input_transform_options or {}

# Order datasets: target first, then remaining (same as HeterogeneousMTGP).
child_datasets = dataset.datasets.copy()
target_dataset = child_datasets.pop(dataset.target_outcome_name)
all_datasets = [target_dataset] + list(child_datasets.values())

# The feature_names[:task_feature_index] slice only works when the task
# column is the last column (index == -1). Guard against other positions
# the same way ImputedMultiTaskGP.construct_inputs does.
# Delegate feature ordering and index computation to MultiTaskDataset.
all_datasets, feature_indices_list, d = dataset.get_heterogeneous_feature_mapping()
feature_indices = dict(enumerate(feature_indices_list))
task_feature_index = (
dataset.task_feature_index if (dataset.task_feature_index is not None) else -1
dataset.task_feature_index if dataset.task_feature_index is not None else -1
)
if task_feature_index != -1:
raise NotImplementedError(
"LearnedFeatureImputation argparse only supports "
"task_feature_index == -1. Got "
f"task_feature_index={task_feature_index}."
)

# Use target's feature order as canonical (NO alphabetical sort).
# Source-only features are appended at the end.
all_features: list[str] = list(target_dataset.feature_names[:task_feature_index])
for ds in all_datasets[1:]:
for fn in ds.feature_names[:task_feature_index]:
if fn not in all_features:
all_features.append(fn)
d = len(all_features)

# Map each task's features to indices in the global feature space.
feature_indices = {
task_idx: [
all_features.index(fn) for fn in ds.feature_names[:task_feature_index]
]
for task_idx, ds in enumerate(all_datasets)
}

dtype = torch_dtype or torch.float64
# Constrain imputation values to [0, 1] since the preceding Normalize
Expand All @@ -402,10 +374,14 @@ def _input_transform_argparse_learned_feature_imputation(
torch.ones(d, dtype=dtype, device=torch_device),
]
)
# The target task is at position 0 (target_dataset is prepended above), so
# at posterior time — when X arrives without a task column — LFI applies
# the target task's imputation pattern.
kwargs: dict[str, Any] = {
"feature_indices": feature_indices,
"d": d,
"task_feature_index": task_feature_index,
"target_task": 0,
"bounds": bounds,
"device": torch_device,
"dtype": dtype,
Expand Down
26 changes: 23 additions & 3 deletions ax/generators/torch/botorch_modular/surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from collections import OrderedDict
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass, field
from dataclasses import dataclass, field, replace
from logging import Logger
from typing import Any, cast

Expand Down Expand Up @@ -67,6 +67,7 @@
from botorch.models.transforms.input import (
ChainedInputTransform,
InputTransform,
LearnedFeatureImputation,
Normalize,
)
from botorch.models.transforms.outcome import ChainedOutcomeTransform, OutcomeTransform
Expand Down Expand Up @@ -1253,6 +1254,22 @@ def _submodel_input_constructor_mtgp(
) -> dict[str, Any]:
if len(dataset.outcome_names) > 1:
raise NotImplementedError("Multi-output Multi-task GPs are not yet supported.")
# If LearnedFeatureImputation is in the model config, tell construct_inputs
# to map heterogeneous per-task features to the full joint feature space.
# This must happen before the base call so construct_inputs can handle
# heterogeneous MultiTaskDatasets without raising.
uses_lfi = isinstance(model_config.input_transform_classes, list) and any(
issubclass(cls, LearnedFeatureImputation)
for cls in model_config.input_transform_classes
)
if uses_lfi and "map_heterogeneous_to_full" not in model_config.model_options:
model_config = replace(
model_config,
model_options={
**model_config.model_options,
"map_heterogeneous_to_full": True,
},
)
formatted_model_inputs = _submodel_input_constructor_base(
botorch_model_class=botorch_model_class,
model_config=model_config,
Expand All @@ -1266,9 +1283,12 @@ def _submodel_input_constructor_mtgp(
# specify output tasks so that model.num_outputs = 1
# since the model only models a single outcome
if formatted_model_inputs.get("output_tasks") is None:
# SSD doesn't use -1, so we need to normalize here
# SSD doesn't use -1, so we need to normalize here. Use the SSD's bound
# length since target_values is keyed by SSD column index — for
# heterogeneous MultiTaskDatasets this differs from the per-task
# dataset's feature_names length.
task_feature = none_throws(
normalize_indices(indices=[task_feature], d=len(dataset.feature_names))
normalize_indices(indices=[task_feature], d=len(search_space_digest.bounds))
)[0]
if (search_space_digest.target_values is not None) and (
target_value := search_space_digest.target_values.get(task_feature)
Expand Down
77 changes: 37 additions & 40 deletions ax/generators/torch/botorch_modular/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,32 @@ def use_model_list(
return True


def _ensure_input_transform(
    model_config: ModelConfig,
    transform_cls: type[InputTransform],
    position: int | None = None,
    default_options: dict[str, Any] | None = None,
) -> None:
    """Ensure ``transform_cls`` is in ``model_config.input_transform_classes``.

    If the user hasn't specified any transforms (``DEFAULT``), initialise the
    list with ``[transform_cls]``. Otherwise insert at ``position`` (or append
    when ``position`` is ``None``) only when the class isn't already present.
    Mutates ``model_config`` in-place.

    Args:
        model_config: The ``ModelConfig`` to update in-place.
        transform_cls: The ``InputTransform`` subclass that must be present.
        position: Optional index at which to insert the class into an existing
            list when it is missing; appended at the end when ``None``.
        default_options: Optional default options to register for
            ``transform_cls`` (keyed by its class name) when the caller has
            not already configured it. Defaults to an empty dict. NOTE(review):
            the code this helper replaced registered ``{"bounds": None}`` for
            ``Normalize``; pass that here if that explicit default matters —
            confirm against ``Normalize``'s constructor defaults.
    """
    itc = model_config.input_transform_classes
    if isinstance(itc, list):
        # User supplied an explicit list; only add when missing, respecting
        # the requested position (e.g. Normalize is conventionally first).
        if transform_cls not in itc:
            if position is None:
                itc.append(transform_cls)
            else:
                itc.insert(position, transform_cls)
    else:
        # Transforms were left at DEFAULT; start a fresh explicit list.
        model_config.input_transform_classes = [transform_cls]
    ito = model_config.input_transform_options or {}
    # Only fill in defaults when the user hasn't configured this transform.
    ito.setdefault(
        transform_cls.__name__, {} if default_options is None else default_options
    )
    model_config.input_transform_options = ito


def copy_model_config_with_default_values(
model_config: ModelConfig,
dataset: SupervisedDataset,
Expand All @@ -235,43 +261,15 @@ def copy_model_config_with_default_values(
specified_model_class=model_config_copy.botorch_model_class,
)

# Handle heterogeneous multi-task datasets.
# Handle heterogeneous multi-task datasets: ensure Normalize is present
# and add LearnedFeatureImputation for models that don't handle
# heterogeneity natively.
if isinstance(dataset, MultiTaskDataset) and dataset.has_heterogeneous_features:
if model_config_copy.botorch_model_class is HeterogeneousMTGP:
# HeterogeneousMTGP handles heterogeneity natively; just ensure
# Normalize is present (bounds are set later by the TL adapter).
itc = model_config_copy.input_transform_classes
if isinstance(itc, list):
if Normalize not in itc:
itc.insert(0, Normalize)
ito = model_config_copy.input_transform_options or {}
ito.setdefault("Normalize", {"bounds": None})
model_config_copy.input_transform_options = ito
else:
model_config_copy.input_transform_classes = [Normalize]
ito = model_config_copy.input_transform_options or {}
ito.setdefault("Normalize", {"bounds": None})
model_config_copy.input_transform_options = ito
else:
# Other models need Normalize + LFI to pad features via
# map_heterogeneous_to_full.
itc = model_config_copy.input_transform_classes
if isinstance(itc, list):
if Normalize not in itc:
itc.insert(0, Normalize)
ito = model_config_copy.input_transform_options or {}
ito.setdefault("Normalize", {"bounds": None})
model_config_copy.input_transform_options = ito
if LearnedFeatureImputation not in itc:
itc.append(LearnedFeatureImputation)
else:
model_config_copy.input_transform_classes = [
Normalize,
LearnedFeatureImputation,
]
ito = model_config_copy.input_transform_options or {}
ito.setdefault("Normalize", {"bounds": None})
model_config_copy.input_transform_options = ito
_ensure_input_transform(model_config_copy, Normalize, position=0)
if model_config_copy.botorch_model_class is not None and not issubclass(
model_config_copy.botorch_model_class, HeterogeneousMTGP
):
_ensure_input_transform(model_config_copy, LearnedFeatureImputation)

if model_config_copy.mll_class is None:
model_config_copy.mll_class = (
Expand Down Expand Up @@ -321,16 +319,15 @@ def choose_model_class(
)

# Check for heterogeneous multi-task datasets. If a model class was
# explicitly specified, respect it; otherwise default to HeterogeneousMTGP.
# explicitly specified, respect it; otherwise default to MultiTaskGP
# (LearnedFeatureImputation handles missing features).
if (
search_space_digest.task_features
and isinstance(dataset, MultiTaskDataset)
and dataset.has_heterogeneous_features
):
model_class = (
specified_model_class
if specified_model_class is not None
else HeterogeneousMTGP
specified_model_class if specified_model_class is not None else MultiTaskGP
)
logger.debug(f"Chose BoTorch model class: {model_class}.")
return model_class
Expand Down
4 changes: 1 addition & 3 deletions ax/generators/torch/tests/test_input_transform_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,9 +475,7 @@ def test_argparse_learned_feature_imputation(self) -> None:
target_outcome_name="y0",
task_feature_index=0,
)
with self.assertRaisesRegex(
NotImplementedError, "task_feature_index == -1"
):
with self.assertRaisesRegex(NotImplementedError, "task_feature_index.*-1"):
input_transform_argparse(
LearnedFeatureImputation,
dataset=bad_ds,
Expand Down
74 changes: 73 additions & 1 deletion ax/generators/torch/tests/test_surrogate.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
_construct_specified_input_transforms,
_extract_model_kwargs,
_make_botorch_input_transform,
_submodel_input_constructor_mtgp,
submodel_input_constructor,
Surrogate,
SurrogateSpec,
Expand Down Expand Up @@ -59,7 +60,12 @@
from botorch.models.model import Model, ModelList # noqa: F401 -- used in Mocks.
from botorch.models.multitask import MultiTaskGP
from botorch.models.pairwise_gp import PairwiseGP, PairwiseLaplaceMarginalLogLikelihood
from botorch.models.transforms.input import ChainedInputTransform, Log10, Normalize
from botorch.models.transforms.input import (
ChainedInputTransform,
LearnedFeatureImputation,
Log10,
Normalize,
)
from botorch.models.transforms.outcome import OutcomeTransform, Standardize
from botorch.utils.datasets import MultiTaskDataset, SupervisedDataset
from botorch.utils.evaluation import compute_in_sample_model_fit_metric
Expand Down Expand Up @@ -277,6 +283,72 @@ def test__make_botorch_input_transform(self) -> None:
self.assertEqual(transform.indices.tolist(), [0])
self.assertEqual(transform.bounds.tolist(), [[1.0], [5.0]])

def test_submodel_input_constructor_mtgp_map_heterogeneous(self) -> None:
    """With LFI configured, ``_submodel_input_constructor_mtgp`` forwards
    ``map_heterogeneous_to_full`` to ``construct_inputs``, so a
    heterogeneous ``MultiTaskDataset`` is zero-padded to the joint feature
    space and usable with ``MultiTaskGP``; without LFI it raises."""
    from botorch.exceptions.errors import (
        UnsupportedError as BotorchUnsupportedError,
    )

    # Target task observes only x0; source task additionally observes x1.
    target_dataset = SupervisedDataset(
        X=torch.tensor([[1.0, 0.0], [2.0, 0.0]]),
        Y=torch.tensor([[1.0], [2.0]]),
        feature_names=["x0", "task"],
        outcome_names=["y_task_0"],
    )
    source_dataset = SupervisedDataset(
        X=torch.tensor([[3.0, 4.0, 1.0], [5.0, 6.0, 1.0]]),
        Y=torch.tensor([[3.0], [4.0]]),
        feature_names=["x0", "x1", "task"],
        outcome_names=["y_task_1"],
    )
    dataset = MultiTaskDataset(
        datasets=[target_dataset, source_dataset],
        target_outcome_name="y_task_0",
        task_feature_index=-1,
    )
    self.assertTrue(dataset.has_heterogeneous_features)
    digest = SearchSpaceDigest(
        feature_names=["x0", "x1", "task"],
        bounds=[(0.0, 5.0), (0.0, 6.0), (0.0, 1.0)],
        task_features=[2],
        target_values={2: 0.0},
    )
    surrogate = Surrogate(
        surrogate_spec=SurrogateSpec(
            model_configs=[ModelConfig(botorch_model_class=MultiTaskGP)]
        )
    )

    with self.subTest("with LFI — construct_inputs succeeds"):
        lfi_config = ModelConfig(
            botorch_model_class=MultiTaskGP,
            input_transform_classes=[Normalize, LearnedFeatureImputation],
        )
        constructed = _submodel_input_constructor_mtgp(
            botorch_model_class=MultiTaskGP,
            model_config=lfi_config,
            dataset=dataset,
            search_space_digest=digest,
            surrogate=surrogate,
        )
        # Padded to the full joint feature space: x0, x1, task.
        self.assertEqual(constructed["train_X"].shape[-1], 3)

    with self.subTest("without LFI — construct_inputs raises"):
        plain_config = ModelConfig(
            botorch_model_class=MultiTaskGP,
            input_transform_classes=[Normalize],
        )
        with self.assertRaises(BotorchUnsupportedError):
            _submodel_input_constructor_mtgp(
                botorch_model_class=MultiTaskGP,
                model_config=plain_config,
                dataset=dataset,
                search_space_digest=digest,
                surrogate=surrogate,
            )


class SurrogateTest(TestCase):
def setUp(self, cuda: bool = False) -> None:
Expand Down
Loading
Loading