File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -38,9 +38,7 @@ def __call__(
3838 if prefix_lens is not None and context_lengths is not None :
3939 input_lengths = context_lengths - prefix_lens
4040
41- if target_len == 1 or (
42- prefix_lens is not None and bool (mx .any (prefix_lens > 0 ))
43- ):
41+ if target_len == 1 or (prefix_lens is not None and bool (mx .any (prefix_lens > 0 ))):
4442 conv_state , state = cache .read_states (state_slot_mapping )
4543 else :
4644 conv_state = mx .zeros (
Original file line number Diff line number Diff line change 1515
1616from parallax .server .cache .base import BaseCache
1717from parallax .utils .prefix_cache_utils import compute_attention_with_prefix_cache
18- from parallax_utils .logging_config import get_logger
1918from parallax_extensions .ops import paged_attention_v1 , reshape_and_cache
19+ from parallax_utils .logging_config import get_logger
2020
2121logger = get_logger (__name__ )
2222
@@ -167,9 +167,7 @@ def __call__(
167167 if prefix_lens is not None and context_lengths is not None :
168168 input_lengths = context_lengths - prefix_lens
169169
170- if target_len == 1 or (
171- prefix_lens is not None and bool (mx .any (prefix_lens > 0 ))
172- ):
170+ if target_len == 1 or (prefix_lens is not None and bool (mx .any (prefix_lens > 0 ))):
173171 conv_state , state1 = cache .read_states (state_slot_mapping )
174172 else :
175173 conv_state = mx .zeros (
Original file line number Diff line number Diff line change @@ -66,17 +66,13 @@ def get_indexer_cache(self) -> Optional[mx.array]:
6666 def zero_slot (self , slot_idx : int ):
6767 """Reset a request slot to an empty recurrent state."""
6868 if self .conv_state_cache is not None :
69- self .conv_state_cache [0 , slot_idx ] = mx .zeros_like (
70- self .conv_state_cache [0 , slot_idx ]
71- )
69+ self .conv_state_cache [0 , slot_idx ] = mx .zeros_like (self .conv_state_cache [0 , slot_idx ])
7270 if self .linear_state_cache is not None :
7371 self .linear_state_cache [0 , slot_idx ] = mx .zeros_like (
7472 self .linear_state_cache [0 , slot_idx ]
7573 )
7674
77- def snapshot_slot (
78- self , slot_idx : int
79- ) -> Tuple [Optional [mx .array ], Optional [mx .array ]]:
75+ def snapshot_slot (self , slot_idx : int ) -> Tuple [Optional [mx .array ], Optional [mx .array ]]:
8076 """Copy the recurrent state currently stored in a request slot."""
8177 conv_state = None
8278 linear_state = None
You can’t perform that action at this time.
0 commit comments