Skip to content

Commit 4ffd828

Browse files
Implement FMOD as LLK op (#37050)
### Ticket #35977 , #33746 ### Problem description Provide context for the problem. ### What's changed - Implement binary FMOD as LLK op as part of Migration - Replaced PCC checks with ULP / allclose checks - ULP Testing : [Link](https://docs.google.com/spreadsheets/d/1M-Z-DMHojp6AsCNpBAhljFOmw9t_n3p51agLnGc9hhM/edit?gid=1453420635#gid=1453420635) - Perf Comparison for L1 Interleaved config (Device Kernel Duration) <google-sheets-html-origin><!--td {border: 1px solid #cccccc;}br {mso-data-placement:same-cell;}--> Shape | Main branch DKD (ns) | Proposed branch DKD (ns) | Speedup ((main - proposed) / proposed) -- | -- | -- | -- Single tile (32, 32) | 19682 | 4614 | 326.57% 8 tile (64, 128) | 27016 | 4811 | 461.55% ### Checklist - [x] [![All post-commit tests](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml/badge.svg?branch=virdhatchani/Migrate_LLK_FMOD)](https://github.com/tenstorrent/tt-metal/actions/workflows/all-post-commit-workflows.yaml?query=branch:virdhatchani/Migrate_LLK_FMOD) - [x] [![Blackhole Post commit](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml/badge.svg?branch=virdhatchani/Migrate_LLK_FMOD)](https://github.com/tenstorrent/tt-metal/actions/workflows/blackhole-post-commit.yaml?query=branch:virdhatchani/Migrate_LLK_FMOD) - Passed as in main <!-- [Main](https://github.com/tenstorrent/tt-metal/actions/runs/22015627012) --> - [x] [![cpp-unit-tests](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml/badge.svg?branch=virdhatchani/Migrate_LLK_FMOD)](https://github.com/tenstorrent/tt-metal/actions/workflows/tt-metal-l2-nightly.yaml?query=branch:virdhatchani/Migrate_LLK_FMOD) - Passed as in main <!-- [Main](https://github.com/tenstorrent/tt-metal/actions/runs/22015629780) --> - [x] New/Existing tests provide coverage for changes #### Model tests - [x] [![(Single) Choose your 
pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml/badge.svg?branch=virdhatchani/Migrate_LLK_FMOD)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select.yaml?query=branch:virdhatchani/Migrate_LLK_FMOD) - Passed as in main <!-- [Main](https://github.com/tenstorrent/tt-metal/actions/runs/22015805040) --> - [x] [![(Galaxy) Choose your pipeline](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml/badge.svg?branch=virdhatchani/Migrate_LLK_FMOD)](https://github.com/tenstorrent/tt-metal/actions/workflows/pipeline-select-galaxy.yaml?query=branch:virdhatchani/Migrate_LLK_FMOD) - Passed as in main <!-- [Main](https://github.com/tenstorrent/tt-metal/actions/runs/22015930568) -->
1 parent 72e5195 commit 4ffd828

File tree

22 files changed

+421
-154
lines changed

22 files changed

+421
-154
lines changed

tests/ttnn/unit_tests/operations/eltwise/test_binary_composite.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -456,8 +456,8 @@ def test_binary_fmod_decimal_ttnn(input_shapes, device):
456456
golden_function = ttnn.get_golden_function(ttnn.fmod)
457457
golden_tensor = golden_function(in_data1, in_data2, device=device)
458458

459-
comp_pass = compare_pcc([output_tensor], [golden_tensor], 0.9999)
460-
assert comp_pass
459+
output_torch = ttnn.to_torch(output_tensor)
460+
assert torch.allclose(output_torch, golden_tensor, rtol=5e-2, atol=1e-5)
461461

462462

463463
@pytest.mark.parametrize(
@@ -476,7 +476,7 @@ def test_fmod_ttnn(input_shapes, device):
476476
golden_function = ttnn.get_golden_function(ttnn.fmod)
477477
golden_tensor = golden_function(in_data1, scalar, device=device)
478478

479-
comp_pass = compare_pcc([output_tensor], [golden_tensor])
479+
comp_pass = assert_with_ulp(golden_tensor, output_tensor, 1)
480480
assert comp_pass, f"Failed for scalar={scalar}"
481481

482482

tests/ttnn/unit_tests/operations/eltwise/test_div_ops.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from models.common.utility_functions import torch_random
1010
from functools import partial
1111
from tests.tt_eager.python_api_testing.sweep_tests.generation_funcs import gen_func_with_cast_tt
12+
from tests.ttnn.utils_for_testing import assert_with_ulp
1213

1314
pytestmark = pytest.mark.use_module_device
1415

@@ -141,8 +142,7 @@ def test_binary_fmod_bf16(
141142
output = ttnn.fmod(input_tensor_a, input_tensor_b)
142143
output = ttnn.to_torch(output)
143144

144-
pcc = ttnn.pearson_correlation_coefficient(torch_output_tensor, output)
145-
assert pcc >= 0.99
145+
assert_with_ulp(torch_output_tensor, output, 1)
146146

147147

148148
# This test was added for #17362

tests/ttnn/unit_tests/operations/eltwise/test_fmod.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
import torch
66
import pytest
77
import ttnn
8-
from tests.ttnn.nightly.unit_tests.operations.eltwise.backward.utility_funcs import compare_equal
8+
from tests.ttnn.utils_for_testing import assert_with_ulp
99

1010

1111
@pytest.mark.parametrize(
@@ -44,3 +44,24 @@ def test_fmod_nan(testing_dtype, device):
4444
output_tensor = ttnn.to_torch(tt_result)
4545

4646
assert torch.equal(torch.isnan(golden), torch.isnan(output_tensor))
47+
48+
49+
@pytest.mark.parametrize("dtype", ["bfloat16", "float32"])
50+
def test_fmod_binary_accuracy(device, dtype):
51+
"""Test fmod binary operation with specific values."""
52+
torch_dtype = getattr(torch, dtype)
53+
ttnn_dtype = getattr(ttnn, dtype)
54+
55+
torch_input_a = torch.tensor([[5.0, 7.0, -5.0, -7.0, 3.5, 10.0, 1.5, -1.5, 9.0, 15.0]], dtype=torch_dtype)
56+
torch_input_b = torch.tensor([[2.0, 4.0, 2.0, 4.0, 2.0, 4.0, 0.5, 0.5, -2.0, -4.0]], dtype=torch_dtype)
57+
58+
golden_fn = ttnn.get_golden_function(ttnn.fmod)
59+
golden = golden_fn(torch_input_a, torch_input_b, device=device)
60+
61+
input_tensor_a = ttnn.from_torch(torch_input_a, dtype=ttnn_dtype, layout=ttnn.TILE_LAYOUT, device=device)
62+
input_tensor_b = ttnn.from_torch(torch_input_b, dtype=ttnn_dtype, layout=ttnn.TILE_LAYOUT, device=device)
63+
64+
output = ttnn.fmod(input_tensor_a, input_tensor_b)
65+
output = ttnn.to_torch(output)
66+
67+
assert_with_ulp(golden, output, 1)

tt_metal/hw/CMakeLists.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -611,7 +611,7 @@ target_sources(
611611
inc/api/compute/eltwise_unary/where.h
612612
inc/api/compute/ema.h
613613
inc/api/compute/experimental/mul_reduce_scalar.h
614-
inc/api/compute/fmod_int32.h
614+
inc/api/compute/binary_fmod.h
615615
inc/api/compute/gcd.h
616616
inc/api/compute/layernorm.h
617617
inc/api/compute/lcm.h
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
// SPDX-FileCopyrightText: © 2026 Tenstorrent AI ULC
2+
//
3+
// SPDX-License-Identifier: Apache-2.0
4+
5+
#pragma once
6+
7+
#include "ckernel.h"
8+
#include "ckernel_defs.h"
9+
#include "ckernel_sfpu_remainder_int32.h"
10+
#include "sfpi.h"
11+
12+
namespace ckernel::sfpu {
13+
14+
// FMOD = a - trunc(a / b) * b
15+
// Implemented using 32-bit integer remainder kernel (see ckernel_sfpu_remainder_int32.h)
16+
sfpi_inline void calculate_fmod_int32_body(
17+
const uint dst_index_in0, const uint dst_index_in1, const uint dst_index_out) {
18+
// size of each tile in Dest is 64/SFP_DESTREG_STRIDE = 32 rows when using sfpi to load/store
19+
constexpr uint dst_tile_size_sfpi = 32;
20+
21+
// Read inputs
22+
sfpi::vInt a_signed = sfpi::dst_reg[dst_index_in0 * dst_tile_size_sfpi];
23+
sfpi::vInt b_signed = sfpi::dst_reg[dst_index_in1 * dst_tile_size_sfpi];
24+
25+
// Compute unsigned remainder
26+
sfpi::vInt r = compute_unsigned_remainder_int32(a_signed, b_signed);
27+
28+
// FMOD sign handling (result has the same sign as a)
29+
v_if(a_signed < 0) { r = -r; }
30+
v_endif;
31+
32+
sfpi::dst_reg[dst_index_out * dst_tile_size_sfpi] = r;
33+
}
34+
35+
template <bool is_fp32_dest_acc_en = false>
36+
sfpi_inline sfpi::vFloat _sfpu_binary_fmod_(sfpi::vFloat in0, sfpi::vFloat in1) {
37+
// fmod(a, b) = a - trunc(a/b) * b
38+
39+
sfpi::vFloat a = in0;
40+
sfpi::vFloat b = in1;
41+
sfpi::vFloat b_abs = sfpi::abs(b);
42+
43+
// Compute reciprocal 1/b
44+
sfpi::vFloat recip = ckernel::sfpu::_sfpu_reciprocal_<2>(b);
45+
46+
// Compute a/b = a * (1/b)
47+
sfpi::vFloat div_result = a * recip;
48+
49+
// Compute trunc(a/b)
50+
// Input in LReg0, output in LReg1. LReg2/LReg3 are clobbered by _trunc_body_(),
51+
// so we must read them to inform the SFPI register allocator they are not immediately available.
52+
sfpi::l_reg[sfpi::LRegs::LReg0] = div_result;
53+
_trunc_body_();
54+
sfpi::vFloat trunc_div = sfpi::l_reg[sfpi::LRegs::LReg1];
55+
sfpi::vFloat tmp2 = sfpi::l_reg[sfpi::LRegs::LReg2];
56+
sfpi::vFloat tmp3 = sfpi::l_reg[sfpi::LRegs::LReg3];
57+
58+
// Compute fmod = a - trunc(a/b) * b
59+
sfpi::vFloat result = a - trunc_div * b;
60+
61+
// Post-correction - fmod result must satisfy |result| < |b|
62+
// If |result| >= |b|, the truncation was wrong by 1
63+
sfpi::vFloat result_abs = sfpi::abs(result);
64+
65+
// If result >= b, we truncated too low, add/subtract b to correct
66+
v_if(result_abs >= b_abs) {
67+
// Determine correction direction based on sign of result
68+
v_if(result >= sfpi::vFloat(0.0f)) {
69+
result = result - b_abs; // result was positive and too big
70+
}
71+
v_else {
72+
result = result + b_abs; // result was negative and too big (magnitude)
73+
}
74+
v_endif;
75+
}
76+
v_endif;
77+
78+
// Sign correction - fmod result must have same sign as 'a' (or be zero)
79+
// If a > 0 and result < 0, the truncation was 1 too high, need to add b
80+
// If a < 0 and result > 0, the truncation was 1 too low, need to subtract b
81+
// This fixes cases where a/b ≈ 0.9999999 but rounds to 1 due to reciprocal error
82+
v_if(a >= sfpi::vFloat(0.0f)) {
83+
// a is positive, result should be >= 0
84+
v_if(result < sfpi::vFloat(0.0f)) {
85+
result = result + b_abs; // over-truncated
86+
}
87+
v_endif;
88+
}
89+
v_else {
90+
// a is negative, result should be <= 0
91+
v_if(result > sfpi::vFloat(0.0f)) {
92+
result = result - b_abs; // under-truncated
93+
}
94+
v_endif;
95+
}
96+
v_endif;
97+
98+
// Handle special cases using conditional assignment (NOT early return!)
99+
// When a == b, fmod(a, b) = 0
100+
v_if(a == b) { result = sfpi::vFloat(0.0f); }
101+
v_endif;
102+
103+
// Handle division by zero - return NaN
104+
v_if(b == sfpi::vFloat(0.0f)) { result = sfpi::vFloat(std::numeric_limits<float>::quiet_NaN()); }
105+
v_endif;
106+
107+
if constexpr (!is_fp32_dest_acc_en) {
108+
result = reinterpret<sfpi::vFloat>(sfpi::float_to_fp16b(result, 0));
109+
}
110+
111+
return result;
112+
}
113+
114+
template <bool APPROXIMATION_MODE, int ITERATIONS>
115+
inline void calculate_fmod_int32(const uint dst_index_in0, const uint dst_index_in1, const uint dst_index_out) {
116+
#pragma GCC unroll 8
117+
for (int d = 0; d < ITERATIONS; d++) {
118+
calculate_fmod_int32_body(dst_index_in0, dst_index_in1, dst_index_out);
119+
sfpi::dst_reg++;
120+
}
121+
}
122+
123+
template <bool APPROXIMATION_MODE, int ITERATIONS = 8, bool is_fp32_dest_acc_en = false>
124+
inline void calculate_sfpu_binary_fmod(const uint dst_index_in0, const uint dst_index_in1, const uint dst_index_out) {
125+
for (int d = 0; d < ITERATIONS; d++) {
126+
// size of each tile in Dest is 64/SFP_DESTREG_STRIDE = 32 rows when using sfpi to load/store
127+
constexpr uint dst_tile_size_sfpi = 32;
128+
sfpi::vFloat in0 = sfpi::dst_reg[dst_index_in0 * dst_tile_size_sfpi];
129+
sfpi::vFloat in1 = sfpi::dst_reg[dst_index_in1 * dst_tile_size_sfpi];
130+
131+
sfpi::vFloat result = _sfpu_binary_fmod_<is_fp32_dest_acc_en>(in0, in1);
132+
133+
sfpi::dst_reg[dst_index_out * dst_tile_size_sfpi] = result;
134+
sfpi::dst_reg++;
135+
}
136+
}
137+
138+
template <bool APPROXIMATION_MODE>
139+
inline void fmod_int32_init() {
140+
div_floor_init<APPROXIMATION_MODE>();
141+
}
142+
143+
template <bool APPROXIMATION_MODE>
144+
inline void fmod_binary_init() {
145+
_init_sfpu_reciprocal_<false>();
146+
}
147+
148+
} // namespace ckernel::sfpu

tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_fmod_int32.h

Lines changed: 0 additions & 49 deletions
This file was deleted.

tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_fmod_int32.h renamed to tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/llk_math_eltwise_binary_sfpu_binary_fmod.h

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#include "llk_math_eltwise_binary_sfpu_init.h"
88
#include "llk_math_eltwise_binary_sfpu_params.h"
9-
#include "ckernel_sfpu_fmod_int32.h"
9+
#include "ckernel_sfpu_binary_fmod.h"
1010

1111
namespace ckernel {
1212

@@ -22,4 +22,20 @@ inline void llk_math_eltwise_binary_sfpu_fmod_int32(
2222
sfpu::calculate_fmod_int32<APPROXIMATE, 8>, dst_index0, dst_index1, odst, vector_mode);
2323
}
2424

25+
template <bool APPROXIMATE>
26+
inline void llk_math_eltwise_binary_sfpu_binary_fmod_init() {
27+
llk_math_eltwise_binary_sfpu_init<SfpuType::unused, APPROXIMATE>(sfpu::fmod_binary_init<APPROXIMATE>);
28+
}
29+
30+
template <bool APPROXIMATE, bool is_fp32_dest_acc_en = false>
31+
inline void llk_math_eltwise_binary_sfpu_binary_fmod(
32+
uint dst_index0, uint32_t dst_index1, uint32_t odst, int vector_mode = VectorMode::RC) {
33+
_llk_math_eltwise_binary_sfpu_params_<APPROXIMATE>(
34+
sfpu::calculate_sfpu_binary_fmod<APPROXIMATE, 8, is_fp32_dest_acc_en>,
35+
dst_index0,
36+
dst_index1,
37+
odst,
38+
vector_mode);
39+
}
40+
2541
} // namespace ckernel

0 commit comments

Comments
 (0)