Skip to content

Commit 929aa26

Browse files
ruochen99facebook-github-bot
authored andcommitted
SymInt and FakeTensor tracing compatibility
Differential Revision: D105054082
1 parent adaafec commit 929aa26

2 files changed

Lines changed: 6 additions & 4 deletions

File tree

generative_recommenders/common.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -261,9 +261,10 @@ def apply_sampling(
261261

262262
def switch_to_contiguous_if_needed(x: torch.Tensor) -> torch.Tensor:
263263
if not torch.jit.is_scripting() and torch.compiler.is_compiling():
264-
# Tell Dynamo this data-dependent value is in the range (0, 10**9)
265-
torch._check(x.size(0) > 0)
266-
torch._check(x.size(0) < 10**9)
264+
# Range the size as [0, 10**9-1] without requiring s > 0 to be
265+
# provable, so unbacked SymInts (e.g. from custom-op fake kernels)
266+
# don't fail the strict ``_check(s > 0)`` form at trace time.
267+
torch._check_is_size(x.size(0), max=10**9 - 1)
267268
if x.stride(-1) == 1:
268269
return x
269270
return x.contiguous()

generative_recommenders/ops/triton/triton_hstu_linear.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@
3232
)
3333
from generative_recommenders.ops.triton.triton_addmm import maybe_triton_addmm_fwd
3434
from generative_recommenders.ops.utils import maybe_register_custom_op
35+
from torch.fx.experimental.symbolic_shapes import guard_or_false
3536

3637

3738
def _get_layer_norm_mul_dropout_fwd_multirow_configs() -> List[triton.Config]:
@@ -1217,7 +1218,7 @@ def triton_layer_norm_mul_dropout_fwd(
12171218
assert weight.numel() == D
12181219
assert bias.numel() == D
12191220

1220-
if N == 0:
1221+
if guard_or_false(N == 0):
12211222
D = x.shape[1]
12221223
if concat_u and concat_x:
12231224
y = torch.empty((0, 3 * D), dtype=x.dtype, device=x.device)

0 commit comments

Comments
 (0)