Skip to content

Commit ea00d74

Browse files
Carl Hvarfner authored and facebook-github-bot committed
Impute per-dim empirical mean (not zero) for missing features in HeterogeneousMTGP (#3294)
Summary: HeterogeneousMTGP zero-pads missing per-task feature columns, which risks heavily skewing dimensions whose parameter ranges lie substantially outside zero, rendering these dimensions very difficult to fit. Impute the per-dim empirical mean of the train_Xs columns containing that dim instead, keeping padded values inside the model's input range. Falls back to 0 if no task contains a dim. Reviewed By: saitcakmak. Differential Revision: D102390346
1 parent 4946523 commit ea00d74

2 files changed

Lines changed: 82 additions & 7 deletions

File tree

botorch/models/heterogeneous_mtgp.py

Lines changed: 65 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -110,8 +110,21 @@ def __init__(
110110
"""
111111
self.full_feature_dim = full_feature_dim
112112
self.feature_indices = feature_indices
113+
imputation_values = self._compute_imputation_values(
114+
train_Xs=train_Xs,
115+
feature_indices=feature_indices,
116+
full_feature_dim=full_feature_dim,
117+
)
118+
# The first time we map to full tensor, we have to pass in the imputation values
119+
# as they have not yet been registered as buffers - this has to wait until after
120+
# super().__init__.
113121
full_X = torch.cat(
114-
[self.map_to_full_tensor(X=X, task_index=i) for i, X in enumerate(train_Xs)]
122+
[
123+
self.map_to_full_tensor(
124+
X=X, task_index=i, imputation_values=imputation_values
125+
)
126+
for i, X in enumerate(train_Xs)
127+
]
115128
)
116129
full_Y = torch.cat(train_Ys)
117130
full_Yvar = None if train_Yvars is None else torch.cat(train_Yvars)
@@ -139,6 +152,7 @@ def __init__(
139152
outcome_transform=outcome_transform,
140153
validate_task_values=validate_task_values,
141154
)
155+
self.register_buffer("feature_imputation_values", imputation_values)
142156

143157
@classmethod
144158
def get_all_tasks(
@@ -158,36 +172,80 @@ def get_all_tasks(
158172
all_tasks_inferred = [0] + all_tasks_inferred
159173
return all_tasks_inferred, task_feature, num_non_task_features
160174

161-
def map_to_full_tensor(self, X: Tensor, task_index: int) -> Tensor:
175+
def map_to_full_tensor(
176+
self,
177+
X: Tensor,
178+
task_index: int,
179+
imputation_values: Tensor | None = None,
180+
) -> Tensor:
162181
"""Map a tensor of task-specific features to the full tensor of features,
163182
utilizing the feature indices to map each feature to its corresponding
164183
position in the full tensor. Also append the task index as the last column.
165-
The columns of the full tensor that are not used by the given task will be
166-
filled with zeros.
184+
The columns of the full tensor that are not used by the given task are
185+
filled with the per-dimension empirical mean computed across all tasks
186+
that contain that dimension (see ``_compute_imputation_values``). This
187+
avoids out-of-domain padding values that would otherwise be squashed by
188+
an input transform with fixed bounds (e.g. ``Normalize``).
167189
168190
Args:
169191
X: A tensor of shape ``(n x d_i)`` where ``d_i`` is the number of features
170192
in the original task dataset.
171193
task_index: The index of the task whose features are being mapped.
194+
imputation_values: Optional pre-computed imputation values. If not
195+
provided, uses ``self.feature_imputation_values``.
172196
173197
Returns:
174198
A tensor of shape ``(n x (self.full_feature_dim + 1))`` containing the
175199
mapped features.
176200
177201
Example:
178-
>>> # Suppose full feature dim is 3 and the feature indices for
179-
>>> # task 5 are [2, 0].
202+
>>> # Suppose full feature dim is 3, the feature indices for task 5
203+
>>> # are [2, 0], and the empirical mean for missing dim 1 is 7.0.
180204
>>> X = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
181205
>>> X_full = self.map_to_full_tensor(X=X, task_index=5)
182-
>>> # X_full = torch.tensor([[2.0, 0.0, 1.0, 5.0], [4.0, 0.0, 3.0, 5.0]])
206+
>>> # X_full = torch.tensor([[2.0, 7.0, 1.0, 5.0], [4.0, 7.0, 3.0, 5.0]])
183207
"""
208+
if imputation_values is None:
209+
imputation_values = self.feature_imputation_values
184210
X_full = torch.zeros(
185211
*X.shape[:-1], self.full_feature_dim + 1, dtype=X.dtype, device=X.device
186212
)
213+
X_full[..., : self.full_feature_dim] = imputation_values
187214
X_full[..., self.feature_indices[task_index]] = X
188215
X_full[..., -1] = task_index
189216
return X_full
190217

218+
@staticmethod
def _compute_imputation_values(
    train_Xs: list[Tensor],
    feature_indices: list[list[int]],
    full_feature_dim: int,
) -> Tensor:
    """Compute the per-dimension empirical mean over all tasks containing
    each dimension of the joint feature space.

    For each dimension ``d`` in ``[0, full_feature_dim)``, the values of every
    task's ``train_X`` column that maps to ``d`` are pooled and averaged. The
    pooling concatenates the raw observations, so tasks with more points
    contribute proportionally more to the mean.

    Args:
        train_Xs: Per-task training inputs. The last dimension of the ``i``-th
            tensor is indexed by the positions of ``feature_indices[i]``.
        feature_indices: For each task, the full-space dimension that each
            column of the task's ``train_X`` corresponds to.
        full_feature_dim: The number of dimensions of the joint feature space.

    Returns:
        A tensor of shape ``(full_feature_dim,)`` holding the per-dim mean,
        with the dtype and device of ``train_Xs[0]``. Dimensions that no
        (non-empty) task contains are left at 0. The result is detached from
        any autograd graph of ``train_Xs``.

    Raises:
        IndexError: If ``train_Xs`` is empty.
    """
    dtype = train_Xs[0].dtype
    device = train_Xs[0].device

    # Group the observed columns by full-space dimension in a single pass over
    # the tasks instead of re-scanning every task's index list per dimension.
    columns_by_dim: dict[int, list[Tensor]] = {}
    for indices, X in zip(feature_indices, train_Xs):
        if X.numel() == 0:
            # Empty tasks carry no information and would yield NaN means.
            continue
        # If a dimension is listed more than once, only the first column is
        # used (the semantics of ``list.index``).
        first_column: dict[int, int] = {}
        for col, d in enumerate(indices):
            first_column.setdefault(d, col)
        for d, col in first_column.items():
            # Detach: the imputation values are constants, so they must not
            # carry (or keep alive) an autograd graph through ``train_Xs``.
            columns_by_dim.setdefault(d, []).append(
                X[..., col].detach().reshape(-1)
            )

    imputation = torch.zeros(full_feature_dim, dtype=dtype, device=device)
    for d, columns in columns_by_dim.items():
        # Ignore indices outside of the joint feature space.
        if 0 <= d < full_feature_dim:
            imputation[d] = torch.cat(columns).mean()
    return imputation
248+
191249
def posterior(
192250
self,
193251
X: Tensor,

test/models/test_heterogeneous_mtgp.py

Lines changed: 17 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -122,6 +122,23 @@ def test_standard_heterogeneous_mtgp(self) -> None:
122122
model.likelihood.noise_covar.noise.shape[-1], model.num_tasks
123123
)
124124

125+
with self.subTest("imputation_uses_per_dim_empirical_mean"):
126+
# Full feature space is [x1, x2, x3, x4, x5]. x3 is only in task 0,
127+
# x4 and x5 are only in task 2. Imputation values for missing dims
128+
# should equal the empirical mean of those columns across tasks.
129+
expected_x3_mean = self.ds1.X[:, 2].mean()
130+
expected_x4_mean = self.ds3.X[:, 2].mean()
131+
expected_x5_mean = self.ds3.X[:, 3].mean()
132+
self.assertAllClose(model.feature_imputation_values[2], expected_x3_mean)
133+
self.assertAllClose(model.feature_imputation_values[3], expected_x4_mean)
134+
self.assertAllClose(model.feature_imputation_values[4], expected_x5_mean)
135+
# Task 1 (ds2) does not have x3, x4, x5 -- those columns in the
136+
# full training tensor must equal the imputation values, not zero.
137+
task1_rows = model.train_inputs[0][model.train_inputs[0][:, -1] == 1]
138+
self.assertAllClose(task1_rows[:, 2], expected_x3_mean.expand(3))
139+
self.assertAllClose(task1_rows[:, 3], expected_x4_mean.expand(3))
140+
self.assertAllClose(task1_rows[:, 4], expected_x5_mean.expand(3))
141+
125142
# Evaluate the posterior (task column required).
126143
with self.assertRaisesRegex(UnsupportedError, "output_indices"):
127144
model.posterior(self.ds1.X, output_indices=[0, 1])

0 commit comments

Comments
 (0)