Skip to content

Commit 5d61e30

Browse files
committed
fix isnan and isfinite errors
1 parent effcd59 commit 5d61e30

File tree

2 files changed

+56
-0
lines changed

2 files changed

+56
-0
lines changed

pykokkos/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@
1717
set_device_id,
1818
)
1919

20+
from pykokkos.lib.ufuncs import _isnan as isnan, _isfinite as isfinite
21+
2022
from pykokkos.lib.info import iinfo, finfo
2123
from pykokkos.lib.create import zeros, zeros_like, ones, ones_like, full, full_like
2224
from pykokkos.lib.manipulate import reshape, ravel, expand_dims

pykokkos/lib/ufuncs.py

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,3 +145,57 @@ def _equal(view1, view2, profiler_name: Optional[str] = None):
145145
view2=view2,
146146
)
147147
return out
148+
149+
def _isnan(view, profiler_name: Optional[str] = None):
150+
dtype = view.dtype
151+
ndims = len(view.shape)
152+
if ndims > 2:
153+
raise NotImplementedError("isnan() ufunc only supports up to 2D views")
154+
out = pk.View([*view.shape], dtype=pk.bool)
155+
if view.shape == ():
156+
tid = 1
157+
else:
158+
tid = view.shape[0]
159+
if view.ndim == 0:
160+
new_view = pk.View([1], dtype=view.dtype)
161+
new_view[0] = view
162+
view = new_view
163+
_ufunc_kernel_dispatcher(
164+
profiler_name=profiler_name,
165+
tid=tid,
166+
dtype=dtype,
167+
ndims=ndims,
168+
op="isnan",
169+
sub_dispatcher=pk.parallel_for,
170+
out=out,
171+
view=view,
172+
)
173+
return out
174+
175+
def _isfinite(view, profiler_name: Optional[str] = None):
176+
dtype = view.dtype
177+
ndims = len(view.shape)
178+
if ndims > 2:
179+
raise NotImplementedError("isfinite() ufunc only supports up to 2D views")
180+
if view.size == 0:
181+
out = pk.View(view.shape, dtype=pk.bool)
182+
return out
183+
out = pk.View([*view.shape], dtype=pk.bool)
184+
if view.shape == ():
185+
new_view = pk.View([1], dtype=dtype)
186+
new_view[:] = view
187+
view = new_view
188+
tid = 1
189+
else:
190+
tid = view.shape[0]
191+
_ufunc_kernel_dispatcher(
192+
profiler_name=profiler_name,
193+
tid=tid,
194+
dtype=dtype,
195+
ndims=ndims,
196+
op="isfinite",
197+
sub_dispatcher=pk.parallel_for,
198+
out=out,
199+
view=view,
200+
)
201+
return out

0 commit comments

Comments
 (0)