Skip to content

Commit 60d70c3

Browse files
PaulZhang12facebook-github-bot
authored andcommitted
TorchScript bad_alloc issue (#2542)
Summary: Pull Request resolved: #2542 Differential Revision: D65495806
1 parent 42c512c commit 60d70c3

File tree

3 files changed

+13
-7
lines changed

3 files changed

+13
-7
lines changed

torchrec/models/tests/test_deepfm.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,11 @@ def test_basic(self) -> None:
3535

3636
# check tracer compatibility
3737
gm = torch.fx.GraphModule(dense_arch, Tracer().trace(dense_arch))
38-
script = torch.jit.script(gm)
39-
script(dense_arch_input)
38+
39+
# TODO: Causes std::bad_alloc in OSS env
40+
# script = torch.jit.script(gm)
41+
42+
# script(dense_arch_input)
4043

4144

4245
class FMInteractionArchTest(unittest.TestCase):

torchrec/modules/tests/test_mc_modules.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Dict
1212

1313
import torch
14+
from torchrec.fx import Tracer
1415
from torchrec.modules.mc_modules import (
1516
average_threshold_filter,
1617
DistanceLFU_EvictionPolicy,
@@ -357,5 +358,6 @@ def test_fx_jit_script_not_training(self) -> None:
357358
)
358359

359360
model.train(False)
360-
gm = torch.fx.symbolic_trace(model)
361-
torch.jit.script(gm)
361+
gm = torch.fx.GraphModule(model, Tracer().trace(model))
362+
# TODO: Causes std::bad_alloc in OSS env
363+
# torch.jit.script(gm)

torchrec/modules/tests/test_mlp.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515
from hypothesis import given, settings
1616
from torch import nn
17-
from torchrec.fx import symbolic_trace
17+
from torchrec.fx import symbolic_trace, Tracer
1818
from torchrec.modules.mlp import MLP, Perceptron
1919

2020

@@ -107,5 +107,6 @@ def test_fx_script_MLP(self) -> None:
107107
layer_sizes = [16, 8, 4]
108108
m = MLP(in_features, layer_sizes)
109109

110-
gm = symbolic_trace(m)
111-
torch.jit.script(gm)
110+
gm = torch.fx.GraphModule(m, Tracer().trace(m))
111+
# TODO: Causes std::bad_alloc in OSS env
112+
# torch.jit.script(gm)

0 commit comments

Comments
 (0)