Conversation
✅ Test Results - PASSED

🎉 All tests passed! This PR is ready for review.

✅ Test Coverage Report: Coverage of Changed Lines
```python
class DemoConfig:
    max_num_seqs: int = 4
    max_num_batched_tokens: int = 8192
    block_size: int = 32
```

Does block_size = 32 mean each KV block can store KV vectors for 32 tokens?
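For intuition, a tiny sketch of that relationship (assuming the usual paged-KV layout; the helper name is made up):

```python
import math

def num_blocks_needed(seq_len: int, block_size: int = 32) -> int:
    # Each KV block holds the K/V vectors for block_size tokens, so a sequence
    # needs ceil(seq_len / block_size) blocks in its block table.
    return math.ceil(seq_len / block_size)

assert num_blocks_needed(1) == 1
assert num_blocks_needed(32) == 1
assert num_blocks_needed(33) == 2
```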
```python
    return torch.argmax(logits, dim=-1)


def get_last_token_logits(
```

I guess this is that weird part where the model generates logits (scores over the LLM vocabulary) for every token :) ... and the other positions are what speculative decoding uses?
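A hedged sketch of what a helper like this might do (the real signature isn't shown in the hunk; `prompt_lens` and the tensor layout are assumptions):

```python
import torch

def get_last_token_logits(logits: torch.Tensor, prompt_lens: list[int]) -> torch.Tensor:
    # logits: [batch, seq_len, vocab]. For plain greedy decoding only each
    # prompt's last position matters; the earlier per-token logits are what
    # speculative-decoding-style verification would consume.
    last_pos = torch.tensor([n - 1 for n in prompt_lens])
    return logits[torch.arange(logits.shape[0]), last_pos, :]

logits = torch.randn(2, 16, 32000)
last = get_last_token_logits(logits, prompt_lens=[10, 16])
next_tokens = torch.argmax(last, dim=-1)  # mirrors the argmax in the quoted code
```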
```python
mesh_shape = system_name_to_mesh_shape(requested_system_name.upper())
logging.info(f"MESH_DEVICE: '{requested_system_name}' - mesh shape: {mesh_shape}")

fabric_config = ttnn.FabricConfig.FABRIC_1D
```

Are we doing prefill on Galaxy? I thought Galaxy is 2D?

Reply: You should be able to map any device to 1D. It's not tightly coupled to the physical topology but to the workload distribution across the chips: https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Programming_Multiple_Meshes/Programming_Multiple_Meshes.md#6-fabric-configuration
```python
logging.info("Running prefills...")
while not engine.is_finished():
    logits, scheduled = engine.step()
```

I guess one step means one forward pass? The while loop is here because we can batch more requests than we can compute in one forward pass... and `scheduled` is the set of prompts we just ran prefill for?
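For reference, a hypothetical sketch of what that loop amounts to. The mock engine below is only there to illustrate the assumed `step()` contract, not the real implementation:

```python
import torch

class MockEngine:
    # Stand-in for the real engine: each step() schedules up to max_num_seqs
    # waiting prompts, runs one forward pass, and returns (logits for that
    # batch, the sequence IDs it just prefilled).
    def __init__(self, prompts, max_num_seqs=4, vocab=32000):
        self.waiting = list(range(len(prompts)))
        self.max_num_seqs = max_num_seqs
        self.vocab = vocab

    def is_finished(self):
        return not self.waiting

    def step(self):
        batch = self.waiting[:self.max_num_seqs]
        self.waiting = self.waiting[self.max_num_seqs:]
        return torch.randn(len(batch), self.vocab), batch

engine = MockEngine(prompts=[[1, 2, 3]] * 10)
results = {}
while not engine.is_finished():
    logits, scheduled = engine.step()        # one forward pass per step
    for i, seq_id in enumerate(scheduled):
        results[seq_id] = logits[i]
print(len(results))  # 10 prompts prefilled over ceil(10 / 4) = 3 steps
```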
```python
print("=" * 60)
kv_cache = engine.get_kv_cache()
print(f"num_layers: {len(kv_cache)}")
for layer_idx in [0, len(kv_cache) - 1]:
```

So len(kv_cache) is the number of layers of the LLM, i.e. 61?
```python
        )

        self._generator._prepare_run_configs("prefill", kv_cache_override=tt_kv_cache_config)
        self._kv_cache = self._generator.get_kv_cache()
```

I guess this is the part where we allocate memory in device RAM based on the number of layers, hidden dim, batch size, etc.? Although it should already know it needs to preallocate 12 GB... From my research, vLLM splits those 12 GB into num_layers pools of cache blocks. Do we do this as well?

Reply: _kv_cache is actually a list of 61 objects, one per layer.
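A rough sketch of the per-layer layout being described (shapes, dtype and block count are illustrative, not the actual tt-metal allocation):

```python
import torch

num_layers = 61          # as noted in the reply: one cache entry per layer
num_blocks = 16          # tiny pool for illustration; the real pool is much larger
block_size = 32
num_kv_heads, head_dim = 8, 128

# One (K, V) block pool per layer; a block ID indexes the first dimension of a
# pool, which is how the integer block tables end up referring to device memory.
kv_cache = [
    (
        torch.zeros(num_blocks, num_kv_heads, block_size, head_dim),  # keys
        torch.zeros(num_blocks, num_kv_heads, block_size, head_dim),  # values
    )
    for _ in range(num_layers)
]
assert len(kv_cache) == num_layers
```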
```python
        # Build page_table from all sequences' block tables
        # Shape: [batch_size, max_num_blocks_per_req]
        max_blocks_per_req = max(len(seq.block_table) for seq in seqs)
        page_table = torch.zeros((len(seqs), max_blocks_per_req), dtype=torch.int32)
```

Is it int32 because each entry in the page table is the address (or index?) of a physical block of device RAM?

Reply: Each entry is a KV cache block ID.
```python
        # Shape: [batch_size, max_num_blocks_per_req]
        max_blocks_per_req = max(len(seq.block_table) for seq in seqs)
        page_table = torch.zeros((len(seqs), max_blocks_per_req), dtype=torch.int32)
        for i, seq in enumerate(seqs):
```

This was confusing to me at first, but I guess we can fill the page table with real data here because we know exactly how many physical blocks each sequence will need; the dynamic part (where it asks the block manager for more blocks on demand) happens in the decode phase?

Reply: Exactly. What was confusing, btw?
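A small sketch of how that padded page table gets filled from per-sequence block tables (the block IDs are made up):

```python
import torch

block_tables = [[3, 7, 8], [12], [0, 1]]   # block IDs each sequence got from the block manager
max_blocks_per_req = max(len(bt) for bt in block_tables)
page_table = torch.zeros((len(block_tables), max_blocks_per_req), dtype=torch.int32)
for i, bt in enumerate(block_tables):
    page_table[i, : len(bt)] = torch.tensor(bt, dtype=torch.int32)
print(page_table)
# tensor([[ 3,  7,  8],
#         [12,  0,  0],
#         [ 0,  1,  0]], dtype=torch.int32)
```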
```python
        logger.info("Output: logits.shape=%s", tuple(logits.shape))

        user_logits = logits[0, 0, :prompt_len, :]
```

Why are we always indexing [0, 0]?

Reply: _prefill returns a single tensor of logits for the full sequence, with shape [1, 1, seq_len, V]:
- 1, 1: the batch and (row) batch dimensions
- seq_len: length of the prefill sequence (same as the input tokens)
- V: vocabulary size
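So indexing [0, 0] just drops the two leading singleton dimensions; a quick illustration with made-up sizes:

```python
import torch

seq_len, vocab = 128, 32000
prompt_len = 100
logits = torch.randn(1, 1, seq_len, vocab)          # what _prefill returns: [1, 1, seq_len, V]

user_logits = logits[0, 0, :prompt_len, :]          # [prompt_len, V]
next_token = torch.argmax(user_logits[-1], dim=-1)  # only the last position feeds sampling
```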
```python
    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.free_block_ids: deque[int] = deque(range(num_blocks))
```

It's pretty confusing at first that a plain Python-level deque of ints ends up mapping to real physical blocks of RAM on the device.
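A minimal sketch of why that works: the device memory was already carved into `num_blocks` fixed-size blocks up front, so an ID is just an index into those preallocated tensors. The allocate/free methods here are hypothetical additions for illustration, not the real class:

```python
from collections import deque

class BlockManager:
    def __init__(self, num_blocks: int, block_size: int):
        self.block_size = block_size
        self.num_blocks = num_blocks
        self.free_block_ids: deque[int] = deque(range(num_blocks))

    def allocate(self, n: int) -> list[int]:
        # Hand out n free block IDs; each ID indexes a slot in the preallocated
        # per-layer KV cache tensors on the device.
        if n > len(self.free_block_ids):
            raise RuntimeError("out of KV cache blocks")
        return [self.free_block_ids.popleft() for _ in range(n)]

    def free(self, block_ids: list[int]) -> None:
        # Finished sequences return their IDs; no device memory actually moves.
        self.free_block_ids.extend(block_ids)

mgr = BlockManager(num_blocks=8, block_size=32)
seq_blocks = mgr.allocate(3)   # e.g. [0, 1, 2]; these are what go into the page_table
mgr.free(seq_blocks)
```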
| """ | ||
| Compute max number of KV cache blocks that fit in available memory. | ||
|
|
||
| Memory per block (across all layers): num_layers * tensors_per_layer |
There was a problem hiding this comment.
tensor per layer here is actually the number of heads?
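The thread doesn't answer this, but in most paged-KV implementations tensors_per_layer is 2 (one K and one V tensor per layer), and the heads show up in the per-block byte size instead. A hedged sketch of that kind of computation; all names and sizes below are assumptions, not the actual function:

```python
def max_kv_cache_blocks(
    available_bytes: int,
    num_layers: int,
    num_kv_heads: int,
    head_dim: int,
    block_size: int,
    dtype_bytes: int = 2,        # e.g. bfloat16
    tensors_per_layer: int = 2,  # assumed: one K and one V tensor per layer
) -> int:
    # Bytes one block occupies across all layers.
    bytes_per_block = (
        num_layers * tensors_per_layer
        * block_size * num_kv_heads * head_dim * dtype_bytes
    )
    return available_bytes // bytes_per_block

# e.g. 12 GB of headroom with made-up model dims:
print(max_kv_cache_blocks(12 * 1024**3, num_layers=61, num_kv_heads=8,
                          head_dim=128, block_size=32))
```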
```python
    )

    @timed()
    def schedule(self) -> list[PrefillSequence]:
```

I guess continuous batching doesn't make sense here, since everything in a batch is padded and takes the same time to execute; is that why this is so simple?
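A hypothetical sketch of the kind of simple FIFO scheduler being described (PrefillSequence here is a stand-in with only the field the sketch needs, not the real class):

```python
from dataclasses import dataclass

@dataclass
class PrefillSequence:
    seq_id: int
    num_tokens: int

def schedule(waiting: list[PrefillSequence],
             max_num_seqs: int = 4,
             max_num_batched_tokens: int = 8192) -> list[PrefillSequence]:
    # Plain FIFO: keep taking sequences until either the sequence count or the
    # per-batch token budget runs out. No preemption, no continuous batching.
    batch: list[PrefillSequence] = []
    budget = max_num_batched_tokens
    while waiting and len(batch) < max_num_seqs and waiting[0].num_tokens <= budget:
        seq = waiting.pop(0)
        budget -= seq.num_tokens
        batch.append(seq)
    return batch

queue = [PrefillSequence(i, 3000) for i in range(5)]
print([s.seq_id for s in schedule(queue)])  # [0, 1]: a third 3000-token prompt would exceed 8192
```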
Where do we tell the device