Add experimental EZ convenience API for simplified Metalium programming #37909
Open
olofj wants to merge 6 commits into tenstorrent:main from
Conversation
The existing TT-Metalium API requires substantial boilerplate for even
simple programs: manual MeshDevice lifecycle, MeshBuffer configuration
with DeviceLocalBufferConfig/ReplicatedBufferConfig, CircularBufferConfig
construction, TensorAccessorArgs assembly, DataMovementConfig/ComputeConfig
wiring, MeshWorkload + EnqueueMeshWorkload orchestration, and explicit
device teardown. This cognitive overhead makes the API hard to learn and
the resulting code hard to read, especially for programming examples
that are supposed to teach newcomers.
The new `tt::tt_metal::experimental::ez` namespace provides two RAII
wrappers that dramatically simplify the common case:
**DeviceContext** — wraps MeshDevice creation, command queue, and teardown:
```cpp
DeviceContext ctx(device_id);
auto buf = ctx.dram_tile_buffer(num_tiles);
ctx.write(buf, input_data);
ctx.run(program);
auto result = ctx.read<bfloat16>(buf);
```
Key methods:
- dram_tile_buffer(), dram_buffer(), l1_buffer(), sharded_l1_buffer()
- write<T>(), read<T>() with blocking control
- run() (synchronous), launch()/finish() (asynchronous)
- physical_core() for logical-to-physical coordinate translation
- device() for direct MeshDevice access when needed
**ProgramBuilder** — fluent builder for Programs with method chaining:
```cpp
auto builder = ProgramBuilder(core);
builder.cb(CBIndex::c_0, 2)
    .cb(CBIndex::c_1, 2)
    .cb(CBIndex::c_16, 2);
auto& reader = builder.reader("reader.cpp", {buf_a, buf_b});
auto& writer = builder.writer("writer.cpp", {buf_c});
builder.compute("compute.cpp", MathFidelity::HiFi4, {Mt, Kt, Nt});
reader.runtime_args({buf_a->address(), buf_b->address(), n_tiles});
ctx.run(builder.build());
```
Key features:
- Auto-generates TensorAccessorArgs from buffer lists
- compile_args placed before TensorAccessorArgs (matching codebase convention)
- .on(core_spec) for per-core kernel assignment in multi-core programs
- .runtime_args_at(core, args) for SPMD per-core arguments
- .runtime_args(lambda) for computed per-core arguments
- .semaphore() for semaphore creation
- .cb(CircularBufferConfig) raw overload for advanced patterns like
shared CB indices (multiple CB indices aliasing same L1 memory)
- .cb(index, l1_buffer) for L1-backed circular buffers
The API is intentionally minimal and non-opinionated: it wraps the
existing Metalium primitives without hiding them. Users can always
call ctx.device() or builder.build() to drop down to the full API
when the convenience layer doesn't cover their use case.
Includes unit tests validating buffer creation, circular buffer
configuration, kernel execution, and data round-trip correctness.
Convert all single-core programming examples from the verbose Metalium
API to the new experimental::ez convenience wrappers. These examples
demonstrate the most basic Metalium programming patterns: creating a
device, allocating DRAM buffers, configuring circular buffers, launching
reader/compute/writer kernels, and reading back results.
The old API required 6-8 separate setup steps per program:
```cpp
auto mesh_device = distributed::MeshDevice::create_unit_mesh(0);
distributed::MeshCommandQueue& cq = mesh_device->mesh_command_queue();
distributed::MeshWorkload workload;
auto device_range = distributed::MeshCoordinateRange(mesh_device->shape());
Program program = CreateProgram();
// ... DeviceLocalBufferConfig, ReplicatedBufferConfig, MeshBuffer::create
// ... CircularBufferConfig, CreateCircularBuffer
// ... TensorAccessorArgs assembly, DataMovementConfig, CreateKernel
// ... SetRuntimeArgs, EnqueueWriteMeshBuffer, EnqueueMeshWorkload
// ... EnqueueReadMeshBuffer, mesh_device->close()
```
With the EZ API, the same program reads naturally:
```cpp
DeviceContext ctx(device_id);
auto buf = ctx.dram_tile_buffer(n_tiles);
auto builder = ProgramBuilder(core);
builder.cb(CBIndex::c_0, 2).cb(CBIndex::c_1, 2);
auto& reader = builder.reader("reader.cpp", {buf_a, buf_b});
reader.runtime_args({buf_a->address(), n_tiles});
ctx.write(buf_a, data);
ctx.run(builder.build());
auto result = ctx.read<bfloat16>(buf);
```
All educational comments explaining Tenstorrent hardware concepts
(tile layout, circular buffers, Tensix core architecture, SFPU/FPU
pipelines) are preserved — only API boilerplate is removed.
Converted examples:
- hello_world_datamovement_kernel (DRAM read/write basics)
- hello_world_datatypes_kernel (bfloat16 data handling)
- hello_world_compute_kernel (basic compute kernel)
- add_2_integers_in_compute (FPU tile addition)
- add_2_integers_in_riscv (RISC-V scalar addition)
- custom_sfpi_add (SFPU custom kernel)
- custom_sfpi_smoothstep (SFPU math approximation)
- eltwise_binary (element-wise binary operation)
- loopback (DRAM round-trip)
- contributed/vecadd (community vector addition)
Convert five examples that showcase data movement patterns, sharding, and
multi-stage compute pipelines:
- shard_data_rm: Row-major sharding with the ShardConfig helper
  Demonstrates DeviceContext::sharded_l1_buffer() and L1-backed circular
  buffers — the old API required manually computing shard specs, buffer
  configs, and wiring them together.
