Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 19 additions & 2 deletions ax/adapter/torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,7 @@ def _get_fit_args(
search_space: SearchSpace,
experiment_data: ExperimentData,
update_outcomes_and_parameters: bool,
data_parameters: list[str] | None = None,
) -> tuple[
list[SupervisedDataset],
list[list[TCandidateMetadata]] | None,
Expand All @@ -791,6 +792,11 @@ def _get_fit_args(
update_outcomes_and_parameters: Whether to update `self.outcomes` with
all outcomes found in the observations and `self.parameters` with
all parameters in the search space. Typically only used in `_fit`.
data_parameters: When provided, columns to extract from
``experiment_data``. Defaults to ``self.parameters``. Useful when
the model space is larger than the data (e.g. transfer learning
with heterogeneous search spaces where the data only contains
target columns but the model operates in a joint feature space).

Returns:
The datasets & metadata, extracted from the ``experiment_data``, and the
Expand Down Expand Up @@ -818,12 +824,23 @@ def _get_fit_args(
search_space_digest = extract_search_space_digest(
search_space=search_space, param_names=self.parameters
)
extract_params = (
data_parameters if data_parameters is not None else self.parameters
)
# When data_parameters differs from self.parameters, the SSD's
# task_feature indices don't match the data columns. Pass None
# to skip MultiTaskDataset wrapping (the caller handles it).
extraction_ssd = (
None
if data_parameters is not None and data_parameters != self.parameters
else search_space_digest
)
# Convert observations to datasets
datasets, ordered_outcomes, candidate_metadata = self._convert_experiment_data(
experiment_data=experiment_data,
outcomes=self.outcomes,
parameters=self.parameters,
search_space_digest=search_space_digest,
parameters=extract_params,
search_space_digest=extraction_ssd,
)
datasets = self._update_w_aux_exp_datasets(datasets=datasets)

Expand Down
144 changes: 49 additions & 95 deletions ax/adapter/transfer_learning/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from __future__ import annotations

import dataclasses
import warnings
from collections.abc import Mapping, Sequence
from logging import Logger
Expand Down Expand Up @@ -196,6 +195,24 @@ def __init__(
default_model_gen_options=default_model_gen_options,
)

def _set_search_space(self, search_space: SearchSpace) -> None:
    """Set the target search space and the (possibly larger) model space.

    Overrides the base implementation for transfer learning: every
    ``RangeParameter`` that appears in ``self.joint_search_space`` but not
    in the target ``search_space`` is cloned into ``_model_space``, and its
    name is recorded in ``_source_only_params``. Shared parameters keep the
    target bounds, so the model space naturally covers the full joint
    feature space while Normalize stays anchored to target bounds (target
    data maps into [0, 1]).

    Args:
        search_space: The target experiment's search space.
    """
    self._search_space = search_space.clone()
    expanded_space = search_space.clone()
    source_only: set[str] = set()
    for p_name, p in self.joint_search_space.parameters.items():
        # Only RangeParameters absent from the target space are added;
        # anything else in the joint space is left untouched.
        if p_name not in expanded_space.parameters and isinstance(
            p, RangeParameter
        ):
            expanded_space.add_parameter(p.clone())
            source_only.add(p_name)
    self._source_only_params = source_only
    self._model_space = expanded_space

def _transform_data(
self,
experiment_data: ExperimentData,
Expand Down Expand Up @@ -505,94 +522,18 @@ def _get_task_datasets(
)
return task_datasets

def _expand_ssd_to_joint_space(
def _get_target_data_parameters(
self,
search_space_digest: SearchSpaceDigest,
) -> SearchSpaceDigest:
"""Expand SSD bounds and feature_names to cover the joint search space.

The SSD produced by ``_get_fit_args`` reflects the target search space.
When source experiments have additional parameters, the model operates
in the full joint feature space. This method appends bounds and feature
names for source-only parameters so that input transforms receive
correct full-space bounds.
all_params: list[str],
) -> list[str]:
"""Filter a joint parameter list to target-only params + task feature.

Source-only params (those added by ``_set_search_space`` from the joint
space) are excluded because the target experiment data does not have
those columns. Uses untransformed names, which are stable across
transforms (Range params are never renamed by OneHot, IntToFloat, etc.).
"""
existing_names = set(search_space_digest.feature_names)
extra_names: list[str] = []
extra_bounds: list[tuple[int | float, int | float]] = []
# Only collect parameters absent from the target SSD. Shared
# parameters that appear in both target and source keep the target
# bounds -- source observations outside those bounds will normalize
# outside [0, 1]. This is intentional, as the GP hyperprior is calibrated
# for a __target__ task in [0, 1]^D.
for name, param in self.joint_search_space.parameters.items():
if name not in existing_names and isinstance(param, RangeParameter):
extra_names.append(name)
extra_bounds.append((param.lower, param.upper))
if not extra_names:
return search_space_digest
# Insert source-only params before the task feature
task_features = search_space_digest.task_features
if len(task_features) == 1:
tf_idx = task_features[0]
names = list(search_space_digest.feature_names)
bounds = list(search_space_digest.bounds)
# Raise if index-based fields (other than the task feature
# itself) reference indices at or above tf_idx, since we would
# need to shift them when inserting extra params.
for field_name in (
"ordinal_features",
"categorical_features",
"fidelity_features",
):
indices = getattr(search_space_digest, field_name)
if any(i >= tf_idx for i in indices):
raise UnsupportedError(
f"Cannot expand SSD: {field_name} contains index >= {tf_idx}."
)
if any(
i >= tf_idx and i not in task_features
for i in search_space_digest.discrete_choices
):
raise UnsupportedError(
f"Cannot expand SSD: discrete_choices contains index >= {tf_idx}."
)
if search_space_digest.hierarchical_dependencies is not None and any(
i >= tf_idx for i in search_space_digest.hierarchical_dependencies
):
raise UnsupportedError(
"Cannot expand SSD: hierarchical_dependencies contains "
f"index >= {tf_idx}."
)
names[tf_idx:tf_idx] = extra_names
bounds[tf_idx:tf_idx] = extra_bounds
n_extra = len(extra_names)
new_task_features = [tf_idx + n_extra]
new_target_values = dict(search_space_digest.target_values)
if tf_idx in new_target_values:
new_target_values[new_task_features[0]] = new_target_values.pop(tf_idx)
new_discrete = dict(search_space_digest.discrete_choices)
if tf_idx in new_discrete:
new_discrete[new_task_features[0]] = new_discrete.pop(tf_idx)
return dataclasses.replace(
search_space_digest,
feature_names=names,
bounds=bounds,
task_features=new_task_features,
target_values=new_target_values,
discrete_choices=new_discrete,
)
elif len(task_features) == 0:
# No task feature -- just append.
return dataclasses.replace(
search_space_digest,
feature_names=search_space_digest.feature_names + extra_names,
bounds=search_space_digest.bounds + extra_bounds,
)
else:
raise UnsupportedError(
"Multiple task features are not supported in transfer learning."
)
return [p for p in all_params if p not in self._source_only_params]

def _fit(
self,
Expand All @@ -610,15 +551,20 @@ def _fit(
if experiment_data.arm_data.empty:
# Temporarily unset self.outcomes to avoid an error in _get_fit_args.
self.outcomes = []
# Pre-compute the joint param ordering (mirrors _get_fit_args logic)
# so we can derive the target-only subset for data extraction.
all_params = list(search_space.parameters.keys())
task_name = Keys.TASK_FEATURE_NAME.value
if task_name in all_params:
all_params.remove(task_name)
all_params.append(task_name)
target_data_params = self._get_target_data_parameters(all_params)
datasets, candidate_metadata, search_space_digest = self._get_fit_args(
search_space=search_space,
experiment_data=experiment_data,
update_outcomes_and_parameters=True,
data_parameters=target_data_params,
)
# Expand SSD bounds to cover source-only params from the joint search
# space. This ensures Normalize (and other input transforms) get bounds
# for the full feature space, not just the target dims.
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
if experiment_data.arm_data.empty:
self.outcomes = outcomes
# Temporarily set datasets to None. We will construct empty datasets
Expand Down Expand Up @@ -656,12 +602,13 @@ def _cross_validate(
) -> list[ObservationData]:
if self.parameters is None:
raise ValueError(FIT_MODEL_ERROR.format(action="_cross_validate"))
target_data_params = self._get_target_data_parameters(self.parameters)
datasets, _, search_space_digest = self._get_fit_args(
search_space=search_space,
experiment_data=cv_training_data,
update_outcomes_and_parameters=False,
data_parameters=target_data_params,
)
search_space_digest = self._expand_ssd_to_joint_space(search_space_digest)
# Add the task feature to SSD, to ensure that a multi-task model is selected.
if len(search_space_digest.task_features) > 1:
raise UnsupportedError(
Expand Down Expand Up @@ -728,9 +675,10 @@ def gen(
if fixed_features is None:
fixed_features = ObservationFeatures(parameters={})
fixed_features.parameters.setdefault(name, target_p.value)
# Fix source-only params so the optimizer doesn't search over them.
# Center is a reasonable default; LearnedFeatureImputation overwrites
# these with learned values when configured.
# Fix source-only params that ARE in the search space (e.g. injected
# as FixedParam with a backfill value) so the optimizer doesn't search
# over them. Params NOT in the search space are handled by the model
# internally (HeterogeneousMTGP natively, LFI for MultiTaskGP).
joint_center = self.joint_search_space.compute_naive_center()
for name, param in self.joint_search_space.parameters.items():
if name not in search_space.parameters and isinstance(
Expand All @@ -739,6 +687,11 @@ def gen(
if fixed_features is None:
fixed_features = ObservationFeatures(parameters={})
fixed_features.parameters.setdefault(name, joint_center[name])
# At gen time, restrict self.parameters to params that exist in the
# gen-time search space. Source-only params absent from _search_space
# are handled by the model (LFI imputation or HeterogeneousMTGP).
saved_parameters = self.parameters
self.parameters = self._get_target_data_parameters(self.parameters)
generator_run = super().gen(
n=n,
search_space=search_space,
Expand All @@ -747,6 +700,7 @@ def gen(
fixed_features=fixed_features,
model_gen_options=model_gen_options,
)
self.parameters = saved_parameters
# Remove the parameters that are not in the target experiment's search
# space, and update candidate_metadata_by_arm_signature to reflect the
# new arm. We use the experiment's search space rather than
Expand Down
Loading
Loading