Skip to content

Commit 73b0c7a

Browse files
authored
D2M: Add experimental get_noc_multicast_addr API that flips start/end coords on noc1 (#3855)
### Ticket No mlir issue, part of overarching matmul objective. ### Problem description tt-metal does not flip start and end coordinates in `get_noc_multicast_addr` This is required because the NoC requires the start coordinate to be reached first. ### What's changed This PR introduces an experimental API to reverse start and end coordinates for multicast on noc1. Note: This is not to do with flipping coordinate values in terms of the noc layout. This is simply reversing the start & end coordinates. ### Checklist - [ ] New/Existing tests provide coverage for changes
1 parent 85b2c8a commit 73b0c7a

File tree

7 files changed

+67
-9
lines changed

7 files changed

+67
-9
lines changed

include/ttmlir/Dialect/TTKernel/IR/TTKernelOps.td

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -986,7 +986,17 @@ def TTKernel_MyYOp : TTKernel_Op<"my_y"> {
986986
def TTKernel_GetNocMulticastAddrOp : TTKernel_Op<"get_noc_multicast_addr"> {
987987
let summary = "GetNocMulticastAddr";
988988
let description = [{
989-
GetNocMulticastAddr
989+
Default tt-metal get_noc_multicast_addr
990+
}];
991+
992+
let arguments = (ins IndexLike:$noc_x_start, IndexLike:$noc_y_start, IndexLike:$noc_x_end, IndexLike:$noc_y_end, AnyTypeOf<[I32, TTKernel_L1Addr, TTKernel_Semaphore]>:$addr, Optional<I8>:$noc);
993+
let results = (outs TTKernel_NocAddr:$mcastNocAddr);
994+
}
995+
996+
def TTKernel_ExperimentalGetNocMulticastAddrOp : TTKernel_Op<"experimental::get_noc_multicast_addr"> {
997+
let summary = "Experimental GetNocMulticastAddr";
998+
let description = [{
999+
Default tt-metal get_noc_multicast_addr, but flips mcast start and end coordinates on NOC1.
9901000
}];
9911001

9921002
let arguments = (ins IndexLike:$noc_x_start, IndexLike:$noc_y_start, IndexLike:$noc_x_end, IndexLike:$noc_y_end, AnyTypeOf<[I32, TTKernel_L1Addr, TTKernel_Semaphore]>:$addr, Optional<I8>:$noc);

include/ttmlir/Target/TTKernel/LLKs/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ include(GenerateRawStringHeader)
55
set(LLK_HEADERS
66
${CMAKE_SOURCE_DIR}/include/ttmlir/Target/TTKernel/LLKs/experimental_tilize_llks.h
77
${CMAKE_SOURCE_DIR}/include/ttmlir/Target/TTKernel/LLKs/experimental_untilize_llks.h
8+
${CMAKE_SOURCE_DIR}/include/ttmlir/Target/TTKernel/LLKs/experimental_dataflow_api.h
89
)
910

1011
# Set the output directory for generated headers
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
// SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
#ifndef TTMLIR_TARGET_TTKERNEL_LLKS_EXPERIMENTAL_DATAFLOW_API_H
6+
#define TTMLIR_TARGET_TTKERNEL_LLKS_EXPERIMENTAL_DATAFLOW_API_H
7+
8+
namespace experimental {
9+
10+
FORCE_INLINE
11+
std::uint64_t
12+
get_noc_multicast_addr(std::uint32_t noc_x_start, std::uint32_t noc_y_start,
13+
std::uint32_t noc_x_end, std::uint32_t noc_y_end,
14+
std::uint32_t addr, uint8_t noc = noc_index) {
15+
/*
16+
Get an encoding which contains tensix core and address you want to
17+
read from/write to via the noc
18+
*/
19+
if (noc) {
20+
// noc 1
21+
return NOC_MULTICAST_ADDR(
22+
DYNAMIC_NOC_X(noc, noc_x_end), DYNAMIC_NOC_Y(noc, noc_y_end),
23+
DYNAMIC_NOC_X(noc, noc_x_start), DYNAMIC_NOC_Y(noc, noc_y_start), addr);
24+
} else {
25+
// noc 0
26+
return NOC_MULTICAST_ADDR(
27+
DYNAMIC_NOC_X(noc, noc_x_start), DYNAMIC_NOC_Y(noc, noc_y_start),
28+
DYNAMIC_NOC_X(noc, noc_x_end), DYNAMIC_NOC_Y(noc, noc_y_end), addr);
29+
}
30+
}
31+
32+
} // namespace experimental
33+
34+
#endif

lib/Conversion/TTIRToTTKernel/TTIRToTTKernel.cpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -638,9 +638,10 @@ class TTIRDMARewriter : public OpConversionPattern<ttir::DMAOp> {
638638
op.getLoc(), op.getMcastShape()[0], op.getMcastShape()[1]);
639639
auto numDests = rewriter.create<arith::IndexCastOp>(
640640
op.getLoc(), rewriter.getI32Type(), numDestsIdx);
641-
auto mcastAddr = rewriter.create<ttkernel::GetNocMulticastAddrOp>(
642-
op.getLoc(), virtX, virtY, mcastEndX, mcastEndY, dstL1Start,
643-
nullptr);
641+
auto mcastAddr =
642+
rewriter.create<ttkernel::ExperimentalGetNocMulticastAddrOp>(
643+
op.getLoc(), virtX, virtY, mcastEndX, mcastEndY, dstL1Start,
644+
nullptr);
644645
if (adaptor.getSrc() == adaptor.getDst()) {
645646
// If src and dst refer to the same memref, we do not loopback mcast
646647
// Dests are one less because the sender core is not included
@@ -989,9 +990,10 @@ class TTIRSemaphoreUpdateRewriter : public OpConversionPattern<ConcreteOp> {
989990
op.getLoc(), op.getMcastShape()[0], op.getMcastShape()[1]);
990991
Value numDests = rewriter.create<arith::IndexCastOp>(
991992
op.getLoc(), rewriter.getI32Type(), numDestsIdx);
992-
auto mcastAddr = rewriter.create<ttkernel::GetNocMulticastAddrOp>(
993-
op.getLoc(), virtX, virtY, mcastEndX, mcastEndY, semaphoreAddr,
994-
nullptr);
993+
auto mcastAddr =
994+
rewriter.create<ttkernel::ExperimentalGetNocMulticastAddrOp>(
995+
op.getLoc(), virtX, virtY, mcastEndX, mcastEndY, semaphoreAddr,
996+
nullptr);
995997

996998
auto semaphorePtr =
997999
rewriter.create<ttkernel::CastToL1PtrOp>(op.getLoc(), semaphoreAddr);

lib/Conversion/TTKernelToEmitC/TTKernelToEmitC.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,8 @@ class ConvertTTKernelToEmitCPass
659659
TTKernelToEmitCOpaqueRewriter<ttkernel::NocAsyncWriteTileOp>,
660660
TTKernelToEmitCOpaqueRewriter<ttkernel::NocAsyncWriteBarrierOp>,
661661
TTKernelToEmitCOpaqueRewriter<ttkernel::GetNocMulticastAddrOp>,
662+
TTKernelToEmitCOpaqueRewriter<
663+
ttkernel::ExperimentalGetNocMulticastAddrOp>,
662664
TTKernelToEmitCOpaqueRewriter<
663665
ttkernel::NocAsyncWriteMulticastOnePacketOp>,
664666
TTKernelToEmitCOpaqueRewriter<ttkernel::NocAsyncWriteMulticastOp>,

lib/Target/TTKernel/TTKernelToCpp.cpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
#include "ttmlir/Dialect/TTKernel/IR/TTKernelOpsTypes.h"
88

9+
#include "ttmlir/Target/TTKernel/LLKs/experimental_dataflow_api_generated.h"
910
#include "ttmlir/Target/TTKernel/LLKs/experimental_tilize_llks_generated.h"
1011
#include "ttmlir/Target/TTKernel/LLKs/experimental_untilize_llks_generated.h"
1112

@@ -43,6 +44,7 @@ class ScopedModuleHelper {
4344

4445
builder->create<emitc::IncludeOp>(loc, "dataflow_api.h",
4546
/*isStandard=*/false);
47+
emitExperimentalLLKs();
4648
}
4749
if (threadType == ThreadType::Compute) {
4850
builder->create<emitc::IncludeOp>(loc, "llk_defs.h",
@@ -165,6 +167,13 @@ void dprint(Arg &&arg, ArgV&&... argv) {
165167
experimental_untilize_llks_generated_len);
166168
builder->create<emitc::VerbatimOp>(loc, experimentalUntilizeLLKs);
167169
}
170+
171+
if (hasCall("experimental::get_noc_multicast_addr")) {
172+
auto experimentalDataflowLLKs =
173+
StringRef(experimental_dataflow_api_generated,
174+
experimental_dataflow_api_generated_len);
175+
builder->create<emitc::VerbatimOp>(loc, experimentalDataflowLLKs);
176+
}
168177
}
169178

170179
bool hasCall(StringRef name) {

test/ttmlir/Conversion/TTIRToTTKernel/dma_lowering.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ func.func @test_local_to_remote_multicast_regular(%arg0: memref<2x2x!ttcore.tile
4242
%c4 = arith.constant 4 : index
4343
// CHECK: ttkernel.get_read_ptr
4444
// CHECK: ttkernel.get_write_ptr
45-
// CHECK: ttkernel.get_noc_multicast_addr
45+
// CHECK: ttkernel.experimental::get_noc_multicast_addr
4646
// CHECK: ttkernel.noc_async_write_multicast
4747
%0 = ttir.dma %arg0[%c0, %c0], %arg0[%c0, %c0] core[%c1, %c2] mcast[%c3, %c4] : (memref<2x2x!ttcore.tile<32x32, f32>, #l1_>, memref<2x2x!ttcore.tile<32x32, f32>, #l1_>) -> !ttir.mem_tx
4848
ttir.dma_wait %0
@@ -59,7 +59,7 @@ func.func @test_local_to_remote_multicast_loopback(%arg0: memref<2x2x!ttcore.til
5959
%c4 = arith.constant 4 : index
6060
// CHECK: ttkernel.get_read_ptr
6161
// CHECK: ttkernel.get_write_ptr
62-
// CHECK: ttkernel.get_noc_multicast_addr
62+
// CHECK: ttkernel.experimental::get_noc_multicast_addr
6363
// CHECK: ttkernel.noc_async_write_multicast_loopback_src
6464
%0 = ttir.dma %arg0[%c0, %c0], %arg1[%c0, %c0] core[%c1, %c2] mcast[%c3, %c4] : (memref<2x2x!ttcore.tile<32x32, f32>, #l1_>, memref<2x2x!ttcore.tile<32x32, f32>, #l1_>) -> !ttir.mem_tx
6565
ttir.dma_wait %0

0 commit comments

Comments
 (0)