Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 3609ec0

Browse files
committed May 2, 2025
VLLM Workaround
1 parent fb39817 commit 3609ec0

File tree

1 file changed

+8
-7
lines changed

1 file changed

+8
-7
lines changed
 

‎torchao/prototype/mx_formats/utils.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -35,13 +35,14 @@ def to_blocked(input_matrix) -> Tensor:
3535
padded_cols = n_col_blocks * 4
3636

3737
padded = input_matrix
38-
# if (rows, cols) != (padded_rows, padded_cols):
39-
padded = torch.zeros(
40-
(padded_rows, padded_cols),
41-
device=input_matrix.device,
42-
dtype=input_matrix.dtype,
43-
)
44-
padded[:rows, :cols] = input_matrix
38+
# TODO This is to work around VLLM's usage of compile w/ dynamic shapes
39+
if torch.compiler.is_compiling() or (rows, cols) != (padded_rows, padded_cols):
40+
padded = torch.zeros(
41+
(padded_rows, padded_cols),
42+
device=input_matrix.device,
43+
dtype=input_matrix.dtype,
44+
)
45+
padded[:rows, :cols] = input_matrix
4546

4647
# Rearrange the blocks
4748
blocks = padded.view(n_row_blocks, 128, n_col_blocks, 4).permute(0, 2, 1, 3)

0 commit comments

Comments (0)
Please sign in to comment.