Skip to content

Commit 246e0ae

Browse files
PyOpenCL target: Add, test overflow of large argument counts into SVM struct
Co-authored-by: Matthias Diener <[email protected]>
1 parent 57fda39 commit 246e0ae

File tree

9 files changed

+445
-78
lines changed

9 files changed

+445
-78
lines changed

loopy/codegen/result.py

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,10 @@
2727
import islpy as isl
2828

2929

30+
if TYPE_CHECKING:
31+
from loopy.codegen import CodeGenerationState
32+
33+
3034
def process_preambles(preambles: Sequence[Tuple[int, str]]) -> Sequence[str]:
3135
seen_preamble_tags = set()
3236
dedup_preambles = []
@@ -171,7 +175,8 @@ def all_code(self):
171175
+ "\n\n"
172176
+ str(self.host_program.ast))
173177

174-
def current_program(self, codegen_state):
178+
def current_program(
179+
self, codegen_state: "CodeGenerationState") -> GeneratedProgram:
175180
if codegen_state.is_generating_device_code:
176181
if self.device_programs:
177182
result = self.device_programs[-1]
@@ -344,13 +349,23 @@ def generate_host_or_device_program(codegen_state, schedule_index):
344349

345350
cur_prog = codegen_result.current_program(codegen_state)
346351
body_ast = cur_prog.ast
347-
fdecl_ast = ast_builder.get_function_declaration(
352+
fdef_preambles, fdecl_ast = ast_builder.get_function_declaration(
348353
codegen_state, codegen_result, schedule_index)
349354

350355
fdef_ast = ast_builder.get_function_definition(
351356
codegen_state, codegen_result,
352357
schedule_index, fdecl_ast, body_ast)
353358

359+
if fdef_preambles:
360+
if codegen_state.is_generating_device_code:
361+
codegen_result = codegen_result.copy(
362+
device_preambles=(
363+
codegen_result.device_preambles + tuple(fdef_preambles)))
364+
else:
365+
codegen_result = codegen_result.copy(
366+
host_preambles=(
367+
codegen_result.host_preambles + tuple(fdef_preambles)))
368+
354369
codegen_result = codegen_result.with_new_program(
355370
codegen_state,
356371
cur_prog.copy(

loopy/target/__init__.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,8 @@ def get_function_definition(
203203
def get_function_declaration(
204204
self, codegen_state: CodeGenerationState,
205205
codegen_result: CodeGenerationResult, schedule_index: int
206-
) -> ASTType:
206+
) -> Tuple[Sequence[Tuple[str, str]], ASTType]:
207+
"""Returns preambles and the AST for the function declaration."""
207208
raise NotImplementedError
208209

209210
def generate_top_of_body(
@@ -293,14 +294,16 @@ def __str__(self):
293294
return ""
294295

295296

296-
class DummyHostASTBuilder(ASTBuilderBase):
297+
class DummyHostASTBuilder(ASTBuilderBase[None]):
297298
def get_function_definition(self, codegen_state, codegen_result,
298299
schedule_index, function_decl, function_body):
299300
return function_body
300301

301-
def get_function_declaration(self, codegen_state, codegen_result,
302-
schedule_index):
303-
return None
302+
def get_function_declaration(
303+
self, codegen_state, codegen_result,
304+
schedule_index,
305+
) -> Tuple[Sequence[Tuple[str, str]], None]:
306+
return [], None
304307

305308
def get_temporary_decls(self, codegen_state, schedule_index):
306309
return []

loopy/target/c/__init__.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
THE SOFTWARE.
2424
"""
2525

26-
from typing import cast, Tuple, Optional
26+
from typing import cast, Tuple, Optional, Sequence
2727
import re
2828

2929
import numpy as np # noqa
@@ -817,8 +817,10 @@ def get_function_definition(
817817
else:
818818
return Collection(result+[Line(), fbody])
819819

820-
def get_function_declaration(self, codegen_state: CodeGenerationState,
821-
codegen_result: CodeGenerationResult, schedule_index: int) -> Generable:
820+
def get_function_declaration(
821+
self, codegen_state: CodeGenerationState,
822+
codegen_result: CodeGenerationResult, schedule_index: int
823+
) -> Tuple[Sequence[Tuple[str, str]], Generable]:
822824
kernel = codegen_state.kernel
823825

824826
assert codegen_state.kernel.linearization is not None
@@ -846,7 +848,7 @@ def get_function_declaration(self, codegen_state: CodeGenerationState,
846848
passed_names = [arg.name for arg in kernel.args]
847849
written_names = kernel.get_written_variables()
848850

849-
return FunctionDeclarationWrapper(
851+
return [], FunctionDeclarationWrapper(
850852
FunctionDeclaration(
851853
name,
852854
[self.arg_to_cgen_declarator(

loopy/target/cuda.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
THE SOFTWARE.
2424
"""
2525

26+
from typing import Tuple, Sequence
27+
2628
import numpy as np
2729
from pymbolic import var
2830
from pytools import memoize_method
29-
from cgen import Declarator, Const
31+
from cgen import Declarator, Const, Generable
3032

3133
from loopy.target.c import CFamilyTarget, CFamilyASTBuilder
3234
from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper
@@ -35,6 +37,8 @@
3537
from loopy.kernel.array import ArrayBase, FixedStrideArrayDimTag, VectorArrayDimTag
3638
from loopy.kernel.data import AddressSpace, ImageArg, ConstantArg, ArrayArg
3739
from loopy.kernel.function_interface import ScalarCallable
40+
from loopy.codegen.result import CodeGenerationResult
41+
from loopy.codegen import CodeGenerationState
3842

3943

4044
# {{{ vector types
@@ -316,9 +320,11 @@ def known_callables(self):
316320

317321
# {{{ top-level codegen
318322

319-
def get_function_declaration(self, codegen_state, codegen_result,
320-
schedule_index):
321-
fdecl = super().get_function_declaration(
323+
def get_function_declaration(
324+
self, codegen_state: CodeGenerationState,
325+
codegen_result: CodeGenerationResult, schedule_index: int
326+
) -> Tuple[Sequence[Tuple[str, str]], Generable]:
327+
preambles, fdecl = super().get_function_declaration(
322328
codegen_state, codegen_result, schedule_index)
323329

324330
from loopy.target.c import FunctionDeclarationWrapper
@@ -348,7 +354,7 @@ def get_function_declaration(self, codegen_state, codegen_result,
348354

349355
fdecl = CudaLaunchBounds(nthreads, fdecl)
350356

351-
return FunctionDeclarationWrapper(fdecl)
357+
return preambles, FunctionDeclarationWrapper(fdecl)
352358

353359
def preamble_generators(self):
354360

loopy/target/ispc.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
"""
2525

2626

27-
from typing import cast, Tuple
27+
from typing import cast, Tuple, Sequence
2828

2929
import numpy as np # noqa
3030
import pymbolic.primitives as p
@@ -202,8 +202,10 @@ def get_dtype_registry(self):
202202
class ISPCASTBuilder(CFamilyASTBuilder):
203203
# {{{ top-level codegen
204204

205-
def get_function_declaration(self, codegen_state: CodeGenerationState,
206-
codegen_result: CodeGenerationResult, schedule_index: int) -> Generable:
205+
def get_function_declaration(
206+
self, codegen_state: CodeGenerationState,
207+
codegen_result: CodeGenerationResult, schedule_index: int
208+
) -> Tuple[Sequence[Tuple[str, str]], Generable]:
207209
name = codegen_result.current_program(codegen_state).name
208210
kernel = codegen_state.kernel
209211

@@ -243,7 +245,7 @@ def get_function_declaration(self, codegen_state: CodeGenerationState,
243245
arg_decls))
244246

245247
from loopy.target.c import FunctionDeclarationWrapper
246-
return FunctionDeclarationWrapper(result)
248+
return [], FunctionDeclarationWrapper(result)
247249

248250
def get_kernel_call(self, codegen_state: CodeGenerationState,
249251
subkernel_name: str,

loopy/target/opencl.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,10 +23,12 @@
2323
THE SOFTWARE.
2424
"""
2525

26+
from typing import Tuple, Sequence
27+
2628
import numpy as np
2729
from pymbolic import var
2830
from pytools import memoize_method
29-
from cgen import Declarator
31+
from cgen import Declarator, Generable
3032

3133
from loopy.target.c import CFamilyTarget, CFamilyASTBuilder
3234
from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper
@@ -36,6 +38,8 @@
3638
from loopy.kernel.array import VectorArrayDimTag, FixedStrideArrayDimTag, ArrayBase
3739
from loopy.kernel.data import AddressSpace, ImageArg, ConstantArg
3840
from loopy.kernel.function_interface import ScalarCallable
41+
from loopy.codegen import CodeGenerationState
42+
from loopy.codegen.result import CodeGenerationResult
3943

4044

4145
# {{{ dtype registry wrappers
@@ -624,20 +628,26 @@ def preamble_generators(self):
624628

625629
# {{{ top-level codegen
626630

627-
def get_function_declaration(self, codegen_state, codegen_result,
628-
schedule_index):
629-
fdecl = super().get_function_declaration(
631+
def get_function_declaration(
632+
self, codegen_state: CodeGenerationState,
633+
codegen_result: CodeGenerationResult, schedule_index: int
634+
) -> Tuple[Sequence[Tuple[str, str]], Generable]:
635+
preambles, fdecl = super().get_function_declaration(
630636
codegen_state, codegen_result, schedule_index)
631637

632638
from loopy.target.c import FunctionDeclarationWrapper
633639
assert isinstance(fdecl, FunctionDeclarationWrapper)
634640
if not codegen_state.is_entrypoint:
635641
# auxiliary kernels need not mention opencl speicific qualifiers
636642
# for a functions signature
637-
return fdecl
643+
return preambles, fdecl
638644

639-
fdecl = fdecl.subdecl
645+
return preambles, FunctionDeclarationWrapper(
646+
self._wrap_kernel_decl(codegen_state, schedule_index, fdecl.subdecl))
640647

648+
def _wrap_kernel_decl(
649+
self, codegen_state: CodeGenerationState, schedule_index: int,
650+
fdecl: Declarator) -> Declarator:
641651
from cgen.opencl import CLKernel, CLRequiredWorkGroupSize
642652
fdecl = CLKernel(fdecl)
643653

@@ -654,7 +664,7 @@ def get_function_declaration(self, codegen_state, codegen_result,
654664

655665
fdecl = CLRequiredWorkGroupSize(local_sizes, fdecl)
656666

657-
return FunctionDeclarationWrapper(fdecl)
667+
return fdecl
658668

659669
def generate_top_of_body(self, codegen_state):
660670
from loopy.kernel.data import ImageArg

0 commit comments

Comments
 (0)