13
13
from pyop2 .datatypes import as_numpy_dtype
14
14
from pyop2 .exceptions import KernelTypeError , MapValueError , SetTypeError
15
15
from pyop2 .global_kernel import (GlobalKernelArg , DatKernelArg , MixedDatKernelArg ,
16
- MatKernelArg , MixedMatKernelArg , GlobalKernel )
16
+ MatKernelArg , MixedMatKernelArg , PassthroughKernelArg , GlobalKernel )
17
17
from pyop2 .local_kernel import LocalKernel , CStringLocalKernel , LoopyLocalKernel
18
18
from pyop2 .types import (Access , Global , AbstractDat , Dat , DatView , MixedDat , Mat , Set ,
19
19
MixedSet , ExtrudedSet , Subset , Map , ComposedMap , MixedMap )
@@ -39,6 +39,10 @@ class GlobalParloopArg(ParloopArg):
39
39
40
40
data : Global
41
41
42
+ @property
43
+ def _kernel_args_ (self ):
44
+ return self .data ._kernel_args_
45
+
42
46
@property
43
47
def map_kernel_args (self ):
44
48
return ()
@@ -59,6 +63,10 @@ def __post_init__(self):
59
63
if self .map_ is not None :
60
64
self .check_map (self .map_ )
61
65
66
+ @property
67
+ def _kernel_args_ (self ):
68
+ return self .data ._kernel_args_
69
+
62
70
@property
63
71
def map_kernel_args (self ):
64
72
return self .map_ ._kernel_args_ if self .map_ else ()
@@ -81,6 +89,10 @@ class MixedDatParloopArg(ParloopArg):
81
89
def __post_init__ (self ):
82
90
self .check_map (self .map_ )
83
91
92
+ @property
93
+ def _kernel_args_ (self ):
94
+ return self .data ._kernel_args_
95
+
84
96
@property
85
97
def map_kernel_args (self ):
86
98
return self .map_ ._kernel_args_ if self .map_ else ()
@@ -102,6 +114,10 @@ def __post_init__(self):
102
114
for m in self .maps :
103
115
self .check_map (m )
104
116
117
+ @property
118
+ def _kernel_args_ (self ):
119
+ return self .data ._kernel_args_
120
+
105
121
@property
106
122
def map_kernel_args (self ):
107
123
rmap , cmap = self .maps
@@ -120,12 +136,34 @@ def __post_init__(self):
120
136
for m in self .maps :
121
137
self .check_map (m )
122
138
139
+ @property
140
+ def _kernel_args_ (self ):
141
+ return self .data ._kernel_args_
142
+
123
143
@property
124
144
def map_kernel_args (self ):
125
145
rmap , cmap = self .maps
126
146
return tuple (itertools .chain (* itertools .product (rmap ._kernel_args_ , cmap ._kernel_args_ )))
127
147
128
148
149
+ @dataclass
150
+ class PassthroughParloopArg (ParloopArg ):
151
+ # a pointer
152
+ data : int
153
+
154
+ @property
155
+ def _kernel_args_ (self ):
156
+ return (self .data ,)
157
+
158
+ @property
159
+ def map_kernel_args (self ):
160
+ return ()
161
+
162
+ @property
163
+ def maps (self ):
164
+ return ()
165
+
166
+
129
167
class Parloop :
130
168
"""A parallel loop invocation.
131
169
@@ -167,7 +205,7 @@ def arglist(self):
167
205
"""Prepare the argument list for calling generated code."""
168
206
arglist = self .iterset ._kernel_args_
169
207
for d in self .arguments :
170
- arglist += d .data . _kernel_args_
208
+ arglist += d ._kernel_args_
171
209
172
210
# Collect an ordered set of maps (ignore duplicates)
173
211
maps = {m : None for d in self .arguments for m in d .map_kernel_args }
@@ -224,6 +262,8 @@ def __call__(self):
224
262
def increment_dat_version (self ):
225
263
"""Increment dat versions of :class:`DataCarrier`s in the arguments."""
226
264
for lk_arg , gk_arg , pl_arg in self .zipped_arguments :
265
+ if isinstance (pl_arg , PassthroughParloopArg ):
266
+ continue
227
267
assert isinstance (pl_arg .data , DataCarrier )
228
268
if lk_arg .access is not Access .READ :
229
269
if pl_arg .data in self .reduced_globals :
@@ -520,6 +560,10 @@ class GlobalLegacyArg(LegacyArg):
520
560
data : Global
521
561
access : Access
522
562
563
+ @property
564
+ def dtype (self ):
565
+ return self .data .dtype
566
+
523
567
@property
524
568
def global_kernel_arg (self ):
525
569
return GlobalKernelArg (self .data .dim )
@@ -537,6 +581,10 @@ class DatLegacyArg(LegacyArg):
537
581
map_ : Optional [Map ]
538
582
access : Access
539
583
584
+ @property
585
+ def dtype (self ):
586
+ return self .data .dtype
587
+
540
588
@property
541
589
def global_kernel_arg (self ):
542
590
map_arg = self .map_ ._global_kernel_arg if self .map_ is not None else None
@@ -556,6 +604,10 @@ class MixedDatLegacyArg(LegacyArg):
556
604
map_ : MixedMap
557
605
access : Access
558
606
607
+ @property
608
+ def dtype (self ):
609
+ return self .data .dtype
610
+
559
611
@property
560
612
def global_kernel_arg (self ):
561
613
args = []
@@ -579,6 +631,10 @@ class MatLegacyArg(LegacyArg):
579
631
lgmaps : Optional [Tuple [Any , Any ]] = None
580
632
needs_unrolling : Optional [bool ] = False
581
633
634
+ @property
635
+ def dtype (self ):
636
+ return self .data .dtype
637
+
582
638
@property
583
639
def global_kernel_arg (self ):
584
640
map_args = [m ._global_kernel_arg for m in self .maps ]
@@ -599,6 +655,10 @@ class MixedMatLegacyArg(LegacyArg):
599
655
lgmaps : Tuple [Any ] = None
600
656
needs_unrolling : Optional [bool ] = False
601
657
658
+ @property
659
+ def dtype (self ):
660
+ return self .data .dtype
661
+
602
662
@property
603
663
def global_kernel_arg (self ):
604
664
nrows , ncols = self .data .sparsity .shape
@@ -618,6 +678,28 @@ def parloop_arg(self):
618
678
return MixedMatParloopArg (self .data , tuple (self .maps ), self .lgmaps )
619
679
620
680
681
+ @dataclass
682
+ class PassthroughArg (LegacyArg ):
683
+ """Argument that is simply passed to the local kernel without packing.
684
+
685
+ :param dtype: The datatype of the argument. This is needed for code generation.
686
+ :param data: A pointer to the data.
687
+ """
688
+ # We don't know what the local kernel is doing with this argument
689
+ access = Access .RW
690
+
691
+ dtype : Any
692
+ data : int
693
+
694
+ @property
695
+ def global_kernel_arg (self ):
696
+ return PassthroughKernelArg ()
697
+
698
+ @property
699
+ def parloop_arg (self ):
700
+ return PassthroughParloopArg (self .data )
701
+
702
+
621
703
def ParLoop (* args , ** kwargs ):
622
704
return LegacyParloop (* args , ** kwargs )
623
705
@@ -641,7 +723,7 @@ def LegacyParloop(local_knl, iterset, *args, **kwargs):
641
723
# finish building the local kernel
642
724
local_knl .accesses = tuple (a .access for a in args )
643
725
if isinstance (local_knl , CStringLocalKernel ):
644
- local_knl .dtypes = tuple (a .data . dtype for a in args )
726
+ local_knl .dtypes = tuple (a .dtype for a in args )
645
727
646
728
global_knl_args = tuple (a .global_kernel_arg for a in args )
647
729
extruded = iterset ._extruded
0 commit comments