 parser.add_argument("--device", type=str, default="gpu")
 args = parser.parse_args()

+use_mlx = False
 if args.solution == "tiny_llm":
     from tiny_llm import Qwen2Model, simple_generate
-
     print("Using your tiny_llm solution")
 elif args.solution == "tiny_llm_week1_ref" or args.solution == "week1_ref":
     from tiny_llm_week1_ref import Qwen2Model, simple_generate
-
     print("Using tiny_llm_week1_ref solution")
 elif args.solution == "tiny_llm_week2_ref" or args.solution == "week2_ref":
     from tiny_llm_week2_ref import Qwen2Model, simple_generate
-
     print("Using tiny_llm_week2_ref solution")
+elif args.solution == "mlx":
+    use_mlx = True
+    from mlx_lm.generate import stream_generate
+    print("Using the original mlx model")
 else:
     raise ValueError(f"Solution {args.solution} not supported")

 mlx_model, tokenizer = load(
     args.model,
     tokenizer_config={"eos_token": "<|im_end|>"},
-    model_config={"tie_word_embeddings": False, "rope_traditional": True},
+    model_config={"tie_word_embeddings": False, "rope_traditional": False},
 )

 with mx.stream(mx.gpu if args.device == "gpu" else mx.cpu):
-    tiny_llm_model = Qwen2Model(mlx_model)
+    if use_mlx:
+        tiny_llm_model = mlx_model
+    else:
+        tiny_llm_model = Qwen2Model(mlx_model)
     messages = [
         {"role": "system", "content": "You are a helpful assistant."},
         {"role": "user", "content": args.prompt},
     ]
     prompt = tokenizer.apply_chat_template(
         messages, tokenize=False, add_generation_prompt=True
     )
-    simple_generate(tiny_llm_model, tokenizer, prompt)
+    if not use_mlx:
+        simple_generate(tiny_llm_model, tokenizer, prompt)
+    else:
+        for resp in stream_generate(tiny_llm_model, tokenizer, prompt):
+            print(resp.text, end="", flush=True)
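The new "mlx" branch bypasses the course's simple_generate and streams tokens through the upstream mlx_lm generation loop instead. A minimal standalone sketch of that path, assuming a recent mlx_lm release where load and stream_generate are importable from the top-level package and stream_generate yields response objects with a .text field; the model id is a placeholder standing in for args.model:

# Sketch of the upstream mlx_lm path taken when --solution mlx is selected.
from mlx_lm import load, stream_generate

# Placeholder model id; substitute whatever is passed via args.model.
model, tokenizer = load(
    "Qwen/Qwen2-7B-Instruct-MLX",
    tokenizer_config={"eos_token": "<|im_end|>"},
)

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Give me a short introduction to large language models."},
]
prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

# stream_generate yields incremental responses; each one carries the newly decoded text.
for resp in stream_generate(model, tokenizer, prompt):
    print(resp.text, end="", flush=True)
print()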