Skip to content

Commit 17ac795

Browse files
Carl Hvarfnermeta-codesync[bot]
authored andcommitted
Fix in-design filtering regression (#5103)
Summary: Pull Request resolved: #5103 D94693361 introduced a regression: when source experiments have more parameters than target and status_quo is set, FillMissingParameters adds extra columns to target arm data during _compute_in_design. check_membership_df then returns [False] for all rows because df_cols != ss_params. All target arms are incorrectly filtered out. Fix: Override _compute_in_design in TransferLearningAdapter to use check_all_parameters_present=False. Reviewed By: saitcakmak Differential Revision: D97625737 fbshipit-source-id: e34339cf1b2bbf2445ab8bd7ac8b539900164751
1 parent c6a1770 commit 17ac795

1 file changed

Lines changed: 25 additions & 0 deletions

File tree

ax/adapter/transfer_learning/adapter.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -239,6 +239,31 @@ def _transform_data(
239239

240240
return experiment_data, search_space
241241

242+
def _compute_in_design(
243+
self,
244+
search_space: SearchSpace,
245+
experiment_data: ExperimentData,
246+
) -> list[bool]:
247+
"""Compute in-design status for heterogeneous transfer learning.
248+
249+
Overrides base class to use check_all_parameters_present=False, which
250+
tolerates extra columns in arm_data beyond the search space parameters.
251+
This is necessary for heterogeneous TL where FillMissingParameters adds
252+
columns for source-only parameters (e.g. 'z') to target arm data,
253+
causing the default extra_params check to reject all rows.
254+
"""
255+
experiment_data, _ = self._transform_data(
256+
experiment_data=experiment_data,
257+
search_space=search_space,
258+
transforms=self._raw_transforms[:1],
259+
transform_configs=self._transform_configs,
260+
assign_transforms=False,
261+
)
262+
return search_space.check_membership_df(
263+
arm_data=experiment_data.arm_data,
264+
check_all_parameters_present=False,
265+
)
266+
242267
def get_training_data(self, filter_in_design: bool = False) -> ExperimentData:
243268
"""Returns the training data for the current experiment, with its metadata
244269
updated to include the task value.

0 commit comments

Comments
 (0)