2 changes: 1 addition & 1 deletion README.md
@@ -47,7 +47,7 @@ Week 1 is complete. Week 2 is in progress.
| 3.1 | Paged Attention - Part 1 | 🚧 | 🚧 | 🚧 |
| 3.2 | Paged Attention - Part 2 | 🚧 | 🚧 | 🚧 |
| 3.3 | MoE (Mixture of Experts) | 🚧 | 🚧 | 🚧 |
| 3.4 | Speculative Decoding | 🚧 | 🚧 | 🚧 |
| 3.4 | Speculative Decoding | 🚧 | ✅ | 🚧 |
| 3.5 | RAG Pipeline | 🚧 | 🚧 | 🚧 |
| 3.6 | AI Agent / Tool Calling | 🚧 | 🚧 | 🚧 |
| 3.7 | Long Context | 🚧 | 🚧 | 🚧 |
23 changes: 22 additions & 1 deletion main.py
@@ -7,6 +7,7 @@

parser = argparse.ArgumentParser()
parser.add_argument("--model", type=str, default="qwen2-7b")
parser.add_argument("--draft-model", type=str, default=None)
parser.add_argument(
"--prompt",
type=str,
@@ -39,6 +40,7 @@
models,
simple_generate,
simple_generate_with_kv_cache,
speculative_generate,
sampler,
)

@@ -53,6 +55,15 @@
args.model = models.shortcut_name_to_full_name(args.model)
mlx_model, tokenizer = load(args.model)

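# optionally load a smaller draft model for speculative decoding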
if args.draft_model:
    if args.loader == "week1":
        raise ValueError("Draft model not supported for week1")
    args.draft_model = models.shortcut_name_to_full_name(args.draft_model)
    draft_mlx_model, draft_tokenizer = load(args.draft_model)
else:
    draft_mlx_model = None
    draft_tokenizer = None

with mx.stream(mx.gpu if args.device == "gpu" else mx.cpu):
if use_mlx:
tiny_llm_model = mlx_model
@@ -67,6 +78,13 @@
tiny_llm_model = models.dispatch_model(
args.model, mlx_model, week=2, enable_flash_attn=args.enable_flash_attn
)
if draft_mlx_model is not None:
print(f"Using draft model {args.draft_model}")
draft_tiny_llm_model = models.dispatch_model(
args.draft_model, draft_mlx_model, week=2, enable_flash_attn=args.enable_flash_attn
)
else:
draft_tiny_llm_model = None
else:
raise ValueError(f"Loader {args.loader} not supported")
messages = [
@@ -86,7 +104,10 @@
if args.loader == "week1":
simple_generate(tiny_llm_model, tokenizer, prompt, sampler=sampler)
elif args.loader == "week2":
simple_generate_with_kv_cache(tiny_llm_model, tokenizer, prompt)
if draft_tiny_llm_model is not None:
speculative_generate(draft_tiny_llm_model, tiny_llm_model, draft_tokenizer, tokenizer, prompt)
else:
simple_generate_with_kv_cache(tiny_llm_model, tokenizer, prompt)
else:
sampler = mlx_lm.sample_utils.make_sampler(
args.sampler_temp, top_p=args.sampler_top_p, top_k=args.sampler_top_k
95 changes: 95 additions & 0 deletions src/tiny_llm_ref/generate.py
@@ -69,3 +69,98 @@ def _step(model, y, offset, kv_cache):
# Otherwise, we add the decoded token size (which is always 1).
offset += tokens.size
tokens = token

def speculative_generate(
    draft_model: Qwen2ModelWeek2,
    model: Qwen2ModelWeek2,
    draft_tokenizer: TokenizerWrapper,
    tokenizer: TokenizerWrapper,
    prompt: str,
) -> str:
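    # Speculative decoding: a small draft model proposes num_drafts tokens per round,
    # the target model verifies them in one batched forward pass, and the longest
    # prefix whose greedy choices agree is accepted. Both models are assumed to share
    # the same tokenizer/vocabulary.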
draft_kv_cache = [TinyKvFullCache() for _ in range(draft_model.num_hidden_layers)]
kv_cache = [TinyKvFullCache() for _ in range(model.num_hidden_layers)]

def _step(model, y, offset, kv_cache, n_tokens=1):
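        # One forward pass followed by greedy sampling. When verifying a draft we need
        # the predictions at the last n_tokens positions; otherwise only the final
        # position is needed.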
logits = model(y[None], offset, kv_cache)
if n_tokens > 1:
logits = logits[:, -n_tokens:, :]
else:
logits = logits[:, -1, :]
        logprobs = logits - mx.logsumexp(logits, axis=-1, keepdims=True)
        y = mx.argmax(logprobs, axis=-1)
return y, logprobs.squeeze(0)

    # prefill a model with the prompt; used for both the draft and the target model
def _prefill(model, tokenizer, prompt, kv_cache):
prefill_tokens = mx.array(tokenizer.encode(prompt, add_special_tokens=False))
offset = 0
token, _ = _step(model, prefill_tokens, offset, kv_cache)
mx.eval(token)
        if token.item() == tokenizer.eos_token_id:
            # the prompt already ends the sequence; signal the caller to stop
            return None, prefill_tokens.size
        return token, prefill_tokens.size

    _, draft_offset = _prefill(draft_model, draft_tokenizer, prompt, draft_kv_cache)
    token, offset = _prefill(model, tokenizer, prompt, kv_cache)
    if token is None:
        return

def _decode_one(token, tokenizer):
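        # stream the token into the detokenizer; returns False once EOS is reached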
if token.item() == tokenizer.eos_token_id:
return False
detokenizer = tokenizer.detokenizer
detokenizer.add_token(token.item())
return True


def draft_generate(model, last_token, offset, kv_cache, num_drafts):
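        # let the draft model propose num_drafts tokens autoregressively, feeding each
        # sampled token back in as the next input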
tokens = []
for _ in range(num_drafts):
token, _ = _step(model, last_token, offset, kv_cache)
mx.eval(token)
tokens.append(token.item())
last_token = token
return tokens

num_drafts = 4

def _rewind_cache(kv_cache, revert_len):
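        # drop the last revert_len positions from every layer's KV cache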
for layer in kv_cache:
layer.rewind(revert_len)

def _print_text(text, progress):
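        # progress = number of tokens emitted this round; show the tail of the decoded text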
        text = text.replace("\n", " ")
        print(f"+{progress} {text[-80:]}")

# speculative decode
while True:
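        # One speculation round:
        #   1. the draft model proposes num_drafts tokens after the last accepted token;
        #   2. the target model scores the accepted token plus all drafts in a single
        #      batched forward pass;
        #   3. the longest prefix where the target's greedy choice agrees with the draft
        #      is accepted; at the first disagreement the target's token is kept and both
        #      KV caches are rewound past the rejected draft positions.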
draft_tokens = draft_generate(draft_model, token, draft_offset, draft_kv_cache, num_drafts)
draft_offset += num_drafts
# assume both models use the same tokenizer
draft_tokens = mx.concat([token, mx.array(draft_tokens)])
new_tokens, _ = _step(model, draft_tokens, offset, kv_cache, num_drafts + 1)
new_tokens = new_tokens.tolist()[0]
offset += num_drafts + 1
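        # Align target predictions with draft proposals: after the shift, new_tokens[i]
        # is the token the target would emit at the position where the draft proposed
        # draft_tokens[i]; index 0 is the previously accepted token, so it always matches.
        # last_new_token is the target's prediction after the full draft and is only used
        # when every draft token is accepted.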
last_new_token = new_tokens[-1]
new_tokens = mx.array([token.item()] + new_tokens[:-1])
assert len(new_tokens) == len(draft_tokens)
accept_all = True
for i in range(len(new_tokens)):
if new_tokens[i] != draft_tokens[i]:
                # The target disagrees at position i: keep the target's token here,
                # discard the remaining draft tokens, and rewind both KV caches so the
                # next round restarts from this position.
assert i >= 1 # first token is always the same
revert_len = len(draft_tokens) - i
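                # The draft cache holds one position fewer than the target cache
                # (the last draft token was sampled but never fed back to the draft
                # model), hence the smaller rewind.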
_rewind_cache(draft_kv_cache, revert_len - 1)
draft_offset -= revert_len - 1
_rewind_cache(kv_cache, revert_len)
token = mx.array([new_tokens[i]])
offset -= revert_len
assert offset == draft_offset
assert offset == kv_cache[0].offset
_print_text(tokenizer._detokenizer.text, i)
accept_all = False
break
if not _decode_one(new_tokens[i], tokenizer):
print(tokenizer._detokenizer.text)
return
if accept_all:
_print_text(tokenizer._detokenizer.text, len(new_tokens))
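            # All drafts accepted: the draft model never consumed its own last proposal,
            # so feed it through once (output discarded) to keep the draft cache and
            # offset aligned with the target before the next round.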
draft_generate(draft_model, mx.array(draft_tokens[-1:]), draft_offset, draft_kv_cache, 1)
token = mx.array([last_new_token])
draft_offset += 1
4 changes: 4 additions & 0 deletions src/tiny_llm_ref/kv_cache.py
@@ -143,3 +143,7 @@ def update_and_fetch(
self.key_values = (new_keys, new_values)
self.offset += S
return new_keys, new_values, self.offset, mask

def rewind(self, n: int):
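        # drop the last n cached positions; used by speculative decoding to discard
        # rejected draft tokens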
self.offset -= n
self.key_values = (self.key_values[0][:, :, :self.offset], self.key_values[1][:, :, :self.offset])