Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions flashinfer/decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -2715,15 +2715,15 @@ def xqa_batch_decode_with_kv_cache(
query_new,
k_cache,
v_cache,
k_cache_sf,
v_cache_sf,
block_tables,
seq_lens_new,
out_4d,
scratch,
semaphore,
num_kv_heads,
page_size,
k_sf_cache=k_cache_sf,
v_sf_cache=v_cache_sf,
sinks=sinks_new,
q_scale=q_scale_value,
kv_scale=kv_scale_value,
Expand Down
19 changes: 13 additions & 6 deletions flashinfer/norm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@

import functools
import os
from typing import Optional
import warnings
from typing import Optional, Union

import torch

Expand Down Expand Up @@ -62,11 +63,17 @@ def get_norm_module():


def _normalize_scale_tensor(
scale: torch.Tensor, ref_tensor: torch.Tensor
scale: Union[float, torch.Tensor], ref_tensor: torch.Tensor
) -> torch.Tensor:
"""Normalize quantization scale tensor to 1D shape (1,) on target device."""
"""Normalize quantization scale to 1D tensor of shape (1,) on target device."""
if not isinstance(scale, torch.Tensor):
raise TypeError(f"scale must be torch.Tensor, got {type(scale)}")
warnings.warn(
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why does this interface have to be stable version over version? Naively I would think that _-prefixed functions are not part of the public interface and therefore have no external stability guarantees.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

rmsnorm_quant and fused_add_rmsnorm_quant were the breaking changes (type changed). the compatibility fix landed in this helper function so both old and new signature are supported

"Passing scale as a float is deprecated and will be removed in a future "
"release. Use a torch.Tensor of shape (1,) instead.",
FutureWarning,
stacklevel=3,
)
scale = torch.tensor([scale], dtype=torch.float32, device=ref_tensor.device)
if scale.device != ref_tensor.device:
scale = scale.to(ref_tensor.device)
if scale.dtype != torch.float32:
Expand Down Expand Up @@ -159,7 +166,7 @@ def rmsnorm_quant(
out: torch.Tensor,
input: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
scale: Union[float, torch.Tensor],
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
Expand Down Expand Up @@ -268,7 +275,7 @@ def fused_add_rmsnorm_quant(
input: torch.Tensor,
residual: torch.Tensor,
weight: torch.Tensor,
scale: torch.Tensor,
scale: Union[float, torch.Tensor],
eps: float = 1e-6,
enable_pdl: Optional[bool] = None,
) -> None:
Expand Down
5 changes: 3 additions & 2 deletions flashinfer/xqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,8 +155,6 @@ def xqa(
q: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
k_sf_cache: Optional[torch.Tensor],
v_sf_cache: Optional[torch.Tensor],
page_table: torch.Tensor,
seq_lens: torch.Tensor,
output: torch.Tensor,
Expand All @@ -174,6 +172,9 @@ def xqa(
rcp_out_scale: float = 1.0,
q_seq_len: int = 1,
mask: Optional[torch.Tensor] = None,
*,
Comment thread
aleozlx marked this conversation as resolved.
k_sf_cache: Optional[torch.Tensor] = None,
v_sf_cache: Optional[torch.Tensor] = None,
) -> None:
r"""Apply attention with paged KV cache using XQA kernel.
Parameters
Expand Down
2 changes: 0 additions & 2 deletions tests/attention/test_xqa.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,6 @@ def test_xqa(
q_heads,
cache_k_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_k_heads,
cache_v_heads.to(torch.float8_e4m3fn) if fp8_kv_cache else cache_v_heads,
None,
None,
page_list_arg,
seq_len_list,
output,
Expand Down
2 changes: 1 addition & 1 deletion version.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.6.6
0.6.7
Loading