
RFC: XLA-Optimized Speculative Decoding via SpeculativeSampler #2513

@Vivek1106-04


Problem

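Autoregressive decoding latency for large models (e.g., Llama3-90B, Gemma-27B) is memory-bandwidth bound: each new token requires streaming the full set of model weights from memory, while the compute units sit largely idle. Speculative Decoding (Leviathan et al.) addresses this by letting a smaller draft model propose several tokens that the large target model then verifies in a single forward pass. This has been reported to yield 2x-3x speedups without changing the output distribution. KerasHub currently lacks a native, optimized implementation.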
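As a rough guide to where such speedups come from, Leviathan et al. show that if each draft token is accepted independently with probability α and the draft proposes γ tokens per step, the target model produces on average (1 - α^(γ+1)) / (1 - α) tokens per forward pass. The walltime improvement is that quantity divided by (γc + 1), where c is the cost of one draft step relative to one target step. The helpers below simply evaluate those two expressions; α, γ and c are inputs that would have to be measured for a concrete model pair.

```python
def expected_tokens_per_target_call(alpha: float, gamma: int) -> float:
    """Expected tokens produced per target forward pass (Leviathan et al.).

    alpha: per-token acceptance probability, assumed i.i.d. across positions.
    gamma: number of draft tokens proposed per step (draft_steps).
    """
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must be in [0, 1]")
    if alpha == 1.0:
        # Every draft is accepted, plus the one token the target always adds.
        return float(gamma + 1)
    # Closed form of the geometric series 1 + alpha + ... + alpha**gamma.
    return (1.0 - alpha ** (gamma + 1)) / (1.0 - alpha)


def expected_walltime_speedup(alpha: float, gamma: int, c: float) -> float:
    """Walltime improvement factor; c = draft step cost / target step cost."""
    return expected_tokens_per_target_call(alpha, gamma) / (gamma * c + 1.0)


for alpha in (0.6, 0.8, 0.9):
    print(alpha, round(expected_tokens_per_target_call(alpha, gamma=5), 2))
```

With γ = 5 the loop prints about 2.38, 3.69 and 4.69 tokens per target call for α = 0.6, 0.8 and 0.9; the realized speedup is lower once the draft cost c is included.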

Existing external implementations (like Hugging Face's assisted_generation) often orchestrate the draft and verify steps from a Python loop. Every iteration then incurs a host-device (CPU-GPU/TPU) synchronization ("Python Ping-Pong"), and the logic cannot be compiled into a single XLA graph, which is critical for TPU performance.
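To make the "Python Ping-Pong" concrete, here is a toy JAX sketch contrasting a host-driven decode loop with one compiled via jax.lax.fori_loop. `toy_logits`, `decode_host_loop` and `decode_in_graph` are hypothetical names, and `toy_logits` is only a stand-in for a real model forward pass; the point is where the loop lives, not the model.

```python
from functools import partial

import jax
import jax.numpy as jnp

VOCAB = 16  # toy vocabulary size (hypothetical)


def toy_logits(token):
    # Hypothetical stand-in for a model forward pass: strongly prefers token + 1.
    return jax.nn.one_hot((token + 1) % VOCAB, VOCAB) * 10.0


def decode_host_loop(first_token, num_steps):
    # Python-orchestrated loop: int(...) blocks on the device and copies each
    # token back to the host before the next step can be dispatched.
    step = jax.jit(lambda t: jnp.argmax(toy_logits(t)))
    tokens = [first_token]
    for _ in range(num_steps):
        tokens.append(int(step(jnp.int32(tokens[-1]))))
    return tokens


@partial(jax.jit, static_argnums=1)
def decode_in_graph(first_token, num_steps):
    # num_steps is static, so the output buffer shape is known at trace time.
    buf = jnp.zeros((num_steps + 1,), jnp.int32).at[0].set(first_token)

    def body(i, buf):
        nxt = jnp.argmax(toy_logits(buf[i])).astype(jnp.int32)
        return buf.at[i + 1].set(nxt)

    # The whole loop is lowered into one XLA computation: no per-step host sync.
    return jax.lax.fori_loop(0, num_steps, body, buf)


print(decode_host_loop(3, 4))
print(decode_in_graph(3, 4))
```

Both functions produce the same tokens; only the second keeps every step on the device.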

Proposed Solution

I propose implementing Speculative Decoding natively in KerasHub by introducing a SpeculativeSampler class.

Unlike standard approaches that modify the generate() loop, this approach:

  1. Encapsulates Logic: Keeps CausalLM.generate() clean by isolating speculation logic within the Sampler abstraction.
  2. Enables XLA Compilation: Orchestrates the Draft -> Verify -> Accept loop entirely within the graph (jax.jit/tf.function), maximizing hardware utilization.
  3. Static Shape Compatibility: Uses fixed-length padding for draft sequences to ensure compatibility with XLA/TPU static shape requirements.
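To illustrate how these three points fit together, here is a minimal JAX sketch of greedy (temperature 0) speculative decoding run as a single jax.lax.while_loop. `target_logits` and `draft_logits` are hypothetical toy bigram "models" standing in for the real networks, and greedy verification (accept a draft token only if it equals the target's argmax) is a simplification of the stochastic rule in Leviathan et al. The token buffer is padded to `max_len + k` so every draft and verify step writes into a fixed-shape array.

```python
from functools import partial

import jax
import jax.numpy as jnp

VOCAB = 8  # toy vocabulary size (hypothetical)


def target_logits(tokens):
    # Hypothetical target "model": always predicts (t + 1) % VOCAB.
    return jax.nn.one_hot((tokens + 1) % VOCAB, VOCAB)


def draft_logits(tokens):
    # Hypothetical draft "model": agrees with the target except after token 3.
    return jax.nn.one_hot(jnp.where(tokens == 3, 0, (tokens + 1) % VOCAB), VOCAB)


@partial(jax.jit, static_argnums=(1, 2))
def speculative_greedy(first_token, max_len, k):
    # Static padding: the largest write index is max_len + k - 1.
    buf = jnp.zeros((max_len + k,), jnp.int32).at[0].set(first_token)

    def draft(buf, pos):
        # Draft: k cheap autoregressive steps after the last committed token.
        def step(i, b):
            tok = jnp.argmax(draft_logits(b[pos + i])).astype(jnp.int32)
            return b.at[pos + i + 1].set(tok)

        return jax.lax.fori_loop(0, k, step, buf)

    def body(state):
        buf, pos, iters = state  # pos: index of the last committed token
        proposal = draft(buf, pos)
        window = jax.lax.dynamic_slice(proposal, (pos,), (k + 1,))
        # Verify: one target pass scores all k + 1 positions at once.
        target_next = jnp.argmax(target_logits(window), axis=-1).astype(jnp.int32)
        # Accept: keep the longest prefix of drafts that matches the target.
        matches = (window[1:] == target_next[:-1]).astype(jnp.int32)
        n = jnp.sum(jnp.cumprod(matches)).astype(jnp.int32)
        # Append the target's own token right after the accepted prefix.
        buf = proposal.at[pos + n + 1].set(target_next[n])
        return buf, pos + n + 1, iters + 1

    def cond(state):
        return state[1] < max_len - 1

    init = (buf, jnp.int32(0), jnp.int32(0))
    buf, _, iters = jax.lax.while_loop(cond, body, init)
    return buf[:max_len], iters
```

For `speculative_greedy(0, 10, 3)` the tokens match plain greedy decoding with the target model (0, 1, ..., 7, 0, 1) while using 3 target passes instead of 9; stale draft entries beyond the committed prefix are overwritten on the next iteration or sliced off at the end.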

Proposed API

The API follows the standard Keras pattern: the user configures the decoding strategy (the "how") once via compile(), so the runtime generate() call stays unchanged.

import keras_hub

# 1. Load the target and draft models (they must share the same tokenizer/vocabulary)
target_model = keras_hub.models.Llama3CausalLM.from_preset("llama3_8b_en")
draft_model = keras_hub.models.Llama3CausalLM.from_preset("llama3_1b_en")

# 2. Attach the speculative strategy
# This allows us to pre-compile the draft generation graph
target_model.compile(
    sampler=keras_hub.samplers.SpeculativeSampler(  # proposed class, not yet in KerasHub
        draft_model=draft_model,
        draft_steps=5,  # K: number of draft tokens proposed per verification step
        temperature=1.0,
    )
)

# 3. Fast Inference (Clean API)
target_model.generate("The future of AI is", max_length=128)
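Because the example uses temperature=1.0, a sampler like this cannot simply compare argmaxes. To keep samples distributed exactly as the target model's, it would need the modified rejection-sampling rule from Leviathan et al.: accept draft token x_i with probability min(1, p_i(x_i) / q_i(x_i)), and at the first rejection resample from the normalized residual max(0, p - q). The sketch below implements one such verification step in JAX with fixed shapes. The function name, argument layout and `pad_id` are hypothetical; `p` and `q` are assumed to already be temperature-scaled probabilities over a shared vocabulary, and `draft_tokens` are assumed to have been sampled from `q` (so q_i(x_i) > 0).

```python
import jax
import jax.numpy as jnp


def speculative_accept(key, draft_tokens, draft_probs, target_probs, pad_id=0):
    """One Draft -> Verify -> Accept step with static shapes (hypothetical API).

    draft_tokens: int32[K], sampled from draft_probs.
    draft_probs:  float[K, V], draft distributions q_0..q_{K-1}.
    target_probs: float[K + 1, V], target distributions p_0..p_K.
    Returns a fixed-length [K + 1] token buffer (same dtype as draft_tokens) and
    the number of valid tokens in it (accepted drafts plus one target token).
    """
    k = draft_tokens.shape[0]
    key_u, key_s = jax.random.split(key)
    idx = jnp.arange(k)
    p_tok = target_probs[idx, draft_tokens]  # p_i(x_i)
    q_tok = draft_probs[idx, draft_tokens]   # q_i(x_i)
    # Accept x_i with probability min(1, p_i(x_i) / q_i(x_i)).
    u = jax.random.uniform(key_u, (k,))
    accepted = (u < jnp.minimum(1.0, p_tok / q_tok)).astype(jnp.int32)
    # Only the prefix before the first rejection counts.
    n = jnp.sum(jnp.cumprod(accepted))
    # Residual max(0, p_n - q_n); the appended zero row makes n == K sample from p_K.
    q_padded = jnp.concatenate([draft_probs, jnp.zeros_like(draft_probs[:1])], axis=0)
    residual = jnp.maximum(target_probs[n] - q_padded[n], 0.0)
    next_token = jax.random.categorical(key_s, jnp.log(residual / residual.sum()))
    # Fixed-length output: accepted drafts, then the new token, then padding.
    out = jnp.concatenate([draft_tokens, jnp.zeros((1,), draft_tokens.dtype)])
    out = jnp.where(jnp.arange(k + 1) < n, out, pad_id)
    out = out.at[n].set(next_token.astype(out.dtype))
    return out, n + 1
```

Since `draft_steps` fixes K, every tensor here has a static shape, so a step like this could sit inside the same compiled loop as the draft and verify passes.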
