Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions ax/adapter/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@
]

Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY, StandardizeY]
# Transfer-learning variant of ``Y_trans``: identical except ``StandardizeY``
# is omitted.
TL_Y_trans: list[type[Transform]] = [Derelativize, Winsorize, BilogY]

# Expected `List[Type[Transform]]` for 2nd anonymous parameter to
# call `list.__add__` but got `List[Type[SearchSpaceToChoice]]`.
Expand Down
21 changes: 19 additions & 2 deletions ax/adapter/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,7 @@ def _get_fit_args(
search_space: SearchSpace,
experiment_data: ExperimentData,
update_outcomes_and_parameters: bool,
data_parameters: list[str] | None = None,
) -> tuple[
list[SupervisedDataset],
list[list[TCandidateMetadata]] | None,
Expand All @@ -791,6 +792,11 @@ def _get_fit_args(
update_outcomes_and_parameters: Whether to update `self.outcomes` with
all outcomes found in the observations and `self.parameters` with
all parameters in the search space. Typically only used in `_fit`.
data_parameters: When provided, columns to extract from
``experiment_data``. Defaults to ``self.parameters``. Useful when
the model space is larger than the data (e.g. transfer learning
with heterogeneous search spaces where the data only contains
target columns but the model operates in a joint feature space).

Returns:
The datasets & metadata, extracted from the ``experiment_data``, and the
Expand Down Expand Up @@ -818,12 +824,23 @@ def _get_fit_args(
search_space_digest = extract_search_space_digest(
search_space=search_space, param_names=self.parameters
)
extract_params = (
data_parameters if data_parameters is not None else self.parameters
)
# When data_parameters differs from self.parameters, the SSD's
# task_feature indices don't match the data columns. Pass None
# to skip MultiTaskDataset wrapping (the caller handles it).
extraction_ssd = (
None
if data_parameters is not None and data_parameters != self.parameters
else search_space_digest
)
# Convert observations to datasets
datasets, ordered_outcomes, candidate_metadata = self._convert_experiment_data(
experiment_data=experiment_data,
outcomes=self.outcomes,
parameters=self.parameters,
search_space_digest=search_space_digest,
parameters=extract_params,
search_space_digest=extraction_ssd,
)
datasets = self._update_w_aux_exp_datasets(datasets=datasets)

Expand Down
192 changes: 73 additions & 119 deletions ax/adapter/transfer_learning/adapter.py

Large diffs are not rendered by default.

273 changes: 177 additions & 96 deletions ax/adapter/transfer_learning/tests/test_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,113 +5,194 @@

# pyre-strict

from __future__ import annotations

from unittest.mock import MagicMock, patch

import torch
from ax.adapter.transfer_learning.adapter import TL_EXP, TransferLearningAdapter
from ax.adapter.transforms.metadata_to_task import MetadataToTask
from ax.core.arm import Arm
from ax.core.auxiliary_source import AuxiliarySource
from ax.core.experiment import Experiment
from ax.core.parameter import FixedParameter, ParameterType, RangeParameter
from ax.core.search_space import SearchSpace
from ax.generators.torch.botorch_modular.generator import BoTorchGenerator
from ax.utils.common.constants import Keys
from ax.utils.common.testutils import TestCase
from ax.utils.testing.core_stubs import get_experiment_with_observations
from ax.utils.testing.mock import mock_botorch_optimize
from pyre_extensions import assert_is_instance, none_throws

class ExpandSsdToJointSpaceTest(TestCase):
def setUp(self) -> None:
super().setUp()
self.adapter = MagicMock(spec=TransferLearningAdapter)

def _make_joint_ss(self, params: dict[str, tuple[float, float]]) -> SearchSpace:
return SearchSpace(
parameters=[
RangeParameter(
name=n,
lower=lo,
upper=hi,
parameter_type=ParameterType.FLOAT,
)
for n, (lo, hi) in params.items()
]
)
def _make_ss(params: dict[str, tuple[float, float]]) -> SearchSpace:
    """Build a search space of float RangeParameters from name -> (lower, upper)."""
    parameters = []
    for name, (lower, upper) in params.items():
        parameters.append(
            RangeParameter(
                name=name,
                lower=lower,
                upper=upper,
                parameter_type=ParameterType.FLOAT,
            )
        )
    return SearchSpace(parameters=parameters)

def test_no_extra_params_returns_unchanged(self) -> None:
type(self.adapter).joint_search_space = PropertyMock(
return_value=self._make_joint_ss({"x1": (0, 1), "x2": (0, 1)})
)
ssd = SearchSpaceDigest(
feature_names=["x1", "x2", "task"],
bounds=[(0, 1), (0, 1), (0, 2)],
task_features=[2],
target_values={2: 0},
)
result = TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd)
self.assertIs(result, ssd)

def test_single_task_feature_inserts_before_task(self) -> None:
type(self.adapter).joint_search_space = PropertyMock(
return_value=self._make_joint_ss(
{"x1": (0, 1), "x2": (0, 1), "x3": (-2, 5)}
)
)
ssd = SearchSpaceDigest(
feature_names=["x1", "x2", "task"],
bounds=[(0, 1), (0, 1), (0, 2)],
task_features=[2],
target_values={2: 0},
)
result = TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd)
self.assertEqual(result.feature_names, ["x1", "x2", "x3", "task"])
self.assertEqual(result.bounds, [(0, 1), (0, 1), (-2, 5), (0, 2)])
self.assertEqual(result.task_features, [3])
self.assertEqual(result.target_values, {3: 0})

def test_zero_task_features_appends(self) -> None:
type(self.adapter).joint_search_space = PropertyMock(
return_value=self._make_joint_ss({"x1": (0, 1), "x2": (-1, 3)})
)
ssd = SearchSpaceDigest(
feature_names=["x1"],
bounds=[(0, 1)],
)
result = TransferLearningAdapter._expand_ssd_to_joint_space(self.adapter, ssd)
self.assertEqual(result.feature_names, ["x1", "x2"])
self.assertEqual(result.bounds, [(0, 1), (-1, 3)])
def _gen_experiment(
    experiment_name: str,
    num_trials: int,
    search_space: SearchSpace | None = None,
) -> Experiment:
    """Create a named experiment with ``num_trials`` random single-metric observations."""
    observations = torch.rand(num_trials, 1).tolist()
    experiment = get_experiment_with_observations(
        observations=observations,
        search_space=search_space,
    )
    experiment.name = experiment_name
    return experiment

class SetSearchSpaceTest(TestCase):
    """_set_search_space adds source-only params to _model_space while
    preserving target bounds for shared params."""

    def test_model_space_has_source_only_params(self) -> None:
        target_ss = _make_ss({"x": (0, 1), "y": (0, 1)})
        source_ss = _make_ss({"x": (0, 5), "y": (0, 5), "z": (0, 5)})
        target_exp = _gen_experiment("target", num_trials=3, search_space=target_ss)
        source_exp = _gen_experiment("source", num_trials=3, search_space=source_ss)
        # status_quo supplies a value for z; the backfilled_not_source_only
        # subTest below relies on this.
        source_exp.status_quo = Arm(parameters={"x": 1.0, "y": 1.0, "z": 2.5})
        target_exp.auxiliary_experiments_by_purpose[TL_EXP] = [
            AuxiliarySource(experiment=source_exp)
        ]
        adapter = TransferLearningAdapter(
            experiment=target_exp,
            search_space=target_ss,
            data=target_exp.lookup_data(),
            generator=BoTorchGenerator(),
            transforms=[MetadataToTask],
            fit_on_init=False,
        )
        with self.subTest("model_space_has_z"):
            self.assertIn("z", adapter._model_space.parameters)
        with self.subTest("search_space_unchanged"):
            self.assertNotIn("z", adapter._search_space.parameters)
        with self.subTest("backfilled_not_source_only"):
            self.assertNotIn("z", adapter._source_only_params)
        with self.subTest("shared_params_keep_target_bounds"):
            x_param = assert_is_instance(
                adapter._model_space.parameters["x"], RangeParameter
            )
            self.assertEqual(x_param.lower, 0.0)
            self.assertEqual(x_param.upper, 1.0)
        with self.subTest("source_only_without_backfill"):
            # No status_quo on source2, so w has no backfill value.
            source_ss2 = _make_ss({"x": (0, 5), "w": (0, 10)})
            source_exp2 = _gen_experiment(
                "source2", num_trials=3, search_space=source_ss2
            )
            target_exp.auxiliary_experiments_by_purpose[TL_EXP] = [
                AuxiliarySource(experiment=source_exp2)
            ]
            adapter2 = TransferLearningAdapter(
                experiment=target_exp,
                search_space=target_ss,
                data=target_exp.lookup_data(),
                generator=BoTorchGenerator(),
                transforms=[MetadataToTask],
                fit_on_init=False,
            )
            self.assertIn("w", adapter2._model_space.parameters)
            self.assertIsInstance(adapter2._model_space.parameters["w"], RangeParameter)


class GetTargetDataParametersTest(TestCase):
    """_get_target_data_parameters filters joint params to target-only + task."""

    def test_filters_source_only_params(self) -> None:
        adapter = MagicMock(spec=TransferLearningAdapter)
        adapter._source_only_params = {"z"}
        joint_params = ["x", "y", "z", Keys.TASK_FEATURE_NAME.value]
        result = TransferLearningAdapter._get_target_data_parameters(
            adapter, joint_params
        )
        # Source-only "z" is dropped; shared params and the task feature remain.
        self.assertEqual(result, ["x", "y", Keys.TASK_FEATURE_NAME.value])

    def test_no_source_only_params_returns_all(self) -> None:
        adapter = MagicMock(spec=TransferLearningAdapter)
        adapter._source_only_params = set()
        params = ["x", "y", Keys.TASK_FEATURE_NAME.value]
        result = TransferLearningAdapter._get_target_data_parameters(adapter, params)
        self.assertEqual(result, params)

class FitWithDataParametersTest(TestCase):
    """After _fit, self.parameters = joint params and SSD has joint bounds,
    without needing _expand_ssd_to_joint_space."""

    @mock_botorch_optimize
    def test_fit_heterogeneous_ssd_has_joint_bounds(self) -> None:
        target_ss = _make_ss({"x": (0, 1), "y": (0, 1)})
        source_ss = _make_ss({"x": (0, 5), "y": (0, 5), "z": (0, 5)})
        target_exp = _gen_experiment("target", num_trials=3, search_space=target_ss)
        source_exp = _gen_experiment("source", num_trials=5, search_space=source_ss)
        target_exp.auxiliary_experiments_by_purpose[TL_EXP] = [
            AuxiliarySource(experiment=source_exp)
        ]
        adapter = TransferLearningAdapter(
            experiment=target_exp,
            search_space=target_ss,
            data=target_exp.lookup_data(),
            generator=BoTorchGenerator(),
            transforms=[MetadataToTask],
            fit_on_init=False,
        )
        adapter.outcomes = list(
            none_throws(target_exp.optimization_config).objective.metric_names
        )
        adapter.parameters = list(target_exp.search_space.parameters)

        # Wrap (not replace) generator.fit so the real fit runs and we can
        # inspect the search_space_digest it was called with.
        with patch.object(
            adapter.generator, "fit", wraps=none_throws(adapter.generator).fit
        ) as gen_fit:
            experiment_data, search_space = adapter._process_and_transform_data(
                experiment=target_exp,
            )
            adapter._fit(search_space=search_space, experiment_data=experiment_data)

        gen_fit.assert_called_once()
        ssd = gen_fit.call_args[1]["search_space_digest"]
        # SSD feature names include source-only z (from the joint model space)
        self.assertIn("z", ssd.feature_names)
        # SSD bounds cover the joint space
        z_idx = ssd.feature_names.index("z")
        self.assertEqual(ssd.bounds[z_idx], (0.0, 5.0))
        # self.parameters is the joint set
        self.assertIn("z", adapter.parameters)
        # Task feature is last
        self.assertEqual(adapter.parameters[-1], Keys.TASK_FEATURE_NAME.value)

    @mock_botorch_optimize
    def test_fit_and_gen_heterogeneous(self) -> None:
        """Full fit+gen round-trip with heterogeneous search spaces.
        No status_quo on source, so z is truly source-only (no backfill)."""
        target_ss = _make_ss({"x": (0, 1), "y": (0, 1)})
        source_ss = _make_ss({"x": (0, 5), "y": (0, 5), "z": (0, 5)})
        target_exp = _gen_experiment("target", num_trials=3, search_space=target_ss)
        source_exp = _gen_experiment("source", num_trials=5, search_space=source_ss)
        target_exp.auxiliary_experiments_by_purpose[TL_EXP] = [
            AuxiliarySource(experiment=source_exp)
        ]
        adapter = TransferLearningAdapter(
            experiment=target_exp,
            search_space=target_ss,
            data=target_exp.lookup_data(),
            generator=BoTorchGenerator(),
            transforms=[MetadataToTask],
            fit_on_init=True,
        )
        gr = adapter.gen(n=1)
        # Generated arms should only have target params
        for arm in gr.arms:
            self.assertIn("x", arm.parameters)
            self.assertIn("y", arm.parameters)
            self.assertNotIn("z", arm.parameters)
Original file line number Diff line number Diff line change
Expand Up @@ -402,10 +402,14 @@ def _input_transform_argparse_learned_feature_imputation(
torch.ones(d, dtype=dtype, device=torch_device),
]
)
# The target task is at position 0 (target_dataset is prepended above), so
# at posterior time — when X arrives without a task column — LFI applies
# the target task's imputation pattern.
kwargs: dict[str, Any] = {
"feature_indices": feature_indices,
"d": d,
"task_feature_index": task_feature_index,
"target_task": 0,
"bounds": bounds,
"device": torch_device,
"dtype": dtype,
Expand Down
Loading
Loading