Skip to content

Commit 14200c3

Browse files
author
Trevor Morris
authored
[Neo] Get cuda arch from cuda_target_arch rather than querying gpu (#131)
* [Neo] Get cuda arch from cuda_target_arch rather than querying gpu * fix lint branch * move import
1 parent 99048cb commit 14200c3

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

python/tvm/relay/op/strategy/cuda.py

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,20 @@
2020
import tvm
2121
from tvm.te import SpecializedCondition
2222
from tvm.contrib import nvcc
23+
from tvm.autotvm.env import AutotvmGlobalScope
2324
from .generic import *
2425
from .. import op as _op
2526
from .... import get_global_func
2627

28+
def get_cross_compile_compute_ver():
29+
"""Temporary workaround to enable cross compiling for GPU in Neo. tvm.gpu(0).compute_version
30+
will encounter an error if there is no GPU present. Instead, we use compute_version from
31+
set_cuda_target_arch"""
32+
if AutotvmGlobalScope.current.cuda_target_arch:
33+
arch = AutotvmGlobalScope.current.cuda_target_arch.split("sm_")[-1]
34+
return arch[0] + "." + arch[1]
35+
return tvm.gpu(0).compute_version
36+
2737
@schedule_injective.register(["cuda", "gpu"])
2838
def schedule_injective_cuda(attrs, outs, target):
2939
"""schedule injective ops for cuda"""
@@ -146,7 +156,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
146156
pre_flag=False)
147157
if judge_winograd_shape:
148158
if target.target_name == "cuda" and \
149-
nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
159+
nvcc.have_tensorcore(get_cross_compile_compute_ver()) and \
150160
judge_winograd_tensorcore:
151161
strategy.add_implementation(
152162
wrap_compute_conv2d(topi.cuda.conv2d_nhwc_winograd_tensorcore),
@@ -163,7 +173,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
163173
name="conv2d_nhwc_winograd_direct.cuda",
164174
plevel=5)
165175
if target.target_name == "cuda":
166-
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
176+
if nvcc.have_tensorcore(get_cross_compile_compute_ver()):
167177
if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
168178
(N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
169179
(N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
@@ -265,7 +275,7 @@ def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_ty
265275
dilation_h, dilation_w,
266276
pre_flag=True)
267277
if target.target_name == "cuda" and \
268-
nvcc.have_tensorcore(tvm.gpu(0).compute_version) and \
278+
nvcc.have_tensorcore(get_cross_compile_compute_ver()) and \
269279
judge_winograd_tensorcore:
270280
strategy.add_implementation(
271281
wrap_compute_conv2d(
@@ -363,7 +373,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
363373
N, _, _, _, _ = get_const_tuple(data.shape)
364374
_, _, _, CI, CO = get_const_tuple(kernel.shape)
365375
if target.target_name == "cuda":
366-
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
376+
if nvcc.have_tensorcore(get_cross_compile_compute_ver()):
367377
if (N % 16 == 0 and CI % 16 == 0 and CO % 16 == 0) or \
368378
(N % 8 == 0 and CI % 16 == 0 and CO % 32 == 0) or \
369379
(N % 32 == 0 and CI % 16 == 0 and CO % 8 == 0):
@@ -459,7 +469,7 @@ def dense_strategy_cuda(attrs, inputs, out_type, target):
459469
name="dense_large_batch.cuda",
460470
plevel=5)
461471
if target.target_name == "cuda":
462-
if nvcc.have_tensorcore(tvm.gpu(0).compute_version):
472+
if nvcc.have_tensorcore(get_cross_compile_compute_ver()):
463473
if(i % 16 == 0 and b % 16 == 0 and o % 16 == 0) \
464474
or (i % 16 == 0 and b % 8 == 0 and o % 32 == 0) \
465475
or (i % 16 == 0 and b % 32 == 0 and o % 8 == 0):

tests/scripts/task_lint.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ echo "clang-format check..."
5151
# check lastest change, for squash merge into master
5252
./tests/lint/git-clang-format.sh HEAD~1
5353
# chekc against origin/master for PRs.
54-
./tests/lint/git-clang-format.sh origin/dev
54+
./tests/lint/git-clang-format.sh origin/release-1.3.0
5555

5656
echo "Check codestyle of python code..."
5757
make pylint

0 commit comments

Comments
 (0)