Skip to content

Commit b4a5329

Browse files
committed
Merge branch 'features/nonzero-updates' into merge_nonzero
2 parents a05e0ea + 8a2292f commit b4a5329

7 files changed

Lines changed: 193 additions & 137 deletions

File tree

.github/workflows/array-api.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ jobs:
1717
with:
1818
path: heat
1919
- name: Setup MPI
20-
uses: mpi4py/setup-mpi@dbbb80b116bea57fc1788daf7dbbf7ab3df3a0f1 # v1.4.2
20+
uses: mpi4py/setup-mpi@f200dce75b64188be849b46657dcf86c721937b2 # v1.4.3
2121
with:
2222
mpi: ${{ matrix.mpi }}
2323
- name: Use Python ${{ matrix.python-version }}

.github/workflows/push_main.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,12 +97,12 @@ jobs:
9797
name: Check REUSE compliance
9898
steps:
9999
- name: Harden Runner
100-
uses: step-security/harden-runner@a5ad31d6a139d249332a2605b85202e8c0b78450 # v2.19.1
100+
uses: step-security/harden-runner@9af89fc71515a100421586dfdb3dc9c984fbf411 # v2.19.4
101101
with:
102102
egress-policy: audit
103103

104104
- name: Checkout
105-
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
105+
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
106106

107107
- name: Setup Python
108108
uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0

.github/workflows/weekly.yml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,11 +30,11 @@ jobs:
3030
uses: actions/checkout@df4cb1c069e1874edd31b4311f1884172cec0e10 # v6.0.3
3131
# Initializes the CodeQL tools for scanning.
3232
- name: Initialize CodeQL
33-
uses: github/codeql-action/init@87557b9c84dde89fdd9b10e88954ac2f4248e463 # v4.36.1
33+
uses: github/codeql-action/init@8aad20d150bbac5944a9f9d289da16a4b0d87c1e # v4.36.2
3434
with:
3535
languages: python
3636
- name: Perform CodeQL Analysis
37-
uses: github/codeql-action/analyze@87557b9c84dde89fdd9b10e88954ac2f4248e463 # v4.36.1
37+
uses: github/codeql-action/analyze@8aad20d150bbac5944a9f9d289da16a4b0d87c1e # v4.36.2
3838
with:
3939
category: "/language:python"
4040
scorecard:
@@ -83,7 +83,7 @@ jobs:
8383

8484
# Upload the results to GitHub's code scanning dashboard.
8585
- name: "Upload to code-scanning"
86-
uses: github/codeql-action/upload-sarif@87557b9c84dde89fdd9b10e88954ac2f4248e463 # v4.36.1
86+
uses: github/codeql-action/upload-sarif@8aad20d150bbac5944a9f9d289da16a4b0d87c1e # v4.36.2
8787
with:
8888
sarif_file: results.sarif
8989
check-links:

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ repos:
2626

2727
- repo: https://github.com/astral-sh/ruff-pre-commit
2828
# Ruff version.
29-
rev: v0.15.16
29+
rev: v0.15.17
3030
hooks:
3131
# Run the linter.
3232
- id: ruff

heat/core/indexing.py

Lines changed: 85 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -7,18 +7,18 @@
77
from .communication import MPI
88
from .dndarray import DNDarray
99
from . import factories
10+
from .sanitation import sanitize_in
1011
from . import types
1112
from . import manipulations
12-
from . import sanitation
1313

1414
__all__ = ["nonzero", "where"]
1515

1616

1717
def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarray:
1818
"""
19-
Return a Tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``,
19+
Return a tuple of :class:`~heat.core.dndarray.DNDarray`s, one for each dimension of ``x``,
2020
containing the indices of the non-zero elements in that dimension. If ``x`` is split then
21-
the result is split in the 0th dimension. However, this :class:`~heat.core.dndarray.DNDarray`
21+
the result is split in the first dimension. However, this :class:`~heat.core.dndarray.DNDarray`
2222
can be UNBALANCED as it contains the indices of the non-zero elements on each node.
2323
The values in ``x`` are always tested and returned in row-major, C-style order.
2424
The corresponding non-zero values can be obtained with: ``x[nonzero(x)]``.
@@ -54,29 +54,24 @@ def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarr
5454
>>> y[ht.nonzero(y > 3)]
5555
DNDarray([4, 5, 6, 7, 8, 9], dtype=ht.int64, device=cpu:0, split=0)
5656
"""
57-
sanitation.sanitize_in(x)
58-
local_x = x.larray
57+
sanitize_in(x)
5958

