Skip to content

Commit f855a13

Browse files
alignment enforced for q,k,v
Signed-off-by: Jinseok Lee <jindol21@rebellions.ai>
1 parent fbc4f04 commit f855a13

3 files changed

Lines changed: 24 additions & 0 deletions

File tree

vllm_rbln/triton_kernels/attention.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,10 @@ def _(
377377

378378
output = torch.empty_like(query)
379379

380+
query = rblib.align_tensor_last_dim_to_64(query)
381+
key = rblib.align_tensor_last_dim_to_64(key)
382+
value = rblib.align_tensor_last_dim_to_64(value)
383+
380384
NUM_HEAD = query.shape[1]
381385
NUM_GROUP = query.shape[2]
382386
HEAD_DIM = query.shape[-1]
@@ -435,6 +439,10 @@ def _(
435439

436440
output = torch.empty_like(query)
437441

442+
query = rblib.align_tensor_last_dim_to_64(query)
443+
key = rblib.align_tensor_last_dim_to_64(key)
444+
value = rblib.align_tensor_last_dim_to_64(value)
445+
438446
NUM_HEAD = query.shape[1]
439447
NUM_GROUP = query.shape[2]
440448
HEAD_DIM = query.shape[-1]

vllm_rbln/triton_kernels/causal_attention.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,10 @@ def _(
519519

520520
output = torch.empty_like(query)
521521

522+
query = rblib.align_tensor_last_dim_to_64(query)
523+
key = rblib.align_tensor_last_dim_to_64(key)
524+
value = rblib.align_tensor_last_dim_to_64(value)
525+
522526
NUM_HEAD = query.shape[1]
523527
NUM_GROUP = query.shape[2]
524528
HEAD_DIM = query.shape[-1]
@@ -591,6 +595,10 @@ def _(
591595

592596
output = torch.empty_like(query)
593597

598+
query = rblib.align_tensor_last_dim_to_64(query)
599+
key = rblib.align_tensor_last_dim_to_64(key)
600+
value = rblib.align_tensor_last_dim_to_64(value)
601+
594602
NUM_HEAD = query.shape[1]
595603
NUM_GROUP = query.shape[2]
596604
HEAD_DIM = query.shape[-1]

vllm_rbln/triton_kernels/sliding_window_attention.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -453,6 +453,10 @@ def _(
453453
qk_scale = qk_scale.to(torch.float32)
454454
output = torch.empty_like(query)
455455

456+
query = rblib.align_tensor_last_dim_to_64(query)
457+
key = rblib.align_tensor_last_dim_to_64(key)
458+
value = rblib.align_tensor_last_dim_to_64(value)
459+
456460
NUM_HEAD = query.shape[1]
457461
NUM_GROUP = query.shape[2]
458462
HEAD_DIM = query.shape[-1]
@@ -525,6 +529,10 @@ def _(
525529
qk_scale = qk_scale.to(torch.float32)
526530
output = torch.empty_like(query)
527531

532+
query = rblib.align_tensor_last_dim_to_64(query)
533+
key = rblib.align_tensor_last_dim_to_64(key)
534+
value = rblib.align_tensor_last_dim_to_64(value)
535+
528536
NUM_HEAD = query.shape[1]
529537
NUM_GROUP = query.shape[2]
530538
HEAD_DIM = query.shape[-1]

0 commit comments

Comments
 (0)