Skip to content

Commit c172745

Browse files
Microvefacebook-github-bot
authored and committed
in progress
Differential Revision: D66976511
1 parent 92b903f commit c172745

File tree

1 file changed

+106
-3
lines changed

1 file changed

+106
-3
lines changed

torchrec/pt2/utils.py

+106-3
Original file line numberDiff line numberDiff line change
@@ -9,10 +9,10 @@
99

1010

1111
import functools
12-
from typing import Any, Callable
12+
from typing import Any, Callable, Dict, List, Optional, Tuple
1313

1414
import torch
15-
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
15+
from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor
1616

1717
"""
1818
Prepares KJT for PT2 tracing.
@@ -28,6 +28,7 @@
2828
def kjt_for_pt2_tracing(
2929
kjt: KeyedJaggedTensor,
3030
convert_to_vb: bool = False,
31+
mark_length: bool = False,
3132
) -> KeyedJaggedTensor:
3233
# Breaking dependency cycle
3334
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
@@ -78,25 +79,127 @@ def kjt_for_pt2_tracing(
7879
weights = kjt.weights_or_none()
7980
if weights is not None:
8081
torch._dynamo.decorators.mark_unbacked(weights, 0)
82+
if mark_length:
83+
torch._dynamo.decorators.mark_unbacked(lengths, 0)
8184

82-
return KeyedJaggedTensor(
85+
length_per_key_marked_dynamic = []
86+
87+
for length in kjt.length_per_key():
88+
length_per_key_marked_dynamic.append(length)
89+
90+
return PT2KeyedJaggedTensor(
8391
keys=kjt.keys(),
8492
values=values,
8593
lengths=lengths,
8694
weights=weights,
8795
stride=stride if not is_vb else None,
8896
stride_per_key_per_rank=kjt.stride_per_key_per_rank() if is_vb else None,
8997
inverse_indices=inverse_indices,
98+
length_per_key=(length_per_key_marked_dynamic if is_vb else None),
9099
)
91100

92101

102+
class PT2KeyedJaggedTensor(KeyedJaggedTensor):
    """
    This subclass of KeyedJaggedTensor is used to support PT2 tracing.
    We can apply some modifications to make KJT friendly for PT2 tracing.

    Trick used below: plain Python ints (``length_per_key``,
    ``stride_per_key_per_rank``) would be burned into the traced graph as
    constants. Instead each int ``n`` is stored as an empty tensor of shape
    ``(n, 0)`` whose dim 0 is marked dynamic, and read back via ``t.size(0)``
    so dynamo sees a symbolic size rather than a specialized constant.
    """

    def __init__(
        self,
        keys: List[str],
        values: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
        lengths: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
        stride: Optional[int] = None,
        stride_per_key_per_rank: Optional[List[List[int]]] = None,
        # Below exposed to ensure torch.script-able
        stride_per_key: Optional[List[int]] = None,
        length_per_key: Optional[List[int]] = None,
        lengths_offset_per_key: Optional[List[int]] = None,
        offset_per_key: Optional[List[int]] = None,
        index_per_key: Optional[Dict[str, int]] = None,
        jt_dict: Optional[Dict[str, JaggedTensor]] = None,
        inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
    ) -> None:
        super().__init__(
            keys=keys,
            values=values,
            weights=weights,
            lengths=lengths,
            offsets=offsets,
            stride=stride,
            stride_per_key_per_rank=stride_per_key_per_rank,
            stride_per_key=stride_per_key,
            # Deliberately NOT forwarded: the int list is instead encoded as
            # dynamic-sized tensors below so the override of length_per_key()
            # serves symbolic sizes during tracing.
            length_per_key=None,
            lengths_offset_per_key=lengths_offset_per_key,
            offset_per_key=offset_per_key,
            index_per_key=index_per_key,
            jt_dict=jt_dict,
            inverse_indices=inverse_indices,
        )
        # One (length, 0) placeholder tensor per key; dim 0 carries the value.
        self.length_per_key_tensors: List[torch.Tensor] = []
        for length in length_per_key or []:
            t = torch.empty((length, 0))
            torch._dynamo.mark_dynamic(t, 0)
            self.length_per_key_tensors.append(t)
        # Same encoding for the per-key, per-rank strides (nested list).
        self.stride_per_key_per_rank_tensor: List[List[torch.Tensor]] = []
        for strides_per_key in stride_per_key_per_rank or []:
            strides_per_key_list: List[torch.Tensor] = []
            for s in strides_per_key:
                t = torch.empty((s, 0))
                torch._dynamo.mark_dynamic(t, 0)
                strides_per_key_list.append(t)
            self.stride_per_key_per_rank_tensor.append(strides_per_key_list)

    def length_per_key(self) -> List[int]:
        # Prefer the tensor-encoded (symbolically-sized) values when present;
        # otherwise fall back to the parent's computation. The result is also
        # cached on self._length_per_key, mirroring the parent's caching slot.
        if len(self.length_per_key_tensors) > 0:
            self._length_per_key = [t.size(0) for t in self.length_per_key_tensors]
        else:
            self._length_per_key = super().length_per_key()
        return self._length_per_key

    def stride_per_key_per_rank(self) -> List[List[int]]:
        # Decode the nested tensor encoding back into ints when available.
        if len(self.stride_per_key_per_rank_tensor) > 0:
            self._stride_per_key_per_rank = [
                [t.size(0) for t in strides_per_key_list]
                for strides_per_key_list in self.stride_per_key_per_rank_tensor
            ]
        # NOTE(review): when the tensor list is empty this relies on
        # self._stride_per_key_per_rank having been set by the parent
        # constructor — presumably always true; confirm against
        # KeyedJaggedTensor.__init__.
        stride_per_key_per_rank = self._stride_per_key_per_rank
        return stride_per_key_per_rank if stride_per_key_per_rank is not None else []
171+
172+
93173
# pyre-ignore
94174
def default_pipeline_input_transformer(inp):
95175
for attr_name in ["id_list_features", "id_score_list_features"]:
96176
if hasattr(inp, attr_name):
97177
attr = getattr(inp, attr_name)
98178
if isinstance(attr, KeyedJaggedTensor):
99179
setattr(inp, attr_name, kjt_for_pt2_tracing(attr))
180+
for attr_name in [
181+
"uhm_history_timestamps",
182+
"raw_uhm_history_timestamps",
183+
"event_id_list_feature_invert_indexes",
184+
]:
185+
if hasattr(inp, attr_name):
186+
attr = getattr(inp, attr_name)
187+
if isinstance(attr, dict):
188+
for key in attr:
189+
torch._dynamo.decorators.mark_dynamic(attr[key], 0)
190+
torch._dynamo.decorators.mark_dynamic(inp.supervision_label["keys"], 0)
191+
torch._dynamo.decorators.mark_dynamic(inp.supervision_label["values"], 0)
192+
193+
for attr_name in ["event_id_list_features_seqs"]:
194+
if hasattr(inp, attr_name):
195+
attr = getattr(inp, attr_name)
196+
if isinstance(attr, dict):
197+
for key in attr:
198+
if isinstance(attr[key], KeyedJaggedTensor):
199+
attr[key] = kjt_for_pt2_tracing(attr[key], mark_length=True)
200+
201+
setattr(inp, attr_name, attr)
202+
100203
return inp
101204

102205

0 commit comments

Comments
 (0)