File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -377,6 +377,10 @@ def _(
377377
378378 output = torch .empty_like (query )
379379
380+ query = rblib .align_tensor_last_dim_to_64 (query )
381+ key = rblib .align_tensor_last_dim_to_64 (key )
382+ value = rblib .align_tensor_last_dim_to_64 (value )
383+
380384 NUM_HEAD = query .shape [1 ]
381385 NUM_GROUP = query .shape [2 ]
382386 HEAD_DIM = query .shape [- 1 ]
@@ -435,6 +439,10 @@ def _(
435439
436440 output = torch .empty_like (query )
437441
442+ query = rblib .align_tensor_last_dim_to_64 (query )
443+ key = rblib .align_tensor_last_dim_to_64 (key )
444+ value = rblib .align_tensor_last_dim_to_64 (value )
445+
438446 NUM_HEAD = query .shape [1 ]
439447 NUM_GROUP = query .shape [2 ]
440448 HEAD_DIM = query .shape [- 1 ]
Original file line number Diff line number Diff line change @@ -519,6 +519,10 @@ def _(
519519
520520 output = torch .empty_like (query )
521521
522+ query = rblib .align_tensor_last_dim_to_64 (query )
523+ key = rblib .align_tensor_last_dim_to_64 (key )
524+ value = rblib .align_tensor_last_dim_to_64 (value )
525+
522526 NUM_HEAD = query .shape [1 ]
523527 NUM_GROUP = query .shape [2 ]
524528 HEAD_DIM = query .shape [- 1 ]
@@ -591,6 +595,10 @@ def _(
591595
592596 output = torch .empty_like (query )
593597
598+ query = rblib .align_tensor_last_dim_to_64 (query )
599+ key = rblib .align_tensor_last_dim_to_64 (key )
600+ value = rblib .align_tensor_last_dim_to_64 (value )
601+
594602 NUM_HEAD = query .shape [1 ]
595603 NUM_GROUP = query .shape [2 ]
596604 HEAD_DIM = query .shape [- 1 ]
Original file line number Diff line number Diff line change @@ -453,6 +453,10 @@ def _(
453453 qk_scale = qk_scale .to (torch .float32 )
454454 output = torch .empty_like (query )
455455
456+ query = rblib .align_tensor_last_dim_to_64 (query )
457+ key = rblib .align_tensor_last_dim_to_64 (key )
458+ value = rblib .align_tensor_last_dim_to_64 (value )
459+
456460 NUM_HEAD = query .shape [1 ]
457461 NUM_GROUP = query .shape [2 ]
458462 HEAD_DIM = query .shape [- 1 ]
@@ -525,6 +529,10 @@ def _(
525529 qk_scale = qk_scale .to (torch .float32 )
526530 output = torch .empty_like (query )
527531
532+ query = rblib .align_tensor_last_dim_to_64 (query )
533+ key = rblib .align_tensor_last_dim_to_64 (key )
534+ value = rblib .align_tensor_last_dim_to_64 (value )
535+
528536 NUM_HEAD = query .shape [1 ]
529537 NUM_GROUP = query .shape [2 ]
530538 HEAD_DIM = query .shape [- 1 ]
You can’t perform that action at this time.
0 commit comments