@@ -256,30 +256,78 @@ def nanmean(input, dim=None, keepdim=False, *, dtype=None, out=None):
256256def nanmedian (input , dim = None , keepdim = False , * , out = None ):
257257 if dim is None :
258258 flattened_input = ivy .flatten (input )
259- sorted_input = ivy .sort (flattened_input )
260- nonnan_index = int (sorted_input .shape [0 ] - ivy .isnan (sorted_input ).sum ())
261- return sorted_input [(nonnan_index - 1 ) // 2 ]
259+ non_nan_mask = ~ ivy .isnan (flattened_input )
260+ non_nan_values = flattened_input [non_nan_mask ]
261+
262+ if non_nan_values .size == 0 :
263+ return ivy .array (float ('nan' ))
264+
265+ sorted_values = ivy .sort (non_nan_values )
266+ n = sorted_values .shape [0 ]
267+ if n % 2 == 1 :
268+ return sorted_values [n // 2 ]
269+ else :
270+ return sorted_values [n // 2 - 1 ]
262271
263272 nanmedian_tuple = namedtuple ("nanmedian" , ["values" , "indices" ])
264273
265274 if input .ndim == 0 :
266275 result = nanmedian_tuple (input , ivy .array (0 ))
267276 else :
268- sorted_indices = ivy .argsort (input , axis = dim )
269- nonnan_index = (
270- sorted_indices .shape [dim ] - ivy .isnan (input ).sum (axis = 1 ) - 1
271- ) // 2
272- nonnan_index = ivy .expand_dims (nonnan_index , axis = 1 )
273- nanmedian_indices = ivy .gather_nd (sorted_indices , nonnan_index , batch_dims = 1 )
274- nanmedian_values = ivy .take_along_axis (
275- input , ivy .expand_dims (nanmedian_indices , axis = dim ), dim
276- ).squeeze (axis = dim )
277+ if dim < 0 :
278+ dim = input .ndim + dim
279+
280+ input_transposed = ivy .moveaxis (input , dim , - 1 )
281+ original_shape = list (input_transposed .shape )
282+
283+ reshaped = ivy .reshape (input_transposed , (- 1 , original_shape [- 1 ]))
284+
285+ median_values = []
286+ median_indices = []
287+
288+ for i in range (reshaped .shape [0 ]):
289+ row = reshaped [i ]
290+ non_nan_mask = ~ ivy .isnan (row )
291+ non_nan_values = row [non_nan_mask ]
292+
293+ if non_nan_values .size == 0 :
294+ median_values .append (float ('nan' ))
295+ median_indices .append (0 )
296+ else :
297+ non_nan_indices = ivy .nonzero (non_nan_mask )[0 ]
298+
299+ sorted_indices = ivy .argsort (non_nan_values )
300+ n = non_nan_values .shape [0 ]
301+
302+ if n % 2 == 1 :
303+ median_idx = n // 2
304+ median_val = non_nan_values [sorted_indices [median_idx ]]
305+ original_idx = non_nan_indices [sorted_indices [median_idx ]]
306+ else :
307+ median_idx = n // 2 - 1
308+ median_val = non_nan_values [sorted_indices [median_idx ]]
309+ original_idx = non_nan_indices [sorted_indices [median_idx ]]
310+
311+ median_values .append (median_val )
312+ median_indices .append (original_idx )
313+
314+ median_values = ivy .array (median_values )
315+ median_indices = ivy .array (median_indices )
316+
317+ result_shape = original_shape [:- 1 ]
318+ if result_shape :
319+ median_values = ivy .reshape (median_values , result_shape )
320+ median_indices = ivy .reshape (median_indices , result_shape )
321+ else :
322+ median_values = median_values [0 ]
323+ median_indices = median_indices [0 ]
277324
278325 if keepdim :
279- nanmedian_values = ivy .expand_dims (nanmedian_values , axis = dim )
280- nanmedian_indices = ivy .expand_dims (nanmedian_tuple , axis = dim )
326+ median_values = ivy .expand_dims (median_values , axis = dim )
327+ median_indices = ivy .expand_dims (median_indices , axis = dim )
328+
329+ result = nanmedian_tuple (median_values , median_indices )
281330
282- result = nanmedian_tuple (nanmedian_values , nanmedian_indices )
283331 if out is not None :
284332 ivy .inplace_update (out [0 ], result .values )
285333 ivy .inplace_update (out [1 ], result .indices )
0 commit comments