
Commit 0185a8d

Lint the code
1 parent f607751 commit 0185a8d

File tree

22 files changed: +116 -90 lines changed


.github/workflows/linter.yaml

+3

@@ -17,6 +17,9 @@ jobs:
         uses: actions/checkout@v3
         with:
           path: tritonbench
+      - name: Install deps
+        run: |
+          pip install ruff-api
       - name: Check Formatting
         uses: omnilib/ufmt@action-v1
         with:

test/test_cpu/main.py

-1

@@ -7,7 +7,6 @@


 class TestTritonbenchCpu(unittest.TestCase):
-
     def _get_test_op(self):
         parser = get_parser(["--device", "cpu", "--op", "test_op"])
         tb_args, extra_args = parser.parse_known_args(

tools/cuda_utils.py

+1 -3

@@ -238,7 +238,5 @@ def check_torch_nightly_version(force_date: Optional[str] = None):
     if args.install_torch_nightly:
         install_pytorch_nightly(cuda_version=args.cudaver, env=os.environ)
     if args.check_torch_nightly_version:
-        assert (
-            not args.install_torch_nightly
-        ), "Error: Can't run install torch nightly and check version in the same command."
+        assert not args.install_torch_nightly, "Error: Can't run install torch nightly and check version in the same command."
         check_torch_nightly_version(args.force_date)

tritonbench/components/workers/subprocess_rpc.py

+3 -1

@@ -274,8 +274,10 @@ def write(self, msg: bytes) -> None:
     def get_writer_pid(self) -> int:
         assert (
            self._writer_pid is not None
-        ), "Writer pid is not specified. Maybe calling from child process or input pipe.\
+        ), (
+            "Writer pid is not specified. Maybe calling from child process or input pipe.\
             Please report a bug."
+        )
         return self._writer_pid

     def set_writer_pid(self, writer_pid: int) -> None:

tritonbench/kernels/triton_fused_attention.py

+10 -7

@@ -35,7 +35,6 @@


 class TmaAutoTuneHelper:
-
     # duck typing wrapper to implement the same interface as TmaDescKernelParam in Triton PR #4498
     class KernelParamWrapper:
         def __init__(self, desc):

@@ -457,7 +456,6 @@ def _attn_fwd_tma(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #
     HEAD_DIM: tl.constexpr,  #
     STAGE: tl.constexpr,  #
 ):
-
     tl.static_assert(BLOCK_N <= HEAD_DIM)
     start_m = tl.program_id(0)
     off_hz = tl.program_id(1)

@@ -569,7 +567,14 @@ def _attn_fwd_tma(  # Q, V, desc_k, desc_v, sm_scale, M, Out,  #

 @triton.jit
 def _attn_bwd_preprocess(
-    O, DO, Delta, Z, H, N_CTX, BLOCK_M: tl.constexpr, HEAD_DIM: tl.constexpr  # # # #
+    O,
+    DO,
+    Delta,
+    Z,
+    H,
+    N_CTX,
+    BLOCK_M: tl.constexpr,
+    HEAD_DIM: tl.constexpr,  # # # #
 ):
     off_m = tl.program_id(0) * BLOCK_M + tl.arange(0, BLOCK_M)
     off_hz = tl.program_id(1)

@@ -900,7 +905,6 @@ def _attn_bwd(


 class _attention(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, causal, sm_scale):
         # shape constraints

@@ -949,7 +953,7 @@ def forward(ctx, q, k, v, causal, sm_scale):
             N_CTX=q.shape[2],  #
             HEAD_DIM=HEAD_DIM_K,  #
             STAGE=stage,  #
-            **extra_kern_args
+            **extra_kern_args,
         )

         ctx.save_for_backward(q, k, v, o, M)

@@ -1021,7 +1025,6 @@ def backward(ctx, do):


 class _attention_tma(torch.autograd.Function):
-
     @staticmethod
     def forward(ctx, q, k, v, causal, sm_scale):
         # shape constraints

@@ -1175,7 +1178,7 @@ def grid_tma(META):
             N_CTX=q.shape[2],  #
             HEAD_DIM=HEAD_DIM_K,  #
             STAGE=stage,  #
-            **extra_kern_args
+            **extra_kern_args,
         )

         ctx.save_for_backward(q, k, v, o, M)

tritonbench/operators/gather_gemv/operator.py

-1

@@ -26,7 +26,6 @@


 class Operator(BenchmarkOperator):
-
     @register_metric()
     def gbps(self, fn_name, example_inputs, metrics: BenchmarkOperatorMetrics):
         arg0_1, arg1_1, arg2_1 = example_inputs

tritonbench/operators/gather_gemv/triton_gather_gemv.py

+1 -1

@@ -82,7 +82,7 @@ def triton_red_fused_mv_0(
     rbase = tl.arange(0, RBLOCK)[None, :].to(tl.int64)
     x0 = xindex
     # x0 // rnumel should have the same value of either 0 or 1
-    tmp0 = tl.load(in_ptr0 + ((x0 // rnumel)), None, eviction_policy="evict_last")
+    tmp0 = tl.load(in_ptr0 + (x0 // rnumel), None, eviction_policy="evict_last")
     _tmp11 = tl.full([XBLOCK, RBLOCK], 0, tl.float32)
     for roffset in range(0, rnumel, RBLOCK):
         rindex = roffset + rbase

tritonbench/operators/jagged_layer_norm/operator.py

+4 -5

@@ -36,7 +36,6 @@ def parse_op_args(args: List[str]):


 class Operator(BenchmarkOperator):
-
     DEFAULT_METRICS = ["latency", "accuracy"]
     DEFAULT_PRECISION = "fp32"


@@ -48,8 +47,8 @@ def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
         super().__init__(tb_args, extra_args)
-        self.sizes = list(range(2, 12, 4)) + list(
-            range(12, 23, 3)
+        self.sizes = (
+            list(range(2, 12, 4)) + list(range(12, 23, 3))
         )  # bias towards larger sizes, which are more representative of real-world shapes

         args = parse_op_args(self.extra_args)

@@ -105,8 +104,8 @@ def _inner():
             )  # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm

             padded_normalized = (
-                padded_values - mean
-            ) * padded_mask_values  # mask elements outside of the ragged dimension size for correct variance calculation
+                (padded_values - mean) * padded_mask_values
+            )  # mask elements outside of the ragged dimension size for correct variance calculation

             variance = (
                 torch.sum(

tritonbench/operators/jagged_mean/kernels.py

+12 -8

@@ -53,8 +53,9 @@ def triton_jagged_mean_kernel_simple_fused_sum_then_buffer(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M

-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
-        input_ptr_offsets + (pid_b + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_b),
+        tl.load(input_ptr_offsets + (pid_b + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
     ragged_len = ragged_end - ragged_start

@@ -133,8 +134,9 @@ def triton_jagged_mean_kernel_simple_fused_buffer_then_sum(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M

-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
-        input_ptr_offsets + (pid_b + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_b),
+        tl.load(input_ptr_offsets + (pid_b + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
     ragged_len = ragged_end - ragged_start

@@ -212,8 +214,9 @@ def triton_jagged_mean_kernel_variable_length_loop_sum_then_buffer(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M

-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
-        input_ptr_offsets + (pid_b + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_b),
+        tl.load(input_ptr_offsets + (pid_b + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
     ragged_len = ragged_end - ragged_start

@@ -288,8 +291,9 @@ def triton_jagged_mean_kernel_variable_length_loop_buffer_then_sum(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M

-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
-        input_ptr_offsets + (pid_ragged + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_ragged),
+        tl.load(input_ptr_offsets + (pid_ragged + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]
     ragged_len = ragged_end - ragged_start
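The ragged_start/ragged_end loads in these kernels follow the usual values-plus-offsets layout for jagged batches. A minimal sketch of that layout in plain PyTorch (tensor contents below are illustrative, not taken from the benchmark inputs):

import torch

# A jagged batch stored as one flat values tensor plus an offsets tensor;
# row i spans values[offsets[i] : offsets[i + 1]].
values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])  # all rows concatenated
offsets = torch.tensor([0, 2, 5, 6])  # three rows of length 2, 3, 1

for i in range(len(offsets) - 1):
    ragged_start, ragged_end = offsets[i], offsets[i + 1]  # what each kernel program loads
    print(i, values[ragged_start:ragged_end])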

tritonbench/operators/jagged_mean/operator.py

+30 -20

@@ -92,7 +92,6 @@ def execute_kernel_variable_length_loop(x, sum_then_buffer):


 class Operator(BenchmarkOperator):
-
     DEFAULT_METRICS = ["latency", "accuracy"]
     DEFAULT_PRECISION = "fp32"


@@ -104,8 +103,8 @@ def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
         super().__init__(tb_args, extra_args)
-        self.sizes = list(range(2, 12, 4)) + list(
-            range(12, 23, 3)
+        self.sizes = (
+            list(range(2, 12, 4)) + list(range(12, 23, 3))
         )  # bias towards larger sizes, which are more representative of real-world shapes

         args = parse_op_args(self.extra_args)

@@ -130,28 +129,37 @@ def torch_jagged_mean_unbind_torch_mean(
     def torch_jagged_mean_torch_nanmean(
         self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
     ):
-        return lambda: torch.nanmean(
-            torch.ops.aten._jagged_to_padded_dense_forward(
-                x.values(),
-                [x.offsets()],  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
-                max_lengths=[seqlen],  # max length of ragged dimension
-                padding_value=float("nan"),
-            ),
-            dim=1,
+        return (
+            lambda: torch.nanmean(
+                torch.ops.aten._jagged_to_padded_dense_forward(
+                    x.values(),
+                    [
+                        x.offsets()
+                    ],  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
+                    max_lengths=[seqlen],  # max length of ragged dimension
+                    padding_value=float("nan"),
+                ),
+                dim=1,
+            )
         )

     @register_benchmark()
     def torch_jagged_mean_torch_sum(
         self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
     ):
-        return lambda: torch.sum(
-            torch.ops.aten._jagged_to_padded_dense_forward(
-                x.values(),
-                [x.offsets()],  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
-                max_lengths=[seqlen],  # max length of ragged dimension
-            ),
-            dim=1,
-        ) / x.offsets().diff().unsqueeze(1)
+        return (
+            lambda: torch.sum(
+                torch.ops.aten._jagged_to_padded_dense_forward(
+                    x.values(),
+                    [
+                        x.offsets()
+                    ],  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
+                    max_lengths=[seqlen],  # max length of ragged dimension
+                ),
+                dim=1,
+            )
+            / x.offsets().diff().unsqueeze(1)
+        )

     @register_benchmark()
     def triton_jagged_mean_simple_fused(

@@ -176,7 +184,9 @@ def torch_compile_nested_tensor_integration(
         self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
     ):
         def _inner(x: torch.Tensor):  # mean along ragged dimension (dim == 1)
-            return torch.mean(x, dim=x._ragged_idx, keepdim=True)  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`.
+            return torch.mean(
+                x, dim=x._ragged_idx, keepdim=True
+            )  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`.

         torch_compile_func = torch.compile(_inner)
         return lambda: torch_compile_func(x)
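Both padded-dense baselines in this operator reduce to the same idea: pad the ragged rows out to a dense (B, seqlen) tensor and reduce over the padded dimension. A small sketch with made-up values (the tensors below are illustrative, not the operator's own inputs):

import torch

padded = torch.tensor([
    [1.0, 3.0, float("nan")],  # row of true length 2, NaN-padded
    [2.0, 4.0, 6.0],           # row of true length 3
])
lengths = torch.tensor([2.0, 3.0])

# torch.nanmean ignores the NaN padding, as in torch_jagged_mean_torch_nanmean.
mean_via_nanmean = torch.nanmean(padded, dim=1)  # tensor([2., 4.])

# Summing with padding treated as zero and dividing by the true row lengths gives
# the same result, mirroring torch_jagged_mean_torch_sum's sum / offsets().diff().
mean_via_sum = torch.nan_to_num(padded).sum(dim=1) / lengths  # tensor([2., 4.])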

tritonbench/operators/jagged_softmax/kernels.py

+6 -4

@@ -54,8 +54,9 @@ def triton_jagged_softmax_kernel_simple_fused_buffer_then_sum(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M

-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
-        input_ptr_offsets + (pid_b + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_b),
+        tl.load(input_ptr_offsets + (pid_b + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

     buffer_max_all = tl.full(

@@ -163,8 +164,9 @@ def triton_jagged_softmax_kernel_variable_length_loop_buffer_then_sum(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M

-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
-        input_ptr_offsets + (pid_b + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_b),
+        tl.load(input_ptr_offsets + (pid_b + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

     buffer_max_all = tl.full(

tritonbench/operators/jagged_softmax/operator.py

+8 -5

@@ -71,7 +71,6 @@ def parse_op_args(args: List[str]):


 class Operator(BenchmarkOperator):
-
     DEFAULT_METRICS = ["latency", "accuracy", "best_config"]
     DEFAULT_PRECISION = "fp32"


@@ -83,8 +82,8 @@ def __init__(
         self, tb_args: argparse.Namespace, extra_args: Optional[List[str]] = None
     ):
         super().__init__(tb_args, extra_args)
-        self.sizes = list(range(2, 12, 4)) + list(
-            range(12, 23, 3)
+        self.sizes = (
+            list(range(2, 12, 4)) + list(range(12, 23, 3))
         )  # bias towards larger sizes, which are more representative of real-world shapes

         args = parse_op_args(self.extra_args)

@@ -114,7 +113,9 @@ def torch_jagged_softmax_torch_sum(
         def _inner():
             padded = torch.ops.aten._jagged_to_padded_dense_forward(
                 x.values(),
-                [x.offsets()],  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
+                [
+                    x.offsets()
+                ],  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `offsets`.
                 max_lengths=[seqlen],  # max length of ragged dimension
                 padding_value=float("-inf"),  # e^-inf = 0
             )

@@ -153,7 +154,9 @@ def torch_compile_nested_tensor_integration(
         self, x: torch.Tensor, B: int, M: int, seqlen: int, sparsity: float
     ):
         def _inner(x: torch.Tensor):  # softmax along ragged dimension
-            return torch.softmax(x, dim=x._ragged_idx)  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`.
+            return torch.softmax(
+                x, dim=x._ragged_idx
+            )  # pyre-ignore: Undefined attribute [16]: `torch._tensor.Tensor` has no attribute `_ragged_idx`.

         torch_compile_func = torch.compile(_inner)
         return lambda: torch_compile_func(
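The padding_value=float("-inf") used by this operator relies on exp(-inf) being zero, so padded slots get zero probability and leave the softmax of the real entries unchanged. A quick illustrative check (values made up):

import torch

row = torch.tensor([1.0, 2.0])
padded_row = torch.tensor([1.0, 2.0, float("-inf")])

# The -inf slot contributes exp(-inf) = 0 to the normalizer, so the real
# entries keep the same probabilities in both cases.
print(torch.softmax(row, dim=0))         # tensor([0.2689, 0.7311])
print(torch.softmax(padded_row, dim=0))  # tensor([0.2689, 0.7311, 0.0000])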

tritonbench/operators/jagged_sum/kernels.py

+12 -8

@@ -52,8 +52,9 @@ def triton_jagged_sum_kernel_simple_fused_sum_then_buffer(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M

-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
-        input_ptr_offsets + (pid_ragged + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_ragged),
+        tl.load(input_ptr_offsets + (pid_ragged + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

     for block_pos in range(

@@ -127,8 +128,9 @@ def triton_jagged_sum_kernel_simple_fused_buffer_then_sum(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M

-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
-        input_ptr_offsets + (pid_ragged + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_ragged),
+        tl.load(input_ptr_offsets + (pid_ragged + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

     for block_pos in range(

@@ -201,8 +203,9 @@ def triton_jagged_sum_kernel_variable_length_loop_sum_then_buffer(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M

-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_b), tl.load(
-        input_ptr_offsets + (pid_b + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_b),
+        tl.load(input_ptr_offsets + (pid_b + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

     for block_start_ragged in range(

@@ -272,8 +275,9 @@ def triton_jagged_sum_kernel_variable_length_loop_buffer_then_sum(
     offsets_m = block_start_m + tl.arange(0, BLOCK_SIZE_M)
     mask_m = offsets_m < M

-    ragged_start, ragged_end = tl.load(input_ptr_offsets + pid_ragged), tl.load(
-        input_ptr_offsets + (pid_ragged + 1)
+    ragged_start, ragged_end = (
+        tl.load(input_ptr_offsets + pid_ragged),
+        tl.load(input_ptr_offsets + (pid_ragged + 1)),
     )  # load start and end offsets for current program, similar to offsets[i] and offsets[i + 1]

     for block_start_ragged in range(
