@@ -14,20 +14,45 @@

namespace mlir::tt::ttnn::workarounds::decomposition {

// static bool isBatchedLinearOp(ttnn::LinearOp linearOp) {
// RankedTensorType inputBType = linearOp.getB().getType();
// auto inputBShape = inputBType.getShape();
// int64_t rank = inputBShape.size();
//
// // if rank <= 2, it cannot be batched. Return false.
// // Check if batched: any dimension before the last 2 has size > 1.
// // i.e. <1x3x128x32xbf16> is batched because of 3.
// // i.e. <1x1x128x32xbf16> is not batched because all dims before last 2
// are 1.
// // i.e. <128xbf16> is not batched because rank < 2.
// return rank > 2 && llvm::any_of(inputBShape.drop_back(2),
// [](int64_t dim) { return dim > 1; });
// }
static bool isBatchedLinearOp(ttnn::LinearOp linearOp) {
RankedTensorType inputBType = linearOp.getB().getType();
auto inputBShape = inputBType.getShape();
int64_t rank = inputBShape.size();

// If rank <= 2, it cannot be batched; return false.
// Check if batched: any dimension before the last 2 has size > 1.
// e.g. <1x3x128x32xbf16> is batched because of the 3.
// e.g. <1x1x128x32xbf16> is not batched because all dims before the last 2 are 1.
// e.g. <128xbf16> is not batched because rank < 2.
return rank > 2 && llvm::any_of(inputBShape.drop_back(2),
[](int64_t dim) { return dim > 1; });
}
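For illustration only, here is a minimal standalone sketch of the same batched-shape predicate. It uses std::any_of and std::vector in place of llvm::any_of and the MLIR shape types; the helper name isBatchedShape and the example shapes are ours, not taken from the pass.

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Mirrors the check above: rank > 2 and some dim before the last two is > 1.
static bool isBatchedShape(const std::vector<int64_t> &shape) {
  if (shape.size() <= 2) {
    return false;
  }
  return std::any_of(shape.begin(), shape.end() - 2,
                     [](int64_t dim) { return dim > 1; });
}

int main() {
  assert(isBatchedShape({1, 3, 128, 32}));  // batched because of the 3
  assert(!isBatchedShape({1, 1, 128, 32})); // all dims before the last 2 are 1
  assert(!isBatchedShape({128, 32}));       // rank 2 can never be batched
  assert(!isBatchedShape({128}));           // rank 1
  return 0;
}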

// Helper function to check if the bias is effectively 1D.
// Returns true if the bias has at most one non-unit dimension (e.g., <64>,
// <1x64>, <1x1x64>).
static bool isEffectively1DBias(TypedValue<RankedTensorType> bias) {
if (!bias) {
return false;
}

RankedTensorType biasType = bias.getType();
if (!biasType) {
return false;
}

auto biasShape = biasType.getShape();
int64_t nonUnitDims = 0;

for (int64_t dim : biasShape) {
if (dim > 1) {
nonUnitDims++;
}
}

// Bias is effectively 1D if it has at most one non-unit dimension
return nonUnitDims <= 1;
}
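In the same spirit, a minimal sketch of the bias check; isEffectively1D and the shapes below are illustrative stand-ins, not part of the pass.

#include <cassert>
#include <cstdint>
#include <vector>

// Counts dims > 1; the bias is "effectively 1D" if at most one remains.
static bool isEffectively1D(const std::vector<int64_t> &shape) {
  int64_t nonUnitDims = 0;
  for (int64_t dim : shape) {
    if (dim > 1) {
      ++nonUnitDims;
    }
  }
  return nonUnitDims <= 1;
}

int main() {
  assert(isEffectively1D({64}));       // <64>
  assert(isEffectively1D({1, 1, 64})); // <1x1x64>
  assert(isEffectively1D({1, 1, 1}));  // an all-unit bias also counts
  assert(!isEffectively1D({64, 64}));  // genuinely 2D bias, needs decomposition
  return 0;
}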

// Calculate the output shape of a matmul operation following tt-metal's logic.
// Reference: ttnn/cpp/ttnn/operations/matmul/matmul.cpp
@@ -37,13 +62,38 @@ computeMatmulOutputShape(llvm::ArrayRef<int64_t> shapeA, bool transposeA,
int64_t rankA = shapeA.size();
int64_t rankB = shapeB.size();

// if (rankA == 1 || rankB == 1) {
// TT_assertv(false,
// "Should not reach linear op workaround if rankA or rankB is
// 1");
// }

SmallVector<int64_t> outputShape;

// Handle rank 1 cases
if (rankA == 1 && rankB == 1) {
// vector dot vector -> scalar (but represented as 1D tensor with size 1)
outputShape.push_back(1);
return outputShape;
}

if (rankA == 1) {
// vector-matrix: (K,) x (..., K, N) -> (..., N)
// Result shape is all batch dims from B plus the N dim (the last dim, or the second-to-last if B is transposed)
outputShape.append(shapeB.begin(), shapeB.end() - 2);
outputShape.push_back(transposeB ? shapeB[rankB - 2] : shapeB[rankB - 1]);
return outputShape;
}

if (rankB == 1) {
// matrix-vector: (..., M, K) x (K,) -> (..., M)
// Result shape is all dims from A except the last (contraction) dim
if (transposeA) {
// If A is transposed, the contraction dim is second-to-last, keep last
outputShape.append(shapeA.begin(), shapeA.end() - 2);
outputShape.push_back(shapeA[rankA - 1]);
} else {
// Normal case: contraction dim is last, keep all but last
outputShape.append(shapeA.begin(), shapeA.end() - 1);
}
return outputShape;
}
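// Illustrative examples of the rank-1 branches above (hypothetical shapes,
// not taken from the original source):
//   (8,)      x (8,)                     -> (1,)
//   (8,)      x (2, 8, 16)               -> (2, 16)
//   (8,)      x (2, 16, 8),  transposeB  -> (2, 16)
//   (2, 4, 8) x (8,)                     -> (2, 4)
//   (2, 4, 8) x (4,),        transposeA  -> (2, 8)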

// Both inputs are at least rank 2
SmallVector<int64_t> batchShapeA(shapeA.begin(), shapeA.end() - 2);
SmallVector<int64_t> batchShapeB(shapeB.begin(), shapeB.end() - 2);
mlir::OpTrait::util::getBroadcastedShape(batchShapeA, batchShapeB,
@@ -72,6 +122,18 @@ LogicalResult
LinearOpRewritePattern::matchAndRewrite(ttnn::LinearOp srcOp,
PatternRewriter &rewriter) const {

// Only decompose if bias exists AND (bias is non-1D OR input B is batched)
if (!srcOp.getBias()) {
return failure();
}

bool biasIsNon1D = !isEffectively1DBias(srcOp.getBias());
bool inputBIsBatched = isBatchedLinearOp(srcOp);

if (!biasIsNon1D && !inputBIsBatched) {
return failure();
}
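// Decision summary for this gate (derived from the checks above):
//   no bias                             -> keep ttnn.linear (pattern fails)
//   bias effectively 1D, B not batched  -> keep ttnn.linear (pattern fails)
//   bias not effectively 1D             -> decompose into matmul + add
//   bias effectively 1D, B batched      -> decompose into matmul + add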

RankedTensorType inputAType = srcOp.getA().getType();
RankedTensorType inputBType = srcOp.getB().getType();
RankedTensorType outputType = srcOp.getResult().getType();
@@ -91,11 +153,18 @@ LinearOpRewritePattern::matchAndRewrite(ttnn::LinearOp srcOp,
rewriter.getContext(), outputEncoding.getDataType());

// Step 1: Create MatMul operation
// Convert activation to StringAttr if present
mlir::StringAttr activationAttr;
if (auto activation = srcOp.getActivation()) {
activationAttr = rewriter.getStringAttr(activation.value());
}

MatmulOp matmulOp = rewriter.create<ttnn::MatmulOp>(
ttmlir::utils::appendLocationSuffix(srcOp.getLoc(), "_decomp_matmul"),
matmulOutputType, srcOp.getA(), srcOp.getB(), srcOp.getTransposeA(),
srcOp.getTransposeB(),
/*matmul_program_config=*/mlir::Attribute(), /*activation=*/nullptr);
/*matmul_program_config=*/mlir::Attribute(),
/*activation=*/activationAttr);
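// Any fused activation rides on the matmul; the trailing add is created
// without one, which is what the linear_with_sigmoid checks added in the
// tests below expect.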

// Step 2: Create Add operation with bias
AddOp addOp = rewriter.create<ttnn::AddOp>(
@@ -87,7 +87,6 @@ def is_callback_enabled():
return debug_stats != "DebugStats Disabled"


@pytest.mark.skip(reason="See https://github.com/tenstorrent/tt-mlir/issues/5789")
def test_intermidate_tensor_manipulation(helper: Helper, request):
binary_path = os.path.join(FLATBUFFER_BASE_PATH, "linear.mlir.tmp.ttnn")
assert os.path.exists(binary_path), f"Binary file not found: {binary_path}"
3 changes: 0 additions & 3 deletions test/python/golden/test_ttnn_fusing.py
@@ -84,9 +84,6 @@ def matmul_sigmoid(
), f"Standalone {activation_name} operation should be fused"


@pytest.mark.xfail(
reason="Fails golden, see https://github.com/tenstorrent/tt-mlir/issues/5789"
)
@pytest.mark.parametrize(
"shapes",
[
6 changes: 0 additions & 6 deletions test/python/golden/test_ttnn_ops.py
@@ -58,9 +58,6 @@ def clamp_tensor(
)


@pytest.mark.skip(
reason="Segfault, see https://github.com/tenstorrent/tt-mlir/issues/5789"
)
@pytest.mark.parametrize(
"shapes", [[(10, 64, 32), (32, 128), (1,)]], ids=shapes_list_str
)
@@ -221,9 +218,6 @@ def matmul(
)


@pytest.mark.skip(
reason="Segfault, see https://github.com/tenstorrent/tt-mlir/issues/5789"
)
@pytest.mark.parametrize(
"shapes",
[
@@ -1,6 +1,5 @@
// RUN: ttmlir-opt --ttnn-fusing -o %t %s
// RUN: FileCheck %s --input-file=%t
// UNSUPPORTED: true

// Test fusing sigmoid activation into linear operation
module {
@@ -31,4 +31,14 @@ module {
%result = "ttnn.linear"(%arg0, %arg1, %bias) : (tensor<1x3x64x128xbf16>, tensor<1x3x128x32xbf16>, tensor<14x4x3x64x32xbf16>) -> tensor<14x4x3x64x32xbf16>
return %result : tensor<14x4x3x64x32xbf16>
}
func.func @linear_with_sigmoid(%arg0: tensor<100x384xbf16>, %arg1: tensor<4x384xbf16>, %arg2: tensor<1x100x4xbf16>) -> tensor<1x100x4xbf16> {
// CHECK-LABEL: func.func @linear_with_sigmoid
// CHECK: "ttnn.matmul"
// CHECK-SAME: activation = "sigmoid"
// CHECK-SAME: -> tensor<100x4xbf16
// CHECK: "ttnn.add"
// CHECK-SAME: -> tensor<1x100x4xbf16
%result = "ttnn.linear"(%arg0, %arg1, %arg2) <{activation = "sigmoid", transpose_a = false, transpose_b = true}> : (tensor<100x384xbf16>, tensor<4x384xbf16>, tensor<1x100x4xbf16>) -> tensor<1x100x4xbf16>
return %result : tensor<1x100x4xbf16>
}
}
@@ -1,8 +1,7 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" -o %t %s
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" --ttnn-adjust-deallocs -o %t2 %s
// RUN: [ "$(cat %t | grep -c ttnn.deallocate)" -eq 3 ]
// RUN: [ "$(cat %t2 | grep -c ttnn.deallocate)" -eq 1 ]
// UNSUPPORTED: true
// RUN: [ "$(cat %t | grep -c ttnn.deallocate)" -eq 4 ]
// RUN: [ "$(cat %t2 | grep -c ttnn.deallocate)" -eq 2 ]
//
// Test for --ttnn-adjust-deallocs pass.
// The test runs the --ttir-to-ttnn-backend-pipeline twice, and follows up with --ttnn-adjust-deallocs for the second run.
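// Note: the expected deallocate counts above were bumped (3 -> 4 and 1 -> 2), presumably because the linear decomposition enabled by this change produces an extra intermediate tensor that also gets deallocated.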
21 changes: 12 additions & 9 deletions test/ttmlir/Dialect/TTNN/linear/linear_tests_positive.mlir
@@ -1,6 +1,5 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline -o %t %s
// RUN: FileCheck %s --input-file=%t
// UNSUPPORTED: true
module {
func.func @linear_1d_1d_bias(%arg0: tensor<128xbf16>, %arg1: tensor<128xbf16>, %bias: tensor<1xbf16>) -> tensor<1xbf16> {
%0 = ttir.empty() : tensor<1xbf16>
@@ -25,10 +24,12 @@ module {
}

func.func @linear_2d_2d_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
// The bias is a 2D tensor, so the linear op is decomposed into matmul + add.
%0 = ttir.empty() : tensor<64x64xbf16>
// CHECK: "ttnn.linear"
// CHECK: "ttnn.matmul"
// CHECK-SAME: tensor<64x128xbf16
// CHECK-SAME: tensor<128x64xbf16
// CHECK: "ttnn.add"
// CHECK-SAME: tensor<64x64xbf16
// CHECK-SAME: tensor<64x64xbf16
%1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
@@ -56,39 +57,41 @@ module {
}

// Tests for linear ops with transposed inputs.
func.func @linear_2d_tranpose_2d_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %bias: tensor<128x128xbf16>) -> tensor<128x128xbf16> {
func.func @linear_2d_tranpose_1d_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %bias: tensor<128xbf16>) -> tensor<128x128xbf16> {
%0 = ttir.empty() : tensor<128x128xbf16>
// CHECK: "ttnn.linear"
// CHECK-SAME: transpose_a = true
// CHECK-SAME: transpose_b = false
// CHECK-SAME: tensor<64x128xbf16
// CHECK-SAME: tensor<64x128xbf16
// CHECK-SAME: tensor<128xbf16
// CHECK-SAME: tensor<128x128xbf16
%1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_a = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
%1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_a = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
return %1 : tensor<128x128xbf16>
}

func.func @linear_2d_2d_transpose_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
func.func @linear_2d_1d_transpose_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<64x128xbf16>, %bias: tensor<64xbf16>) -> tensor<64x64xbf16> {
%0 = ttir.empty() : tensor<64x64xbf16>
// CHECK: "ttnn.linear"
// CHECK-SAME: transpose_a = false
// CHECK-SAME: transpose_b = true
// CHECK-SAME: tensor<64x128xbf16
// CHECK-SAME: tensor<64x128xbf16
// CHECK-SAME: tensor<64x64xbf16
%1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
// CHECK-SAME: tensor<64xbf16
%1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_b = true}> : (tensor<64x128xbf16>, tensor<64x128xbf16>, tensor<64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
return %1 : tensor<64x64xbf16>
}

func.func @linear_2d_tranpose_2d_transpose(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<128x128xbf16>) -> tensor<128x128xbf16> {
func.func @linear_2d_tranpose_2d_transpose(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<128xbf16>) -> tensor<128x128xbf16> {
%0 = ttir.empty() : tensor<128x128xbf16>
// CHECK: "ttnn.linear"
// CHECK-SAME: transpose_a = true
// CHECK-SAME: transpose_b = true
// CHECK-SAME: tensor<64x128xbf16
// CHECK-SAME: tensor<128x64xbf16
// CHECK-SAME: tensor<128xbf16
// CHECK-SAME: tensor<128x128xbf16
%1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_a = true, transpose_b = true}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<128x128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
%1 = "ttir.linear"(%arg0, %arg1, %bias, %0) <{transpose_a = true, transpose_b = true}> : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<128xbf16>, tensor<128x128xbf16>) -> tensor<128x128xbf16>
return %1 : tensor<128x128xbf16>
}

7 changes: 3 additions & 4 deletions test/ttmlir/Dialect/TTNN/linear/simple_linear.mlir
@@ -1,16 +1,15 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline -o %t %s
// RUN: FileCheck %s --input-file=%t
// UNSUPPORTED: true

module {
func.func @simple_linear_with_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
func.func @simple_linear_with_bias(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64xbf16>) -> tensor<64x64xbf16> {
%0 = ttir.empty() : tensor<64x64xbf16>
// CHECK: "ttnn.linear"
// CHECK-SAME: tensor<64x128xbf16
// CHECK-SAME: tensor<128x64xbf16
// CHECK-SAME: tensor<64xbf16
// CHECK-SAME: tensor<64x64xbf16
// CHECK-SAME: tensor<64x64xbf16
%1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
%1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
return %1 : tensor<64x64xbf16>
}
}
@@ -1,7 +1,6 @@
// REQUIRES: opmodel
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=GreedyL1Interleaved tensor-l1-usage-cap=0.75" -o %t %s
// RUN: FileCheck %s --input-file=%t
// UNSUPPORTED: true
module attributes {} {
func.func @forward(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x96xbf16>, %arg2: tensor<64x96xbf16>, %arg3: tensor<96x32xbf16>, %arg4: tensor<64x32xbf16>) -> tensor<64x32xbf16> {
// CHECK: #[[L1_:.*]] = #ttnn.buffer_type<l1>
@@ -1,7 +1,6 @@
// REQUIRES: opmodel
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="enable-optimizer=true memory-layout-analysis-enabled=true memory-layout-analysis-policy=GreedyL1Interleaved tensor-l1-usage-cap=0.75" -o %t %s
// RUN: FileCheck %s --input-file=%t
// UNSUPPORTED: true
#loc = loc("MNISTLinear":4294967295:0)
module @"tt-forge-graph" attributes {} {
func.func @main(%arg0: tensor<1x784xf32> loc("MNISTLinear":4294967295:0), %arg1: tensor<1x10xf32> loc("MNISTLinear":4294967295:0), %arg2: tensor<256x10xf32> loc("MNISTLinear":4294967295:0), %arg3: tensor<1x256xf32> loc("MNISTLinear":4294967295:0), %arg4: tensor<784x256xf32> loc("MNISTLinear":4294967295:0)) -> tensor<1x10xf32> {
1 change: 0 additions & 1 deletion test/ttmlir/Silicon/TTNN/n150/deallocate.mlir
@@ -1,7 +1,6 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" -o %t.mlir %s
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer -o %t.ttnn %t.mlir
// UNSUPPORTED: true
#loc = loc("Dealloc":4294967295:0)
module @"dealloc_test" attributes {} {
func.func @main(%arg0: tensor<1x784xf32> loc("Dealloc":4294967295:0), %arg1: tensor<1x10xf32> loc("Dealloc":4294967295:0), %arg2: tensor<256x10xf32> loc("Dealloc":4294967295:0), %arg3: tensor<1x256xf32> loc("Dealloc":4294967295:0), %arg4: tensor<784x256xf32> loc("Dealloc":4294967295:0)) -> tensor<1x10xf32> {
2 changes: 1 addition & 1 deletion test/ttmlir/Silicon/TTNN/n150/matmul/linear.mlir
@@ -1,6 +1,6 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" -o %t.mlir %s
// RUN: ttmlir-translate --ttnn-to-flatbuffer -o %t.ttnn %t.mlir
// UNSUPPORTED: true

module {
func.func @linear(%arg0: tensor<2x34x1024xf32>, %arg1: tensor<1024x1024xf32>, %bias: tensor<2x34x1024xf32>) -> tensor<2x34x1024xf32> {
%0 = ttir.empty() : tensor<2x34x1024xf32>
2 changes: 1 addition & 1 deletion test/ttmlir/Silicon/TTNN/n150/mixed_precision/linear.mlir
@@ -14,7 +14,7 @@ module {
func.func @linear_with_implicit_broadcast(%arg0: tensor<2x34x1024xf32>, %arg1: tensor<1024x1024xf32> {ttcore.argument_type = #ttcore.argument_type<parameter>}, %bias: tensor<1024xf32>) -> tensor<2x34x1024xf32> {
// CHECK-LABEL: func.func @linear_with_implicit_broadcast
// CHECK: %[[BFP8_WEIGHT:.*]] = ttcore.load_cached({{.*}}, [%arg1]) : {{.*}} -> tensor<{{.*}}bfp_bf8{{.*}}>
// CHECK: "ttnn.matmul"(%arg0, %[[BFP8_WEIGHT]])
// CHECK: "ttnn.linear"(%arg0, %[[BFP8_WEIGHT]], %arg2)
%0 = ttir.empty() : tensor<2x34x1024xf32>
%1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<2x34x1024xf32>, tensor<1024x1024xf32>, tensor<1024xf32>, tensor<2x34x1024xf32>) -> tensor<2x34x1024xf32>
return %1 : tensor<2x34x1024xf32>
@@ -2,7 +2,6 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" -o mnist_linear_out.mlir %s
// RUN: FileCheck %s --input-file=mnist_linear_out.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer -o %t.ttnn mnist_linear_out.mlir
// UNSUPPORTED: true
#loc = loc("MNISTLinear":0:0)
module @MNISTLinear attributes {} {
func.func @forward(%arg0: tensor<1x784xf32> {ttir.name = "input_1"} loc("MNISTLinear":0:0), %arg1: tensor<784x256xf32> {ttir.name = "l1.weight"} loc("MNISTLinear":0:0), %arg2: tensor<256xf32> {ttir.name = "l1.bias"} loc("MNISTLinear":0:0), %arg3: tensor<256x10xf32> {ttir.name = "l2.weight"} loc("MNISTLinear":0:0), %arg4: tensor<10xf32> {ttir.name = "l2.bias"} loc("MNISTLinear":0:0)) -> (tensor<1x10xf32> {ttir.name = "MNISTLinear.output_softmax_9"}) {
7 changes: 3 additions & 4 deletions test/ttmlir/Silicon/TTNN/n150/perf/test_perf_linear.mlir
@@ -1,17 +1,16 @@
// RUN: ttmlir-opt --ttir-to-ttnn-backend-pipeline="system-desc-path=%system_desc_path%" -o %t.mlir %s
// RUN: FileCheck %s --input-file=%t.mlir
// RUN: ttmlir-translate --ttnn-to-flatbuffer -o %t.ttnn %t.mlir
// UNSUPPORTED: true

module {
func.func @linear(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64x64xbf16>) -> tensor<64x64xbf16> {
func.func @linear(%arg0: tensor<64x128xbf16>, %arg1: tensor<128x64xbf16>, %bias: tensor<64xbf16>) -> tensor<64x64xbf16> {
%0 = ttir.empty() : tensor<64x64xbf16>
// CHECK: "ttnn.linear"
// CHECK-SAME: tensor<64x128xbf16
// CHECK-SAME: tensor<128x64xbf16
// CHECK-SAME: tensor<64xbf16
// CHECK-SAME: tensor<64x64xbf16
// CHECK-SAME: tensor<64x64xbf16
%1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64x64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
%1 = "ttir.linear"(%arg0, %arg1, %bias, %0) : (tensor<64x128xbf16>, tensor<128x64xbf16>, tensor<64xbf16>, tensor<64x64xbf16>) -> tensor<64x64xbf16>
return %1 : tensor<64x64xbf16>
}
}