- sfpu_eltwise_chain: Multi-kernel SFPU pipeline (exp → sqrt → gelu)
  Shows ProgramBuilder chaining three compute kernels with shared circular
  buffers. Previously required 50+ lines of boilerplate CB setup; now uses
  builder.cb() with sensible defaults.
- NoC_tile_transfer: Explicit NoC data movement between cores
  Demonstrates per-core kernel placement with builder.on(core) and
  physical_core() translation. The old API required manually constructing
  CoreRange objects and calling multiple overloaded SetRuntimeArgs variants.
- vecadd_sharding: Sharded vector addition across cores
  End-to-end sharding example using ShardConfig for automatic shard spec
  calculation and L1-backed circular buffers.
- eltwise_sfpu: Single-core exponential (SFPU) operation
  Shows the compute() helper with MathFidelity parameter and
  TensorAccessorArgs auto-generation from buffer lists.
Also wires up the EZ API test subdirectory in the parent test
CMakeLists.txt (missed from the initial API commit).
All educational comments explaining Tensix architecture concepts, NoC
mechanics, and data format details are preserved verbatim.
Convert three examples that demonstrate multi-core programming patterns
on Tenstorrent hardware:
- pad_multi_core: Padding operation distributed across cores
Shows CoreRangeSet-based work distribution with per-core runtime
args via the lambda overload: kernel.runtime_args([](const CoreCoord& core) { ... }).
The old API required manually iterating core ranges and calling
SetRuntimeArgs for each core — easy to get wrong with off-by-one
errors on core coordinate arithmetic.
- vecadd_multi_core: Vector addition across multiple cores
Demonstrates the SPMD (Single Program Multiple Data) pattern with
split_work_to_cores for balanced work distribution. Per-core
runtime args use runtime_args_at() for explicit core targeting.
Previously required ~30 lines of boilerplate for program setup
and core iteration.
- contributed/multicast: NoC multicast from sender to receiver cores
The most complex data movement example — uses builder.on(core) to
place different kernels on sender vs. receiver cores, semaphore()
for synchronization, and physical_core() for logical-to-physical
coordinate translation needed by NoC multicast operations.
The old API version was ~100 lines longer and required juggling
multiple CreateKernel/CreateSemaphore/SetRuntimeArgs calls with
different core specs.
All educational comments explaining core allocation strategies,
NoC multicast semantics, and synchronization patterns are preserved.
Convert the three matrix multiplication examples — the most complex
programs in the examples suite — demonstrating progressively
sophisticated parallelization strategies:
- matmul_single_core: Basic tiled matmul on one Tensix core
Shows compile-time args via compute(path, fidelity, {Mt, Kt, Nt})
and TensorAccessorArgs auto-generation from buffer lists.
The compute kernel receives no runtime args — only compile-time
tile dimensions — which the EZ API makes obvious by having no
.runtime_args() call after the kernel declaration.
- matmul_multi_core: SPMD matmul with 2D work distribution
Uses split_work_to_cores for balanced partitioning, then
runtime_args_at() per core to set tile offset/count. The old API
required manually extracting the core grid from split_work_to_cores
output, creating the program, setting up CBs, creating kernels,
and then iterating cores again for runtime args — all as separate
disconnected API calls. The EZ version reads top-to-bottom as a
coherent program description.
- matmul_multicore_reuse: Block-tiled matmul with data reuse
The most complex example. Uses bmm_op_utils::get_large_matmul_params
for automatic block parameter selection, num_cores_to_corerangeset
for core allocation, and shared circular buffer indices (c_16/c_24
aliasing the same L1 memory) via the raw CircularBufferConfig
overload. Per-core runtime args compute 2D block coordinates from
a linear core index.
The shared CB pattern (c_16 for output, c_24 for intermediate)
is functionally required — not an optimization — because the
compute kernel writes partial results to c_24 and the writer
reads finished tiles from c_16, and they must share the same
physical L1 memory region.
All educational comments explaining tile arithmetic, block
parameter constraints, and data reuse strategies are preserved.
olofj (Contributor, Author):
I just realized that the named_args() pattern of needing to be before the kernel (i.e. program()) statement feels backwards, especially since the runtime_args() are after. Looking into refactoring that, but the rest can be reviewed.
Refactor ProgramBuilder so kernels are created lazily at build() time instead of eagerly in reader()/writer()/compute(). This enables defines() and named_args() to be called on KernelRef after the kernel method, and eliminates the need for done() to chain back to the builder. KernelRef now stores all kernel configuration as a deferred descriptor and provides forwarding methods (cb, reader, writer, compute, kernel, on, semaphore, build) that implicitly transition back to the builder.
Ticket
N/A — new experimental API
Problem description
The TT-Metalium programming API is powerful but demands significant boilerplate for even simple programs. A basic "read tiles, compute, write back" example requires manually managing MeshDevice lifecycle, constructing DeviceLocalBufferConfig + ReplicatedBufferConfig + MeshBuffer, assembling CircularBufferConfig objects, building TensorAccessorArgs, wiring DataMovementConfig/ComputeConfig, orchestrating MeshWorkload + EnqueueMeshWorkload, and cleaning up — typically 6-8 disconnected setup steps before any domain logic appears.
This is a steep learning curve for newcomers and makes the programming examples — the primary teaching material — harder to follow than they need to be. The intent of each program gets buried under API ceremony.
What's changed
A new `tt::tt_metal::experimental::ez` namespace provides two thin RAII wrappers (~690 lines of library code) that let programs express what they want instead of how to wire it up:

