2525import time
2626
2727import mlx .core as mx
28+ from mlx_lm .server import convert_chat , process_message_content
2829
2930from parallax .server .cache_manager import CacheManager
3031from parallax .server .request import InitialRequest
3132from parallax .server .sampling .sampler import SamplingBatchInfo
3233from parallax .server .sampling .sampling_params import SamplingParams
34+ from parallax .server .scheduler import _normalize_token_ids
3335from parallax .server .shard_loader import MLXModelLoader
3436from parallax .utils .utils import create_causal_mask , get_layer_types
3537
@@ -44,6 +46,40 @@ def print_rank(message):
4446 print (f"[Rank { tp_rank } ] { message } " )
4547
4648
def get_eos_token_ids(config, tokenizer):
    """Collect the end-of-sequence token ids from the config and tokenizer.

    Args:
        config: Mapping-like model config; only its ``"eos_token_id"`` entry
            is read (via ``.get``).
        tokenizer: Tokenizer object; its ``eos_token_id`` attribute is read
            if present, otherwise treated as ``None``.

    Returns:
        The collection produced by ``_normalize_token_ids`` for the config
        value (or for the tokenizer value when the config has none), updated
        with the normalized tokenizer value.
        NOTE(review): presumably a ``set`` of ints -- confirm against
        ``_normalize_token_ids`` in ``parallax.server.scheduler``.
    """
    config_eos = config.get("eos_token_id")
    tokenizer_eos = getattr(tokenizer, "eos_token_id", None)

    # The config value takes precedence as the starting point; the
    # tokenizer value is only the seed when the config does not define one.
    seed = tokenizer_eos if config_eos is None else config_eos

    merged_ids = _normalize_token_ids(seed)
    # Always fold in the tokenizer's ids so both sources are honoured.
    merged_ids.update(_normalize_token_ids(tokenizer_eos))
    return merged_ids
58+
59+
def build_prompt(messages, tokenizer):
    """Build the prompt text and its token ids for a list of chat messages.

    When the tokenizer has a (truthy) ``chat_template``, the messages are
    passed through ``process_message_content`` and then rendered twice with
    ``tokenizer.apply_chat_template``: once tokenized and once as text.
    Otherwise the text comes from ``convert_chat`` and is tokenized with
    ``tokenizer.encode``.

    Args:
        messages: Chat messages to render.
            NOTE(review): ``process_message_content`` is called on this
            object without using a return value, so it presumably mutates
            the messages in place -- confirm against ``mlx_lm.server``.
        tokenizer: Tokenizer exposing ``chat_template``,
            ``apply_chat_template`` and ``encode``.

    Returns:
        A ``(full_prompt, prompt_tokens)`` tuple.
    """
    # No chat template: fall back to the plain conversion path.
    if not tokenizer.chat_template:
        prompt_text = convert_chat(messages, None)
        return prompt_text, tokenizer.encode(prompt_text)

    process_message_content(messages)

    # Options shared by both renderings; only ``tokenize`` differs.
    shared_options = {"add_generation_prompt": True, "return_dict": False}
    token_ids = tokenizer.apply_chat_template(
        messages,
        None,
        tokenize=True,
        **shared_options,
    )
    prompt_text = tokenizer.apply_chat_template(
        messages,
        None,
        tokenize=False,
        **shared_options,
    )
    return prompt_text, token_ids
