Commit c139171

Microve authored and facebook-github-bot committed
Apply input transformation to eliminate recompilations (#2645)
Summary: KJT has

Differential Revision: D66976511
1 parent 92b903f commit c139171

File tree

1 file changed: torchrec/pt2/utils.py (+111 −3 lines)
@@ -9,10 +9,10 @@
 
 
 import functools
-from typing import Any, Callable
+from typing import Any, Callable, Dict, List, Optional, Tuple
 
 import torch
-from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
 
 """
 Prepares KJT for PT2 tracing.
@@ -28,6 +28,7 @@
 def kjt_for_pt2_tracing(
     kjt: KeyedJaggedTensor,
     convert_to_vb: bool = False,
+    mark_length: bool = False,
 ) -> KeyedJaggedTensor:
     # Breaking dependency cycle
     from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -78,25 +79,132 @@ def kjt_for_pt2_tracing(
     weights = kjt.weights_or_none()
     if weights is not None:
         torch._dynamo.decorators.mark_unbacked(weights, 0)
+    if mark_length:
+        torch._dynamo.decorators.mark_unbacked(lengths, 0)
 
-    return KeyedJaggedTensor(
+    length_per_key_marked_dynamic = []
+
+    for length in kjt.length_per_key():
+        length_per_key_marked_dynamic.append(length)
+
+    return PT2KeyedJaggedTensor(
         keys=kjt.keys(),
         values=values,
         lengths=lengths,
         weights=weights,
         stride=stride if not is_vb else None,
         stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
         inverse_indices=inverse_indices,
+        length_per_key=(length_per_key_marked_dynamic if is_vb else None),
     )
 
 
+class PT2KeyedJaggedTensor(KeyedJaggedTensor):
+    """
+    This subclass of KeyedJaggedTensor is used to support PT2 tracing.
+    We apply some modifications to make the KJT friendly for PT2 tracing.
+    """
+
+    def __init__(
+        self,
+        keys: List[str],
+        values: torch.Tensor,
+        weights: Optional[torch.Tensor] = None,
+        lengths: Optional[torch.Tensor] = None,
+        offsets: Optional[torch.Tensor] = None,
+        stride: Optional[int] = None,
+        stride_per_key_per_rank: Optional[List[List[int]]] = None,
+        stride_per_key: Optional[List[int]] = None,
+        length_per_key: Optional[List[int]] = None,
+        lengths_offset_per_key: Optional[List[int]] = None,
+        offset_per_key: Optional[List[int]] = None,
+        index_per_key: Optional[Dict[str, int]] = None,
+        jt_dict: Optional[Dict[str, JaggedTensor]] = None,
+        inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
+    ) -> None:
+        super().__init__(
+            keys=keys,
+            values=values,
+            weights=weights,
+            lengths=lengths,
+            offsets=offsets,
+            stride=stride,
+            stride_per_key_per_rank=stride_per_key_per_rank,
+            stride_per_key=stride_per_key,
+            length_per_key=None,
+            lengths_offset_per_key=lengths_offset_per_key,
+            offset_per_key=offset_per_key,
+            index_per_key=index_per_key,
+            jt_dict=jt_dict,
+            inverse_indices=inverse_indices,
+        )
+        self.length_per_key_tensors: List[torch.Tensor] = []
+        for length in length_per_key or []:
+            # Dynamo does not support directly marking integers as dynamic, so we embed each integer into a tensor's size and mark that size as dynamic
+            t = torch.empty((length, 0))
+            torch._dynamo.mark_dynamic(t, 0)
+            self.length_per_key_tensors.append(t)
+
+        self.stride_per_key_per_rank_tensor: List[List[torch.Tensor]] = []
+        for strides_per_key in stride_per_key_per_rank or []:
+            strides_per_key_list: List[torch.Tensor] = []
+            for s in strides_per_key:
+                t = torch.empty((s, 0))
+                torch._dynamo.mark_dynamic(t, 0)
+                strides_per_key_list.append(t)
+            self.stride_per_key_per_rank_tensor.append(strides_per_key_list)
+
+    def length_per_key(self) -> List[int]:
+        if len(self.length_per_key_tensors) > 0:
+            # since the sizes were marked as dynamic, this yields a list of dynamic integers
+            self._length_per_key = [t.size(0) for t in self.length_per_key_tensors]
+        else:
+            self._length_per_key = super().length_per_key()
+        return self._length_per_key
+
+    def stride_per_key_per_rank(self) -> List[List[int]]:
+        if len(self.stride_per_key_per_rank_tensor) > 0:
+            self._stride_per_key_per_rank = [
+                [t.size(0) for t in strides_per_key_list]
+                for strides_per_key_list in self.stride_per_key_per_rank_tensor
+            ]
+        else:
+            self._stride_per_key_per_rank = super().stride_per_key_per_rank()
+        return self._stride_per_key_per_rank
+
+
 # pyre-ignore
 def default_pipeline_input_transformer(inp):
+    # Different input items need different handling.
     for attr_name in ["id_list_features", "id_score_list_features"]:
         if hasattr(inp, attr_name):
             attr = getattr(inp, attr_name)
             if isinstance(attr, KeyedJaggedTensor):
                 setattr(inp, attr_name, kjt_for_pt2_tracing(attr))
+    for attr_name in [
+        "uhm_history_timestamps",
+        "raw_uhm_history_timestamps",
+        "event_id_list_feature_invert_indexes",
+    ]:
+        if hasattr(inp, attr_name):
+            attr = getattr(inp, attr_name)
+            if isinstance(attr, dict):
+                for key in attr:
+                    torch._dynamo.decorators.mark_dynamic(attr[key], 0)
+    if hasattr(inp, "supervision_label"):
+        torch._dynamo.decorators.mark_dynamic(inp.supervision_label["keys"], 0)
+        torch._dynamo.decorators.mark_dynamic(inp.supervision_label["values"], 0)
+
+    for attr_name in ["event_id_list_features_seqs"]:
+        if hasattr(inp, attr_name):
+            attr = getattr(inp, attr_name)
+            if isinstance(attr, dict):
+                for key in attr:
+                    if isinstance(attr[key], KeyedJaggedTensor):
+                        attr[key] = kjt_for_pt2_tracing(attr[key], mark_length=True)
+
+            setattr(inp, attr_name, attr)
     return inp
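
Note: the central trick in PT2KeyedJaggedTensor is worth isolating. Dynamo cannot mark a plain Python int as dynamic, so the class embeds each integer in the size of an empty tensor, marks that dimension as dynamic, and reads it back via size(0), which yields a symbolic integer under torch.compile instead of a constant baked into the graph. A minimal sketch of the idea (the helper name as_dynamic_int is ours, not part of torchrec):

import torch

def as_dynamic_int(value: int) -> torch.Tensor:
    # A (value, 0) shape holds zero elements, so no storage is paid;
    # the integer lives only in the tensor's size metadata.
    t = torch.empty((value, 0))
    # Mark dim 0 as dynamic so torch.compile sees a SymInt, not a constant.
    torch._dynamo.mark_dynamic(t, 0)
    return t

length_per_key_tensors = [as_dynamic_int(n) for n in (3, 5, 7)]
# Read back inside a compiled region, each size(0) is symbolic, so a
# batch with different per-key lengths can reuse the same graph.
length_per_key = [t.size(0) for t in length_per_key_tensors]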
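Similarly, kjt_for_pt2_tracing relies on mark_unbacked for values, weights, and (with mark_length=True) lengths: Dynamo then treats dimension 0 as unknown from the very first trace, rather than specializing on the observed size and recompiling when it changes. A rough, torchrec-independent sketch of the effect (pooled_sum and the sizes are illustrative):

import torch

@torch.compile(fullgraph=True)
def pooled_sum(values: torch.Tensor) -> torch.Tensor:
    return values.sum()

x = torch.arange(8, dtype=torch.float32)
# No guard is installed on the concrete size of dim 0, so the first
# compilation already produces a shape-generic graph.
torch._dynamo.decorators.mark_unbacked(x, 0)
pooled_sum(x)

y = torch.arange(20, dtype=torch.float32)
pooled_sum(y)  # should reuse the same graph instead of recompiling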

Comments (0)