Skip to content

Commit 22130a1

Browse files
committed
fix: torch.nanmedian frontend
1 parent 0f5073f commit 22130a1

File tree

2 files changed

+69
-16
lines changed

2 files changed

+69
-16
lines changed

ivy/functional/frontends/torch/reduction_ops.py

Lines changed: 63 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -256,30 +256,78 @@ def nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None):
256256
def nanmedian(input, dim=None, keepdim=False, *, out=None):
257257
if dim is None:
258258
flattened_input = ivy.flatten(input)
259-
sorted_input = ivy.sort(flattened_input)
260-
nonnan_index = int(sorted_input.shape[0] - ivy.isnan(sorted_input).sum())
261-
return sorted_input[(nonnan_index - 1) // 2]
259+
non_nan_mask = ~ivy.isnan(flattened_input)
260+
non_nan_values = flattened_input[non_nan_mask]
261+
262+
if non_nan_values.size == 0:
263+
return ivy.array(float('nan'))
264+
265+
sorted_values = ivy.sort(non_nan_values)
266+
n = sorted_values.shape[0]
267+
if n % 2 == 1:
268+
return sorted_values[n // 2]
269+
else:
270+
return sorted_values[n // 2 - 1]
262271

263272
nanmedian_tuple = namedtuple("nanmedian", ["values", "indices"])
264273

265274
if input.ndim == 0:
266275
result = nanmedian_tuple(input, ivy.array(0))
267276
else:
268-
sorted_indices = ivy.argsort(input, axis=dim)
269-
nonnan_index = (
270-
sorted_indices.shape[dim] - ivy.isnan(input).sum(axis=1) - 1
271-
) // 2
272-
nonnan_index = ivy.expand_dims(nonnan_index, axis=1)
273-
nanmedian_indices = ivy.gather_nd(sorted_indices, nonnan_index, batch_dims=1)
274-
nanmedian_values = ivy.take_along_axis(
275-
input, ivy.expand_dims(nanmedian_indices, axis=dim), dim
276-
).squeeze(axis=dim)
277+
if dim < 0:
278+
dim = input.ndim + dim
279+
280+
input_transposed = ivy.moveaxis(input, dim, -1)
281+
original_shape = list(input_transposed.shape)
282+
283+
reshaped = ivy.reshape(input_transposed, (-1, original_shape[-1]))
284+
285+
median_values = []
286+
median_indices = []
287+
288+
for i in range(reshaped.shape[0]):
289+
row = reshaped[i]
290+
non_nan_mask = ~ivy.isnan(row)
291+
non_nan_values = row[non_nan_mask]
292+
293+
if non_nan_values.size == 0:
294+
median_values.append(float('nan'))
295+
median_indices.append(0)
296+
else:
297+
non_nan_indices = ivy.nonzero(non_nan_mask)[0]
298+
299+
sorted_indices = ivy.argsort(non_nan_values)
300+
n = non_nan_values.shape[0]
301+
302+
if n % 2 == 1:
303+
median_idx = n // 2
304+
median_val = non_nan_values[sorted_indices[median_idx]]
305+
original_idx = non_nan_indices[sorted_indices[median_idx]]
306+
else:
307+
median_idx = n // 2 - 1
308+
median_val = non_nan_values[sorted_indices[median_idx]]
309+
original_idx = non_nan_indices[sorted_indices[median_idx]]
310+
311+
median_values.append(median_val)
312+
median_indices.append(original_idx)
313+
314+
median_values = ivy.array(median_values)
315+
median_indices = ivy.array(median_indices)
316+
317+
result_shape = original_shape[:-1]
318+
if result_shape:
319+
median_values = ivy.reshape(median_values, result_shape)
320+
median_indices = ivy.reshape(median_indices, result_shape)
321+
else:
322+
median_values = median_values[0]
323+
median_indices = median_indices[0]
277324

278325
if keepdim:
279-
nanmedian_values = ivy.expand_dims(nanmedian_values, axis=dim)
280-
nanmedian_indices = ivy.expand_dims(nanmedian_tuple, axis=dim)
326+
median_values = ivy.expand_dims(median_values, axis=dim)
327+
median_indices = ivy.expand_dims(median_indices, axis=dim)
328+
329+
result = nanmedian_tuple(median_values, median_indices)
281330

282-
result = nanmedian_tuple(nanmedian_values, nanmedian_indices)
283331
if out is not None:
284332
ivy.inplace_update(out[0], result.values)
285333
ivy.inplace_update(out[1], result.indices)

ivy_tests/test_ivy/test_frontends/test_torch/test_reduction_ops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -679,10 +679,13 @@ def test_torch_nanmean(
679679
@handle_frontend_test(
680680
fn_tree="torch.nanmedian",
681681
dtype_input_axis=helpers.dtype_values_axis(
682-
available_dtypes=helpers.get_dtypes("numeric"),
682+
available_dtypes=helpers.get_dtypes("float"),
683683
min_num_dims=1,
684684
valid_axis=True,
685685
force_int_axis=True,
686+
min_value=-1e04,
687+
max_value=1e04,
688+
abs_smallest_val=1e-04,
686689
),
687690
keepdim=st.booleans(),
688691
)
@@ -707,6 +710,8 @@ def test_torch_nanmedian(
707710
input=input[0],
708711
dim=dim,
709712
keepdim=keepdim,
713+
atol=1e-02,
714+
rtol=1e-02,
710715
)
711716

712717

0 commit comments

Comments
 (0)