Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 13 additions & 16 deletions heat/core/indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@

def nonzero(x: DNDarray) -> DNDarray:
"""
Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero.. (using ``torch.nonzero``)
If ``x`` is split then the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray`
Return a :class:`~heat.core.dndarray.DNDarray` containing the indices of the elements that are non-zero (using ``torch.nonzero``).
If ``x`` is split then the result is split in the first dimension. However, this :class:`~heat.core.dndarray.DNDarray`
can be UNBALANCED as it contains the indices of the non-zero elements on each node.
Returns an array with one entry for each dimension of ``x``, containing the indices of the non-zero elements in that dimension.
The values in ``x`` are always tested and returned in row-major, C-style order.
Expand Down Expand Up @@ -53,26 +53,23 @@ def nonzero(x: DNDarray) -> DNDarray:
"""
sanitation.sanitize_in(x)

if x.split is None:
# if there is no split then just return the values from torch
lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False)
gout = list(lcl_nonzero.size())
is_split = None
else:
# a is split
lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False)
lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False)

# add offsets mapping from local indices to global indices if x is split
if x.split is not None:
_, _, slices = x.comm.chunk(x.shape, x.split)
lcl_nonzero[..., x.split] += slices[x.split].start
gout = list(lcl_nonzero.size())
gout[0] = x.comm.allreduce(gout[0], MPI.SUM)
is_split = 0

if x.ndim == 1:
lcl_nonzero = lcl_nonzero.squeeze(dim=1)

for g in range(len(gout) - 1, -1, -1):
if gout[g] == 1 and len(gout) > 1:
del gout[g]
# compute global shape of the index array
gout = list(lcl_nonzero.shape)
if x.split is None:
is_split = None
else:
gout[0] = x.comm.allreduce(gout[0], MPI.SUM)
is_split = 0

return DNDarray(
lcl_nonzero,
Expand Down
11 changes: 11 additions & 0 deletions heat/core/tests/test_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,17 @@ def test_nonzero(self):
a[nz] = 10.0
self.assertEqual(ht.all(a[nz] == 10), 1)

# edge case: single non-zero element
for split in [None, 0, 1]:
a = ht.zeros((4, 3), dtype=ht.bool, split=split)
a[1, 2] = True
nz = ht.indexing.nonzero(a)
a.resplit_(None)
nz.resplit_(None)
self.assertEqual(nz.gshape, (1, 2))
self.assertTrue(ht.allclose(a[nz], a[a]))


def test_where(self):
# cases to test
# no x and y
Expand Down