Skip to content

Commit 03fa086

Browse files
xuzhao9facebook-github-bot
authored andcommitted
Fix import bug in data_utils.py
Summary: As the title says Reviewed By: nmacchioni Differential Revision: D65881268 fbshipit-source-id: 91ab130b133e2d35e15244d971882f3b51946331
1 parent f63be70 commit 03fa086

File tree

3 files changed

+7
-3
lines changed

3 files changed

+7
-3
lines changed

tritonbench/operators/softmax/operator.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def _inner():
103103

104104
def get_input_iter(self):
105105
M = 4096
106-
shapes = (tuple(M, 128 * i) for i in range(2, 100))
106+
shapes = [(M, 128 * i) for i in range(2, 100)]
107107
if IS_FBCODE and self.tb_args.production_shapes:
108108
shapes = get_production_shapes(self.name, "softmax")
109109
for M, N in shapes:

tritonbench/utils/data_utils.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from .triton_ops import IS_FBCODE
1+
from .triton_op import IS_FBCODE
22

33

44
def get_production_shapes(op_name, op_type):

tritonbench/utils/parser.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -200,8 +200,12 @@ def _find_param_loc(params, key: str) -> int:
200200
def _remove_params(params, loc):
201201
if loc == -1:
202202
return params
203-
if (loc + 1) < len(params) and params[loc + 1].startswith("--"):
203+
if loc == len(params) - 1:
204+
return params[:loc]
205+
if params[loc + 1].startswith("--"):
204206
return params[:loc] + params[loc + 1 :]
207+
if loc == len(params) - 2:
208+
return params[:loc]
205209
return params[:loc] + params[loc + 2 :]
206210

207211

0 commit comments

Comments
 (0)