Skip to content

Commit 12a5085

Browse files
committed
EZ API: lazy kernel creation, eliminate done()
Refactor ProgramBuilder so kernels are created lazily at build() time instead of eagerly in reader()/writer()/compute(). This enables defines() and named_args() to be called on KernelRef after the kernel method, and eliminates the need for done() to chain back to the builder. KernelRef now stores all kernel configuration as a deferred descriptor and provides forwarding methods (cb, reader, writer, compute, kernel, on, semaphore, build) that implicitly transition back to the builder.
1 parent 5b1f7ae commit 12a5085

File tree

19 files changed

+399
-132
lines changed

19 files changed

+399
-132
lines changed
Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
// SPDX-FileCopyrightText: (c) 2026 Olof Johansson <olof@lixom.net>
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
// Variant of tiles_add that reads CB indices from named compile-time args
6+
// instead of hardcoding them. Used to test per-kernel named_args on KernelRef.
7+
8+
#include <cstdint>
9+
#include "api/compute/common.h"
10+
#include "api/compute/tile_move_copy.h"
11+
#include "api/compute/eltwise_binary.h"
12+
#include "api/compute/compute_kernel_api.h"
13+
14+
void kernel_main() {
15+
uint32_t n_tiles = get_arg_val<uint32_t>(0);
16+
17+
constexpr auto cb_in0 = get_named_compile_time_arg_val("cb_in0");
18+
constexpr auto cb_in1 = get_named_compile_time_arg_val("cb_in1");
19+
constexpr auto cb_out0 = get_named_compile_time_arg_val("cb_out0");
20+
constexpr uint32_t dst_reg = 0;
21+
22+
binary_op_init_common(cb_in0, cb_in1, cb_out0);
23+
add_tiles_init(cb_in0, cb_in1);
24+
25+
for (uint32_t i = 0; i < n_tiles; i++) {
26+
cb_wait_front(cb_in0, 1);
27+
cb_wait_front(cb_in1, 1);
28+
tile_regs_acquire();
29+
add_tiles(cb_in0, cb_in1, 0, 0, dst_reg);
30+
tile_regs_commit();
31+
tile_regs_wait();
32+
cb_reserve_back(cb_out0, 1);
33+
pack_tile(dst_reg, cb_out0);
34+
cb_push_back(cb_out0, 1);
35+
cb_pop_front(cb_in0, 1);
36+
cb_pop_front(cb_in1, 1);
37+
tile_regs_release();
38+
}
39+
}

tests/tt_metal/tt_metal/ez/test_ez_api.cpp

Lines changed: 58 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@ TEST_F(EzApiTest, HelloWorld) {
2020
DeviceContext ctx(0);
2121
auto program = ProgramBuilder(CoreCoord{0, 0})
2222
.compute("tests/tt_metal/tt_metal/ez/kernels/void_compute.cpp")
23-
.done()
2423
.build();
2524
ctx.run(std::move(program));
2625
}
@@ -79,15 +78,12 @@ TEST_F(EzApiTest, EltwiseBinary) {
7978
"tests/tt_metal/tt_metal/ez/kernels/read_tiles.cpp",
8079
{src0, src1})
8180
.runtime_args({src0->address(), src1->address(), n_tiles})
82-
.done()
8381
.compute("tests/tt_metal/tt_metal/ez/kernels/tiles_add.cpp")
8482
.runtime_args({n_tiles})
85-
.done()
8683
.writer(
8784
"tests/tt_metal/tt_metal/ez/kernels/write_tile.cpp",
8885
{dst})
8986
.runtime_args({dst->address(), n_tiles})
90-
.done()
9187
.build();
9288

9389
ctx.run(std::move(program));
@@ -111,7 +107,6 @@ TEST_F(EzApiTest, MultiCore) {
111107
auto& k = builder.compute("tests/tt_metal/tt_metal/ez/kernels/void_compute.cpp");
112108
k.runtime_args_at(CoreCoord{0, 0}, {});
113109
k.runtime_args_at(CoreCoord{1, 0}, {});
114-
k.done();
115110

116111
auto program = builder.build();
117112
ctx.run(std::move(program));
@@ -127,8 +122,7 @@ TEST_F(EzApiTest, PerCoreLambdaArgs) {
127122
.runtime_args([](const CoreCoord& core) -> std::vector<uint32_t> {
128123
// Each core gets a unique arg based on its x coordinate.
129124
return {core.x};
130-
})
131-
.done();
125+
});
132126

