From 6230ae53b109f6941976d64062df3c4f9382cbbf Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 20 Feb 2026 09:31:50 +0100 Subject: [PATCH 1/3] Refactor and bugfix in `heat.indexing.nonzero` --- heat/core/indexing.py | 29 +++++++++++++---------------- 1 file changed, 13 insertions(+), 16 deletions(-) diff --git a/heat/core/indexing.py b/heat/core/indexing.py index 27564ef02c..916aa450df 100644 --- a/heat/core/indexing.py +++ b/heat/core/indexing.py @@ -15,8 +15,8 @@ def nonzero(x: DNDarray) -> DNDarray: """ - Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero.. (using ``torch.nonzero``) - If ``x`` is split then the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray` + Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero (using ``torch.nonzero``). + If ``x`` is split then the result is split in the first dimension. However, this :class:`~heat.core.dndarray.DNDarray` can be UNBALANCED as it contains the indices of the non-zero elements on each node. Returns an array with one entry for each dimension of ``x``, containing the indices of the non-zero elements in that dimension. The values in ``x`` are always tested and returned in row-major, C-style order. @@ -53,26 +53,23 @@ def nonzero(x: DNDarray) -> DNDarray: """ sanitation.sanitize_in(x) - if x.split is None: - # if there is no split then just return the values from torch - lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) - gout = list(lcl_nonzero.size()) - is_split = None - else: - # a is split - lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) + lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False) + + # add offsets mapping from local indices to global indices if x is split + if x.split is not None: _, _, slices = x.comm.chunk(x.shape, x.split) lcl_nonzero[..., x.split] += slices[x.split].start - gout = list(lcl_nonzero.size()) - gout[0] = x.comm.allreduce(gout[0], MPI.SUM) - is_split = 0 if x.ndim == 1: lcl_nonzero = lcl_nonzero.squeeze(dim=1) - for g in range(len(gout) - 1, -1, -1): - if gout[g] == 1 and len(gout) > 1: - del gout[g] + # compute global shape of the index array + gout = list(lcl_nonzero.shape) + if x.split is None: + is_split = None + else: + gout[0] = x.comm.allreduce(gout[0], MPI.SUM) + is_split = 0 return DNDarray( lcl_nonzero, From 0af251af8e44edcf81062289bf5893e1fed2800b Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Fri, 20 Feb 2026 10:10:53 +0100 Subject: [PATCH 2/3] Added test for edge case in `heat.indexing.nonzero` --- heat/core/tests/test_indexing.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/heat/core/tests/test_indexing.py b/heat/core/tests/test_indexing.py index 4707aa28ab..f759969874 100644 --- a/heat/core/tests/test_indexing.py +++ b/heat/core/tests/test_indexing.py @@ -23,6 +23,15 @@ def test_nonzero(self): a[nz] = 10.0 self.assertEqual(ht.all(a[nz] == 10), 1) + # edge case: single non-zero element + for split in [None, 1]: + a = ht.zeros((4, 3), dtype=ht.bool, split=split) + a[1, 2] = True + nz = ht.indexing.nonzero(a) + self.assertEqual(nz.gshape, (1, 2)) + self.assertTrue(ht.allclose(a[nz], a[a])) + + def test_where(self): # cases to test # no x and y From c1e8c2b33dd66a259d03e64e2d555c9d18816cdb Mon Sep 17 00:00:00 2001 From: Thomas Baumann <39156931+brownbaerchen@users.noreply.github.com> Date: Tue, 3 Mar 2026 12:33:13 +0100 Subject: [PATCH 3/3] Modifying test --- heat/core/tests/test_indexing.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/heat/core/tests/test_indexing.py b/heat/core/tests/test_indexing.py index f759969874..0bd1c9307d 100644 --- a/heat/core/tests/test_indexing.py +++ b/heat/core/tests/test_indexing.py @@ -24,10 +24,12 @@ def test_nonzero(self): self.assertEqual(ht.all(a[nz] == 10), 1) # edge case: single non-zero element - for split in [None, 1]: + for split in [None, 0, 1]: a = ht.zeros((4, 3), dtype=ht.bool, split=split) a[1, 2] = True nz = ht.indexing.nonzero(a) + a.resplit_(None) + nz.resplit_(None) self.assertEqual(nz.gshape, (1, 2)) self.assertTrue(ht.allclose(a[nz], a[a]))