
Dmadic tt/prefill server poc #1936

Draft
dmadicTT wants to merge 12 commits into dev from dmadicTT/prefill-server-poc

Conversation

@dmadicTT
Contributor

No description provided.

@github-actions
Contributor

github-actions bot commented Jan 29, 2026

✅ Test Results - PASSED

Summary

| Component | Total | Passed | Skipped | Failed | Status |
| --- | --- | --- | --- | --- | --- |
| tt-inference-server | 352 | 352 | 0 | 0 | ✅ |
| tt-media-server | 433 | 407 | 26 | 0 | ✅ |
| Overall | 785 | 759 | 26 | 0 | ✅ |

Details

  • Python Version: 3.10
  • Workflow: Test Gate
  • Commit: ad754f4
  • Run ID: 21519608138

🎉 All tests passed! This PR is ready for review.

@github-actions
Contributor

github-actions bot commented Jan 29, 2026

✅ Test Coverage Report

Coverage of Changed Lines

| Metric | Value |
| --- | --- |
| Coverage % | |
| Threshold | 50% |
| Status | ✅ PASSED |

💡 This checks coverage of newly added/modified lines only, not total codebase coverage.

```python
class DemoConfig:
    max_num_seqs: int = 4
    max_num_batched_tokens: int = 8192
    block_size: int = 32
```
Contributor

block size = 32 means each KV block can store KV vectors for 32 tokens?
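For intuition (not code from this PR): with block_size = 32, a sequence's K/V vectors are packed into fixed-size pages, so the number of blocks a prompt needs is a plain ceiling division. A minimal sketch with illustrative names:

```python
import math

BLOCK_SIZE = 32  # number of token positions whose K/V vectors fit in one KV cache block

def blocks_needed(prompt_len: int, block_size: int = BLOCK_SIZE) -> int:
    # Each block holds K/V vectors for exactly `block_size` token positions,
    # so a prompt of length L occupies ceil(L / block_size) blocks.
    return math.ceil(prompt_len / block_size)

print(blocks_needed(100))  # -> 4 blocks (3 full + 1 partially filled)
```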

```python
    return torch.argmax(logits, dim=-1)


def get_last_token_logits(
```
Contributor

I guess this was that weird part where the model generates logits (scores for each token in the LLM vocabulary) for each input token :)... the others are used in that speculative decoding.
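A minimal sketch of the usual pattern behind a helper like get_last_token_logits (its full signature isn't shown in this hunk, so the argument names here are assumptions): for each sequence in a padded batch, pick the logits at its last real prompt position, then argmax over the vocabulary to get the first generated token.

```python
import torch

def last_token_logits_sketch(logits: torch.Tensor, prompt_lens: list[int]) -> torch.Tensor:
    # logits: [batch_size, padded_seq_len, vocab_size]
    # For each sequence, take the logits at index prompt_len - 1 (its last real token).
    idx = torch.tensor(prompt_lens) - 1                 # [batch_size]
    batch = torch.arange(logits.shape[0])               # [batch_size]
    return logits[batch, idx, :]                        # [batch_size, vocab_size]

logits = torch.randn(2, 8, 32000)            # two padded prompts, vocab of 32000
last = last_token_logits_sketch(logits, [5, 8])
next_tokens = torch.argmax(last, dim=-1)     # greedy pick, as in the snippet above
print(next_tokens.shape)                     # torch.Size([2])
```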

```python
mesh_shape = system_name_to_mesh_shape(requested_system_name.upper())
logging.info(f"MESH_DEVICE: '{requested_system_name}' - mesh shape: {mesh_shape}")

fabric_config = ttnn.FabricConfig.FABRIC_1D
```
Contributor

are we doing prefill on galaxy? I thought galaxy is 2d?

Contributor Author

You should be able to map any device to 1D. It's not tightly coupled to the physical topology but to the workload distribution across the chips; see https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Programming_Multiple_Meshes/Programming_Multiple_Meshes.md#6-fabric-configuration


logging.info("Running prefills...")
while not engine.is_finished():
logits, scheduled = engine.step()
Contributor

I guess one step means one forward pass? The while loop here is because we can batch more than we can compute in one forward pass... the scheduled here is the prompts we just did a prefill for?
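Roughly, yes: one step() is one scheduled forward pass, and the loop keeps going until the scheduler has drained the waiting queue. A hedged, self-contained sketch of that consumption pattern (ToyPrefillEngine is a stand-in, not this PR's engine):

```python
from collections import deque

class ToyPrefillEngine:
    """Toy stand-in: prefills at most max_num_seqs prompts per step()."""
    def __init__(self, prompts: list[str], max_num_seqs: int = 4):
        self.waiting = deque(prompts)
        self.max_num_seqs = max_num_seqs

    def is_finished(self) -> bool:
        return not self.waiting

    def step(self) -> tuple[list[str], list[str]]:
        # One "forward pass": take up to max_num_seqs prompts off the queue.
        scheduled = [self.waiting.popleft()
                     for _ in range(min(self.max_num_seqs, len(self.waiting)))]
        fake_logits = [f"logits({p})" for p in scheduled]   # placeholder outputs
        return fake_logits, scheduled

engine = ToyPrefillEngine([f"prompt-{i}" for i in range(10)], max_num_seqs=4)
while not engine.is_finished():
    logits, scheduled = engine.step()
    print(len(scheduled), "prompts prefilled this step")    # 4, then 4, then 2
```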

print("=" * 60)
kv_cache = engine.get_kv_cache()
print(f"num_layers: {len(kv_cache)}")
for layer_idx in [0, len(kv_cache) - 1]:
Contributor

len of kv cache is the number of layers of the llm? 61?

Contributor Author

yup

```python
)

self._generator._prepare_run_configs("prefill", kv_cache_override=tt_kv_cache_config)
self._kv_cache = self._generator.get_kv_cache()
```
Contributor

I guess this is the part where we allocate memory in device RAM based on the number of layers, hidden dim, batch size, etc.? Although it should already know it needs to prealloc 12 gigs... from my research, vLLM splits those 12 gigs into num_of_layers pools of cache blocks... do we do this as well?

Contributor Author

_kv_cache is actually a list of 61 objects, one per layer
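For orientation, a common layout for a paged KV cache is one (K, V) tensor pair per transformer layer, each sized by the number of blocks rather than by sequence length. A minimal host-side sketch with made-up, tiny dimensions (this is not the tt-metal allocation itself):

```python
import torch

num_layers = 61                                   # one entry per decoder layer
num_blocks = 4                                    # tiny for the sketch; the real value comes from memory sizing
num_kv_heads, block_size, head_dim = 8, 32, 128   # illustrative shapes

# One (K, V) pair per layer; each holds `num_blocks` fixed-size pages.
kv_cache = [
    (
        torch.zeros(num_blocks, num_kv_heads, block_size, head_dim),  # K pages
        torch.zeros(num_blocks, num_kv_heads, block_size, head_dim),  # V pages
    )
    for _ in range(num_layers)
]

print(len(kv_cache))          # 61 -> matches "one object per layer"
print(kv_cache[0][0].shape)   # torch.Size([4, 8, 32, 128])
```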

```python
# Build page_table from all sequences' block tables
# Shape: [batch_size, max_num_blocks_per_req]
max_blocks_per_req = max(len(seq.block_table) for seq in seqs)
page_table = torch.zeros((len(seqs), max_blocks_per_req), dtype=torch.int32)
```
Contributor

int32 because each entry in the page table is the address (or index?) of the physical block of device RAM?

Contributor Author

each entry is a KV cache block ID
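The hunk is cut off right at the for loop; the usual continuation is to copy each sequence's block_table (its list of block IDs) into one padded row of page_table. A sketch with a toy stand-in for the PR's sequence object:

```python
import torch
from dataclasses import dataclass, field

@dataclass
class ToySeq:                       # stand-in for the PR's PrefillSequence
    block_table: list[int] = field(default_factory=list)

seqs = [ToySeq([3, 7, 12]), ToySeq([1, 5])]

max_blocks_per_req = max(len(seq.block_table) for seq in seqs)
page_table = torch.zeros((len(seqs), max_blocks_per_req), dtype=torch.int32)
for i, seq in enumerate(seqs):
    # Row i maps this sequence's logical pages 0..N-1 to physical block IDs;
    # unused tail entries stay 0 (padding).
    page_table[i, : len(seq.block_table)] = torch.tensor(seq.block_table, dtype=torch.int32)

print(page_table)
# tensor([[ 3,  7, 12],
#         [ 1,  5,  0]], dtype=torch.int32)
```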

```python
# Shape: [batch_size, max_num_blocks_per_req]
max_blocks_per_req = max(len(seq.block_table) for seq in seqs)
page_table = torch.zeros((len(seqs), max_blocks_per_req), dtype=torch.int32)
for i, seq in enumerate(seqs):
```
Contributor

This was confusing to me at first, but I guess we can fill up the page table with real data because we know exactly how many physical blocks each seq will need; the dynamic part (where it asks the block manager for more blocks on the fly) happens in the decode phase?

Contributor Author
@dmadicTT Feb 4, 2026

exactly, what was confusing btw?
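During prefill the block count per sequence is known up front; in decode, a sequence typically asks the block manager for one more block only when it crosses a block boundary. A hedged sketch of that decode-time check (names are illustrative, not this PR's API):

```python
from collections import deque

def maybe_grow_block_table(block_table: list[int], seq_len_after_append: int,
                           block_size: int, free_block_ids: deque) -> None:
    # A new physical block is needed only when the next token would not fit
    # in the blocks already mapped for this sequence.
    needed = -(-seq_len_after_append // block_size)    # ceil division
    while len(block_table) < needed:
        block_table.append(free_block_ids.popleft())   # dynamic allocation in decode

free = deque(range(101, 110))
table = [100]                                          # 1 block -> 32 token slots
maybe_grow_block_table(table, 33, 32, free)            # token 33 spills into a new block
print(table)                                           # [100, 101]
```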


logger.info("Output: logits.shape=%s", tuple(logits.shape))

user_logits = logits[0, 0, :prompt_len, :]
Contributor

why are we always indexing 0,0?

Contributor Author

What _prefill returns:
_prefill returns a single tensor of logits for the full sequence, with shape [1, 1, seq_len, V]:
  • 1, 1: batch and (row) batch dimensions
  • seq_len: length of the prefill sequence (same as the input tokens)
  • V: vocabulary size
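Given that [1, 1, seq_len, V] shape, a quick sketch of what the indexing in the hunk does and how a next token could then be picked (prompt_len and the greedy argmax follow the pattern from earlier in the diff; the shapes here are just placeholders):

```python
import torch

seq_len, vocab = 16, 32000
logits = torch.randn(1, 1, seq_len, vocab)           # what _prefill returns, per the reply above

prompt_len = 10
user_logits = logits[0, 0, :prompt_len, :]           # drop the two singleton dims -> [prompt_len, V]
next_token = torch.argmax(user_logits[-1], dim=-1)   # greedy pick from the last prompt position
print(user_logits.shape, next_token.shape)           # torch.Size([10, 32000]) torch.Size([])
```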

```python
def __init__(self, num_blocks: int, block_size: int):
    self.block_size = block_size
    self.num_blocks = num_blocks
    self.free_block_ids: deque[int] = deque(range(num_blocks))
```
Contributor

very confusing how a simple python level deque with ints will actually map to real physical blocks of RAM on a device
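The deque itself only hands out integer IDs; the "physical" meaning comes from the convention that slice i of every preallocated cache tensor is reserved for whoever holds ID i. A minimal allocate/free sketch around the __init__ shown above (method names are illustrative, not necessarily this PR's):

```python
from collections import deque

class ToyBlockManager:
    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.num_blocks = num_blocks
        # IDs are indices into the preallocated cache tensors; owning an ID means
        # "page [id] of every layer's K/V tensor belongs to this sequence".
        self.free_block_ids: deque[int] = deque(range(num_blocks))

    def allocate(self, n: int) -> list[int]:
        if n > len(self.free_block_ids):
            raise RuntimeError("out of KV cache blocks")
        return [self.free_block_ids.popleft() for _ in range(n)]

    def free(self, block_ids: list[int]) -> None:
        self.free_block_ids.extend(block_ids)    # pages become reusable immediately

mgr = ToyBlockManager(num_blocks=8, block_size=32)
tbl = mgr.allocate(3)
print(tbl, len(mgr.free_block_ids))   # [0, 1, 2] 5
mgr.free(tbl)
print(len(mgr.free_block_ids))        # 8
```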

"""
Compute max number of KV cache blocks that fit in available memory.

Memory per block (across all layers): num_layers * tensors_per_layer
Contributor

tensor per layer here is actually the number of heads?
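"Tensors per layer" in paged-KV setups is usually 2 (one K tensor and one V tensor), with the heads folded into the per-block byte count; whether this code means exactly that isn't visible in the hunk. A hedged back-of-the-envelope sketch of the sizing, with assumed parameter names:

```python
def max_kv_blocks(available_bytes: int, num_layers: int, num_kv_heads: int,
                  block_size: int, head_dim: int, dtype_bytes: int = 2,
                  tensors_per_layer: int = 2) -> int:
    # Bytes one block costs across ALL layers:
    #   layers * (K and V) * kv_heads * tokens_per_block * head_dim * bytes_per_element
    bytes_per_block = (num_layers * tensors_per_layer * num_kv_heads
                       * block_size * head_dim * dtype_bytes)
    return available_bytes // bytes_per_block

# e.g. 12 GiB of headroom, 61 layers, 8 KV heads, 32-token blocks, head_dim 128, bf16
print(max_kv_blocks(12 * 1024**3, 61, 8, 32, 128, 2))   # -> roughly 1600 blocks under these assumptions
```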

```python
)

@timed()
def schedule(self) -> list[PrefillSequence]:
```
Contributor

I guess continuous batching here does not make sense, since everything in a batch is padded and will take the same time to exec, that is why this is very simple?
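That matches the snippet: a prefill-only scheduler can be plain FIFO, packing waiting sequences into a batch until max_num_seqs or max_num_batched_tokens is hit. A toy sketch of that policy (not the PR's scheduler):

```python
from collections import deque
from dataclasses import dataclass

@dataclass
class ToySeq:
    request_id: str
    num_tokens: int

class ToyScheduler:
    def __init__(self, max_num_seqs: int = 4, max_num_batched_tokens: int = 8192):
        self.max_num_seqs = max_num_seqs
        self.max_num_batched_tokens = max_num_batched_tokens
        self.waiting: deque[ToySeq] = deque()

    def schedule(self) -> list[ToySeq]:
        batch, tokens = [], 0
        while self.waiting and len(batch) < self.max_num_seqs:
            nxt = self.waiting[0]
            if tokens + nxt.num_tokens > self.max_num_batched_tokens:
                break                      # next prompt would blow the token budget
            batch.append(self.waiting.popleft())
            tokens += nxt.num_tokens
        return batch

sched = ToyScheduler()
sched.waiting.extend(ToySeq(f"req-{i}", 3000) for i in range(5))
print([s.request_id for s in sched.schedule()])   # ['req-0', 'req-1'] -> 6000 tokens fit
```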

@knovokmetTT
Contributor

Where do we tell the device "allocate 12 gigs"? When we create a page table and fill it with block numbers, where does this mapping actually happen? The kernel on the device thinks it's writing into continuous memory, but someone is doing the mapping... who? How does it index into the device RAM with the block IDs we have prewritten into the page table tensor?
