|
9 | 9 |
|
10 | 10 |
|
11 | 11 | import functools
|
12 |
| -from typing import Any, Callable |
| 12 | +from typing import Any, Callable, Dict, List, Optional, Tuple |
13 | 13 |
|
14 | 14 | import torch
|
15 |
| -from torchrec.sparse.jagged_tensor import KeyedJaggedTensor |
| 15 | +from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor |
16 | 16 |
|
17 | 17 | """
|
18 | 18 | Prepares KJT for PT2 tracing.
|
|
28 | 28 | def kjt_for_pt2_tracing(
|
29 | 29 | kjt: KeyedJaggedTensor,
|
30 | 30 | convert_to_vb: bool = False,
|
| 31 | + mark_length: bool = False, |
31 | 32 | ) -> KeyedJaggedTensor:
|
32 | 33 | # Breaking dependency cycle
|
33 | 34 | from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
|
@@ -78,25 +79,132 @@ def kjt_for_pt2_tracing(
|
78 | 79 | weights = kjt.weights_or_none()
|
79 | 80 | if weights is not None:
|
80 | 81 | torch._dynamo.decorators.mark_unbacked(weights, 0)
|
| 82 | + if mark_length: |
| 83 | + torch._dynamo.decorators.mark_unbacked(lengths, 0) |
81 | 84 |
|
82 |
| - return KeyedJaggedTensor( |
| 85 | + length_per_key_marked_dynamic = [] |
| 86 | + |
| 87 | + for length in kjt.length_per_key(): |
| 88 | + length_per_key_marked_dynamic.append(length) |
| 89 | + |
| 90 | + return PT2KeyedJaggedTensor( |
83 | 91 | keys=kjt.keys(),
|
84 | 92 | values=values,
|
85 | 93 | lengths=lengths,
|
86 | 94 | weights=weights,
|
87 | 95 | stride=stride if not is_vb else None,
|
88 | 96 | stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
|
89 | 97 | inverse_indices=inverse_indices,
|
| 98 | + length_per_key=(length_per_key_marked_dynamic if is_vb else None), |
90 | 99 | )
|
91 | 100 |
|
92 | 101 |
|
class PT2KeyedJaggedTensor(KeyedJaggedTensor):
    """
    A KeyedJaggedTensor variant tailored for PT2 (dynamo) tracing.

    Dynamo cannot mark a plain Python integer as dynamic, so each integer in
    ``length_per_key`` / ``stride_per_key_per_rank`` is embedded as the first
    dimension of a zero-element tensor whose dim 0 is marked dynamic; the
    overridden accessors read the sizes back, yielding dynamic ints under
    tracing.
    """

    def __init__(
        self,
        keys: List[str],
        values: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
        stride: Optional[int] = None,
        stride_per_key_per_rank: Optional[List[List[int]]] = None,
        stride_per_key: Optional[List[int]] = None,
        length_per_key: Optional[List[int]] = None,
        lengths_offset_per_key: Optional[List[int]] = None,
        offset_per_key: Optional[List[int]] = None,
        index_per_key: Optional[Dict[str, int]] = None,
        jt_dict: Optional[Dict[str, JaggedTensor]] = None,
        inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
    ) -> None:
        # length_per_key is deliberately withheld from the parent: the
        # tensor-backed representation built below is the source of truth.
        super().__init__(
            keys=keys,
            values=values,
            weights=weights,
            lengths=lengths,
            offsets=offsets,
            stride=stride,
            stride_per_key_per_rank=stride_per_key_per_rank,
            stride_per_key=stride_per_key,
            length_per_key=None,
            lengths_offset_per_key=lengths_offset_per_key,
            offset_per_key=offset_per_key,
            index_per_key=index_per_key,
            jt_dict=jt_dict,
            inverse_indices=inverse_indices,
        )

        def _dynamic_size_holder(size: int) -> torch.Tensor:
            # Trick: stash `size` in dim 0 of an empty (size, 0) tensor and
            # mark that dim dynamic, since ints can't be marked directly.
            holder = torch.empty((size, 0))
            torch._dynamo.mark_dynamic(holder, 0)
            return holder

        self.length_per_key_tensors: List[torch.Tensor] = [
            _dynamic_size_holder(length) for length in length_per_key or []
        ]
        self.stride_per_key_per_rank_tensor: List[List[torch.Tensor]] = [
            [_dynamic_size_holder(stride_val) for stride_val in per_key_strides]
            for per_key_strides in stride_per_key_per_rank or []
        ]

    def length_per_key(self) -> List[int]:
        """Per-key lengths; served from dynamic size holders when present."""
        if self.length_per_key_tensors:
            # dim 0 was marked dynamic at construction, so size(0) stays a
            # dynamic int under tracing.
            self._length_per_key = [
                holder.size(0) for holder in self.length_per_key_tensors
            ]
        else:
            self._length_per_key = super().length_per_key()
        return self._length_per_key

    def stride_per_key_per_rank(self) -> List[List[int]]:
        """Per-key-per-rank strides; served from dynamic size holders when present."""
        if self.stride_per_key_per_rank_tensor:
            self._stride_per_key_per_rank = [
                [holder.size(0) for holder in per_key_holders]
                for per_key_holders in self.stride_per_key_per_rank_tensor
            ]
        else:
            self._stride_per_key_per_rank = super().stride_per_key_per_rank()
        return self._stride_per_key_per_rank
| 174 | + |
| 175 | + |
93 | 176 | # pyre-ignore
|
94 | 177 | def default_pipeline_input_transformer(inp):
|
| 178 | + # different input items need different handlings |
95 | 179 | for attr_name in ["id_list_features", "id_score_list_features"]:
|
96 | 180 | if hasattr(inp, attr_name):
|
97 | 181 | attr = getattr(inp, attr_name)
|
98 | 182 | if isinstance(attr, KeyedJaggedTensor):
|
99 | 183 | setattr(inp, attr_name, kjt_for_pt2_tracing(attr))
|
| 184 | + for attr_name in [ |
| 185 | + "uhm_history_timestamps", |
| 186 | + "raw_uhm_history_timestamps", |
| 187 | + "event_id_list_feature_invert_indexes", |
| 188 | + ]: |
| 189 | + if hasattr(inp, attr_name): |
| 190 | + attr = getattr(inp, attr_name) |
| 191 | + if isinstance(attr, dict): |
| 192 | + for key in attr: |
| 193 | + torch._dynamo.decorators.mark_dynamic(attr[key], 0) |
| 194 | + if hasattr(inp, "supervision_label"): |
| 195 | + torch._dynamo.decorators.mark_dynamic(inp.supervision_label["keys"], 0) |
| 196 | + torch._dynamo.decorators.mark_dynamic(inp.supervision_label["values"], 0) |
| 197 | + |
| 198 | + for attr_name in ["event_id_list_features_seqs"]: |
| 199 | + if hasattr(inp, attr_name): |
| 200 | + attr = getattr(inp, attr_name) |
| 201 | + if isinstance(attr, dict): |
| 202 | + for key in attr: |
| 203 | + if isinstance(attr[key], KeyedJaggedTensor): |
| 204 | + attr[key] = kjt_for_pt2_tracing(attr[key], mark_length=True) |
| 205 | + |
| 206 | + setattr(inp, attr_name, attr) |
| 207 | + |
100 | 208 | return inp
|
101 | 209 |
|
102 | 210 |
|
|
0 commit comments