Skip to content

Commit c2ca860

Browse files
committed
Add type check for fused workunits
1 parent 844f9aa commit c2ca860

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

pykokkos/interface/parallel_dispatch.py

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def handle_args(is_for: bool, *args) -> HandledArgs:
100100

101101
name: Optional[str] = None
102102
policy: Union[ExecutionPolicy, int]
103-
workunit: Callable
103+
workunit: Union[Callable, List[Callable]]
104104
view: Optional[ViewType] = None
105105
initial_value: Union[int, float] = 0
106106

@@ -151,13 +151,19 @@ def handle_args(is_for: bool, *args) -> HandledArgs:
151151
raise TypeError(
152152
f"ERROR: name expected to be type 'str', got '{name}' of type '{type(name)}'"
153153
)
154-
if not isinstance(policy, ExecutionPolicy) and not isinstance(policy, int):
154+
if not (isinstance(policy, ExecutionPolicy) or isinstance(policy, int)):
155155
raise TypeError(
156156
f"ERROR: policy expected to be type 'ExecutionPolicy' or 'int', got '{policy}' of type '{type(policy)}'"
157157
)
158-
if not isinstance(workunit, Callable):
158+
if not (
159+
isinstance(workunit, Callable)
160+
or (
161+
isinstance(workunit, list)
162+
and all(isinstance(w, Callable) for w in workunit)
163+
)
164+
):
159165
raise TypeError(
160-
f"ERROR: workunit expected to be type 'Callable', got '{workunit}' of type '{type(workunit)}'"
166+
f"ERROR: workunit expected to be type 'Callable' or 'List[Callable]', got '{workunit}' of type '{type(workunit)}'"
161167
)
162168

163169
return HandledArgs(name, policy, workunit, view, initial_value)

0 commit comments

Comments
 (0)