Skip to content

Commit b1f6f1a

Browse files
VirdhatchaniKN authored and muthutt committed
#2902: Fix rsqrt
1 parent 545f0dd commit b1f6f1a

File tree

4 files changed

+23
-21
lines changed

4 files changed

+23
-21
lines changed

tests/tt_eager/python_api_testing/sweep_tests/pytests/tt_dnn/test_eltwise_unary.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -197,7 +197,7 @@ def test_run_eltwise_rsqrt_op(
197197
output_mem_config,
198198
):
199199
datagen_func = [
200-
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=1, high=1e8), torch.bfloat16)
200+
generation_funcs.gen_func_with_cast(partial(generation_funcs.gen_rand, low=0, high=1e8), torch.bfloat16)
201201
]
202202
test_args = generation_funcs.gen_default_dtype_layout_device(input_shapes)[0]
203203
test_args["fast_and_approx"] = fast_and_approx

tests/tt_eager/python_api_testing/sweep_tests/reference_eltwise/eltwise_unary_rsqrt.py

Lines changed: 6 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -8,31 +8,19 @@
88

99
torch.manual_seed(2)
1010

11-
12-
def rsqrt_approx(x, iterations):
13-
# Initial approximation
11+
def rsqrt(x, iterations):
1412
y = 1.0 / x
15-
13+
condition = (x > 0) & (x < 1)
14+
y = torch.where(condition, torch.tensor(1.0), y)
1615
for _ in range(iterations):
1716
y = y * (1.5 - 0.5 * x * y * y) # Newton-Raphson iteration
1817
return y
1918

20-
21-
def rsqrt_accurate(x, iterations):
22-
# Initial approximation
23-
y = 1.0 / x
24-
25-
for _ in range(iterations):
26-
y = y * (1.5 - 0.5 * x * y * y) # Newton-Raphson iteration
27-
return y
28-
29-
30-
n = np.linspace(1, 10, 100)
19+
n = np.linspace(0, 10, 100)
3120
n = torch.from_numpy(n)
3221
lhs = torch.rsqrt(n)
33-
rhs_approx = rsqrt_approx(n, 10)
34-
rhs_accurate = rsqrt_accurate(n, 25)
35-
22+
rhs_approx = rsqrt(n, 10)
23+
rhs_accurate = rsqrt(n, 25)
3624

3725
plt.plot(n, lhs, "-r", label="rsqrt")
3826
plt.plot(n, rhs_accurate, "--g", label="custom rsqrt accurate")

tt_metal/src/ckernels/grayskull/common/inc/ckernel_sfpu.h

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -268,14 +268,21 @@ inline void calculate_rsqrt()
268268

269269
vFloat in = dst_reg[0];
270270
vFloat result = sfpu_reciprocal<false>(in);
271+
v_if(dst_reg[0] < 1.0f){
272+
result = 1.0f;
273+
}v_endif;
271274

272275
for (int r = 0; r < RECIPROCAL_ITERATIONS; r++)
273276
{
274277
// y = y * (1.5 - 0.5 * x * y * y) Newton's method iteration.
275278
result = result * (1.5F - 0.5F * dst_reg[0] * result * result);
276279
}
277280

278-
dst_reg[0] = result;
281+
v_if(dst_reg[0] == 0.0f){
282+
dst_reg[0] = std::numeric_limits<float>::infinity();
283+
}v_else{
284+
dst_reg[0] = result;
285+
}v_endif;
279286

280287
dst_reg++;
281288

tt_metal/src/ckernels/wormhole_b0/common/inc/ckernel_sfpu.h

Lines changed: 8 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -190,14 +190,21 @@ inline void calculate_rsqrt()
190190

191191
vFloat in = dst_reg[0];
192192
vFloat result = sfpu_reciprocal(in);
193+
v_if(dst_reg[0] < 1.0f){
194+
result = 1.0f;
195+
}v_endif;
193196

194197
for (int r = 0; r < RECIPROCAL_ITERATIONS; r++)
195198
{
196199
// y = y * (1.5 - 0.5 * x * y * y) Newton's method iteration.
197200
result = result * (1.5F - 0.5F * dst_reg[0] * result * result);
198201
}
199202

200-
dst_reg[0] = result;
203+
v_if(dst_reg[0] == 0.0f){
204+
dst_reg[0] = std::numeric_limits<float>::infinity();
205+
}v_else{
206+
dst_reg[0] = result;
207+
}v_endif;
201208

202209
dst_reg++;
203210

0 commit comments

Comments
 (0)