6059
if not x.is_distributed():
6160
# nonzero indices as tuple
62-
nonzero = torch.nonzero(input=local_x, as_tuple=as_tuple)
63-
# ensure output split is consistent with distributed execution
64-
out_split = 0 if x.split is not None else None
65-
61+
nonzero = torch.nonzero(input=x.larray, as_tuple=as_tuple)
6662
# bookkeeping for final DNDarray construct
6763
if as_tuple:
6864
nonzero = list(nonzero)
6965
for i, nz_tensor in enumerate(nonzero):
70-
nonzero[i] = factories.array(
71-
nz_tensor, split=out_split, device=x.device, comm=x.comm
72-
)
66+
nonzero[i] = factories.array(nz_tensor, device=x.device, comm=x.comm)
7367
return tuple(nonzero)
74-
# nonzero indices as single 2D DNDarray
75-
return factories.array(nonzero, split=out_split, device=x.device, comm=x.comm)
68+
else:
69+
# nonzero indices as single 2D DNDarray
70+
return factories.array(nonzero, device=x.device, comm=x.comm)
7671

7772
# distributed case
78-
lcl_nonzero = torch.nonzero(input=local_x, as_tuple=False)
79-
nonzero_size = torch.tensor(lcl_nonzero.shape[0], dtype=torch.int64)
73+
lcl_nonzero = torch.nonzero(input=x.larray, as_tuple=False)
74+
nonzero_size = torch.tensor(lcl_nonzero.shape[0], dtype=torch.int64, device="cpu")
8075
nonzero_dtype = types.canonical_heat_type(lcl_nonzero.dtype)
8176

8277
# global nonzero_size
@@ -85,7 +80,33 @@ def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarr
8580
_, displs = x.counts_displs()
8681
lcl_nonzero[:, x.split] += displs[x.comm.rank]
8782

