test_nvfuser_dynamo.py
# Owner(s): ["module: nvfuser"]

import unittest
import warnings
from functools import partial

import torch
import torch._dynamo as torchdynamo
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.testing._internal.jit_utils import RUN_CUDA

RUN_NVFUSER = RUN_CUDA and not TEST_WITH_ROCM


def is_pre_volta():
    if not RUN_NVFUSER:
        return False
    prop = torch.cuda.get_device_properties(torch.cuda.current_device())
    return prop.major < 7


def is_networkx_available():
    try:
        import networkx  # noqa: F401

        return True
    except ImportError:
        return False
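
# Each test below compiles a small function with the "nvprims_nvfuser" Dynamo
# backend, runs it inside warnings.catch_warnings to assert that tracing and
# execution through nvFuser emit no warnings, and compares the result against
# eager execution of the original function retrieved via func.__wrapped__.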
@skipIfTorchDynamo("Not a suitable test for TorchDynamo")
@unittest.skipIf(IS_WINDOWS, "TorchDynamo is not supported on Windows")
@unittest.skipIf(not RUN_NVFUSER, "requires CUDA")
@unittest.skipIf(is_pre_volta(), "Only supported on Volta and newer devices.")
class TestNvFuserDynamo(TestCase):
    def test_basic(self):
        input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32)
        input2 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float32)

        @torchdynamo.optimize("nvprims_nvfuser")
        def func(a, b):
            return a.sin() + b.cos()

        # No warnings and no errors
        with warnings.catch_warnings(record=True) as w:
            nvfuser_result = func(input1, input2)
            self.assertEqual(len(w), 0)

        eager_result = func.__wrapped__(input1, input2)
        self.assertEqual(eager_result, nvfuser_result)
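
    # test_min_cut checks how many tensors each partitioner keeps for the
    # backward pass by counting placeholder and output nodes in the forward
    # and backward FX graphs produced by aot_function.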
    @unittest.skipIf(not is_networkx_available(), "networkx not available")
    def test_min_cut(self):
        from functorch.compile import default_partition
        from torch._dynamo.optimizations.training import nvprims_fw_bw_partition_fn

        def get_fw_bw_graph(f, inps, partitioner):
            from functorch.compile import aot_function

            # Helper functions are taken from functorch/test_aotdispatch.py
            def extract_graph(fx_g, _, graph_cell):
                graph_cell[0] = fx_g
                return fx_g

            fw_graph_cell = [None]
            bw_graph_cell = [None]
            aot_function(
                f,
                fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
                bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell),
                partition_fn=partitioner,
            )(*inps).sum().backward()
            return (fw_graph_cell[0], bw_graph_cell[0])

        def get_ins_outs(fx_g):
            ins = []
            outs = []
            for n in fx_g.graph.nodes:
                if n.op == "placeholder":
                    ins.append(n)
                elif n.op == "output":
                    outs = tuple(n.args[0])
            return ins, outs

        def get_num_ins_outs(fx_g):
            return tuple(len(i) for i in get_ins_outs(fx_g))

        def func(x):
            return x * x * x

        # Default partitioner: the forward graph has 1 input and 3 outputs
        # (the result plus two tensors saved for backward); the backward graph
        # takes 3 inputs (the saved tensors plus the incoming gradient) and
        # produces 1 gradient.
        input1 = make_tensor(
            (3,), device="cpu", dtype=torch.float32, requires_grad=True
        )
        fw_graph, bw_graph = get_fw_bw_graph(func, [input1], default_partition)
        self.assertEqual(get_num_ins_outs(fw_graph), (1, 3))
        self.assertEqual(get_num_ins_outs(bw_graph), (3, 1))

        # Min-cut partitioner: only one tensor is saved for backward, so the
        # forward graph has 2 outputs and the backward graph takes 2 inputs.
        input1 = make_tensor(
            (3,), device="cpu", dtype=torch.float32, requires_grad=True
        )
        fw_graph, bw_graph = get_fw_bw_graph(func, [input1], nvprims_fw_bw_partition_fn)
        self.assertEqual(get_num_ins_outs(fw_graph), (1, 2))
        self.assertEqual(get_num_ins_outs(bw_graph), (2, 1))
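
    # Under autocast the matmul output is computed in reduced precision while
    # w and b are float32, so batch_norm sees mixed input dtypes; this checks
    # that the nvprims path handles the implicit dtype promotion cleanly.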
    def test_batch_norm_implicit_dtype_promotion(self):
        input1 = make_tensor((2, 3, 4, 5), device="cuda", dtype=torch.float32)
        input2 = make_tensor((5, 5), device="cuda", dtype=torch.float32)
        w = make_tensor((3), device="cuda", dtype=torch.float32)
        b = make_tensor((3), device="cuda", dtype=torch.float32)

        @torchdynamo.optimize("nvprims_nvfuser")
        def func(mat1, mat2, w, b):
            o = torch.matmul(mat1, mat2)
            return torch.batch_norm(o, w, b, None, None, True, 1e-2, 1e-5, True)

        # No warnings and no errors
        with torch.cuda.amp.autocast():
            with warnings.catch_warnings(record=True) as warning:
                nvfuser_result = func(input1, input2, w, b)
                self.assertEqual(len(warning), 0)
            eager_result = func.__wrapped__(input1, input2, w, b)
            self.assertEqual(eager_result, nvfuser_result)
    def test_dtype_correctness(self):
        input1 = make_tensor((2, 4, 8), device="cuda", dtype=torch.float16)

        @torchdynamo.optimize("nvprims_nvfuser")
        def func(a):
            tmp = a + 1.0
            # nvFuser promotes the output to fp32 for the math; the
            # FusionDefinition should cast the output dtype back to fp16
            return torch.where(tmp > 0, tmp, 0.0)

        # No warnings and no errors
        with warnings.catch_warnings(record=True) as w:
            nvfuser_result = func(input1)
            self.assertEqual(len(w), 0)

        eager_result = func.__wrapped__(input1)
        self.assertEqual(eager_result, nvfuser_result)


if __name__ == "__main__":
    run_tests()