81+
82+
4783def main ():
4884 parser = argparse .ArgumentParser (description = "Simple offline inference script" )
4985 parser .add_argument (
@@ -76,6 +112,8 @@ def main():
76112 # 2. Initialize CacheManager
77113 num_layers = config .get ("num_hidden_layers" )
78114 num_kv_heads = config .get ("num_key_value_heads" )
115+ if num_kv_heads is None :
116+ num_kv_heads = config .get ("num_attention_groups" )
79117 head_dim = config .get ("head_dim" ) or config .get ("hidden_size" ) // config .get (
80118 "num_attention_heads"
81119 )
@@ -88,6 +126,18 @@ def main():
88126
89127 v_head_dim = config .get ("v_head_dim" )
90128 layer_types = get_layer_types (config , 0 , num_layers )
129+ linear_key_head_dim = config .get ("linear_key_head_dim" )
130+ linear_value_head_dim = config .get ("linear_value_head_dim" )
131+ linear_conv_kernel_dim = config .get ("linear_conv_kernel_dim" )
132+ linear_num_key_heads = config .get ("linear_num_key_heads" )
133+ linear_num_value_heads = config .get ("linear_num_value_heads" )
134+ key_dim , value_dim , conv_dim = None , None , None
135+ if linear_key_head_dim is not None and linear_num_key_heads is not None :
136+ key_dim = linear_key_head_dim * linear_num_key_heads
137+ if linear_value_head_dim is not None and linear_num_value_heads is not None :
138+ value_dim = linear_value_head_dim * linear_num_value_heads
139+ if key_dim is not None and value_dim is not None :
140+ conv_dim = key_dim * 2 + value_dim
91141
92142 cache_manager = CacheManager (
93143 num_layers = num_layers ,
@@ -98,19 +148,17 @@ def main():
98148 cache_memory_fraction = 0.1 ,
99149 head_dim_v = v_head_dim ,
100150 layer_types = layer_types ,
151+ conv_dim = conv_dim ,
152+ conv_kernel_size = linear_conv_kernel_dim ,
153+ linear_k_dim = linear_key_head_dim ,
154+ linear_v_dim = linear_value_head_dim ,
155+ linear_num_k_heads = linear_num_key_heads ,
156+ linear_num_v_heads = linear_num_value_heads ,
101157 )
102158
103159 # 3. Tokenize and Create Request
104160 messages = [{"role" : "user" , "content" : args .prompt }]
105-
106- if hasattr (tokenizer , "apply_chat_template" ) and tokenizer .chat_template is not None :
107- full_prompt = tokenizer .apply_chat_template (
108- messages , tokenize = False , add_generation_prompt = True
109- )
110- else :
111- full_prompt = args .prompt
112-
113- prompt_tokens = tokenizer .encode (full_prompt )
161+ full_prompt , prompt_tokens = build_prompt (messages , tokenizer )
114162 sampling_params = SamplingParams (temperature = args .temp , top_k = args .topk )
115163 request = InitialRequest (
116164 prompt = full_prompt ,
@@ -119,22 +167,9 @@ def main():
119167 max_new_tokens = args .max_tokens ,
120168 )
121169
122- eos_token_ids = []
123- if tokenizer .eos_token_id is not None :
124- if isinstance (tokenizer .eos_token_id , list ):
125- eos_token_ids .extend (tokenizer .eos_token_id )
126- else :
127- eos_token_ids .append (tokenizer .eos_token_id )
128- config_eos = config .get ("eos_token_id" )
129- if config_eos is not None :
130- if isinstance (config_eos , list ):
131- for e in config_eos :
132- if e not in eos_token_ids :
133- eos_token_ids .append (e )
134- elif config_eos not in eos_token_ids :
135- eos_token_ids .append (config_eos )
136-
137- eos_token_ids = set (eos_token_ids )
170+ eos_token_ids = get_eos_token_ids (config , tokenizer )
171+ if not eos_token_ids :
172+ raise ValueError ("EOS token ID must be set for generation." )
138173
139174 # 4. Prefill
140175 print_rank (f"Full prompt:\n { full_prompt } " )
@@ -151,6 +186,9 @@ def main():
151186 input_ids = mx .array ([request .input_ids ])
152187 block_table = mx .array ([cache_manager .get_block_table (request .request_id )], dtype = mx .int32 )
153188 context_lengths = mx .array ([request .prompt_len ], dtype = mx .int32 )
189+ state_slot_mapping = None
190+ if cache_manager .needs_slots :
191+ state_slot_mapping = mx .array ([cache_manager .get_slot (request .request_id )], dtype = mx .int32 )
154192
155193 block_size = cache_manager .block_size
156194 slot_mapping = []
@@ -172,21 +210,27 @@ def main():
172210 block_tables = block_table ,
173211 context_lengths = context_lengths ,
174212 slot_mapping = slot_mapping ,
213+ state_slot_mapping = state_slot_mapping ,
175214 )
176215
177216 sampling_info = SamplingBatchInfo .from_reqs ([request ])
178217
179218 next_token_id = model .logits_to_tokens (logits , context_lengths , sampling_info )
180219
181220 token_id = int (next_token_id [0 ])
182- request .commit_new_token (token_id )
221+ is_finished = token_id in eos_token_ids
222+ if not is_finished :
223+ request .commit_new_token (token_id )
183224
184225 prefill_time = time .perf_counter () - prefill_start
185226 print_rank (f"Token 1 (Prefill) time: { prefill_time * 1000 :.2f} ms" )
186227
187228 # 5. Decode Loop
188229 total_decode_time = 0
189230 for i in range (args .max_tokens - 1 ):
231+ if is_finished :
232+ break
233+
190234 decode_step_start = time .perf_counter ()
191235
192236 success = cache_manager .append_slot (request .request_id )
@@ -204,12 +248,14 @@ def main():
204248 mask = None ,
205249 block_tables = block_table ,
206250 context_lengths = context_lengths ,
251+ state_slot_mapping = state_slot_mapping ,
207252 )
208253
209254 next_token_id = model .logits_to_tokens (logits , mx .array ([1 ]), sampling_info )
210255
211256 token_id = int (next_token_id [0 ])
212- if token_id in eos_token_ids :
257+ is_finished = token_id in eos_token_ids
258+ if is_finished :
213259 break
214260 request .commit_new_token (token_id )
215261
0 commit comments