77from .communication import MPI
88from .dndarray import DNDarray
99from . import factories
10+ from .sanitation import sanitize_in
1011from . import types
1112from . import manipulations
12- from . import sanitation
1313
1414__all__ = ["nonzero" , "where" ]
1515
1616
1717def nonzero (x : DNDarray , as_tuple : bool = True ) -> tuple [DNDarray , ...] | DNDarray :
1818 """
19- Return a Tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``,
19+ Return a tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``,
2020 containing the indices of the non-zero elements in that dimension. If ``x`` is split then
21- the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray`
21+ the result is split in the first dimension. However, this :class:`~heat.core.dndarray.DNDarray`
2222 can be UNBALANCED as it contains the indices of the non-zero elements on each node.
2323 The values in ``x`` are always tested and returned in row-major, C-style order.
2424 The corresponding non-zero values can be obtained with: ``x[nonzero(x)]``.
@@ -54,29 +54,24 @@ def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarr
5454 >>> y[ht.nonzero(y > 3)]
5555 DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0)
5656 """
57- sanitation .sanitize_in (x )
58- local_x = x .larray
57+ sanitize_in (x )
5958
6059 if not x .is_distributed ():
6160 # nonzero indices as tuple
62- nonzero = torch .nonzero (input = local_x , as_tuple = as_tuple )
63- # ensure output split is consistent with distributed execution
64- out_split = 0 if x .split is not None else None
65-
61+ nonzero = torch .nonzero (input = x .larray , as_tuple = as_tuple )
6662 # bookkeeping for final DNDarray construct
6763 if as_tuple :
6864 nonzero = list (nonzero )
6965 for i , nz_tensor in enumerate (nonzero ):
70- nonzero [i ] = factories .array (
71- nz_tensor , split = out_split , device = x .device , comm = x .comm
72- )
66+ nonzero [i ] = factories .array (nz_tensor , device = x .device , comm = x .comm )
7367 return tuple (nonzero )
74- # nonzero indices as single 2D DNDarray
75- return factories .array (nonzero , split = out_split , device = x .device , comm = x .comm )
68+ else :
69+ # nonzero indices as single 2D DNDarray
70+ return factories .array (nonzero , device = x .device , comm = x .comm )
7671
7772 # distributed case
78- lcl_nonzero = torch .nonzero (input = local_x , as_tuple = False )
79- nonzero_size = torch .tensor (lcl_nonzero .shape [0 ], dtype = torch .int64 )
73+ lcl_nonzero = torch .nonzero (input = x . larray , as_tuple = False )
74+ nonzero_size = torch .tensor (lcl_nonzero .shape [0 ], dtype = torch .int64 , device = "cpu" )
8075 nonzero_dtype = types .canonical_heat_type (lcl_nonzero .dtype )
8176
8277 # global nonzero_size
@@ -85,7 +80,33 @@ def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarr
8580 _ , displs = x .counts_displs ()
8681 lcl_nonzero [:, x .split ] += displs [x .comm .rank ]
8782
88- if x .split != 0 :
83+ if x .split == 0 :
84+ # for split=0, the local nonzero indices are already globally ordered along the split axis
85+ if as_tuple : # return indices as tuple of 1D DNDarrays
86+ lcl_nonzero = lcl_nonzero .unbind (dim = 1 )
87+ return tuple (
88+ DNDarray (
89+ nz_tensor ,
90+ gshape = (nonzero_size .item (),),
91+ dtype = nonzero_dtype ,
92+ split = 0 ,
93+ device = x .device ,
94+ comm = x .comm ,
95+ balanced = False ,
96+ )
97+ for nz_tensor in lcl_nonzero
98+ )
99+ else : # return indices as single 2D DNDarray
100+ return DNDarray (
101+ lcl_nonzero ,
102+ gshape = (nonzero_size .item (), x .ndim ),
103+ dtype = nonzero_dtype ,
104+ split = 0 ,
105+ device = x .device ,
106+ comm = x .comm ,
107+ balanced = False ,
108+ )
109+ else :
89110 # construct global 2D DNDarray of nz indices:
90111 shape_2d = (nonzero_size .item (), x .ndim )
91112 global_nonzero = DNDarray (
@@ -100,59 +121,33 @@ def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarr
100121 # vectorized sorting of nz indices along axis 0
101122 global_nonzero .balance_ ()
102123 global_nonzero = manipulations .unique (global_nonzero , axis = 0 )
103- if not as_tuple :
104- # return indices as single 2D DNDarray
105- return global_nonzero
106- # return indices as tuple of 1D DNDarrays
107- lcl_nonzero = global_nonzero .larray .unbind (dim = 1 )
108- return tuple (
109- DNDarray (
110- nz_tensor ,
111- gshape = (nonzero_size .item (),),
112- dtype = nonzero_dtype ,
113- split = 0 ,
114- device = x .device ,
115- comm = x .comm ,
116- balanced = True ,
124+ if as_tuple : # return indices as tuple of 1D DNDarrays
125+ lcl_nonzero = global_nonzero .larray .unbind (dim = 1 )
126+ return tuple (
127+ DNDarray (
128+ nz_tensor ,
129+ gshape = (nonzero_size .item (),),
130+ dtype = nonzero_dtype ,
131+ split = 0 ,
132+ device = x .device ,
133+ comm = x .comm ,
134+ balanced = True ,
135+ )
136+ for nz_tensor in lcl_nonzero
117137 )
118- for nz_tensor in lcl_nonzero
119- )
120-
121- # for split=0, the local nonzero indices are already globally ordered along the split axis
122- if not as_tuple :
123- # return indices as single 2D DNDarray
124- return DNDarray (
125- lcl_nonzero ,
126- gshape = (nonzero_size .item (), x .ndim ),
127- dtype = nonzero_dtype ,
128- split = 0 ,
129- device = x .device ,
130- comm = x .comm ,
131- balanced = False ,
132- )
133- # return indices as tuple of 1D DNDarrays
134- lcl_nonzero = lcl_nonzero .unbind (dim = 1 )
135- return tuple (
136- DNDarray (
137- nz_tensor ,
138- gshape = (nonzero_size .item (),),
139- dtype = nonzero_dtype ,
140- split = 0 ,
141- device = x .device ,
142- comm = x .comm ,
143- balanced = False ,
144- )
145- for nz_tensor in lcl_nonzero
146- )
138+ else : # return indices as single 2D DNDarray
139+ return global_nonzero
147140
148141
149142DNDarray .nonzero = lambda self : nonzero (self , as_tuple = True )
150143DNDarray .nonzero .__doc__ = nonzero .__doc__
151144
152145
153146def where (
154- cond : DNDarray , x : None | int | float | DNDarray = None , y : None | int | float | DNDarray = None
155- ) -> DNDarray | tuple [DNDarray , ...]:
147+ cond : DNDarray ,
148+ x : None | int | float | DNDarray = None ,
149+ y : None | int | float | DNDarray = None ,
150+ ) -> DNDarray :
156151 """
157152 Return a :class:`~heat.core.dndarray.DNDarray` containing elements chosen from ``x`` or ``y`` depending on condition.
158153 Result is a :class:`~heat.core.dndarray.DNDarray` with elements from ``x`` where ``cond`` is True, and from ``y`` elsewhere.
@@ -161,24 +156,38 @@ def where(
161156
162157 Parameters
163158 ----------
164- cond: DNDarray
165- When True, yield ``x``, otherwise yield ``y``.
166- x, y: DNDarray or scalar, optional
167- Values from which to choose. ``x``, ``y`` and ``cond`` must be broadcastable to some shape.
168- If ``x`` and ``y`` are distributed, they must have the same split axis as ``cond``.
159+ cond : DNDarray
160+ Condition of interest, where true yield ``x`` otherwise yield ``y``
161+ x : DNDarray or int or float, optional
162+ Values from which to choose. ``x``, ``y`` and condition need to be broadcastable to some shape.
163+ y : DNDarray or int or float, optional
164+ Values from which to choose. ``x``, ``y`` and condition need to be broadcastable to some shape.
165+
166+ Raises
167+ ------
168+ NotImplementedError
169+ if splits of the two input :class:`~heat.core.dndarray.DNDarray` differ
170+ TypeError
171+ if only x or y is given or both are not DNDarrays or numerical scalars
172+
173+ Notes
174+ -----
175+ When only condition is provided, this function is a shorthand for :func:`nonzero` and the function returns a tuple
176+ of :class:`~heat.core.dndarray.DNDarray`, analogously to ``numpy.where``.
169177
170178 Examples
171179 --------
172180 >>> import heat as ht
173181 >>> x = ht.arange(10, split=0)
174182 >>> ht.where(x < 5, x, 10 * x)
175- DNDarray([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90], dtype=ht.int64, device=cpu:0, split=0)
176-
177- >>> # Indices retrieval (shorthand for nonzero)
178- >>> y = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=0)
179- >>> ht.where(y > 3)
180- (DNDarray([1, 1, 1, 2, 2, 2], dtype=ht.int64, device=cpu:0, split=0),
181- DNDarray([0, 1, 2, 0, 1, 2], dtype=ht.int64, device=cpu:0, split=0))
183+ DNDarray(MPI-rank: 0, Shape: (10,), Split: 0, Local Shape: (10,), Device: cpu:0, Dtype: int32, Data:
184+ [ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90])
185+ >>> y = ht.array([[0, 1, 2], [0, 2, 4], [0, 3, 6]])
186+ >>> ht.where(y < 4, y, -1)
187+ DNDarray(MPI-rank: 0, Shape: (3, 3), Split: None, Local Shape: (3, 3), Device: cpu:0, Dtype: int64, Data:
188+ [[ 0, 1, 2],
189+ [ 0, 2, -1],
190+ [ 0, 3, -1]])
182191 """
183192 # ---- binary where(cond, x, y) branch ------------------------------------
184193 if cond .split is not None and (isinstance (x , DNDarray ) or isinstance (y , DNDarray )):
@@ -198,10 +207,8 @@ def where(
198207 return cond .dtype (cond == 0 ) * y + cond * x
199208
200209 # ---- where(cond) "indices only" branch ----------------------------------
201- elif x is None and y is None :
202- # nonzero() properly handles all cases
203- nz = nonzero (cond )
204- return nz
210+ elif x is None and y is None : # delegate to nonzero(cond)
211+ return nonzero (cond ) # tuple of DNDarrays, one per dimension
205212
206213 # ---- invalid combinations ----------------------------------------------
207214 else :
0 commit comments