133127
auto program = builder.build();
134128
ctx.run(std::move(program));
@@ -142,10 +136,8 @@ TEST_F(EzApiTest, CoreOverride) {
142136

143137
auto program = ProgramBuilder(default_core)
144138
.compute("tests/tt_metal/tt_metal/ez/kernels/void_compute.cpp")
145-
.done()
146139
.on(other_core)
147140
.compute("tests/tt_metal/tt_metal/ez/kernels/void_compute.cpp")
148-
.done()
149141
.build();
150142

151143
ctx.run(std::move(program));
@@ -179,7 +171,6 @@ TEST_F(EzApiTest, L1BackedCircularBuffer) {
179171
auto program = ProgramBuilder(CoreCoord{0, 0})
180172
.cb(tt::CBIndex::c_0, l1_buf)
181173
.compute("tests/tt_metal/tt_metal/ez/kernels/void_compute.cpp")
182-
.done()
183174
.build();
184175
ctx.run(std::move(program));
185176
}
@@ -208,8 +199,7 @@ TEST_F(EzApiTest, Semaphore) {
208199
uint32_t sem_addr = builder.semaphore(0);
209200

210201
builder.compute("tests/tt_metal/tt_metal/ez/kernels/void_compute.cpp")
211-
.runtime_args([&](const CoreCoord&) -> std::vector<uint32_t> { return {sem_addr}; })
212-
.done();
202+
.runtime_args([&](const CoreCoord&) -> std::vector<uint32_t> { return {sem_addr}; });
213203
auto program = builder.build();
214204
ctx.run(std::move(program));
215205
}
@@ -229,8 +219,63 @@ TEST_F(EzApiTest, NonBlockingLaunchAndFinish) {
229219
DeviceContext ctx(0);
230220
auto program = ProgramBuilder(CoreCoord{0, 0})
231221
.compute("tests/tt_metal/tt_metal/ez/kernels/void_compute.cpp")
232-
.done()
233222
.build();
234223
ctx.launch(std::move(program));
235224
ctx.finish();
236225
}
226+
227+
TEST_F(EzApiTest, PerKernelNamedArgs) {
228+
// Verify that named_args() called on a KernelRef applies to that kernel only.
229+
// The compute kernel reads CB indices from named compile-time args; if they
230+
// were routed to the wrong kernel (or lost), compilation would fail.
231+
DeviceContext ctx(0);
232+
constexpr uint32_t n_tiles = 4;
233+
constexpr uint32_t elements = n_tiles * tt::constants::TILE_WIDTH * tt::constants::TILE_HEIGHT;
234+
235+
auto src0 = ctx.dram_tile_buffer(n_tiles);
236+
auto src1 = ctx.dram_tile_buffer(n_tiles);
237+
auto dst = ctx.dram_tile_buffer(n_tiles);
238+
239+
std::mt19937 rng(999);
240+
std::uniform_real_distribution<float> dist(0.0f, 1.0f);
241+
std::vector<bfloat16> a_data(elements), b_data(elements);
242+
for (size_t i = 0; i < elements; ++i) {
243+
a_data[i] = bfloat16(dist(rng));
244+
b_data[i] = bfloat16(dist(rng));
245+
}
246+
247+
ctx.write(src0, a_data);
248+
ctx.write(src1, b_data);
249+
250+
// named_args() is called on the KernelRef returned by .compute(), not on the builder.
251+
constexpr CoreCoord core = {0, 0};
252+
auto program = ProgramBuilder(core)
253+
.cb(tt::CBIndex::c_0)
254+
.cb(tt::CBIndex::c_1)
255+
.cb(tt::CBIndex::c_16)
256+
.reader(
257+
"tests/tt_metal/tt_metal/ez/kernels/read_tiles.cpp",
258+
{src0, src1})
259+
.runtime_args({src0->address(), src1->address(), n_tiles})
260+
.compute("tests/tt_metal/tt_metal/ez/kernels/tiles_add_named.cpp")
261+
.named_args({{"cb_in0", (uint32_t)tt::CBIndex::c_0},
262+
{"cb_in1", (uint32_t)tt::CBIndex::c_1},
263+
{"cb_out0", (uint32_t)tt::CBIndex::c_16}})
264+
.runtime_args({n_tiles})
265+
.writer(
266+
"tests/tt_metal/tt_metal/ez/kernels/write_tile.cpp",
267+
{dst})
268+
.runtime_args({dst->address(), n_tiles})
269+
.build();
270+
271+
ctx.run(std::move(program));
272+
auto result = ctx.read<bfloat16>(dst);
273+
274+
ASSERT_EQ(result.size(), elements);
275+
constexpr float eps = 1e-2f;
276+
for (size_t i = 0; i < elements; ++i) {
277+
float expected = static_cast<float>(a_data[i]) + static_cast<float>(b_data[i]);
278+
float actual = static_cast<float>(result[i]);
279+
EXPECT_NEAR(expected, actual, eps) << "Mismatch at index " << i;
280+
}
281+
}

tt_metal/api/tt-metalium/experimental/ez/program_builder.hpp

Lines changed: 72 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,23 @@ using CoreSpec = std::variant<CoreCoord, CoreRange, CoreRangeSet>;
2727

2828
class ProgramBuilder;
2929

30-
// Handle to a kernel added via ProgramBuilder, supporting fluent runtime_args configuration.
30+
// Deferred kernel descriptor added via ProgramBuilder. Stores all configuration
31+
// and defers CreateKernel + SetRuntimeArgs to build() time.
32+
// Supports fluent chaining: methods for kernel configuration (defines, named_args,
33+
// runtime_args) return KernelRef&, while builder methods (cb, reader, writer,
34+
// compute, kernel, on, semaphore, build) forward to the parent ProgramBuilder.
3135
class KernelRef {
3236
public:
37+
enum class Type { Reader, Writer, Compute, Custom };
38+
39+
// Kernel configuration methods (return KernelRef& for chaining).
40+
41+
// Set preprocessor defines for this kernel.
42+
KernelRef& defines(const std::map<std::string, std::string>& defs);
43+
44+
// Set named compile-time args for this kernel.
45+
KernelRef& named_args(const std::unordered_map<std::string, uint32_t>& args);
46+
3347
// Set runtime args applied uniformly across all cores in the kernel's core spec.
3448
KernelRef& runtime_args(std::initializer_list<uint32_t> args);
3549
KernelRef& runtime_args(const std::vector<uint32_t>& args);
@@ -40,19 +54,70 @@ class KernelRef {
4054
// Set runtime args for a specific core.
4155
KernelRef& runtime_args_at(const CoreCoord& core, const std::vector<uint32_t>& args);
4256

43-
// Return to the ProgramBuilder for chaining.
44-
ProgramBuilder& done();
57+
// Forwarding methods to ProgramBuilder (return appropriate type for chaining).
58+
59+
ProgramBuilder& cb(tt::CBIndex index, uint32_t num_tiles = 2,
60+
tt::DataFormat fmt = tt::DataFormat::Float16_b);
61+
ProgramBuilder& cb(tt::CBIndex index, tt::DataFormat fmt, uint32_t num_tiles, uint32_t page_size);
62+
ProgramBuilder& cb(tt::CBIndex index, const std::shared_ptr<distributed::MeshBuffer>& l1_buffer,
63+
uint32_t num_tiles = 0, tt::DataFormat fmt = tt::DataFormat::Float16_b);
64+
ProgramBuilder& cb(const CircularBufferConfig& config);
65+
66+
KernelRef& reader(
67+
const std::string& path,
68+
const std::vector<std::shared_ptr<distributed::MeshBuffer>>& buffers = {},
69+
const std::vector<uint32_t>& compile_args = {});
70+
KernelRef& writer(
71+
const std::string& path,
72+
const std::vector<std::shared_ptr<distributed::MeshBuffer>>& buffers = {},
73+
const std::vector<uint32_t>& compile_args = {});
74+
KernelRef& compute(
75+
const std::string& path,
76+
MathFidelity fidelity = MathFidelity::HiFi4,
77+
const std::vector<uint32_t>& compile_args = {});
78+
KernelRef& compute(const std::string& path, const ComputeConfig& config);
79+
KernelRef& kernel(
80+
const std::string& path,
81+
const std::variant<DataMovementConfig, ComputeConfig, EthernetConfig>& config);
4582

46-
// Access the underlying kernel handle.
47-
KernelHandle handle() const;
83+
ProgramBuilder& on(const CoreSpec& core_spec);
84+
ProgramBuilder& defines_next(const std::map<std::string, std::string>& defs);
85+
ProgramBuilder& named_args_next(const std::unordered_map<std::string, uint32_t>& args);
86+
uint32_t semaphore(uint32_t initial_value = 0);
87+
uint32_t semaphore(const CoreSpec& cores, uint32_t initial_value = 0);
88+
Program build();
4889

4990
private:
5091
friend class ProgramBuilder;
51-
KernelRef(ProgramBuilder& builder, KernelHandle handle, CoreSpec core_spec);
92+
93+
using DeferredRuntimeArgs = std::function<void(Program&, KernelHandle, const CoreSpec&)>;
94+
95+
KernelRef(ProgramBuilder& builder, Type type, std::string path, CoreSpec core_spec);
5296

5397
ProgramBuilder& builder_;
54-
KernelHandle handle_;
98+
Type type_;
99+
std::string path_;
55100
CoreSpec core_spec_;
101+
102+
// Reader/Writer specific.
103+
std::vector<std::shared_ptr<distributed::MeshBuffer>> buffers_;
104+
std::vector<uint32_t> compile_args_;
105+
106+
// Compute specific.
107+
MathFidelity fidelity_ = MathFidelity::HiFi4;
108+
109+
// Custom kernel config (for kernel() with full config).
110+
std::optional<std::variant<DataMovementConfig, ComputeConfig, EthernetConfig>> custom_config_;
111+
112+
// Shared.
113+
std::map<std::string, std::string> defines_;
114+
std::unordered_map<std::string, uint32_t> named_compile_args_;
115+
116+
// Deferred runtime args — replayed at build() time.
117+
std::vector<DeferredRuntimeArgs> deferred_runtime_args_;
118+
119+
// Materialize this kernel into the program. Called by ProgramBuilder::build().
120+
void materialize(Program& program);
56121
};
57122

58123
// Fluent builder for constructing a Program with circular buffers and kernels.

0 commit comments

Comments
 (0)