Description
Problem
Inference latency for large autoregressive models (e.g., Llama3-90B, Gemma-27B) is memory-bandwidth bound. While Speculative Decoding (Leviathan et al.) offers significant speedups (2x-3x) by leveraging smaller draft models, KerasHub currently lacks a native, optimized implementation.
Existing external implementations (like Hugging Face's assisted_generation) often rely on Python-side orchestration loops. This introduces significant CPU-GPU synchronization overhead ("Python Ping-Pong"), preventing the logic from being fully compiled into a single XLA graph, which is critical for TPU performance.
Proposed Solution
I propose implementing Speculative Decoding natively in KerasHub by introducing a SpeculativeSampler class.
Unlike standard approaches that modify the generate() loop, this approach:
- Encapsulates Logic: Keeps CausalLM.generate() clean by isolating speculation logic within the Sampler abstraction.
- Enables XLA Compilation: Orchestrates the Draft -> Verify -> Accept loop entirely within the graph (jax.jit / tf.function), maximizing hardware utilization (see the sketch after this list).
- Static Shape Compatibility: Uses fixed-length padding for draft sequences to ensure compatibility with XLA/TPU static shape requirements.
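To make the in-graph verification concrete, here is a minimal, backend-agnostic sketch of the accept/reject rule from Leviathan et al. that the sampler could run after the target model scores all K drafted tokens in a single batched pass. The function name, tensor layout, and seed handling below are illustrative assumptions, not part of the proposed API.

import keras
from keras import ops

def count_accepted(draft_token_ids, draft_probs, target_probs, seed=None):
    """Sketch of the speculative accept/reject step (Leviathan et al.).

    draft_token_ids: [batch, K] int token ids sampled from the draft model.
    draft_probs:     [batch, K, vocab] draft distributions q at those steps.
    target_probs:    [batch, K, vocab] target distributions p from one
                     batched verification pass over the drafted tokens.
    Returns:         [batch] count of leading draft tokens accepted.
    """
    idx = ops.expand_dims(draft_token_ids, axis=-1)
    p = ops.squeeze(ops.take_along_axis(target_probs, idx, axis=-1), axis=-1)
    q = ops.squeeze(ops.take_along_axis(draft_probs, idx, axis=-1), axis=-1)
    # Accept token i with probability min(1, p(x_i) / q(x_i)).
    u = keras.random.uniform(ops.shape(p), seed=seed)
    accepted = u < ops.minimum(p / q, 1.0)  # [batch, K] booleans
    # A draft token only survives if every earlier token was also accepted,
    # so the kept count is the length of the leading run of True values.
    leading_run = ops.cumprod(ops.cast(accepted, "int32"), axis=-1)
    # The full sampler would then append the accepted tokens plus one token
    # resampled from the normalized residual max(0, p - q) at the first
    # rejected position, keeping output shapes static via padding.
    return ops.sum(leading_run, axis=-1)

Because this step is expressed purely in keras.ops, it traces into the same XLA graph as the draft and target forward passes, which is what removes the Python-side ping-pong described above.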
Proposed API
The API follows the standard Keras pattern: the user defines the "How" (Strategy) during compilation/setup, keeping the runtime generate call clean.
# 1. Load models
target_model = keras_hub.models.Llama3CausalLM.from_preset("llama3_8b_en")
draft_model = keras_hub.models.Llama3CausalLM.from_preset("llama3_1b_en")
# 2. Attach the speculative strategy
# This allows us to pre-compile the draft generation graph
target_model.compile(
    sampler=keras_hub.samplers.SpeculativeSampler(
        draft_model=draft_model,
        draft_steps=5,  # number of draft tokens (K) proposed per verification step
        temperature=1.0,
    )
)
# 3. Fast Inference (Clean API)
target_model.generate("The future of AI is", max_length=128)
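Because the accept/reject rule only keeps draft tokens the target model would itself have sampled, output quality matches sampling from the target model alone; only latency changes. The draft and target models are assumed to share a tokenizer/vocabulary so drafted token ids can be scored directly by the target model, and draft_steps (K) trades extra draft-model compute against the expected number of tokens accepted per target forward pass.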