@@ -110,8 +110,21 @@ def __init__(
110110 """
111111 self .full_feature_dim = full_feature_dim
112112 self .feature_indices = feature_indices
113+ imputation_values = self ._compute_imputation_values (
114+ train_Xs = train_Xs ,
115+ feature_indices = feature_indices ,
116+ full_feature_dim = full_feature_dim ,
117+ )
118+ # The first time we map to full tensor, we have to pass in the imputation values
119+ # as they have not yet been registered as buffers - this has to wait until after
120+ # super().__init__.
113121 full_X = torch .cat (
114- [self .map_to_full_tensor (X = X , task_index = i ) for i , X in enumerate (train_Xs )]
122+ [
123+ self .map_to_full_tensor (
124+ X = X , task_index = i , imputation_values = imputation_values
125+ )
126+ for i , X in enumerate (train_Xs )
127+ ]
115128 )
116129 full_Y = torch .cat (train_Ys )
117130 full_Yvar = None if train_Yvars is None else torch .cat (train_Yvars )
@@ -139,6 +152,7 @@ def __init__(
139152 outcome_transform = outcome_transform ,
140153 validate_task_values = validate_task_values ,
141154 )
155+ self .register_buffer ("feature_imputation_values" , imputation_values )
142156
143157 @classmethod
144158 def get_all_tasks (
@@ -158,36 +172,80 @@ def get_all_tasks(
158172 all_tasks_inferred = [0 ] + all_tasks_inferred
159173 return all_tasks_inferred , task_feature , num_non_task_features
160174
161- def map_to_full_tensor (self , X : Tensor , task_index : int ) -> Tensor :
175+ def map_to_full_tensor (
176+ self ,
177+ X : Tensor ,
178+ task_index : int ,
179+ imputation_values : Tensor | None = None ,
180+ ) -> Tensor :
162181 """Map a tensor of task-specific features to the full tensor of features,
163182 utilizing the feature indices to map each feature to its corresponding
164183 position in the full tensor. Also append the task index as the last column.
165- The columns of the full tensor that are not used by the given task will be
166- filled with zeros.
184+ The columns of the full tensor that are not used by the given task are
185+ filled with the per-dimension empirical mean computed across all tasks
186+ that contain that dimension (see ``_compute_imputation_values``). This
187+ avoids out-of-domain padding values that would otherwise be squashed by
188+ an input transform with fixed bounds (e.g. ``Normalize``).
167189
168190 Args:
169191 X: A tensor of shape ``(n x d_i)`` where ``d_i`` is the number of features
170192 in the original task dataset.
171193 task_index: The index of the task whose features are being mapped.
194+ imputation_values: Optional pre-computed imputation values. If not
195+ provided, uses ``self.feature_imputation_values``.
172196
173197 Returns:
174198 A tensor of shape ``(n x (self.full_feature_dim + 1))`` containing the
175199 mapped features.
176200
177201 Example:
178- >>> # Suppose full feature dim is 3 and the feature indices for
179- >>> # task 5 are [2, 0].
202+ >>> # Suppose full feature dim is 3, the feature indices for task 5
203+ >>> # are [2, 0], and the empirical mean for missing dim 1 is 7.0 .
180204 >>> X = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
181205 >>> X_full = self.map_to_full_tensor(X=X, task_index=5)
182- >>> # X_full = torch.tensor([[2.0, 0 .0, 1.0, 5.0], [4.0, 0 .0, 3.0, 5.0]])
206+ >>> # X_full = torch.tensor([[2.0, 7 .0, 1.0, 5.0], [4.0, 7 .0, 3.0, 5.0]])
183207 """
208+ if imputation_values is None :
209+ imputation_values = self .feature_imputation_values
184210 X_full = torch .zeros (
185211 * X .shape [:- 1 ], self .full_feature_dim + 1 , dtype = X .dtype , device = X .device
186212 )
213+ X_full [..., : self .full_feature_dim ] = imputation_values
187214 X_full [..., self .feature_indices [task_index ]] = X
188215 X_full [..., - 1 ] = task_index
189216 return X_full
190217
218+ @staticmethod
219+ def _compute_imputation_values (
220+ train_Xs : list [Tensor ],
221+ feature_indices : list [list [int ]],
222+ full_feature_dim : int ,
223+ ) -> Tensor :
224+ """Compute per-dimension empirical mean across all tasks that contain
225+ each dimension of the joint feature space.
226+
227+ For each dimension ``d`` in ``[0, full_feature_dim)``, collects the values
228+ from every task's ``train_X`` column that maps to ``d`` and takes the mean.
229+ These values are used by ``map_to_full_tensor`` to impute missing dims when
230+ embedding a per-task ``X`` into the full feature space.
231+
232+ Returns:
233+ A tensor of shape ``(full_feature_dim,)`` with the per-dim mean. If a
234+ dimension is not present in any task (which should not occur under the
235+ constructor's invariants), the value defaults to 0.
236+ """
237+ dtype = train_Xs [0 ].dtype
238+ device = train_Xs [0 ].device
239+ imputation = torch .zeros (full_feature_dim , dtype = dtype , device = device )
240+ for d in range (full_feature_dim ):
241+ values : list [Tensor ] = []
242+ for indices , X in zip (feature_indices , train_Xs ):
243+ if d in indices and X .numel () > 0 :
244+ values .append (X [..., indices .index (d )].reshape (- 1 ))
245+ if values :
246+ imputation [d ] = torch .cat (values ).mean ()
247+ return imputation
248+
191249 def posterior (
192250 self ,
193251 X : Tensor ,
0 commit comments