Skip to content

Commit 3b13683

Browse files
PaulZhang12facebook-github-bot
authored andcommitted
TorchScript bad_alloc issue (#2542)
Summary: Pull Request resolved: #2542 Differential Revision: D65495806
1 parent 509b0d2 commit 3b13683

File tree

3 files changed

+6
-5
lines changed

3 files changed

+6
-5
lines changed

torchrec/models/tests/test_deepfm.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -202,7 +202,7 @@ def test_fx_script(self) -> None:
202202
sparse_features=sparse_features,
203203
)
204204

205-
gm = symbolic_trace(deepfm_nn)
205+
gm = torch.fx.GraphModule(deepfm_nn, Tracer().trace(deepfm_nn))
206206

207207
scripted_gm = torch.jit.script(gm)
208208

torchrec/modules/tests/test_mc_modules.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import Dict
1212

1313
import torch
14+
from torchrec.fx import Tracer
1415
from torchrec.modules.mc_modules import (
1516
average_threshold_filter,
1617
DistanceLFU_EvictionPolicy,
@@ -357,5 +358,5 @@ def test_fx_jit_script_not_training(self) -> None:
357358
)
358359

359360
model.train(False)
360-
gm = torch.fx.symbolic_trace(model)
361+
gm = torch.fx.GraphModule(model, Tracer().trace(model))
361362
torch.jit.script(gm)

torchrec/modules/tests/test_mlp.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import torch
1515
from hypothesis import given, settings
1616
from torch import nn
17-
from torchrec.fx import symbolic_trace
17+
from torchrec.fx import symbolic_trace, Tracer
1818
from torchrec.modules.mlp import MLP, Perceptron
1919

2020

@@ -99,13 +99,13 @@ def test_fx_script_Perceptron(self) -> None:
9999
# Dry-run to initialize lazy module.
100100
m(torch.randn(batch_size, in_features))
101101

102-
gm = symbolic_trace(m)
102+
gm = torch.fx.GraphModule(m, Tracer().trace(m))
103103
torch.jit.script(gm)
104104

105105
def test_fx_script_MLP(self) -> None:
106106
in_features = 3
107107
layer_sizes = [16, 8, 4]
108108
m = MLP(in_features, layer_sizes)
109109

110-
gm = symbolic_trace(m)
110+
gm = torch.fx.GraphModule(m, Tracer().trace(m))
111111
torch.jit.script(gm)

0 commit comments

Comments
 (0)