88-
if x.split != 0:
83+
if x.split == 0:
84+
# for split=0, the local nonzero indices are already globally ordered along the split axis
85+
if as_tuple: # return indices as tuple of 1D DNDarrays
86+
lcl_nonzero = lcl_nonzero.unbind(dim=1)
87+
return tuple(
88+
DNDarray(
89+
nz_tensor,
90+
gshape=(nonzero_size.item(),),
91+
dtype=nonzero_dtype,
92+
split=0,
93+
device=x.device,
94+
comm=x.comm,
95+
balanced=False,
96+
)
97+
for nz_tensor in lcl_nonzero
98+
)
99+
else: # return indices as single 2D DNDarray
100+
return DNDarray(
101+
lcl_nonzero,
102+
gshape=(nonzero_size.item(), x.ndim),
103+
dtype=nonzero_dtype,
104+
split=0,
105+
device=x.device,
106+
comm=x.comm,
107+
balanced=False,
108+
)
109+
else:
89110
# construct global 2D DNDarray of nz indices:
90111
shape_2d = (nonzero_size.item(), x.ndim)
91112
global_nonzero = DNDarray(
@@ -100,59 +121,33 @@ def nonzero(x: DNDarray, as_tuple: bool = True) -> tuple[DNDarray, ...] | DNDarr
100121
# vectorized sorting of nz indices along axis 0
101122
global_nonzero.balance_()
102123
global_nonzero = manipulations.unique(global_nonzero, axis=0)
103-
if not as_tuple:
104-
# return indices as single 2D DNDarray
105-
return global_nonzero
106-
# return indices as tuple of 1D DNDarrays
107-
lcl_nonzero = global_nonzero.larray.unbind(dim=1)
108-
return tuple(
109-
DNDarray(
110-
nz_tensor,
111-
gshape=(nonzero_size.item(),),
112-
dtype=nonzero_dtype,
113-
split=0,
114-
device=x.device,
115-
comm=x.comm,
116-
balanced=True,
124+
if as_tuple: # return indices as tuple of 1D DNDarrays
125+
lcl_nonzero = global_nonzero.larray.unbind(dim=1)
126+
return tuple(
127+
DNDarray(
128+
nz_tensor,
129+
gshape=(nonzero_size.item(),),
130+
dtype=nonzero_dtype,
131+
split=0,
132+
device=x.device,
133+
comm=x.comm,
134+
balanced=True,
135+
)
136+
for nz_tensor in lcl_nonzero
117137
)
118-
for nz_tensor in lcl_nonzero
119-
)
120-
121-
# for split=0, the local nonzero indices are already globally ordered along the split axis
122-
if not as_tuple:
123-
# return indices as single 2D DNDarray
124-
return DNDarray(
125-
lcl_nonzero,
126-
gshape=(nonzero_size.item(), x.ndim),
127-
dtype=nonzero_dtype,
128-
split=0,
129-
device=x.device,
130-
comm=x.comm,
131-
balanced=False,
132-
)
133-
# return indices as tuple of 1D DNDarrays
134-
lcl_nonzero = lcl_nonzero.unbind(dim=1)
135-
return tuple(
136-
DNDarray(
137-
nz_tensor,
138-
gshape=(nonzero_size.item(),),
139-
dtype=nonzero_dtype,
140-
split=0,
141-
device=x.device,
142-
comm=x.comm,
143-
balanced=False,
144-
)
145-
for nz_tensor in lcl_nonzero
146-
)
138+
else: # return indices as single 2D DNDarray
139+
return global_nonzero
147140

148141

149142
DNDarray.nonzero = lambda self: nonzero(self, as_tuple=True)
150143
DNDarray.nonzero.__doc__ = nonzero.__doc__
151144

152145

153146
def where(
154-
cond: DNDarray, x: None | int | float | DNDarray = None, y: None | int | float | DNDarray = None
155-
) -> DNDarray | tuple[DNDarray, ...]:
147+
cond: DNDarray,
148+
x: None | int | float | DNDarray = None,
149+
y: None | int | float | DNDarray = None,
150+
) -> DNDarray:
156151
"""
157152
Return a :class:`~heat.core.dndarray.DNDarray` containing elements chosen from ``x`` or ``y`` depending on condition.
158153
Result is a :class:`~heat.core.dndarray.DNDarray` with elements from ``x`` where ``cond`` is True, and from ``y`` elsewhere.
@@ -161,24 +156,38 @@ def where(
161156
162157
Parameters
163158
----------
164-
cond: DNDarray
165-
When True, yield ``x``, otherwise yield ``y``.
166-
x, y: DNDarray or scalar, optional
167-
Values from which to choose. ``x``, ``y`` and ``cond`` must be broadcastable to some shape.
168-
If ``x`` and ``y`` are distributed, they must have the same split axis as ``cond``.
159+
cond : DNDarray
160+
Condition of interest, where true yield ``x`` otherwise yield ``y``
161+
x : DNDarray or int or float, optional
162+
Values from which to choose. ``x``, ``y`` and condition need to be broadcastable to some shape.
163+
y : DNDarray or int or float, optional
164+
Values from which to choose. ``x``, ``y`` and condition need to be broadcastable to some shape.
165+
166+
Raises
167+
------
168+
NotImplementedError
169+
if splits of the two input :class:`~heat.core.dndarray.DNDarray` differ
170+
TypeError
171+
if only x or y is given or both are not DNDarrays or numerical scalars
172+
173+
Notes
174+
-----
175+
When only condition is provided, this function is a shorthand for :func:`nonzero` and the function returns a tuple
176+
of :class:`~heat.core.dndarray.DNDarray`, analogously to ``numpy.where``.
169177
170178
Examples
171179
--------
172180
>>> import heat as ht
173181
>>> x = ht.arange(10, split=0)
174182
>>> ht.where(x < 5, x, 10 * x)
175-
DNDarray([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90], dtype=ht.int64, device=cpu:0, split=0)
176-
177-
>>> # Indices retrieval (shorthand for nonzero)
178-
>>> y = ht.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]], split=0)
179-
>>> ht.where(y > 3)
180-
(DNDarray([1, 1, 1, 2, 2, 2], dtype=ht.int64, device=cpu:0, split=0),
181-
DNDarray([0, 1, 2, 0, 1, 2], dtype=ht.int64, device=cpu:0, split=0))
183+
DNDarray(MPI-rank: 0, Shape: (10,), Split: 0, Local Shape: (10,), Device: cpu:0, Dtype: int32, Data:
184+
[ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90])
185+
>>> y = ht.array([[0, 1, 2], [0, 2, 4], [0, 3, 6]])
186+
>>> ht.where(y < 4, y, -1)
187+
DNDarray(MPI-rank: 0, Shape: (3, 3), Split: None, Local Shape: (3, 3), Device: cpu:0, Dtype: int64, Data:
188+
[[ 0, 1, 2],
189+
[ 0, 2, -1],
190+
[ 0, 3, -1]])
182191
"""
183192
# ---- binary where(cond, x, y) branch ------------------------------------
184193
if cond.split is not None and (isinstance(x, DNDarray) or isinstance(y, DNDarray)):
@@ -198,10 +207,8 @@ def where(
198207
return cond.dtype(cond == 0) * y + cond * x
199208

200209
# ---- where(cond) "indices only" branch ----------------------------------
201-
elif x is None and y is None:
202-
# nonzero() properly handles all cases
203-
nz = nonzero(cond)
204-
return nz
210+
elif x is None and y is None: # delegate to nonzero(cond)
211+
return nonzero(cond) # tuple of DNDarrays, one per dimension
205212

206213
# ---- invalid combinations ----------------------------------------------
207214
else:

heat/core/linalg/eigh.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ def _subspaceiteration(
7575
device=columnnorms.device,
7676
)
7777
* statistics.percentile(columnnorms, 100.0 * (1 - (k + safetyparam) / columnnorms.shape[0]))
78-
)
78+
)[0]
7979
X = C[:, idx].balance()
8080

8181
# actual subspace iteration

0 commit comments

Comments
 (0)