ATOM (AiTer Optimized Model) uses a prefill-first scheduler with paged KV cache block management to drive LLM inference on AMD ROCm/HIP GPUs. This guide covers the scheduling algorithm, batch construction, block-level KV cache management, prefix caching, postprocessing, speculative decoding integration, and sequence lifecycle.
| Class | File | Purpose |
|---|---|---|
| `Scheduler` | `atom/model_engine/scheduler.py` | Orchestrates prefill/decode scheduling, preemption, and postprocessing |
| `ScheduledBatch` | `atom/model_engine/scheduler.py` | Immutable snapshot of a scheduled batch sent to the model runner |
| `ScheduledBatchOutput` | `atom/model_engine/scheduler.py` | Holds sampled token IDs and draft token IDs returned from the forward pass |
| `BlockManager` | `atom/model_engine/block_manager.py` | Manages paged KV cache blocks with allocation, deallocation, and prefix caching |
| `Block` | `atom/model_engine/block_manager.py` | Single KV cache block with ID, reference count, hash, and token IDs |
| `Sequence` | `atom/model_engine/sequence.py` | Tracks a single request through its lifetime (tokens, blocks, status, timing) |
| `SequenceStatus` | `atom/model_engine/sequence.py` | Enum: `WAITING`, `RUNNING`, `FINISHED`, `EXIT_ENGINE` |
| `SequenceType` | `atom/model_engine/sequence.py` | Enum: `DUMMY`, `PREFILL`, `DECODE` |
| `RequestOutput` | `atom/model_engine/request.py` | Dataclass streamed to clients with new tokens and finish status |
| `Config` | `atom/config.py` | Scheduling-related fields: `max_num_seqs`, `max_num_batched_tokens`, `kv_cache_block_size`, etc. |
Key config defaults:

| Field | Default | Description |
|---|---|---|
| `max_num_seqs` | 512 | Maximum sequences in a single batch |
| `max_num_batched_tokens` | 16384 | Maximum tokens scheduled in a single step |
| `kv_cache_block_size` | 16 | Tokens per KV cache block (must be a multiple of 16, or 1) |
| `enable_prefix_caching` | False | Enable hash-based prefix block sharing |
| `scheduler_delay_factor` | 0.0 | Delay factor for batching prompt requests (0 = no delay) |
| `gpu_memory_utilization` | 0.9 | Fraction of GPU memory reserved for the KV cache |
The scheduler implements a prefill-first policy: all waiting (prefill) requests are scheduled before any running (decode) requests. The entry point is `Scheduler.schedule()`, which returns a `(ScheduledBatch, dict[int, Sequence])` tuple, or `None` if both queues are empty.
```python
class Scheduler:
    def __init__(self, config: Config):
        self.max_num_seqs = config.max_num_seqs
        self.max_num_batched_tokens = config.max_num_batched_tokens
        self.bos_token_id = config.bos_token_id
        self.eos_token_id = config.eos_token_id
        self.stop_token_ids = config.stop_token_ids
        self.block_manager = BlockManager(config)
        self.waiting: deque[Sequence] = deque()
        self.running: deque[Sequence] = deque()
        self.prev_time = 0.0
        self.prev_prompt = False
        self.last_prompt_latency = 0.0
        self.delay_factor = config.scheduler_delay_factor
        self.use_spec = config.speculative_config is not None
        self.mtp_k: int = (
            config.speculative_config.num_speculative_tokens if self.use_spec else 0
        )
        self.total_draft_tokens = 0
        self.total_accepted_tokens = 0
```

The scheduler maintains two deques -- `waiting` (pending prefill) and `running` (active decode) -- plus a `BlockManager` for KV cache allocation.
`Scheduler.schedule()` proceeds in two phases:

Phase 1 -- Prefill scheduling:

- While the delay gate passes (`_passed_delay`), the waiting queue is non-empty, and `num_seqs_prefill < max_num_seqs`:
  - Peek the first waiting sequence.
  - Compute `num_new_tokens = seq.num_tokens - seq.num_cached_tokens` (prefix cache hits reduce new tokens).
  - If `num_batched_tokens + num_new_tokens > max_num_batched_tokens` or `block_manager.can_allocate(seq)` returns `False`, break.
  - Otherwise: allocate blocks, set `seq.status = RUNNING` and `seq.type = PREFILL`, and move the sequence from `waiting` to `running`.
- If any prefill sequences were scheduled, return the batch immediately (no decode mixing).

Phase 2 -- Decode scheduling (only when zero prefills were scheduled):

- Pop sequences from `running` up to `max_num_seqs`.
- For each sequence, check `block_manager.can_append(seq)`.
- If a block cannot be appended, preempt the last running sequence (move it back to `waiting` with status `WAITING` and deallocate its blocks).
- If the sequence has speculative draft tokens (`seq.spec_token_ids`), record them in `scheduled_spec_decode_tokens`.
- Call `block_manager.may_append(seq, num_new_tokens)` where `num_new_tokens = mtp_k + 1`.
- Re-insert all scheduled sequences back into `running` (preserving order).
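The two phases can be condensed into a toy simulation. This is an illustrative sketch of the prefill-first policy only, not ATOM's actual code: sequences are `(id, token_cost)` pairs, and block management is omitted.

```python
# Simplified sketch of the prefill-first policy: prefills drain first under
# a token budget, and decodes run only when no prefill was scheduled.
from collections import deque

def schedule_step(waiting: deque, running: deque,
                  max_num_seqs: int, max_num_batched_tokens: int) -> list:
    """Return the list of sequence IDs scheduled this step."""
    scheduled, num_batched_tokens = [], 0
    # Phase 1: drain waiting (prefill) requests first.
    while waiting and len(scheduled) < max_num_seqs:
        seq_id, num_new_tokens = waiting[0]
        if num_batched_tokens + num_new_tokens > max_num_batched_tokens:
            break  # token budget exhausted
        waiting.popleft()
        running.append((seq_id, 1))  # decode steps cost 1 token each
        scheduled.append(seq_id)
        num_batched_tokens += num_new_tokens
    if scheduled:
        return scheduled  # prefills never mix with decodes
    # Phase 2: schedule decodes only when zero prefills were scheduled.
    return [seq_id for seq_id, _ in list(running)[:max_num_seqs]]

waiting = deque([("A", 10), ("B", 8)])
running = deque()
first = schedule_step(waiting, running, max_num_seqs=4, max_num_batched_tokens=16)
second = schedule_step(waiting, running, max_num_seqs=4, max_num_batched_tokens=16)
# first == ["A"]: B would push the batch to 18 > 16 tokens.
# second == ["B"]: prefills still take priority over A's pending decode.
```

Only once the waiting queue is empty does a step fall through to Phase 2 and batch the accumulated decodes.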
When scheduler_delay_factor > 0, the scheduler delays prefill scheduling to allow the waiting queue to accumulate more requests for better batching:
```python
def _passed_delay(self, now: float) -> bool:
    if self.prev_prompt:
        self.last_prompt_latency = now - self.prev_time
    self.prev_time, self.prev_prompt = now, False
    if self.delay_factor > 0 and self.waiting:
        earliest_arrival_time = min([seq.arrive_time for seq in self.waiting])
        passed_delay = (now - earliest_arrival_time) > (
            self.delay_factor * self.last_prompt_latency
        ) or not self.running
    else:
        passed_delay = True
    return passed_delay
```

A new prefill is scheduled only when the earliest waiting request has waited longer than `delay_factor * last_prompt_latency`, or when there are no running decode requests.
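The gate condition can be restated as a standalone function with hypothetical numbers (the function and its arguments are illustrative, not ATOM's API):

```python
# Standalone restatement of the delay rule: the gate opens once the oldest
# waiting request has aged past delay_factor * last_prompt_latency, or when
# nothing is decoding.
def passed_delay(now: float, earliest_arrival_time: float,
                 last_prompt_latency: float, delay_factor: float,
                 has_running: bool) -> bool:
    return ((now - earliest_arrival_time) > delay_factor * last_prompt_latency
            or not has_running)

# Last prefill took 0.2 s and delay_factor = 0.5, so the threshold is 0.1 s.
assert passed_delay(now=1.05, earliest_arrival_time=1.0,
                    last_prompt_latency=0.2, delay_factor=0.5,
                    has_running=True) is False  # waited only 0.05 s
assert passed_delay(now=1.15, earliest_arrival_time=1.0,
                    last_prompt_latency=0.2, delay_factor=0.5,
                    has_running=True) is True   # waited 0.15 s > 0.1 s
```

Scaling the delay by the last prompt latency makes the wait proportional to prefill cost: long prompts earn a longer accumulation window.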
When a decode step cannot extend a sequence's KV cache (no free blocks), the scheduler preempts the last running sequence:
```python
def preempt(self, seq: Sequence):
    seq.status = SequenceStatus.WAITING
    self.block_manager.deallocate(seq)
    self.waiting.appendleft(seq)
```

The preempted sequence is pushed to the front of the waiting queue and its blocks are fully deallocated, so it will be re-prefilled on the next scheduling cycle.
ScheduledBatch is constructed by Scheduler.schedule() and passed to the model runner. It is a frozen snapshot of batch metadata.
```python
class ScheduledBatch:
    def __init__(
        self,
        seqs: dict[int, Sequence],
        num_scheduled_tokens: list[int],
        total_tokens_num: int,
        total_tokens_num_prefill: int = 0,
        total_tokens_num_decode: int = 0,
        total_seqs_num: int = 0,
        total_seqs_num_prefill: int = 0,
        total_seqs_num_decode: int = 0,
        is_dummy_run: bool = False,
        num_spec_step: int = 0,
        scheduled_spec_decode_tokens: dict[int, list[int]] = {},
    ):
```

| Field | Type | Description |
|---|---|---|
| `req_ids` | `list[int]` | Sequence IDs in batch order (`list(seqs.keys())`) |
| `scheduled_tokens` | `list[list[int]]` | Last `num_tokens` token IDs per sequence (the tokens to process) |
| `temperatures` | `list[float]` | Sampling temperature per sequence |
| `context_lens` | `list[int]` | Total token count per sequence (`seq.num_tokens`) |
| `block_tables` | `list[list[int]]` | Block ID tables for sequences that have block tables |
| `last_block_num_tokens` | `list[int]` | Number of valid tokens in each sequence's last block |
| `num_cached_tokens` | `list[int]` | Number of tokens served from the prefix cache per sequence |
| `num_scheduled_tokens` | `list[int]` | Number of new tokens scheduled per sequence |
| `total_tokens_num` | `int` | Sum of all scheduled tokens across all sequences |
| `total_tokens_num_prefill` | `int` | Total scheduled tokens for prefill sequences |
| `total_tokens_num_decode` | `int` | Total scheduled tokens for decode sequences |
| `total_seqs_num` | `int` | Total number of sequences in the batch |
| `total_seqs_num_prefill` | `int` | Number of prefill sequences |
| `total_seqs_num_decode` | `int` | Number of decode sequences |
| `is_dummy_run` | `bool` | Whether this is a dummy/warmup run |
| `num_spec_step` | `int` | Number of speculative decode steps (`mtp_k`) |
| `scheduled_spec_decode_tokens` | `dict[int, list[int]]` | Draft token IDs per sequence ID from the prior speculative step |
Returned by the model runner after a forward pass:
```python
class ScheduledBatchOutput:
    def __init__(
        self,
        token_ids: dict[int, tuple[int, ...]],
        draft_token_ids,
    ):
        self.req_ids = list(token_ids.keys())
        self.token_ids = token_ids              # {seq_id: (accepted_token_ids...)}
        self.draft_token_ids = draft_token_ids  # {seq_id: [draft_ids]} or None
```

- `token_ids` maps sequence ID to a tuple of accepted token IDs.
- `draft_token_ids` maps sequence ID to a list of speculative draft token IDs for the next step (when MTP is active).
- A special key `-1` in `token_ids` signals deferred output mode.
The BlockManager implements paged KV cache management with fixed-size blocks.
```python
class Block:
    def __init__(self, block_id):
        self.block_id = block_id  # Unique integer ID
        self.ref_count = 0        # Number of sequences referencing this block
        self.hash = -1            # xxhash64 digest for prefix caching (-1 = unhashed)
        self.token_ids = []       # Token IDs stored in this block
```

Methods:

- `update(hash, token_ids)` -- Sets the block's hash and token content.
- `reset()` -- Sets `ref_count = 1`, `hash = -1`, `token_ids = []` (used on fresh allocation).
```python
class BlockManager:
    def __init__(self, config: Config):
        block_size = config.kv_cache_block_size  # Tokens per block (default 16)
        num_blocks = config.num_kvcache_blocks   # Total blocks in pool
        self.block_size = block_size
        self.blocks: list[Block] = [Block(i) for i in range(num_blocks)]
        self.hash_to_block_id: dict[int, int] = dict()
        self.free_block_ids: deque[int] = deque(range(num_blocks))
        self.used_block_ids: set[int] = set()
        self.enable_prefix_caching = config.enable_prefix_caching
```

The block pool is pre-allocated at startup. `free_block_ids` is a deque for O(1) pop/push, `used_block_ids` tracks active blocks, and `hash_to_block_id` maps content hashes to block IDs for prefix caching.
Called during prefill scheduling for new sequences:
```python
def allocate(self, seq: Sequence):
```

- Iterates over `seq.num_blocks` blocks.
- For each block, computes the hash if the block is full (`len(token_ids) == block_size`). Partial (last) blocks get `hash = -1`.
- If prefix caching is enabled, looks up `hash_to_block_id`:
  - Cache hit: Verifies that `token_ids` match. If the block is already in `used_block_ids`, increments `ref_count`. If it was evicted but still in the free list, re-allocates it. Increments `seq.num_cached_tokens` by `block_size`.
  - Cache miss: Allocates from `free_block_ids[0]`.
- Full blocks are registered in `hash_to_block_id`.
Called when a sequence finishes or is preempted:
```python
def deallocate(self, seq: Sequence):
    for block_id in reversed(seq.block_table):
        block = self.blocks[block_id]
        block.ref_count -= 1
        if block.ref_count == 0:
            self._deallocate_block(block_id)
    seq.num_cached_tokens = 0
    seq.block_table.clear()
```

Blocks are released in reverse order. Shared blocks (with `ref_count > 1` from prefix caching) are not freed until all referencing sequences release them.
```python
def can_allocate(self, seq: Sequence) -> bool:
    return len(self.free_block_ids) >= seq.num_blocks

def can_append(self, seq: Sequence) -> bool:
    return len(self.free_block_ids) >= (len(seq) % self.block_size == 1)
```

- `can_allocate` checks that enough free blocks exist for the full sequence.
- `can_append` checks whether a decode step needs a new block. A new block is needed only when `len(seq) % block_size == 1` (the previous block just filled up), requiring exactly 1 free block; the boolean comparison coerces to `0` or `1`, the number of free blocks required.
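The boolean-as-integer trick in `can_append` is easy to check in isolation. A standalone version of the same expression, with hypothetical lengths:

```python
# Standalone version of the can_append test: (seq_len % block_size == 1)
# coerces to 0 or 1 in the >= comparison, i.e. the number of blocks needed.
def can_append(num_free_blocks: int, seq_len: int, block_size: int = 16) -> bool:
    return num_free_blocks >= (seq_len % block_size == 1)

# Length 33 (= 2*16 + 1): the previous block just filled, so appending the
# next token needs exactly one free block.
assert can_append(1, 33) is True
assert can_append(0, 33) is False
# Length 34 still fits in the current partial block; zero free blocks suffice.
assert can_append(0, 34) is True
```

Note that `seq_len` here is the length *after* the decode step's token is counted, matching the "previous block just filled up" reading above.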
```python
def may_append(self, seq: Sequence, num_new_tokens: int = 1):
```

Called during decode scheduling to extend a sequence's block table:

- If the sequence length modulo `block_size` falls within `(0, num_new_tokens]`, or `block_size == 1`, a new block is needed:
  - Allocates from `free_block_ids` and appends to `block_table`.
  - For `block_size == 1`, immediately computes and stores the hash.
- If `seq_len % block_size == 0`, the last block is now full -- computes and stores its hash using the chained prefix.
- Otherwise the last block is partially filled with `hash = -1` (hash deferred until full).
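The "needs a new block" condition above can be restated as a small predicate and checked against a few decode positions. The function name is hypothetical; `num_new_tokens = mtp_k + 1 = 2` assumes one speculative token:

```python
# Restatement of the may_append new-block condition described above.
def needs_new_block(seq_len: int, block_size: int, num_new_tokens: int) -> bool:
    if block_size == 1:
        return True  # every token occupies its own block
    rem = seq_len % block_size
    return 0 < rem <= num_new_tokens

# With num_new_tokens = 2, positions 1 and 2 past a block boundary require
# a fresh block; later positions reuse the existing partial block.
assert needs_new_block(33, 16, 2) is True   # 33 % 16 == 1
assert needs_new_block(34, 16, 2) is True   # 34 % 16 == 2
assert needs_new_block(35, 16, 2) is False  # 35 % 16 == 3
assert needs_new_block(32, 16, 2) is False  # boundary: last block exactly full
```

The widened `(0, num_new_tokens]` window is what lets a single decode step reserve room for the target token plus `mtp_k` draft tokens at once.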
Prefix caching enables sharing KV cache blocks across sequences that share a common prompt prefix, avoiding redundant computation.
ATOM uses xxhash64 (via the xxhash Python library) for fast, collision-resistant block hashing:
```python
@classmethod
def compute_hash(cls, token_ids: list[int], prefix: int = -1):
    h = xxhash.xxh64()
    if prefix != -1:
        h.update(prefix.to_bytes(8, "little"))
    h.update(np.array(token_ids).tobytes())
    return h.intdigest()
```

Blocks form a hash chain: each block's hash incorporates the previous block's hash as a prefix. This ensures that two blocks with identical token content but different preceding context produce different hashes.
- First block: `compute_hash(token_ids, prefix=-1)` (no prefix).
- Subsequent blocks: `compute_hash(token_ids, prefix=prev_block.hash)`.
- Only full blocks (where `len(token_ids) == block_size`) receive a hash. Partial blocks have `hash = -1` and are not cached.
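The chaining behavior can be demonstrated without GPU code. In this sketch, `hashlib.sha256` stands in for xxhash64 so the example needs no third-party packages; only the chaining logic mirrors the scheme above:

```python
# Sketch of chained block hashing: each block's hash folds in the previous
# block's hash, so identical content under different prefixes diverges.
import hashlib

def block_hash(token_ids: list[int], prefix: int = -1) -> int:
    h = hashlib.sha256()  # stand-in for xxhash.xxh64
    if prefix != -1:
        h.update(prefix.to_bytes(32, "little"))
    h.update(str(token_ids).encode("utf-8"))
    return int.from_bytes(h.digest(), "little")

block_a = [1, 2, 3, 4]
# Same second-block content, different first blocks -> different chain hashes.
chain1 = block_hash(block_a, prefix=block_hash([10, 11, 12, 13]))
chain2 = block_hash(block_a, prefix=block_hash([20, 21, 22, 23]))
assert chain1 != chain2
# Identical prefixes reproduce the same hash, which is what enables cache hits.
assert chain1 == block_hash(block_a, prefix=block_hash([10, 11, 12, 13]))
```

Determinism across processes is the key property: any sequence re-tokenizing the same prefix recomputes the same chain and lands on the same cached blocks.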
During allocate(), for each full block:
- Compute the block hash via the chain.
- Look up `hash_to_block_id.get(h, -1)`.
- If found, verify `self.blocks[block_id].token_ids == token_ids` (a guard against hash collisions).
- Hit: Reuse the block. If already in `used_block_ids`, increment `ref_count`. Add `block_size` to `seq.num_cached_tokens`.
- Miss (or first miss in the chain): Once a cache miss occurs, all subsequent blocks in the sequence are also misses (`cache_miss = True` is sticky). Allocate fresh blocks from the free list.
- On allocation: `block.reset()` sets `ref_count = 1`.
- On cache hit for an in-use block: `ref_count += 1`.
- On deallocation: `ref_count -= 1`. The block returns to the free list only when `ref_count == 0`.
- Shared blocks (prefix cache hits) have `ref_count > 1`.
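The reference-counting rules above can be replayed as a minimal simulation: two sequences share one prefix block, and the block is freed only after both release it. This is an illustration of the invariants, not ATOM's `BlockManager`:

```python
# Minimal ref-count simulation: allocation sets ref_count = 1, a prefix
# cache hit bumps it, and the block is freed only at ref_count == 0.
class Block:
    def __init__(self, block_id: int):
        self.block_id = block_id
        self.ref_count = 0

free_ids = [0]
blocks = {0: Block(0)}

# Sequence A allocates the block fresh: reset() semantics -> ref_count = 1.
free_ids.remove(0)
blocks[0].ref_count = 1
# Sequence B gets a prefix cache hit on the in-use block.
blocks[0].ref_count += 1
assert blocks[0].ref_count == 2

# A finishes: ref_count drops to 1, the block stays allocated for B.
blocks[0].ref_count -= 1
assert 0 not in free_ids
# B finishes: ref_count hits 0 and the block returns to the free list.
blocks[0].ref_count -= 1
if blocks[0].ref_count == 0:
    free_ids.append(0)
assert free_ids == [0]
```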
Set enable_prefix_caching=True in Config. When disabled, the hash lookup in allocate() is skipped entirely (block_id is always -1).
Scheduler.postprocess() is called after the model forward pass to update sequences with sampled tokens, check stop conditions, generate streaming output, and clean up finished sequences.
```python
def postprocess(
    self,
    seqs: list[Sequence],
    fwd_output: ScheduledBatchOutput,
    stream_output_queue=None,
) -> list[Sequence]:
```

For each running sequence whose ID appears in `fwd_output.req_ids`:
- Deferred output or speculative decode with EOS: Replaces placeholder tokens in place:

  ```python
  seq.token_ids[-num_placeholder:] = token_ids
  seq.output_tokens[-num_placeholder:] = token_ids
  ```

- Normal path: Calls `seq.append_token(token_id)` for each accepted token, which appends to `token_ids` and updates `output_tokens`, `last_token`, and `num_tokens`.
The postprocessor checks stop conditions in priority order:

1. Stop token sequences: Compares the tail of `seq.token_ids` against each entry in `seq.stop_token_sequences`. Also checks the MTP-adjusted position for speculative decode. Sets `leave_reason = "stop_sequence"`.
2. EOS token: If `self.eos_token_id` appears in the accepted tokens and `seq.ignore_eos` is `False`. Sets `leave_reason = "eos"`.
3. Stop token IDs: If any accepted token is in `self.stop_token_ids` (from `Config.stop_token_ids`, derived from the model's generation config). Sets `leave_reason = "stop_{token_id}"`.
4. Max tokens: If `seq.num_completion_tokens >= seq.max_tokens`. Sets `leave_reason = "max_tokens"`.
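The priority ordering can be sketched as one function. Everything here (the function, its parameters, the MTP adjustment being omitted) is illustrative, not ATOM's actual postprocess internals:

```python
# Hedged sketch of the stop-condition priority order described above.
from typing import Optional

def check_stop(token_ids: list, new_tokens: list, stop_token_sequences: list,
               eos_token_id: int, ignore_eos: bool, stop_token_ids: set,
               num_completion_tokens: int, max_tokens: int) -> Optional[str]:
    # 1. Token-level stop sequences win over everything else.
    for stop_seq in stop_token_sequences:
        if stop_seq and token_ids[-len(stop_seq):] == stop_seq:
            return "stop_sequence"
    # 2. EOS among the accepted tokens (unless ignored).
    if not ignore_eos and eos_token_id in new_tokens:
        return "eos"
    # 3. Model-specific stop token IDs.
    for tok in new_tokens:
        if tok in stop_token_ids:
            return f"stop_{tok}"
    # 4. Completion-length budget.
    if num_completion_tokens >= max_tokens:
        return "max_tokens"
    return None

# EOS (id 2) arrives, but a matching stop sequence takes priority.
reason = check_stop(token_ids=[5, 6, 7, 2], new_tokens=[7, 2],
                    stop_token_sequences=[[7, 2]], eos_token_id=2,
                    ignore_eos=False, stop_token_ids=set(),
                    num_completion_tokens=3, max_tokens=64)
assert reason == "stop_sequence"
```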
When stream_output_queue is provided, the scheduler creates a RequestOutput for each processed sequence:
```python
request_output = RequestOutput(
    request_id=seq.id,
    output_tokens=output_tokens_list,
    finished=(leave_reason is not None),
    finish_reason=leave_reason,
)
```

`RequestOutput` fields:

| Field | Type | Description |
|---|---|---|
| `request_id` | `int` | Sequence ID |
| `output_tokens` | `list[int]` | Newly generated tokens since the last callback |
| `finished` | `bool` | Whether the sequence is done |
| `finish_reason` | `Optional[str]` | One of `"eos"`, `"max_tokens"`, `"stop_sequence"`, `"stop_{token_id}"`, or `None` |
Stream outputs are batched and put onto stream_output_queue via put_nowait.
For finished sequences:
- Set `seq.status = SequenceStatus.FINISHED`.
- Call `block_manager.deallocate(seq)` to free KV cache blocks.
- Remove the sequence from the `running` deque.
- Return it in the `finished_seqs` list.
When speculative decoding or deferred output is active, placeholder EOS tokens are appended to still-running sequences to reserve KV cache slots for the next step:
```python
if need_placeholder:
    for seq in seqs:
        if seq.status == SequenceStatus.RUNNING:
            for _ in range(seq.num_placeholder):
                seq.append_token(self.eos_token_id)
```

The placeholder count is determined as follows:

- For sequences processed in this step (had output in `fwd_output`): always `1 + mtp_k`, regardless of mode.
- For sequences not processed (skipped in this step), the count depends on the batch-level mode:
  - Deferred output + speculative: `mtp_k + 1`
  - Deferred output only: `1`
  - Speculative only: `mtp_k`
ATOM supports Multi-Token Prediction (MTP) speculative decoding, where a draft model proposes mtp_k additional tokens per step.
```python
self.use_spec = config.speculative_config is not None
self.mtp_k: int = config.speculative_config.num_speculative_tokens if self.use_spec else 0
self.total_draft_tokens = 0
self.total_accepted_tokens = 0
```

Note: `SpeculativeConfig` currently enforces `num_speculative_tokens == 1`.
During decode scheduling:
- If `seq.spec_token_ids` is non-empty, the draft tokens are recorded in `scheduled_spec_decode_tokens[seq.id]`.
- `num_new_tokens = mtp_k + 1` (1 target + `mtp_k` draft tokens), so `may_append` reserves enough block space.
- The `ScheduledBatch` carries `num_spec_step = mtp_k` and the `scheduled_spec_decode_tokens` dict.
```python
def update_spec_stats(self, num_accepted_tokens):
    self.total_draft_tokens += self.mtp_k
    self.total_accepted_tokens += num_accepted_tokens - self.mtp_k
```

Every 1000 draft tokens, the acceptance rate is logged:

```
[MTP Stats] Total draft tokens: 5000, Accepted: 3750, Acceptance rate: 75.00%
```
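A standalone replay of `update_spec_stats` with `mtp_k = 1` shows how the counters evolve. The reading that `num_accepted_tokens` includes the target token (hence the `- mtp_k`) is an inference from the formula, not something the source states:

```python
# Replay of update_spec_stats with mtp_k = 1: each step adds mtp_k drafts;
# subtracting mtp_k from num_accepted_tokens (assumed to include the target
# token) isolates the number of accepted draft tokens.
mtp_k = 1
total_draft_tokens = 0
total_accepted_tokens = 0

def update_spec_stats(num_accepted_tokens: int) -> None:
    global total_draft_tokens, total_accepted_tokens
    total_draft_tokens += mtp_k
    total_accepted_tokens += num_accepted_tokens - mtp_k

# 4 decode steps: the draft is accepted in 3 of them (2 tokens kept per
# step), rejected once (only the target token kept).
for accepted in (2, 2, 1, 2):
    update_spec_stats(accepted)

rate = total_accepted_tokens / total_draft_tokens
assert (total_draft_tokens, total_accepted_tokens) == (4, 3)
assert f"{rate:.2%}" == "75.00%"  # matches the ratio in the log line above
```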
After postprocessing, accepted draft token IDs for the next step are stored on the sequence:
```python
if draft_token_ids and seq.id in draft_token_ids:
    seq.spec_token_ids = draft_token_ids[seq.id]
```

These are picked up by the scheduler on the next `schedule()` call.
The Sequence class represents a single request throughout its lifecycle.
```python
class Sequence:
    def __init__(
        self,
        token_ids: list[int],
        block_size: int,
        sampling_params=SamplingParams(),
        stop_token_sequences: list[list[int]] = None,
        stream_callback: Optional[Callable[[Any], None]] = None,
        id=None,
    ):
```

| Field | Type | Description |
|---|---|---|
| `id` | `int` | Auto-incrementing unique ID (from `itertools.count`) |
| `token_ids` | `list[int]` | Full token sequence (prompt + completion) |
| `block_size` | `int` | KV cache block size (from config) |
| `status` | `SequenceStatus` | Current lifecycle state |
| `type` | `SequenceType` | Current step type (`DUMMY`, `PREFILL`, `DECODE`) |
| `num_tokens` | `int` | Total tokens (prompt + completion); a property whose setter also updates `num_blocks` and `last_block_num_tokens` |
| `num_prompt_tokens` | `int` | Number of prompt tokens (fixed at init) |
| `num_cached_tokens` | `int` | Tokens served from the prefix cache |
| `block_table` | `list[int]` | Ordered list of block IDs assigned to this sequence |
| `last_token` | `int` | Most recently appended token ID |
| `temperature` | `float` | Sampling temperature (from `SamplingParams`) |
| `max_tokens` | `int` | Max completion tokens (from `SamplingParams`, default 64) |
| `ignore_eos` | `bool` | Whether to ignore EOS tokens (from `SamplingParams`) |
| `stop_strings` | `Optional[list[str]]` | Stop strings (from `SamplingParams`) |
| `stop_token_sequences` | `list[list[int]]` | Token-level stop sequences |
| `stream_callback` | `Optional[Callable]` | Per-sequence stream callback |
| `output_tokens` | `list[int]` | Cache of newly generated tokens |
| `spec_token_ids` | `list[int]` | Speculative draft token IDs for the next step |
| `num_placeholder` | `int` | Number of placeholder tokens inserted for speculative/deferred output |
| Field | Type | Description |
|---|---|---|
| `arrive_time` | `float` | Timestamp when the sequence entered the scheduler |
| `first_token_time` | `float` | Timestamp of the first completion token (TTFT measurement) |
| `leave_time` | `float` | Timestamp when the sequence finished |
| `leave_reason` | `str` | Reason for finishing (e.g., `"eos"`, `"max_tokens"`, `"stop_sequence"`) |
| Property | Returns |
|---|---|
| `num_completion_tokens` | `num_tokens - num_prompt_tokens` |
| `prompt_token_ids` | `token_ids[:num_prompt_tokens]` |
| `completion_token_ids` | `token_ids[num_prompt_tokens:]` |
| `num_cached_blocks` | `num_cached_tokens // block_size` |
| `is_finished` | `status == SequenceStatus.FINISHED` |
Setting `num_tokens` triggers derived field updates:

```python
@num_tokens.setter
def num_tokens(self, value):
    self._num_tokens = value
    self.num_blocks = (value + self.block_size - 1) // self.block_size
    self.last_block_num_tokens = self._num_tokens - (self.num_blocks - 1) * self.block_size
```
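A worked example of the setter's arithmetic, with an assumed `block_size` of 16:

```python
# Derived-field updates from the num_tokens setter, traced by hand.
block_size = 16
num_tokens = 35  # e.g. a 33-token prompt plus 2 generated tokens

num_blocks = (num_tokens + block_size - 1) // block_size           # ceiling division
last_block_num_tokens = num_tokens - (num_blocks - 1) * block_size

assert num_blocks == 3             # blocks hold 16 + 16 + 3 tokens
assert last_block_num_tokens == 3  # only 3 slots used in the final block
```

`last_block_num_tokens` is exactly the per-sequence value the scheduler copies into `ScheduledBatch.last_block_num_tokens`.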
```
                     allocate blocks
add(seq) --------> WAITING --------> RUNNING (PREFILL)
                      ^                    |
                      |                    | next schedule() step
                   preempt()               v
                      |            RUNNING (DECODE) <--+
                      +--- can't append    |           |
                                           | stop condition met
                                           v
                                       FINISHED
                                           |
                                           | deallocate blocks
                                           v
                               (removed from running)
```
| Value | Meaning |
|---|---|
| `WAITING` | In the waiting queue, pending prefill |
| `RUNNING` | Actively being processed (prefill or decode) |
| `FINISHED` | Stop condition met, blocks deallocated |
| `EXIT_ENGINE` | Sentinel for engine shutdown |
| Value | Meaning |
|---|---|
| `DUMMY` | Initial state before scheduling |
| `PREFILL` | Currently in prefill phase |
| `DECODE` | Currently in decode phase |
| File | Description |
|---|---|
| `atom/model_engine/scheduler.py` | `Scheduler`, `ScheduledBatch`, `ScheduledBatchOutput` -- scheduling algorithm, postprocessing, speculative decode stats |
| `atom/model_engine/block_manager.py` | `Block`, `BlockManager` -- paged KV cache block pool, allocation/deallocation, prefix caching with xxhash64 |
| `atom/model_engine/sequence.py` | `Sequence`, `SequenceStatus`, `SequenceType` -- request lifecycle, token management, timing |
| `atom/model_engine/request.py` | `RequestOutput` -- streaming output dataclass with `request_id`, `output_tokens`, `finished`, `finish_reason` |
| `atom/config.py` | `Config` -- scheduling-related fields (`max_num_seqs`, `max_num_batched_tokens`, `kv_cache_block_size`, `enable_prefix_caching`, `scheduler_delay_factor`), `SpeculativeConfig` |
| `atom/sampling_params.py` | `SamplingParams` -- temperature, max_tokens, ignore_eos, stop_strings |