Skip to content

Commit ba5bac7

Browse files
#0: Update test
1 parent 280371f commit ba5bac7

File tree

10 files changed

+64
-43
lines changed

10 files changed

+64
-43
lines changed

tests/ttnn/unit_tests/operations/eltwise/test_binary_fp32.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -194,21 +194,19 @@ def test_squared_sum_fp32_activ(device):
194194
"shape",
195195
[
196196
[1, 1, 16, 16],
197-
[1, 1, 80, 80],
198-
[1, 1, 320, 384],
199197
[1, 3, 320, 384],
200198
],
201199
)
202200
def test_add_fp32_input_activ(device, ttnn_function, shape):
203201
x_torch = torch.ones(shape, dtype=torch.float32) * 2
204202
y_torch = torch.ones(shape, dtype=torch.float32) * 4
205-
z_torch = torch.square(torch.nn.functional.silu(x_torch) + y_torch)
203+
z_torch = torch.pow(torch.nn.functional.silu(x_torch) + y_torch, 4)
206204
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
207205
y_tt = ttnn.from_torch(y_torch, dtype=ttnn.float32, layout=ttnn.TILE_LAYOUT, device=device)
208206
z_tt_add = ttnn.add(
209207
x_tt,
210208
y_tt,
211-
activations=[ttnn.UnaryWithParam(ttnn.UnaryOpType.SQUARE)],
209+
activations=[ttnn.UnaryWithParam(ttnn.UnaryOpType.POWER, 4)],
212210
input_tensor_a_activations=[ttnn.UnaryOpType.SILU],
213211
)
214212
tt_out = ttnn.to_torch(z_tt_add)

tests/ttnn/unit_tests/operations/eltwise/test_unary_pow.py

Lines changed: 34 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,26 @@
1-
# SPDX-FileCopyrightText: © 2024 Tenstorrent Inc.
1+
# SPDX-FileCopyrightText: © 2026 Tenstorrent Inc.
22

33
# SPDX-License-Identifier: Apache-2.0
44

55
import torch
66
import pytest
77
import ttnn
8-
from tests.ttnn.utils_for_testing import assert_with_ulp
9-
from tests.ttnn.unit_tests.operations.eltwise.test_expm1 import flush_subnormal_values
8+
from tests.ttnn.utils_for_testing import (
9+
assert_with_ulp,
10+
generate_all_bfloat16_bitpatterns,
11+
flush_subnormal_values_to_zero,
12+
)
1013

1114

1215
def generate_clean_bf16_tensor(dtype=torch.bfloat16):
13-
all_bitpatterns = torch.arange(0, 2**16, dtype=torch.int32).to(torch.uint16)
14-
input_tensor = all_bitpatterns.view(torch.bfloat16) # 65536 values
15-
fp32 = input_tensor.to(torch.float32)
16+
all_bf16 = generate_all_bfloat16_bitpatterns(torch.bfloat16).flatten()
17+
fp32 = all_bf16.to(torch.float32)
1618

1719
# Remove special values (NaN, -0.0, +inf, -inf, subnormals)
1820
neg_zero_mask = (fp32 == 0.0) & torch.signbit(fp32)
1921
tiny = torch.finfo(torch.bfloat16).tiny # 2**-126
2022
good_mask = torch.isfinite(fp32) & ~neg_zero_mask & (fp32.abs() >= tiny)
21-
fp32 = fp32[good_mask] # 65024 values
23+
fp32 = fp32[good_mask] # ~65024 clean values
2224

2325
return fp32.to(dtype)
2426

@@ -41,7 +43,7 @@ def test_pow_arange_masking(exponent, device):
4143
# Generate all possible bit pattern for bf16
4244
tt_input = generate_clean_bf16_tensor(torch.bfloat16)
4345
# If input is subnormal then we assume hardware will flush it to 0.0
44-
tt_input = flush_subnormal_values(tt_input)
46+
tt_input = flush_subnormal_values_to_zero(tt_input)
4547

