
Commit c64e9ef

fix detokenizer extra tokens
Signed-off-by: Alex Chi Z <[email protected]>
1 parent 7a05e22 commit c64e9ef

2 files changed (+20 -10 lines)

main.py (+15 -6)

@@ -13,34 +13,43 @@
 parser.add_argument("--device", type=str, default="gpu")
 args = parser.parse_args()
 
+use_mlx = False
 if args.solution == "tiny_llm":
     from tiny_llm import Qwen2Model, simple_generate
-
     print("Using your tiny_llm solution")
 elif args.solution == "tiny_llm_week1_ref" or args.solution == "week1_ref":
     from tiny_llm_week1_ref import Qwen2Model, simple_generate
-
     print("Using tiny_llm_week1_ref solution")
 elif args.solution == "tiny_llm_week2_ref" or args.solution == "week2_ref":
     from tiny_llm_week2_ref import Qwen2Model, simple_generate
-
     print("Using tiny_llm_week2_ref solution")
+elif args.solution == "mlx":
+    use_mlx = True
+    from mlx_lm.generate import stream_generate
+    print("Using the original mlx model")
 else:
     raise ValueError(f"Solution {args.solution} not supported")
 
 mlx_model, tokenizer = load(
     args.model,
     tokenizer_config={"eos_token": "<|im_end|>"},
-    model_config={"tie_word_embeddings": False, "rope_traditional": True},
+    model_config={"tie_word_embeddings": False, "rope_traditional": False},
 )
 
 with mx.stream(mx.gpu if args.device == "gpu" else mx.cpu):
-    tiny_llm_model = Qwen2Model(mlx_model)
+    if use_mlx:
+        tiny_llm_model = mlx_model
+    else:
+        tiny_llm_model = Qwen2Model(mlx_model)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": args.prompt},
     ]
     prompt = tokenizer.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
-    simple_generate(tiny_llm_model, tokenizer, prompt)
+    if not use_mlx:
+        simple_generate(tiny_llm_model, tokenizer, prompt)
+    else:
+        for resp in stream_generate(tiny_llm_model, tokenizer, prompt):
+            print(resp.text, end="", flush=True)
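For context, the new `mlx` branch simply hands generation off to upstream mlx-lm. Below is a standalone sketch of that path, not part of this commit: the model id and prompt are placeholders, `load` is assumed to come from mlx_lm as elsewhere in main.py, and `apply_chat_template`/`stream_generate` are the same calls main.py uses.

# Standalone sketch of the --solution mlx path (not from this commit).
# Assumptions: mlx_lm is installed; the model id below is a placeholder for
# whatever main.py receives via --model.
from mlx_lm import load
from mlx_lm.generate import stream_generate

mlx_model, tokenizer = load(
    "Qwen/Qwen2-7B-Instruct-MLX",  # placeholder model id
    tokenizer_config={"eos_token": "<|im_end|>"},
)
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},  # placeholder prompt
]
prompt = tokenizer.apply_chat_template(
    messages, tokenize=False, add_generation_prompt=True
)
# stream_generate yields response objects whose .text holds the newly decoded segment
for resp in stream_generate(mlx_model, tokenizer, prompt):
    print(resp.text, end="", flush=True)
print()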

src/tiny_llm_week2_ref/generate.py (+5 -4)

@@ -23,9 +23,10 @@ def _step(model, y, offset):
     # generate/decode
     while True:
         token, _ = _step(model, tokens, offset)
-        offset += tokens.size
-        tokens = token
+        if offset != 0:
+            detokenizer.add_token(token.item())
+            print(detokenizer.last_segment, end="", flush=True)
         if token.item() == tokenizer.eos_token_id:
             break
-        offset += tokens.size
-        tokens = token
+        offset += tokens.size
+        tokens = token
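The generate.py change only reorders the decode loop: the freshly sampled token is pushed to the streaming detokenizer (skipping the pass where offset is still 0) before the EOS check, and offset/tokens advance afterwards. As a reference for the detokenizer calls the loop relies on, here is a toy replay sketch; it assumes the tokenizer is mlx_lm's TokenizerWrapper, whose detokenizer exposes reset() and finalize() in addition to the add_token()/last_segment pair used above.

# Toy illustration (not from this commit): stream decoded text one token at a
# time using the same detokenizer pattern as generate.py. Assumes `tokenizer`
# is the wrapper returned by mlx_lm's load(), with a .detokenizer that has
# reset()/add_token()/last_segment/finalize().
def stream_print(tokenizer, token_ids):
    detokenizer = tokenizer.detokenizer
    detokenizer.reset()
    for tok in token_ids:
        if tok == tokenizer.eos_token_id:
            break
        detokenizer.add_token(tok)
        # last_segment returns only the text produced since it was last read,
        # so printing it incrementally never repeats earlier output.
        print(detokenizer.last_segment, end="", flush=True)
    # flush any text still buffered inside the detokenizer
    detokenizer.finalize()
    print(detokenizer.last_segment, flush=True)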

0 commit comments