Commit 6d85a7d

fix lint
1 parent 8196f49 commit 6d85a7d

4 files changed, +24 -24 lines changed


modules/flash_attn_triton_amd/fwd_prefill.py

Lines changed: 6 additions & 6 deletions
@@ -295,9 +295,9 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, # pylint: disable=un
     # The tensor allocated for L is based on MAX_SEQLENS_Q as that is
     # statically known.
     l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m
-    l_ptrs = l_offset + offs_m * stride_lse_m
+    l_ptrs = l_offset + offs_m * stride_lse_m

-    l = tl.full([BLOCK_M], value=0.0, dtype=ACCUMULATOR_TYPE)
+    l = tl.full([BLOCK_M], value=0.0, dtype=ACCUMULATOR_TYPE)  # noqa: E741

     # mask_m_offsets = start_m + tl.arange(0, BLOCK_M)
     # lse_mask = mask_m_offsets < causal_start_idx
@@ -450,7 +450,7 @@ def attn_fwd(Q, K, V, bias, Cache_seqlens, Cache_batch_idx, # pylint: disable=un

     # write back LSE(Log Sum Exponents), the log of the normalization constant
     l_offset = LSE + off_z * stride_lse_z + off_h_q * stride_lse_h + cu_seqlens_q_start * stride_lse_m
-    l_ptrs = l_offset + offs_m * stride_lse_m
+    l_ptrs = l_offset + offs_m * stride_lse_m
     if USE_EXP2:
         RCP_LN2: tl.constexpr = 1.4426950408889634
         LN2: tl.constexpr = 0.6931471824645996
@@ -499,9 +499,9 @@ def attention_prefill_forward_triton_impl(
     bias: Optional[torch.Tensor],
     layout: Literal["bshd", "bhsd", "thd"],
     # varlen
-    cu_seqlens_q: Optional[torch.Tensor],
+    cu_seqlens_q: Optional[torch.Tensor],
     cu_seqlens_k: Optional[torch.Tensor],
-    max_seqlens_q: int,
+    max_seqlens_q: int,
     max_seqlens_k: int,
     # inference
     cache_seqlens: Optional[Union[(int, torch.Tensor)]],
@@ -570,7 +570,7 @@ def attention_prefill_forward_triton_impl(
     attn_fwd[grid](q, k, v, bias, cache_seqlens, cache_batch_idx,
                    sm_scale, softmax_lse, o, *q_strides, *k_strides, *v_strides, *o_strides,
                    *bias_strides, stride_az, stride_ah, *scores_strides, stride_lse_z, stride_lse_h, stride_lse_m, cu_seqlens_q, cu_seqlens_k,
-                   dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
+                   dropout_p=dropout_p, philox_seed=philox_seed, philox_offset_base=philox_offset, sd_mask=sd_mask, dropout_mask=dropout_mask, alibi_slopes=alibi_slopes,
                    HQ=nheads_q, HK=nheads_k, ACTUAL_BLOCK_DMODEL=head_size, MAX_SEQLENS_Q=max_seqlens_q,
                    MAX_SEQLENS_K=max_seqlens_k, IS_CAUSAL=causal, IS_VARLEN=is_varlen, IS_INFERENCE=is_inference,
                    BLOCK_DMODEL=padded_d_model, USE_BIAS=False if bias is None else True,
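
Note on this file: the only visible change is the "# noqa: E741" marker on the line defining l; the paired -/+ lines that look identical presumably differ only in trailing whitespace, which a rendered diff cannot show. E741 is the flake8/ruff warning for ambiguous single-character names (l, I, O). Since l is the conventional name for the running softmax normalizer in flash-attention kernels, the commit suppresses the warning rather than renaming. A minimal standalone sketch of the rule (made-up values, not repo code):

    l = [1.0, 2.0, 3.0]               # flake8/ruff report: E741 ambiguous variable name 'l'
    l = [1.0, 2.0, 3.0]  # noqa: E741 # suppressed: the conventional name is kept
    print(sum(l))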

modules/flash_attn_triton_amd/utils.py

Lines changed: 15 additions & 15 deletions
@@ -42,7 +42,7 @@ class MetaData():
     rotary_cos: Optional[torch.Tensor] = None
     rotary_interleaved: bool = False
     rotary_conjunction: bool = False
-
+

     def __repr__(self) -> str:
         return (f"MetaData(\n"
@@ -161,7 +161,7 @@ def generate_varlen_tensor(
     if batch_size is None:
         valid_batch_sizes = [bs for bs in [1, 2, 4, 8, 16, 32, 64] if bs <= total_seqlen]
         batch_size = random.choice(valid_batch_sizes)
-
+
     # get seqlens
     if equal_seqlens:
         seqlens = torch.full(
@@ -241,14 +241,14 @@ def input_helper(
     TOTAL_SEQLENS_Q = BATCH * N_CTX_Q
     TOTAL_SEQLENS_K = BATCH * N_CTX_K
     equal_seqlens=False
-
+
     # gen tensors
     # TODO: the gen functions should maybe have different gen modes like random, ones, increasing seqlen
     q, cu_seqlens_q, max_seqlen_q = generate_varlen_tensor(TOTAL_SEQLENS_Q, HQ, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
     k, cu_seqlens_k, max_seqlen_k = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
     v, _, _ = generate_varlen_tensor(TOTAL_SEQLENS_K, HK, D_HEAD, batch_size=BATCH, dtype=dtype, device=device, equal_seqlens=equal_seqlens, DEBUG_INPUT=DEBUG_INPUT)
     do = torch.ones_like(q) if DEBUG_INPUT else torch.randn_like(q)
-
+
     # setup metadata
     if DEBUG_INPUT:
         sm_scale = 1
@@ -313,7 +313,7 @@ def input_helper(

         return qkv, do, metadata
     else:
-        assert False, f"Unsupported packing mode: {packing}"
+        raise AssertionError(f"Unsupported packing mode: {packing}")

 # -------------------------------
 # Alibi
@@ -366,21 +366,21 @@ def get_shape_from_layout(
     elif layout == 'thd':
         total_seqlen, num_heads, head_dim = x.shape
         if cu_seqlens is None:
-            raise ValueError("cu_seqlens must be provided for varlen (thd) layout")
+            raise ValueError("cu_seqlens must be provided for varlen (thd) layout")
         if max_seqlen is None:
             raise ValueError("max_seqlen must be provided for varlen (thd) layout")
-
+
         batch, max_seqlen_final, num_heads, head_dim = len(cu_seqlens) - 1, max_seqlen, num_heads, head_dim
     else:
-        assert False, "Got unsupported layout."
+        raise AssertionError("Got unsupported layout.")

     return batch, max_seqlen_final, num_heads, head_dim


 def get_shapes_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = None, max_seqlen_q=None, max_seqlen_k=None):
     batch_q, seqlen_q, nheads_q, head_size_q = get_shape_from_layout(q, layout, cu_seqlens_q, max_seqlen_q)
     batch_k, seqlen_k, nheads_k, head_size_k = get_shape_from_layout(k, layout, cu_seqlens_k, max_seqlen_k)
-
+
     # assert
     assert batch_q == batch_k
     assert head_size_q == head_size_k
@@ -389,13 +389,13 @@ def get_shapes_from_layout(q, k, layout, cu_seqlens_q = None, cu_seqlens_k = Non

 def get_stride_from_layout(x: torch.Tensor, layout:Literal["bshd", "bhsd", "thd"]):
     if layout == 'thd':
-        strides = (0, x.stride(1), x.stride(0), x.stride(2))
+        strides = (0, x.stride(1), x.stride(0), x.stride(2))
     elif layout == 'bhsd':
         strides = (x.stride(0), x.stride(1), x.stride(2), x.stride(3))
     elif layout == 'bshd':
         strides = (x.stride(0), x.stride(2), x.stride(1), x.stride(3))
     else:
-        assert False, 'Got unsupported layout.'
+        raise AssertionError('Got unsupported layout.')
     return strides

 def get_shape_and_strides_from_layout(x: torch.Tensor, layout: Literal["bshd", "bhsd", "thd"], cu_seqlens: Optional[torch.Tensor] = None, max_seqlen: Optional[int] = None):
@@ -458,22 +458,22 @@ def write_dropout_mask(x, tensor_name = "tensor"):
     if True:
         BLOCK_M = 64
         BLOCK_N = 64
-
+
         # Calculate number of blocks in each dimension
         m_blocks = math.ceil(seqlen_m / BLOCK_M)
         n_blocks = math.ceil(seqlen_n / BLOCK_N)
-
+
         # Process each block
         for m_block in range(m_blocks):
             # Calculate row range for current block
             row_start = m_block * BLOCK_M
             row_end = min(row_start + BLOCK_M, seqlen_m)
-
+
             for n_block in range(n_blocks):
                 # Calculate column range for current block
                 col_start = n_block * BLOCK_N
                 col_end = min(col_start + BLOCK_N, seqlen_n)
-
+
                 # Extract and write the current block
                 for row_idx in range(row_start, row_end):
                     row_data = dropout_mask[row_idx][col_start:col_end]
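
Note on this file: besides whitespace cleanup, every substantive edit replaces "assert False, msg" with "raise AssertionError(msg)". The motivation (the pattern flake8-bugbear reports as B011) is that assert statements are compiled out when Python runs with -O, so an "assert False" guard on an unreachable branch can silently fall through, while an explicit raise always fires. A standalone sketch of the difference (hypothetical function, not repo code):

    def get_stride_unsafe(layout: str):
        if layout == "thd":
            return (0, 1)
        assert False, f"Got unsupported layout: {layout}"  # becomes a no-op under python -O

    def get_stride_safe(layout: str):
        if layout == "thd":
            return (0, 1)
        raise AssertionError(f"Got unsupported layout: {layout}")  # always raises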

modules/processing.py

Lines changed: 2 additions & 2 deletions
@@ -260,7 +260,7 @@ def fill_fields_from_opts(self):
         self.s_noise = self.s_noise if self.s_noise is not None else opts.s_noise

     @property
-    def sd_model(self):
+    def sd_model(self):  # noqa: F811
         return shared.sd_model

     @sd_model.setter
@@ -1683,7 +1683,7 @@ def __post_init__(self):
         self.initial_noise_multiplier = opts.initial_noise_multiplier if self.initial_noise_multiplier is None else self.initial_noise_multiplier

     @property
-    def mask_blur(self):
+    def mask_blur(self):  # noqa: F811
         if self.mask_blur_x == self.mask_blur_y:
             return self.mask_blur_x
         return None
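
Note on this file: F811 is pyflakes' "redefinition of unused name" warning. It fires here presumably because sd_model and mask_blur already exist earlier in their classes (for example as dataclass fields) and are then intentionally redefined as properties. A hypothetical illustration of the silenced pattern (not repo code):

    class Processed:
        sd_model = None  # earlier attribute, e.g. a dataclass field

        @property
        def sd_model(self):  # noqa: F811 -- intentional: the property shadows the field
            return "active model"

    print(Processed().sd_model)  # -> active model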

modules/zluda.py

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@
 from torch._prims_common import DeviceLikeType
 import onnxruntime as ort
 from modules import shared, devices, zluda_installer
-from modules.zluda_installer import core, default_agent # pylint: disable=unused-import
+from modules.zluda_installer import core, default_agent # noqa: F401
 from modules.onnx_impl.execution_providers import available_execution_providers, ExecutionProvider


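
Note on this file: core and default_agent are imported but never used in this module, presumably so that other modules can re-import them from modules.zluda. The commit swaps the pylint directive for the flake8/ruff equivalent, since "# pylint: disable=unused-import" only silences pylint and F401 (unused import) would still be reported. A stand-in sketch of the re-export pattern using the stdlib (not repo code):

    from math import tau as default_agent  # noqa: F401  kept only for downstream imports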
