|
14 | 14 |
|
15 | 15 | import torch
|
16 | 16 | import torch.nn as nn
|
| 17 | +from tensordict import TensorDict |
17 | 18 | from torchrec.distributed.embedding_tower_sharding import (
|
18 | 19 | EmbeddingTowerCollectionSharder,
|
19 | 20 | EmbeddingTowerSharder,
|
|
46 | 47 | @dataclass
|
47 | 48 | class ModelInput(Pipelineable):
|
48 | 49 | float_features: torch.Tensor
|
49 |
| - idlist_features: KeyedJaggedTensor |
50 |
| - idscore_features: Optional[KeyedJaggedTensor] |
| 50 | + idlist_features: Union[KeyedJaggedTensor, TensorDict] |
| 51 | + idscore_features: Optional[Union[KeyedJaggedTensor, TensorDict]] |
51 | 52 | label: torch.Tensor
|
52 | 53 |
|
53 | 54 | @staticmethod
|
@@ -76,11 +77,13 @@ def generate(
|
76 | 77 | randomize_indices: bool = True,
|
77 | 78 | device: Optional[torch.device] = None,
|
78 | 79 | max_feature_lengths: Optional[List[int]] = None,
|
| 80 | + input_type: str = "kjt", |
79 | 81 | ) -> Tuple["ModelInput", List["ModelInput"]]:
|
80 | 82 | """
|
81 | 83 | Returns a global (single-rank training) batch
|
82 | 84 | and a list of local (multi-rank training) batches of world_size.
|
83 | 85 | """
|
| 86 | + |
84 | 87 | batch_size_by_rank = [batch_size] * world_size
|
85 | 88 | if variable_batch_size:
|
86 | 89 | batch_size_by_rank = [
|
@@ -199,11 +202,26 @@ def _validate_pooling_factor(
|
199 | 202 | )
|
200 | 203 | global_idlist_lengths.append(lengths)
|
201 | 204 | global_idlist_indices.append(indices)
|
202 |
| - global_idlist_kjt = KeyedJaggedTensor( |
203 |
| - keys=idlist_features, |
204 |
| - values=torch.cat(global_idlist_indices), |
205 |
| - lengths=torch.cat(global_idlist_lengths), |
206 |
| - ) |
| 205 | + |
| 206 | + if input_type == "kjt": |
| 207 | + global_idlist_input = KeyedJaggedTensor( |
| 208 | + keys=idlist_features, |
| 209 | + values=torch.cat(global_idlist_indices), |
| 210 | + lengths=torch.cat(global_idlist_lengths), |
| 211 | + ) |
| 212 | + elif input_type == "td": |
| 213 | + dict_of_nt = { |
| 214 | + k: torch.nested.nested_tensor_from_jagged( |
| 215 | + values=values, |
| 216 | + lengths=lengths, |
| 217 | + ) |
| 218 | + for k, values, lengths in zip( |
| 219 | + idlist_features, global_idlist_indices, global_idlist_lengths |
| 220 | + ) |
| 221 | + } |
| 222 | + global_idlist_input = TensorDict(source=dict_of_nt) |
| 223 | + else: |
| 224 | + raise ValueError(f"For IdList features, unknown input type {input_type}") |
207 | 225 |
|
208 | 226 | for idx in range(len(idscore_ind_ranges)):
|
209 | 227 | ind_range = idscore_ind_ranges[idx]
|
@@ -245,16 +263,25 @@ def _validate_pooling_factor(
|
245 | 263 | global_idscore_lengths.append(lengths)
|
246 | 264 | global_idscore_indices.append(indices)
|
247 | 265 | global_idscore_weights.append(weights)
|
248 |
| - global_idscore_kjt = ( |
249 |
| - KeyedJaggedTensor( |
250 |
| - keys=idscore_features, |
251 |
| - values=torch.cat(global_idscore_indices), |
252 |
| - lengths=torch.cat(global_idscore_lengths), |
253 |
| - weights=torch.cat(global_idscore_weights), |
| 266 | + |
| 267 | + if input_type == "kjt": |
| 268 | + global_idscore_input = ( |
| 269 | + KeyedJaggedTensor( |
| 270 | + keys=idscore_features, |
| 271 | + values=torch.cat(global_idscore_indices), |
| 272 | + lengths=torch.cat(global_idscore_lengths), |
| 273 | + weights=torch.cat(global_idscore_weights), |
| 274 | + ) |
| 275 | + if global_idscore_indices |
| 276 | + else None |
254 | 277 | )
|
255 |
| - if global_idscore_indices |
256 |
| - else None |
257 |
| - ) |
| 278 | + elif input_type == "td": |
| 279 | + assert ( |
| 280 | + len(idscore_features) == 0 |
| 281 | + ), "TensorDict does not support weighted features" |
| 282 | + global_idscore_input = None |
| 283 | + else: |
| 284 | + raise ValueError(f"For weighted features, unknown input type {input_type}") |
258 | 285 |
|
259 | 286 | if randomize_indices:
|
260 | 287 | global_float = torch.rand(
|
@@ -303,36 +330,57 @@ def _validate_pooling_factor(
|
303 | 330 | weights[lengths_cumsum[r] : lengths_cumsum[r + 1]]
|
304 | 331 | )
|
305 | 332 |
|
306 |
| - local_idlist_kjt = KeyedJaggedTensor( |
307 |
| - keys=idlist_features, |
308 |
| - values=torch.cat(local_idlist_indices), |
309 |
| - lengths=torch.cat(local_idlist_lengths), |
310 |
| - ) |
| 333 | + if input_type == "kjt": |
| 334 | + local_idlist_input = KeyedJaggedTensor( |
| 335 | + keys=idlist_features, |
| 336 | + values=torch.cat(local_idlist_indices), |
| 337 | + lengths=torch.cat(local_idlist_lengths), |
| 338 | + ) |
311 | 339 |
|
312 |
| - local_idscore_kjt = ( |
313 |
| - KeyedJaggedTensor( |
314 |
| - keys=idscore_features, |
315 |
| - values=torch.cat(local_idscore_indices), |
316 |
| - lengths=torch.cat(local_idscore_lengths), |
317 |
| - weights=torch.cat(local_idscore_weights), |
| 340 | + local_idscore_input = ( |
| 341 | + KeyedJaggedTensor( |
| 342 | + keys=idscore_features, |
| 343 | + values=torch.cat(local_idscore_indices), |
| 344 | + lengths=torch.cat(local_idscore_lengths), |
| 345 | + weights=torch.cat(local_idscore_weights), |
| 346 | + ) |
| 347 | + if local_idscore_indices |
| 348 | + else None |
| 349 | + ) |
| 350 | + elif input_type == "td": |
| 351 | + dict_of_nt = { |
| 352 | + k: torch.nested.nested_tensor_from_jagged( |
| 353 | + values=values, |
| 354 | + lengths=lengths, |
| 355 | + ) |
| 356 | + for k, values, lengths in zip( |
| 357 | + idlist_features, local_idlist_indices, local_idlist_lengths |
| 358 | + ) |
| 359 | + } |
| 360 | + local_idlist_input = TensorDict(source=dict_of_nt) |
| 361 | + assert ( |
| 362 | + len(idscore_features) == 0 |
| 363 | + ), "TensorDict does not support weighted features" |
| 364 | + local_idscore_input = None |
| 365 | + |
| 366 | + else: |
| 367 | + raise ValueError( |
| 368 | + f"For weighted features, unknown input type {input_type}" |
318 | 369 | )
|
319 |
| - if local_idscore_indices |
320 |
| - else None |
321 |
| - ) |
322 | 370 |
|
323 | 371 | local_input = ModelInput(
|
324 | 372 | float_features=global_float[r * batch_size : (r + 1) * batch_size],
|
325 |
| - idlist_features=local_idlist_kjt, |
326 |
| - idscore_features=local_idscore_kjt, |
| 373 | + idlist_features=local_idlist_input, |
| 374 | + idscore_features=local_idscore_input, |
327 | 375 | label=global_label[r * batch_size : (r + 1) * batch_size],
|
328 | 376 | )
|
329 | 377 | local_inputs.append(local_input)
|
330 | 378 |
|
331 | 379 | return (
|
332 | 380 | ModelInput(
|
333 | 381 | float_features=global_float,
|
334 |
| - idlist_features=global_idlist_kjt, |
335 |
| - idscore_features=global_idscore_kjt, |
| 382 | + idlist_features=global_idlist_input, |
| 383 | + idscore_features=global_idscore_input, |
336 | 384 | label=global_label,
|
337 | 385 | ),
|
338 | 386 | local_inputs,
|
@@ -623,8 +671,9 @@ def to(self, device: torch.device, non_blocking: bool = False) -> "ModelInput":
|
623 | 671 |
|
624 | 672 | def record_stream(self, stream: torch.Stream) -> None:
|
625 | 673 | self.float_features.record_stream(stream)
|
626 |
| - self.idlist_features.record_stream(stream) |
627 |
| - if self.idscore_features is not None: |
| 674 | + if isinstance(self.idlist_features, KeyedJaggedTensor): |
| 675 | + self.idlist_features.record_stream(stream) |
| 676 | + if isinstance(self.idscore_features, KeyedJaggedTensor): |
628 | 677 | self.idscore_features.record_stream(stream)
|
629 | 678 | self.label.record_stream(stream)
|
630 | 679 |
|
@@ -1831,6 +1880,8 @@ def forward(self, input: ModelInput) -> ModelInput:
|
1831 | 1880 | )
|
1832 | 1881 |
|
1833 | 1882 | # stride will be same but features will be joined
|
| 1883 | + assert isinstance(modified_input.idlist_features, KeyedJaggedTensor) |
| 1884 | + assert isinstance(self._extra_input.idlist_features, KeyedJaggedTensor) |
1834 | 1885 | modified_input.idlist_features = KeyedJaggedTensor.concat(
|
1835 | 1886 | [modified_input.idlist_features, self._extra_input.idlist_features]
|
1836 | 1887 | )
|
|
0 commit comments