Skip to content

Commit 03fe112

Browse files
committed
update
1 parent 2d0180b commit 03fe112

3 files changed

Lines changed: 5 additions & 13 deletions

File tree

src/parallax/models/qwen3_5.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -38,9 +38,7 @@ def __call__(
3838
if prefix_lens is not None and context_lengths is not None:
3939
input_lengths = context_lengths - prefix_lens
4040

41-
if target_len == 1 or (
42-
prefix_lens is not None and bool(mx.any(prefix_lens > 0))
43-
):
41+
if target_len == 1 or (prefix_lens is not None and bool(mx.any(prefix_lens > 0))):
4442
conv_state, state = cache.read_states(state_slot_mapping)
4543
else:
4644
conv_state = mx.zeros(

src/parallax/models/qwen3_next.py

Lines changed: 2 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -15,8 +15,8 @@
1515

1616
from parallax.server.cache.base import BaseCache
1717
from parallax.utils.prefix_cache_utils import compute_attention_with_prefix_cache
18-
from parallax_utils.logging_config import get_logger
1918
from parallax_extensions.ops import paged_attention_v1, reshape_and_cache
19+
from parallax_utils.logging_config import get_logger
2020

2121
logger = get_logger(__name__)
2222

@@ -167,9 +167,7 @@ def __call__(
167167
if prefix_lens is not None and context_lengths is not None:
168168
input_lengths = context_lengths - prefix_lens
169169

170-
if target_len == 1 or (
171-
prefix_lens is not None and bool(mx.any(prefix_lens > 0))
172-
):
170+
if target_len == 1 or (prefix_lens is not None and bool(mx.any(prefix_lens > 0))):
173171
conv_state, state1 = cache.read_states(state_slot_mapping)
174172
else:
175173
conv_state = mx.zeros(

src/parallax/server/cache/linear_cache.py

Lines changed: 2 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -66,17 +66,13 @@ def get_indexer_cache(self) -> Optional[mx.array]:
6666
def zero_slot(self, slot_idx: int):
6767
"""Reset a request slot to an empty recurrent state."""
6868
if self.conv_state_cache is not None:
69-
self.conv_state_cache[0, slot_idx] = mx.zeros_like(
70-
self.conv_state_cache[0, slot_idx]
71-
)
69+
self.conv_state_cache[0, slot_idx] = mx.zeros_like(self.conv_state_cache[0, slot_idx])
7270
if self.linear_state_cache is not None:
7371
self.linear_state_cache[0, slot_idx] = mx.zeros_like(
7472
self.linear_state_cache[0, slot_idx]
7573
)
7674

77-
def snapshot_slot(
78-
self, slot_idx: int
79-
) -> Tuple[Optional[mx.array], Optional[mx.array]]:
75+
def snapshot_slot(self, slot_idx: int) -> Tuple[Optional[mx.array], Optional[mx.array]]:
8076
"""Copy the recurrent state currently stored in a request slot."""
8177
conv_state = None
8278
linear_state = None

0 commit comments

Comments
 (0)