4648
tt_in = ttnn.from_torch(
4749
tt_input,
@@ -57,7 +59,29 @@ def test_pow_arange_masking(exponent, device):
5759
tt_result = ttnn.pow(tt_in, exponent)
5860
result = ttnn.to_torch(tt_result)
5961
# If expected output is subnormal then its calculated value should be 0.0 (hardware assumed to flush to 0.0)
60-
result = flush_subnormal_values(result)
61-
golden = flush_subnormal_values(golden)
62+
result = flush_subnormal_values_to_zero(result)
63+
golden = flush_subnormal_values_to_zero(golden)
6264

6365
assert_with_ulp(golden, result, 1, allow_nonfinite=True)
66+
67+
68+
@pytest.mark.parametrize(
69+
"op_type,exponent",
70+
[
71+
(ttnn.UnaryOpType.POWER_ITERATIVE, 0),
72+
(ttnn.UnaryOpType.POWER_ITERATIVE, 2),
73+
(ttnn.UnaryOpType.POWER, 0),
74+
(ttnn.UnaryOpType.POWER, 2),
75+
(ttnn.UnaryOpType.POWER, 1.5),
76+
(ttnn.UnaryOpType.POWER, -1.9),
77+
],
78+
)
79+
def test_power_as_activation(device, op_type, exponent):
80+
x_torch = torch.rand([16, 16], dtype=torch.bfloat16) + 1.5
81+
z_torch = torch.pow(x_torch + x_torch, exponent)
82+
83+
x_tt = ttnn.from_torch(x_torch, dtype=ttnn.bfloat16, layout=ttnn.TILE_LAYOUT, device=device)
84+
z_tt = ttnn.add(x_tt, x_tt, activations=[ttnn.UnaryWithParam(op_type, exponent)])
85+
tt_out = ttnn.to_torch(z_tt)
86+
87+
assert_with_ulp(z_torch, tt_out, 1)

tt_metal/hw/ckernels/blackhole/metal/llk_api/llk_sfpu/ckernel_sfpu_unary_power.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ sfpi_inline sfpi::vFloat _sfpu_unary_power_21f_(sfpi::vFloat base, sfpi::vFloat
5151
sfpi::vFloat x = sfpi::setexp(abs_base, 127); // set exp to exp bias (put base in range of 1-2)
5252

5353
// 3rd order polynomial approx - determined using rminimax over [1,2]
54-
sfpi::vFloat series_result = x * (x * (x * 0x2.44734p-4f - 0xd.e712ap-4f) + 0x2.4f5388p+0f) - 0x1.952992p+0f;
54+
vFloat series_result = PolynomialEvaluator::eval(x, -0x1.952992p+0f, 0x2.4f5388p+0f, -0xd.e712ap-4f, 0x2.44734p-4f);
5555

5656
// Convert exponent to float
5757
sfpi::vInt exp = sfpi::exexp(base);
@@ -192,14 +192,14 @@ inline void calculate_unary_power(const uint32_t exponent) {
192192
*/
193193
template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
194194
inline void calculate_unary_power_iterative(const uint32_t exponent) {
195-
// Old iterative approach for integer exponents 0, 1, 2, 3
196-
// exponent contains IEEE 754 float bits - convert to actual integer
195+
// iterative approach for positive integer exponents
196+
// exponent contains IEEE 754 float bits - convert to integer
197197
const float exp_float = Converter::as_float(exponent);
198198
const uint exp = (uint)exp_float;
199199
#pragma GCC unroll 8
200200
for (int d = 0; d < ITERATIONS; d++) {
201-
vFloat in = sfpi::dst_reg[0];
202-
vFloat result = 1.0f;
201+
sfpi::vFloat in = sfpi::dst_reg[0];
202+
sfpi::vFloat result = 1.0f;
203203
uint e = exp;
204204
while (e > 0) {
205205
if (e & 1) {

tt_metal/hw/ckernels/wormhole_b0/metal/llk_api/llk_sfpu/ckernel_sfpu_unary_power.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ sfpi_inline sfpi::vFloat _sfpu_unary_power_21f_(sfpi::vFloat base, sfpi::vFloat
5252
sfpi::vFloat x = sfpi::setexp(abs_base, 127); // set exp to exp bias (put base in range of 1-2)
5353

5454
// 3rd order polynomial approx - determined using rminimax over [1,2]
55-
sfpi::vFloat series_result = x * (x * (x * 0x2.44734p-4f - 0xd.e712ap-4f) + 0x2.4f5388p+0f) - 0x1.952992p+0f;
55+
vFloat series_result = PolynomialEvaluator::eval(x, -0x1.952992p+0f, 0x2.4f5388p+0f, -0xd.e712ap-4f, 0x2.44734p-4f);
5656

5757
// Convert exponent to float
5858
sfpi::vInt exp = sfpi::exexp(base);
@@ -193,14 +193,14 @@ inline void calculate_unary_power(const uint32_t exponent) {
193193
*/
194194
template <bool APPROXIMATION_MODE, int ITERATIONS = 8>
195195
inline void calculate_unary_power_iterative(const uint32_t exponent) {
196-
// Old iterative approach for integer exponents 0, 1, 2, 3
197-
// exponent contains IEEE 754 float bits - convert to actual integer
196+
// iterative approach for positive integer exponents
197+
// exponent contains IEEE 754 float bits - convert to integer
198198
const float exp_float = Converter::as_float(exponent);
199199
const uint exp = (uint)exp_float;
200200
#pragma GCC unroll 8
201201
for (int d = 0; d < ITERATIONS; d++) {
202-
vFloat in = sfpi::dst_reg[0];
203-
vFloat result = 1.0f;
202+
sfpi::vFloat in = sfpi::dst_reg[0];
203+
sfpi::vFloat result = 1.0f;
204204
uint e = exp;
205205
while (e > 0) {
206206
if (e & 1) {

tt_metal/include/compute_kernel_api.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -348,7 +348,7 @@ ALWI void power_tile_init() { MATH((llk_math_eltwise_unary_sfpu_power_init<APPRO
348348
* acquired state via *acquire_dst* call. This call is blocking and is only
349349
* available on the compute engine.
350350
*
351-
* Note: Uses iterative multiplication for positive integer exponents. Optimal for small exponents (0,1,2,3).
351+
* Note: Unlike power_tile, power_iterative_tile() only supports positive integer scalars. It uses an iterative multiplication loop to compute values, and is faster than power_tile for small exponents (e.g. 1, 2, 3)
352352
*
353353
* Return value: None
354354
*
@@ -358,14 +358,14 @@ ALWI void power_tile_init() { MATH((llk_math_eltwise_unary_sfpu_power_init<APPRO
358358
* | param0 | The exponent as IEEE 754 float bits | uint32_t | Must be a positive integer exponent | True |
359359
*/
360360
// clang-format on
361-
ALWI void power_tile_iterative(uint32_t idst, uint32_t param0) {
361+
ALWI void power_iterative_tile(uint32_t idst, uint32_t param0) {
362362
MATH((llk_math_eltwise_unary_sfpu_power_iterative<APPROX>(idst, param0)));
363363
}
364364

365365
/**
366366
* Please refer to documentation for any_init.
367367
*/
368-
ALWI void power_tile_iterative_init() { MATH((llk_math_eltwise_unary_sfpu_power_iterative_init<APPROX>())); }
368+
ALWI void power_iterative_tile_init() { MATH((llk_math_eltwise_unary_sfpu_power_iterative_init<APPROX>())); }
369369

370370
// clang-format off
371371
// exp2 : y = 2 ^ x ==> [y = exp(x * log(2))]

ttnn/cpp/ttnn/deprecated/tt_dnn/kernels/compute/moreh_common.hpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1029,8 +1029,8 @@ ALWI void power_tile_to_cb(
10291029
copy_tile_init_with_dt(cb_x);
10301030
copy_tile(cb_x, 0, dst0);
10311031

1032-
power_tile_iterative_init();
1033-
power_tile_iterative(dst0, p);
1032+
power_iterative_tile_init();
1033+
power_iterative_tile(dst0, p);
10341034

10351035
if (p_is_negative) {
10361036
recip_tile_init();
@@ -1124,8 +1124,8 @@ ALWI void power_tile_with_abs_x_to_cb(
11241124
abs_tile_init();
11251125
abs_tile(dst0);
11261126

1127-
power_tile_iterative_init();
1128-
power_tile_iterative(dst0, p);
1127+
power_iterative_tile_init();
1128+
power_iterative_tile(dst0, p);
11291129

11301130
if (p_is_negative) {
11311131
recip_tile_init();
@@ -1219,8 +1219,8 @@ ALWI void power_and_recip_tile_to_cb(
12191219
copy_tile_init_with_dt(cb_x);
12201220
copy_tile(cb_x, 0, dst0);
12211221

1222-
power_tile_iterative_init();
1223-
power_tile_iterative(dst0, p);
1222+
power_iterative_tile_init();
1223+
power_iterative_tile(dst0, p);
12241224

12251225
if (p_is_negative) {
12261226
recip_tile_init();

ttnn/cpp/ttnn/operations/eltwise/binary/binary_composite.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ namespace operations::binary {
2929
struct ExecutePower {
3030
static Tensor invoke(
3131
const Tensor& input_tensor,
32-
uint32_t exponent,
32+
int32_t exponent,
3333
const std::optional<MemoryConfig>& output_mem_config = std::nullopt,
3434
const std::optional<Tensor>& optional_output_tensor = std::nullopt);
3535

ttnn/cpp/ttnn/operations/eltwise/binary/binary_nanobind.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2089,7 +2089,7 @@ void bind_power(nb::module_& mod, const binary_operation_t& /*operation*/, const
20892089
ttnn::nanobind_overload_t{
20902090
[](const binary_operation_t& self,
20912091
const Tensor& input_tensor,
2092-
uint32_t exponent,
2092+
int32_t exponent,
20932093
const std::optional<MemoryConfig>& memory_config,
20942094
const std::optional<Tensor>& output_tensor) -> ttnn::Tensor {
20952095
return self(input_tensor, exponent, memory_config, output_tensor);

ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_composite_op.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -809,10 +809,9 @@ Tensor ExecutePower::invoke(
809809
float exponent,
810810
const std::optional<MemoryConfig>& output_mem_config,
811811
const std::optional<Tensor>& output_tensor) {
812-
// For exponents 0, 1, 2, 3: use iterative approach
813-
if (exponent == 0.0f || exponent == 1.0f || exponent == 2.0f || exponent == 3.0f) {
814-
return ttnn::operations::unary::ExecuteUnaryTSVariant<ttnn::operations::unary::UnaryOpType::POWER_ITERATIVE>::
815-
invoke(input_a, exponent, output_mem_config, output_tensor);
812+
float exponent_floor = std::floor(exponent);
813+
if (static_cast<int32_t>(exponent_floor) == exponent) {
814+
return ExecutePower::invoke(input_a, static_cast<int32_t>(exponent), output_mem_config, output_tensor);
816815
}
817816
return ttnn::operations::unary::ExecuteUnaryTSVariant<ttnn::operations::unary::UnaryOpType::POWER>::invoke(
818817
input_a, exponent, output_mem_config, output_tensor);
@@ -821,11 +820,11 @@ Tensor ExecutePower::invoke(
821820
// power - integer exponent
822821
Tensor ExecutePower::invoke(
823822
const Tensor& input,
824-
uint32_t exponent,
823+
int32_t exponent,
825824
const std::optional<MemoryConfig>& output_mem_config,
826825
const std::optional<Tensor>& output_tensor) {
827826
// For exponents 0, 1, 2, 3: use iterative approach
828-
if (exponent <= 3) {
827+
if (exponent == 0 || exponent == 1 || exponent == 2 || exponent == 3) {
829828
return ttnn::operations::unary::ExecuteUnaryTSVariant<ttnn::operations::unary::UnaryOpType::POWER_ITERATIVE>::
830829
invoke(input, exponent, output_mem_config, output_tensor);
831830
}

ttnn/cpp/ttnn/operations/eltwise/unary/common/unary_op_utils.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -152,8 +152,8 @@ std::pair<std::string, std::string> get_op_init_and_func_parameterized(
152152
case UnaryOpType::POWER_ITERATIVE:
153153
// For exponents 0, 1, 2, 3: use iterative approach
154154
return {
155-
"power_tile_iterative_init();",
156-
fmt::format("power_tile_iterative({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
155+
"power_iterative_tile_init();",
156+
fmt::format("power_iterative_tile({}, {:#x}u);", idst, std::bit_cast<uint32_t>(param0))};
157157
case UnaryOpType::LEAKY_RELU:
158158
return {
159159
"leaky_relu_tile_init();",

0 commit comments

Comments (0)