
Commit 2190abc

Fused lerp operation & lerp_tile LLK (#37441)
### Ticket
#36763

### Problem description
For Wan2.2, we want to perform the following:
```
permuted_noise_pred = permuted_noise_uncond + current_guidance_scale * (
    permuted_noise_pred - permuted_noise_uncond
)
```
This is equivalent to `lerp(permuted_noise_uncond, permuted_noise_pred, current_guidance_scale)`. However, `ttnn.lerp` is defined as a composite and may be lossy due to `ttnn::add`, `ttnn::subtract` and `ttnn::multiply`, especially with bfloat16 data. Fusing the operations should also improve performance. (A host-side sanity check of this identity follows the checklist below.)

### What's changed
- Add fused lerp implementation
- Update `ttnn.lerp` to take `output_tensor`
- Add `lerp_tile` LLK
- Update tests for lerp; tests now use ULP instead of PCC
- Add shape for Wan2.2 in test

#### Performance
Performance (Wormhole N150) on a 9472 x 64 tensor in DRAM:

| dtype | Branch | Total Duration [ms] |
| -- | -- | -- |
| bfloat16 | main | 8813 |
| bfloat16 | new | 1495 |
| float32 | main | 15461 |
| float32 | new | 2761 |

### Checklist
- [x] [![All post-commit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml/badge.svg?branch=nmaurice/36763-fused-lerp)](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml?query=branch:nmaurice/36763-fused-lerp)
- [x] [![Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml/badge.svg?branch=nmaurice/36763-fused-lerp)](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml?query=branch:nmaurice/36763-fused-lerp) (unrelated failure, also in other [branches](https://github.com/tenstorrent/tt-metal/actions/runs/21864535471/job/63102851430))
- [x] [![cpp-unit-tests](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml/badge.svg?branch=nmaurice/36763-fused-lerp)](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml?query=branch:nmaurice/36763-fused-lerp) (executes `test_lerp.py`)
- [x] New/Existing tests provide coverage for changes

#### Model tests
If your changes cover model-related code, you should run tests corresponding to the affected models and platforms (Single card, T3K, Galaxy). The "Choose your pipeline" workflows facilitate running multiple kinds of tests in a single run. Each offers `models-mandatory` and `models-extended` presets. The former includes a minimal set of tests, to be run always; the latter extends that with additional ones. Use your best judgement in deciding which is most appropriate for your PR.

- [ ] [![(Single) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml/badge.svg?branch=nmaurice/36763-fused-lerp)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml?query=branch:nmaurice/36763-fused-lerp)
  - [ ] `models-mandatory` preset (runs: [Device perf regressions](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-device-models.yaml) and [Frequent model and ttnn tests](https://github.com/tenstorrent/tt-metal/actions/workflows/fast-dispatch-full-regressions-and-models.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/single-card-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/perf-models.yaml) tests)
  - [ ] other selection - specify runs
- [ ] [![(T3K) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-t3k.yaml/badge.svg?branch=nmaurice/36763-fused-lerp)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-t3k.yaml?query=branch:nmaurice/36763-fused-lerp)
  - [ ] `models-mandatory` preset (runs: [Unit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-unit-tests.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/t3000-model-perf-tests.yaml) tests)
  - [ ] other selection - specify runs
- [ ] [![(Galaxy) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml/badge.svg?branch=nmaurice/36763-fused-lerp)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml?query=branch:nmaurice/36763-fused-lerp)
  - [ ] `models-mandatory` preset (runs: [Quick tests](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-quick.yaml))
  - [ ] `models-extended` preset (runs: the mandatory tests, plus [Demo](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-demo-tests.yaml) and [Model perf](https://github.com/tenstorrent/tt-metal/actions/workflows/galaxy-perf-tests.yaml) tests)
  - [ ] other selection - specify runs
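The fusion rests on the identity stated in the problem description; it can be sanity-checked on the host with plain torch. A minimal sketch (the tensors and guidance scale here are illustrative stand-ins, not values from the PR):

```python
import torch

# Illustrative stand-ins for the Wan2.2 tensors named in the snippet above.
permuted_noise_uncond = torch.randn(4, 8)
permuted_noise_pred = torch.randn(4, 8)
current_guidance_scale = 3.5

# Classifier-free-guidance update written out explicitly...
explicit = permuted_noise_uncond + current_guidance_scale * (
    permuted_noise_pred - permuted_noise_uncond
)
# ...equals linear interpolation from uncond to pred with the scale as weight.
fused = torch.lerp(permuted_noise_uncond, permuted_noise_pred, current_guidance_scale)

# torch.lerp may internally rearrange the formula for accuracy, so compare up
# to rounding rather than bit-exactly.
assert torch.allclose(explicit, fused)
```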
1 parent a4bcdc7 commit 2190abc

File tree: 21 files changed (+452, -63 lines)


tests/ttnn/nightly/unit_tests/operations/eltwise/test_lerp.py

Lines changed: 88 additions & 40 deletions
```diff
@@ -8,69 +8,117 @@
 
 import ttnn
 
-from math import pi
-from tests.ttnn.utils_for_testing import assert_with_pcc
+from tests.ttnn.utils_for_testing import assert_with_ulp
+
+
+def run_lerp_test(
+    device,
+    h,
+    w,
+    low,
+    high,
+    end,
+    weight,
+    ttnn_function,
+    use_scalar_weight=False,
+    ulp_threshold=1,
+    input_dtype="bfloat16",
+    output_dtype=None,
+):
+    torch_input_dtype = getattr(torch, input_dtype)
+
+    torch_input_tensor_a = torch.linspace(low, high, steps=h * w, dtype=torch_input_dtype).reshape((h, w))
+    torch_input_tensor_b = torch.full((h, w), end, dtype=torch_input_dtype)
+
+    golden_function = ttnn.get_golden_function(ttnn_function)
+
+    if use_scalar_weight:
+        torch_weight = weight
+        ttnn_weight = weight
+    else:
+        torch_weight = torch.full((h, w), weight, dtype=torch_input_dtype)
+        ttnn_weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT, device=device)
 
+    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
+    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
 
-def run_lerp_test_float(device, h, w, low, high, end, weight, ttnn_function, torch_function, pcc=0.9999):
-    torch_input_tensor_a = torch.linspace(low, high, steps=h * w, dtype=torch.bfloat16).reshape((h, w))
-    torch_input_tensor_b = torch.full((h, w), end, dtype=torch.bfloat16)
+    calculated_tensor = None
+    if output_dtype is not None:
+        torch_dtype = getattr(torch, output_dtype)
+        ttnn_output_dtype = getattr(ttnn, output_dtype)
+        torch_input_tensor_a = torch_input_tensor_a.to(torch_dtype)
+        torch_input_tensor_b = torch_input_tensor_b.to(torch_dtype)
+        calculated_tensor = ttnn.empty((h, w), dtype=ttnn_output_dtype, layout=ttnn.TILE_LAYOUT, device=device)
 
-    torch_output_tensor = torch_function(torch_input_tensor_a, torch_input_tensor_b, weight)
+    golden_output_tensor = golden_function(
+        torch_input_tensor_a,
+        torch_input_tensor_b,
+        torch_weight,
+    )
 
-    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
-    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
+    calculated_tensor = ttnn_function(input_tensor_a, input_tensor_b, ttnn_weight, output_tensor=calculated_tensor)
 
-    output_tensor = ttnn_function(input_tensor_a, input_tensor_b, weight)
-    output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
-    output_tensor = ttnn.from_device(output_tensor)
-    output_tensor = ttnn.to_torch(output_tensor)
+    if output_dtype is not None:
+        assert calculated_tensor.dtype == ttnn_output_dtype
 
-    assert_with_pcc(torch_output_tensor, output_tensor, pcc)
+    calculated_tensor = ttnn.to_torch(calculated_tensor)
+    assert_with_ulp(golden_output_tensor, calculated_tensor, ulp_threshold=ulp_threshold)
 
 
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
 @pytest.mark.parametrize("weight", [0.5])
-def test_lerp_float_a(device, h, w, weight):
-    run_lerp_test_float(device, h, w, 0, 90, 100, weight, ttnn.lerp, torch.lerp)
+@pytest.mark.parametrize("input_dtype", ["bfloat16", "float32"])
+def test_lerp_float_a(device, h, w, weight, input_dtype):
+    run_lerp_test(device, h, w, 0, 90, 100, weight, ttnn.lerp, use_scalar_weight=True, input_dtype=input_dtype)
 
 
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
 @pytest.mark.parametrize("weight", [0.75])
-def test_lerp_float_b(device, h, w, weight):
-    run_lerp_test_float(device, h, w, 1, 80, 99, weight, ttnn.lerp, torch.lerp, pcc=0.999)
-
-
-def run_lerp_test_tensor(device, h, w, low, high, end, weight, ttnn_function, torch_function, pcc=0.9999):
-    torch_input_tensor_a = torch.linspace(low, high, steps=h * w, dtype=torch.bfloat16).reshape((h, w))
-    torch_input_tensor_b = torch.full((h, w), end, dtype=torch.bfloat16)
-    torch_weight = torch.full((h, w), weight, dtype=torch.bfloat16)
-
-    torch_output_tensor = torch_function(torch_input_tensor_a, torch_input_tensor_b, torch_weight)
-
-    input_tensor_a = ttnn.from_torch(torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device)
-    input_tensor_b = ttnn.from_torch(torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device)
-    input_weight = ttnn.from_torch(torch_weight, layout=ttnn.TILE_LAYOUT, device=device)
-
-    output_tensor = ttnn_function(input_tensor_a, input_tensor_b, input_weight)
-    output_tensor = ttnn.to_layout(output_tensor, ttnn.ROW_MAJOR_LAYOUT)
-    output_tensor = ttnn.from_device(output_tensor)
-    output_tensor = ttnn.to_torch(output_tensor)
-
-    assert_with_pcc(torch_output_tensor, output_tensor, pcc)
+@pytest.mark.parametrize("input_dtype", ["bfloat16", "float32"])
+def test_lerp_float_b(device, h, w, weight, input_dtype):
+    run_lerp_test(
+        device, h, w, 1, 80, 99, weight, ttnn.lerp, use_scalar_weight=True, ulp_threshold=2, input_dtype=input_dtype
+    )
 
 
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
 @pytest.mark.parametrize("weight", [0.5])
-def test_lerp_tensor_a(device, h, w, weight):
-    run_lerp_test_tensor(device, h, w, 0, 90, 100, weight, ttnn.lerp, torch.lerp)
+@pytest.mark.parametrize("input_dtype", ["bfloat16", "float32"])
+def test_lerp_tensor_a(device, h, w, weight, input_dtype):
+    run_lerp_test(device, h, w, 0, 90, 100, weight, ttnn.lerp, use_scalar_weight=False, input_dtype=input_dtype)
 
 
 @pytest.mark.parametrize("h", [64])
 @pytest.mark.parametrize("w", [128])
 @pytest.mark.parametrize("weight", [0.75])
-def test_lerp_tensor_b(device, h, w, weight):
-    run_lerp_test_tensor(device, h, w, 1, 80, 99, weight, ttnn.lerp, torch.lerp, pcc=0.999)
+@pytest.mark.parametrize("input_dtype", ["bfloat16", "float32"])
+def test_lerp_tensor_b(device, h, w, weight, input_dtype):
+    run_lerp_test(
+        device, h, w, 1, 80, 99, weight, ttnn.lerp, use_scalar_weight=False, ulp_threshold=2, input_dtype=input_dtype
+    )
+
+
+@pytest.mark.parametrize("h", [64])
+@pytest.mark.parametrize("w", [9472])
+@pytest.mark.parametrize("weight", [0.75])
+@pytest.mark.parametrize("input_dtype", ["bfloat16", "float32"])
+def test_lerp_fp32_preallocated_output(device, h, w, weight, input_dtype):
+    """Lerp with bfloat16 inputs (two tensors + scalar weight) and preallocated float32 output.
+    Checks that output is correct within 1 ULP for float32."""
+    run_lerp_test(
+        device,
+        h,
+        w,
+        1,
+        80,
+        99,
+        weight,
+        ttnn.lerp,
+        use_scalar_weight=True,
+        ulp_threshold=1,
+        output_dtype="float32",
+        input_dtype=input_dtype,
+    )
```
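Distilled from `run_lerp_test` above, the new call pattern with a preallocated output looks roughly as follows; a sketch assuming an already-open `device` handle (e.g. from `ttnn.open_device`), with shapes and dtypes taken from the tests:

```python
import torch
import ttnn

# Inputs in bfloat16, as in test_lerp_fp32_preallocated_output above.
a = ttnn.from_torch(
    torch.linspace(1, 80, steps=64 * 128, dtype=torch.bfloat16).reshape((64, 128)),
    layout=ttnn.TILE_LAYOUT,
    device=device,
)
b = ttnn.from_torch(
    torch.full((64, 128), 99, dtype=torch.bfloat16),
    layout=ttnn.TILE_LAYOUT,
    device=device,
)

# Preallocate a float32 output; the fused op writes into it instead of
# allocating a new tensor (the new output_tensor argument).
out = ttnn.empty((64, 128), dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
out = ttnn.lerp(a, b, 0.75, output_tensor=out)
assert out.dtype == ttnn.float32
```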

tests/ttnn/unit_tests/operations/eltwise/test_ternary_composite.py

Lines changed: 4 additions & 4 deletions
```diff
@@ -223,8 +223,8 @@ def test_lerp_overload_ttnn(input_shapes, value, device):
     golden_fn = ttnn.get_golden_function(ttnn.lerp)
     golden_tensor = golden_fn(in_data1, in_data2, value)
 
-    comp_pass = compare_pcc([output_tensor], [golden_tensor])
-    assert comp_pass
+    output_torch = output_tensor.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
+    assert_with_ulp(golden_tensor, output_torch, ulp_threshold=2)
 
 
 @pytest.mark.parametrize(
@@ -244,8 +244,8 @@ def test_lerp_ttnn(input_shapes, device):
     golden_fn = ttnn.get_golden_function(ttnn.lerp)
     golden_tensor = golden_fn(in_data1, in_data2, in_data3)
 
-    comp_pass = compare_pcc([output_tensor], [golden_tensor])
-    assert comp_pass
+    output_torch = output_tensor.cpu().to(ttnn.ROW_MAJOR_LAYOUT).to_torch()
+    assert_with_ulp(golden_tensor, output_torch, ulp_threshold=2)
 
 
 @pytest.mark.parametrize(
```
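The switch from `compare_pcc` to `assert_with_ulp` changes what is measured: instead of a global correlation coefficient, every element must land within a fixed number of representable floats of the golden value. For intuition, a rough host-side sketch of a ULP distance (illustrative only; the real `assert_with_ulp` in `tests.ttnn.utils_for_testing` may be implemented differently):

```python
import torch

def ulp_distance(golden: torch.Tensor, actual: torch.Tensor) -> torch.Tensor:
    """Element-wise ULP distance between two float32 tensors (illustrative)."""

    def to_ordered_int(t: torch.Tensor) -> torch.Tensor:
        # Reinterpret the float bits as integers; for same-sign IEEE-754 values
        # the bit patterns are ordered, so their difference counts the number
        # of representable floats between them.
        bits = t.to(torch.float32).view(torch.int32).to(torch.int64)
        # Remap negative floats so integer order matches float order.
        return torch.where(bits < 0, -(bits & 0x7FFFFFFF), bits)

    return (to_ordered_int(golden) - to_ordered_int(actual)).abs()

# Usage mirroring the assertions above: every element within 2 ULP.
# assert ulp_distance(golden_tensor, output_torch).max() <= 2
```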
Lines changed: 40 additions & 0 deletions
```cpp
// SPDX-FileCopyrightText: © 2026 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_defs.h"
#include "sfpi.h"
#include "ckernel_sfpu_binary.h"

namespace ckernel::sfpu {

template <bool APPROXIMATION_MODE, bool is_fp32_dest_acc_en, DataFormat data_format, int ITERATIONS>
inline void calculate_lerp(
    const uint dst_index_in0,  // input (start)
    const uint dst_index_in1,  // end
    const uint dst_index_in2,  // weight
    const uint dst_index_out) {
    static_assert(
        data_format == DataFormat::Float32 || data_format == DataFormat::Float16_b,
        "Unsupported data format for calculate_lerp(). Supported data formats are: Float32, Float16_b.");

    // size of each tile in Dest is 64/SFP_DESTREG_STRIDE = 32 rows when using sfpi to load/store
    constexpr uint dst_tile_size_sfpi = 32;
    // lerp: out = input + weight * (end - input)
#pragma GCC unroll 8
    for (int d = 0; d < ITERATIONS; d++) {
        sfpi::vFloat in0 = sfpi::dst_reg[dst_index_in0 * dst_tile_size_sfpi];
        sfpi::vFloat in1 = sfpi::dst_reg[dst_index_in1 * dst_tile_size_sfpi];
        sfpi::vFloat in2 = sfpi::dst_reg[dst_index_in2 * dst_tile_size_sfpi];
        sfpi::vFloat result = in0 + in2 * (in1 - in0);
        if constexpr (!is_fp32_dest_acc_en) {
            result = float32_to_bf16_rne(result);
        }
        sfpi::dst_reg[dst_index_out * dst_tile_size_sfpi] = result;
        sfpi::dst_reg++;
    }
}

}  // namespace ckernel::sfpu
```
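One detail worth flagging in the kernel above: the SFPU computes in fp32 registers, and when fp32 dest accumulation is disabled the result is rounded to bfloat16 with round-to-nearest-even (`float32_to_bf16_rne`) before the store. A host-side numpy sketch of that rounding semantics (an illustration of the semantics only, not the SFPU implementation; NaN payloads and other corner cases are not specially handled):

```python
import numpy as np

def f32_to_bf16_rne(x: np.ndarray) -> np.ndarray:
    """Round float32 to bfloat16 (round-to-nearest-even), returned as float32."""
    bits = x.astype(np.float32).view(np.uint32)
    # bfloat16 keeps the top 16 bits of the float32 pattern, so rounding means
    # adding half of the dropped low 16 bits, with the kept LSB breaking ties
    # toward even, then truncating.
    lsb = (bits >> np.uint32(16)) & np.uint32(1)
    rounded = (bits + np.uint32(0x7FFF) + lsb) & np.uint32(0xFFFF0000)
    return rounded.view(np.float32)

# 1.00390625 sits exactly between two bf16 values; RNE picks the even mantissa.
print(f32_to_bf16_rne(np.array([1.00390625], dtype=np.float32)))  # prints [1.]
```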
Lines changed: 29 additions & 0 deletions
```cpp
// SPDX-FileCopyrightText: © 2026 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_math_eltwise_ternary_sfpu_params.h"
#include "ckernel_sfpu_lerp.h"

namespace ckernel {

template <bool APPROXIMATE, bool is_fp32_dest_acc_en, DataFormat data_format, int ITERATIONS = 8>
inline void llk_math_eltwise_ternary_sfpu_lerp(
    uint dst_index0, uint dst_index1, uint dst_index2, uint odst, int vector_mode = (int)VectorMode::RC) {
    _llk_math_eltwise_ternary_sfpu_params_<APPROXIMATE>(
        sfpu::calculate_lerp<APPROXIMATE, is_fp32_dest_acc_en, data_format, ITERATIONS>,
        dst_index0,
        dst_index1,
        dst_index2,
        odst,
        vector_mode);
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_ternary_sfpu_lerp_init() {
    _llk_math_eltwise_ternary_sfpu_init_<SfpuType::lerp>();
}

}  // namespace ckernel
```

tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu_types.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -151,4 +151,5 @@ enum class SfpuType {
     unary_max_uint32,
     unary_min_uint32,
     addcdiv,
+    lerp,
 };
```
Lines changed: 40 additions & 0 deletions
```cpp
// SPDX-FileCopyrightText: © 2026 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_defs.h"
#include "sfpi.h"
#include "ckernel_sfpu_binary.h"

namespace ckernel::sfpu {

template <bool APPROXIMATION_MODE, bool is_fp32_dest_acc_en, DataFormat data_format, int ITERATIONS>
inline void calculate_lerp(
    const uint dst_index_in0,  // input (start)
    const uint dst_index_in1,  // end
    const uint dst_index_in2,  // weight
    const uint dst_index_out) {
    static_assert(
        data_format == DataFormat::Float32 || data_format == DataFormat::Float16_b,
        "Unsupported data format for calculate_lerp(). Supported data formats are: Float32, Float16_b.");

    // size of each tile in Dest is 64/SFP_DESTREG_STRIDE = 32 rows when using sfpi to load/store
    constexpr uint dst_tile_size_sfpi = 32;
    // lerp: out = input + weight * (end - input)
#pragma GCC unroll 8
    for (int d = 0; d < ITERATIONS; d++) {
        sfpi::vFloat in0 = sfpi::dst_reg[dst_index_in0 * dst_tile_size_sfpi];
        sfpi::vFloat in1 = sfpi::dst_reg[dst_index_in1 * dst_tile_size_sfpi];
        sfpi::vFloat in2 = sfpi::dst_reg[dst_index_in2 * dst_tile_size_sfpi];
        sfpi::vFloat result = in0 + in2 * (in1 - in0);
        if constexpr (!is_fp32_dest_acc_en) {
            result = float32_to_bf16_rne(result);
        }
        sfpi::dst_reg[dst_index_out * dst_tile_size_sfpi] = result;
        sfpi::dst_reg++;
    }
}

}  // namespace ckernel::sfpu
```
Lines changed: 29 additions & 0 deletions
```cpp
// SPDX-FileCopyrightText: © 2026 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "llk_math_eltwise_ternary_sfpu_params.h"
#include "ckernel_sfpu_lerp.h"

namespace ckernel {

template <bool APPROXIMATE, bool is_fp32_dest_acc_en, DataFormat data_format, int ITERATIONS = 8>
inline void llk_math_eltwise_ternary_sfpu_lerp(
    uint dst_index0, uint dst_index1, uint dst_index2, uint odst, int vector_mode = (int)VectorMode::RC) {
    _llk_math_eltwise_ternary_sfpu_params_<APPROXIMATE>(
        sfpu::calculate_lerp<APPROXIMATE, is_fp32_dest_acc_en, data_format, ITERATIONS>,
        dst_index0,
        dst_index1,
        dst_index2,
        odst,
        vector_mode);
}

template <bool APPROXIMATE>
inline void llk_math_eltwise_ternary_sfpu_lerp_init() {
    _llk_math_eltwise_ternary_sfpu_init_<SfpuType::lerp>();
}

}  // namespace ckernel
```

tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu_types.h

Lines changed: 1 addition & 0 deletions
```diff
@@ -151,4 +151,5 @@ enum class SfpuType {
     unary_max_uint32,
     unary_min_uint32,
     addcdiv,
+    lerp,
 };
```
Lines changed: 36 additions & 0 deletions
```cpp
// SPDX-FileCopyrightText: © 2026 Tenstorrent AI ULC
//
// SPDX-License-Identifier: Apache-2.0

#pragma once

#include "api/compute/common_globals.h"
#ifdef TRISC_MATH
#include "llk_math_eltwise_ternary_sfpu_lerp.h"
#endif

namespace ckernel {

// clang-format off
/**
 * Performs elementwise linear interpolation (lerp): out = input + weight * (end - input)
 *
 * | Argument | Description                                             | Type     | Valid Range                                            | Required |
 * |----------|---------------------------------------------------------|----------|--------------------------------------------------------|----------|
 * | idst0    | Index of the tile in DST register buffer (input/start)  | uint32_t | Must be less than the size of the DST register buffer  | True     |
 * | idst1    | Index of the tile in DST register buffer (end)          | uint32_t | Must be less than the size of the DST register buffer  | True     |
 * | idst2    | Index of the tile in DST register buffer (weight)       | uint32_t | Must be less than the size of the DST register buffer  | True     |
 * | odst     | Index of the tile in DST register buffer (output)       | uint32_t | Must be less than the size of the DST register buffer  | True     |
 */
// clang-format on
template <DataFormat data_format>
ALWI void lerp_tile(uint32_t idst0, uint32_t idst1, uint32_t idst2, uint32_t odst) {
    MATH((llk_math_eltwise_ternary_sfpu_lerp<APPROX, DST_ACCUM_MODE, data_format>(idst0, idst1, idst2, odst)));
}

/**
 * Please refer to documentation for any_init.
 */
ALWI void lerp_tile_init() { MATH((llk_math_eltwise_ternary_sfpu_lerp_init<APPROX>())); }

}  // namespace ckernel
```

ttnn/cpp/ttnn/operations/eltwise/ternary/device/kernels/compute/ternary_sfpu_col_scalar_bcast_tts_tst.cpp

Lines changed: 1 addition & 0 deletions
```diff
@@ -6,6 +6,7 @@
 
 #include "api/compute/eltwise_unary/eltwise_unary.h"
 #include "api/compute/eltwise_unary/where.h"
+#include "api/compute/eltwise_unary/lerp.h"
 #include "api/compute/eltwise_unary/fill.h"
 #include "ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_common.hpp"
 #include "ttnn/operations/eltwise/binary_ng/device/kernels/compute/eltwise_utils_sfpu.hpp"
```