DeviceContext — owns device lifecycle, buffer creation, and data transfer.

ProgramBuilder — fluent builder for programs with deferred kernel creation and method chaining.

Kernel creation is deferred until build(). This means per-kernel configuration — defines(), named_args(), runtime_args() — can be called on the KernelRef returned by reader()/writer()/compute(), and each applies only to that kernel. Forwarding methods on KernelRef allow seamless chaining back to the builder without any intermediate commit step.

The API is intentionally minimal — it wraps existing Metalium primitives without hiding them. ctx.device() and builder.build() provide escape hatches to the full API at any point.

All 21 programming examples have been converted. The reduction is most dramatic in the simpler examples — exactly the ones newcomers encounter first — where up to 45% of the code was pure boilerplate. Even the complex multi-core matmul examples see meaningful cleanup. Educational comments about hardware concepts (tile layout, circular buffers, NoC mechanics, SFPU pipelines) are preserved throughout.

Importantly, the lines removed are not comments or logic — they are API plumbing. The domain logic and educational content are preserved.

Key ProgramBuilder features exercised across the examples:
- reader()/writer() auto-generate TensorAccessorArgs from buffer lists
- .on(core) places kernels on specific cores for multi-core programs
- .runtime_args(lambda) computes per-core args without manual core iteration
- .defines() and .named_args() on KernelRef for per-kernel compile-time configuration
- .semaphore() and .physical_core() for synchronization and NoC multicast
- .cb(index, l1_buffer) for L1-backed circular buffers
- .cb(CircularBufferConfig) raw overload for shared CB index patterns

Commit structure:
Checklist