Skip to content

Commit 8eb14f7

Browse files
authored
[Blitz Decode] Integrate Embedding with H2D (#37913)
### Ticket [Link to Github Issue](#37662) ### Problem description - We need to add the embedding op for Blitz Decode: - `[vocab_size, embedding_dim] = [129280, 7168]` - DRAM Interleaved with `page_size = embedding_dim * sizeof(bfloat16)` - For this implementation, we decided to fuse the embedding lookup with the H2D receiver ### What's changed - Modify `HostInterface` to accept an embedding tensor as an optional input - If specified, the interface will be run using the new `fused_h2d_receiver_embedding.cpp` kernel, which does a direct embedding lookup from input tokens and sends the embedding to the downstream receiver - Bugfix: Support xfer sizes > `NOC_MAX_BURST_SIZE` in `HostInterface`. The original issue was that the NOC read/write APIs used for interfacing with PCIe only supported up to a single NOC packet. Utility APIs for support arbitrary data sizes are added to `models/demos/deepseek_v3_b1/micro_ops/host_io/kernels/pcie_noc_utils.h` - Tested main use case for Blitz Decode and additional embedding tensor shapes ### Checklist - [ ] [![All post-commit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml/badge.svg?branch=asaigal/embedding_rebased)](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml?query=branch:asaigal/embedding_rebased) - [ ] [![Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml/badge.svg?branch=asaigal/embedding_rebased)](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml?query=branch:asaigal/embedding_rebased) - [ ] [![cpp-unit-tests](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml/badge.svg?branch=asaigal/embedding_rebased)](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml?query=branch:asaigal/embedding_rebased) - [ ] New/Existing tests provide coverage for changes #### Model tests If your changes cover model-related code, you should run tests corresponding to affected models and platforms (Single card, T3K, Galaxy). "Choose your pipeline" workflows facilitate running multiple kinds of tests in a single run. Each offers `models-mandatory` and `models-extended` presets. The former includes a minimal set of tests, to be run always. The latter extends that with additional ones - use your best judgement in deciding which is the most appropriate for your PR. - [ ] [![(Single) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml/badge.svg?branch=asaigal/embedding_rebased)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml?query=branch:asaigal/embedding_rebased) - [ ] `models-mandatory` preset (runs: [Device perf regressions](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml) and [Frequent model and ttnn tests](https://github.com/tenstorrent/tt-metal/actions/workflows/fast-dispatch-full-regressions-and-models.yaml)) - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/single-card-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml) tests) - [ ] other selection - specify runs - [ ] [![(T3K) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-t3k.yaml/badge.svg?branch=asaigal/embedding_rebased)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-t3k.yaml?query=branch:asaigal/embedding_rebased) - [ ] `models-mandatory` preset (runs: [Unit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-unit-tests.yaml)) - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-model-perf-tests.yaml) tests) - [ ] other selection - specify runs - [ ] [![(Galaxy) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml/badge.svg?branch=asaigal/embedding_rebased)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml?query=branch:asaigal/embedding_rebased) - [ ] `models-mandatory` preset (runs: [Quick tests](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-quick.yaml)) - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-perf-tests.yaml) tests) - [ ] other selection - specify runs
1 parent ecc3ca4 commit 8eb14f7

File tree

7 files changed

+408
-39
lines changed

7 files changed

+408
-39
lines changed

models/demos/deepseek_v3_b1/micro_ops/host_io/kernels/d2h_sender.cpp

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
#include <cstdint>
55
#include "api/dataflow/dataflow_api.h"
66
#include "api/socket_api.h"
7-
#include "api/debug/dprint.h"
7+
#include "pcie_noc_utils.h"
88

99
FORCE_INLINE bool socket_wait_for_pages_with_termination(
1010
const SocketReceiverInterface& socket, uint32_t num_pages, volatile tt_l1_ptr uint32_t* termination_semaphore) {
@@ -63,14 +63,13 @@ void kernel_main() {
6363
break;
6464
}
6565
uint32_t read_addr = get_read_ptr(upstream_interface_index);
66-
noc_wwrite_with_state<noc_mode, write_cmd_buf, CQ_NOC_SNDL, CQ_NOC_SEND, CQ_NOC_WAIT, true, false>(
66+
noc_async_wide_write_any_len_with_state(
6767
NOC_INDEX,
6868
read_addr,
6969
pcie_xy_enc,
7070
((static_cast<uint64_t>(write_addr_hi) << 32) | sender_socket.downstream_fifo_addr) +
7171
sender_socket.write_ptr,
72-
page_size,
73-
1);
72+
page_size);
7473
noc_async_writes_flushed();
7574
cb_pop_front(upstream_interface_index, 1);
7675
} else {
@@ -79,14 +78,13 @@ void kernel_main() {
7978
break;
8079
}
8180
uint32_t read_addr = receiver_socket.read_ptr;
82-
noc_wwrite_with_state<noc_mode, write_cmd_buf, CQ_NOC_SNDL, CQ_NOC_SEND, CQ_NOC_WAIT, true, false>(
81+
noc_async_wide_write_any_len_with_state(
8382
NOC_INDEX,
8483
read_addr,
8584
pcie_xy_enc,
8685
((static_cast<uint64_t>(write_addr_hi) << 32) | sender_socket.downstream_fifo_addr) +
8786
sender_socket.write_ptr,
88-
page_size,
89-
1);
87+
page_size);
9088
socket_pop_pages(receiver_socket, 1);
9189
noc_async_writes_flushed();
9290
socket_notify_sender(receiver_socket);
@@ -102,5 +100,4 @@ void kernel_main() {
102100

103101
noc_async_write_barrier();
104102
noc_async_read_barrier();
105-
DPRINT << "End D2H Main Loop" << ENDL();
106103
}
Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
// SPDX-FileCopyrightText: © 2026 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
#include <cstdint>
5+
#include "api/dataflow/dataflow_api.h"
6+
#include "api/socket_api.h"
7+
#include "api/tensor/tensor_accessor.h"
8+
#include "pcie_noc_utils.h"
9+
10+
FORCE_INLINE bool socket_wait_for_pages_with_termination(
11+
const SocketReceiverInterface& socket, uint32_t num_pages, volatile tt_l1_ptr uint32_t* termination_semaphore) {
12+
constexpr uint32_t termination_value = 1;
13+
while (!socket_wait_for_pages(socket, num_pages, 1000)) {
14+
invalidate_l1_cache();
15+
if (termination_semaphore[0] == termination_value) {
16+
return false;
17+
}
18+
}
19+
return true;
20+
}
21+
22+
void kernel_main() {
23+
// Get this value from MeshSocket struct on host
24+
constexpr uint32_t recv_socket_config_addr = get_compile_time_arg_val(0);
25+
constexpr uint32_t termination_semaphore_addr = get_compile_time_arg_val(1);
26+
constexpr uint32_t token_page_size = get_compile_time_arg_val(2);
27+
constexpr bool pull_from_host = get_compile_time_arg_val(3);
28+
constexpr bool loopback_mode = get_compile_time_arg_val(4);
29+
constexpr uint32_t downstream_interface_index = get_compile_time_arg_val(5);
30+
constexpr uint32_t embedding_cb_index = get_compile_time_arg_val(6);
31+
constexpr uint32_t embedding_page_size = get_compile_time_arg_val(7);
32+
constexpr uint32_t embedding_addr = get_compile_time_arg_val(8);
33+
// TensorAccessorArgs for embedding tensor at CT arg index 9
34+
constexpr auto embedding_args = TensorAccessorArgs<9>();
35+
36+
auto embedding_accessor = TensorAccessor(embedding_args, embedding_addr, embedding_page_size);
37+
38+
SocketReceiverInterface receiver_socket = create_receiver_socket_interface(recv_socket_config_addr);
39+
SocketSenderInterface sender_socket = {};
40+
41+
if constexpr (!loopback_mode) {
42+
sender_socket = create_sender_socket_interface(downstream_interface_index);
43+
set_sender_socket_page_size(sender_socket, embedding_page_size);
44+
}
45+
set_receiver_socket_page_size(receiver_socket, token_page_size);
46+
47+
// Read first page of embedding tensor from DRAM into CB
48+
49+
uint32_t read_addr_hi = receiver_socket.h2d.data_addr_hi;
50+
uint32_t read_addr_lo = receiver_socket.h2d.data_addr_lo;
51+
uint32_t pcie_xy_enc = receiver_socket.h2d.pcie_xy_enc;
52+
53+
volatile tt_l1_ptr uint32_t* termination_semaphore =
54+
reinterpret_cast<volatile tt_l1_ptr uint32_t*>(termination_semaphore_addr);
55+
while (true) {
56+
// Wait for pages in H2D socket
57+
if (!socket_wait_for_pages_with_termination(receiver_socket, 1, termination_semaphore)) {
58+
break;
59+
}
60+
if constexpr (pull_from_host) {
61+
// Pages available in H2D socket - read over PCIe
62+
noc_async_wide_read_any_len_with_state(
63+
NOC_INDEX,
64+
pcie_xy_enc,
65+
((static_cast<uint64_t>(read_addr_hi) << 32) | read_addr_lo) + receiver_socket.read_ptr -
66+
receiver_socket.fifo_addr,
67+
receiver_socket.read_ptr,
68+
token_page_size);
69+
noc_async_read_barrier();
70+
}
71+
72+
// TODO: Add and assert that token id is within vocab size
73+
volatile tt_l1_ptr uint32_t* token_id_ptr =
74+
reinterpret_cast<volatile tt_l1_ptr uint32_t*>(receiver_socket.read_ptr);
75+
// Embedding CB is a scratch pad for now. We only read into the first slot of the CB.
76+
// TODO: Setup separate reader to pipeline reads.
77+
uint32_t l1_write_addr = get_write_ptr(embedding_cb_index);
78+
uint64_t noc_addr = embedding_accessor.get_noc_addr(*token_id_ptr);
79+
noc_async_read(noc_addr, l1_write_addr, embedding_page_size);
80+
noc_async_read_barrier();
81+
82+
if constexpr (loopback_mode) {
83+
cb_reserve_back(downstream_interface_index, 1);
84+
noc_async_write(
85+
get_noc_addr(get_read_ptr(embedding_cb_index)),
86+
get_noc_addr(get_write_ptr(downstream_interface_index)),
87+
embedding_page_size);
88+
noc_async_write_barrier();
89+
cb_push_back(downstream_interface_index, 1);
90+
} else {
91+
sender_downstream_encoding downstream_enc = get_downstream_encoding(sender_socket, 0);
92+
socket_reserve_pages(sender_socket, 1);
93+
noc_async_write(
94+
get_noc_addr(get_read_ptr(embedding_cb_index)),
95+
get_noc_addr(
96+
downstream_enc.d2d.downstream_noc_x,
97+
downstream_enc.d2d.downstream_noc_y,
98+
sender_socket.write_ptr + sender_socket.downstream_fifo_addr),
99+
embedding_page_size);
100+
socket_push_pages(sender_socket, 1);
101+
socket_notify_receiver(sender_socket);
102+
noc_async_writes_flushed();
103+
}
104+
socket_pop_pages(receiver_socket, 1);
105+
socket_notify_sender(receiver_socket);
106+
invalidate_l1_cache();
107+
}
108+
109+
update_socket_config(receiver_socket);
110+
if constexpr (!loopback_mode) {
111+
socket_barrier(sender_socket);
112+
}
113+
114+
noc_async_write_barrier();
115+
noc_async_read_barrier();
116+
}

models/demos/deepseek_v3_b1/micro_ops/host_io/kernels/h2d_receiver.cpp

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
#include <cstdint>
55
#include "api/dataflow/dataflow_api.h"
66
#include "api/socket_api.h"
7+
#include "pcie_noc_utils.h"
78

89
FORCE_INLINE bool socket_wait_for_pages_with_termination(
910
const SocketReceiverInterface& socket, uint32_t num_pages, volatile tt_l1_ptr uint32_t* termination_semaphore) {
@@ -50,7 +51,7 @@ void kernel_main() {
5051
}
5152
if constexpr (pull_from_host) {
5253
// Pages available in H2D socket - read over PCIe
53-
noc_read_with_state<noc_mode, read_cmd_buf, CQ_NOC_SNDL, CQ_NOC_SEND, CQ_NOC_WAIT>(
54+
noc_async_wide_read_any_len_with_state(
5455
NOC_INDEX,
5556
pcie_xy_enc,
5657
((static_cast<uint64_t>(read_addr_hi) << 32) | read_addr_lo) + receiver_socket.read_ptr -
@@ -59,15 +60,16 @@ void kernel_main() {
5960
page_size);
6061
noc_async_read_barrier();
6162
}
63+
6264
if constexpr (loopback_mode) {
6365
cb_reserve_back(downstream_interface_index, 1);
6466
noc_async_write(
6567
receiver_socket.read_ptr, get_noc_addr(get_write_ptr(downstream_interface_index)), page_size);
6668
noc_async_write_barrier();
6769
cb_push_back(downstream_interface_index, 1);
6870
} else {
69-
socket_reserve_pages(sender_socket, 1);
7071
sender_downstream_encoding downstream_enc = get_downstream_encoding(sender_socket, 0);
72+
socket_reserve_pages(sender_socket, 1);
7173
noc_async_write(
7274
receiver_socket.read_ptr,
7375
get_noc_addr(
@@ -92,5 +94,4 @@ void kernel_main() {
9294

9395
noc_async_write_barrier();
9496
noc_async_read_barrier();
95-
DPRINT << "End H2D Main Loop" << ENDL();
9697
}
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
// SPDX-FileCopyrightText: © 2026 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
#pragma once
5+
6+
#include <cstdint>
7+
#include "api/dataflow/dataflow_api.h"
8+
9+
// Contains utility functions to perform IO operations of variable length
10+
// over PCIe.
11+
12+
// This implementation is currently not optimized to minimize RISC cycles.
13+
// APIs can be made more stateful, especially for the HostIO op, since the PCIe
14+
// NOC encoding is constant.
15+
16+
FORCE_INLINE void noc_async_wide_write_any_len_with_state(
17+
uint32_t noc, uint32_t src_addr, uint32_t dst_noc_addr, uint64_t dst_addr, uint32_t len_bytes) {
18+
while (len_bytes > NOC_MAX_BURST_SIZE) {
19+
noc_wwrite_with_state<noc_mode, write_cmd_buf, CQ_NOC_SNDL, CQ_NOC_SEND, CQ_NOC_WAIT, true, false>(
20+
noc, src_addr, dst_noc_addr, dst_addr, NOC_MAX_BURST_SIZE, 1);
21+
len_bytes -= NOC_MAX_BURST_SIZE;
22+
src_addr += NOC_MAX_BURST_SIZE;
23+
dst_addr += NOC_MAX_BURST_SIZE;
24+
}
25+
noc_wwrite_with_state<noc_mode, write_cmd_buf, CQ_NOC_SNDL, CQ_NOC_SEND, CQ_NOC_WAIT, true, false>(
26+
noc, src_addr, dst_noc_addr, dst_addr, len_bytes, 1);
27+
}
28+
29+
FORCE_INLINE void noc_async_wide_read_any_len_with_state(
30+
uint32_t noc, uint64_t src_noc_encoding, uint64_t src_addr, uint32_t dst_addr, uint32_t len_bytes) {
31+
while (len_bytes > NOC_MAX_BURST_SIZE) {
32+
noc_read_with_state<noc_mode, read_cmd_buf, CQ_NOC_SNDL, CQ_NOC_SEND, CQ_NOC_WAIT>(
33+
noc, src_noc_encoding, src_addr, dst_addr, NOC_MAX_BURST_SIZE);
34+
len_bytes -= NOC_MAX_BURST_SIZE;
35+
src_addr += NOC_MAX_BURST_SIZE;
36+
dst_addr += NOC_MAX_BURST_SIZE;
37+
}
38+
noc_read_with_state<noc_mode, read_cmd_buf, CQ_NOC_SNDL, CQ_NOC_SEND, CQ_NOC_WAIT>(
39+
noc, src_noc_encoding, src_addr, dst_addr, len_bytes);
40+
}

0 commit comments

Comments
 (0)