2323THE SOFTWARE.
2424"""
2525
26+ from typing import Tuple , Sequence
27+
2628import numpy as np
2729from pymbolic import var
2830from pytools import memoize_method
29- from cgen import Declarator
31+ from cgen import Declarator , Generable
3032
3133from loopy .target .c import CFamilyTarget , CFamilyASTBuilder
3234from loopy .target .c .codegen .expression import ExpressionToCExpressionMapper
3638from loopy .kernel .array import VectorArrayDimTag , FixedStrideArrayDimTag , ArrayBase
3739from loopy .kernel .data import AddressSpace , ImageArg , ConstantArg
3840from loopy .kernel .function_interface import ScalarCallable
41+ from loopy .codegen import CodeGenerationState
42+ from loopy .codegen .result import CodeGenerationResult
3943
4044
4145# {{{ dtype registry wrappers
@@ -624,20 +628,26 @@ def preamble_generators(self):
624628
625629 # {{{ top-level codegen
626630
627- def get_function_declaration (self , codegen_state , codegen_result ,
628- schedule_index ):
629- fdecl = super ().get_function_declaration (
631+ def get_function_declaration (
632+ self , codegen_state : CodeGenerationState ,
633+ codegen_result : CodeGenerationResult , schedule_index : int
634+ ) -> Tuple [Sequence [Tuple [str , str ]], Generable ]:
635+ preambles , fdecl = super ().get_function_declaration (
630636 codegen_state , codegen_result , schedule_index )
631637
632638 from loopy .target .c import FunctionDeclarationWrapper
633639 assert isinstance (fdecl , FunctionDeclarationWrapper )
634640 if not codegen_state .is_entrypoint :
635641 # auxiliary kernels need not mention OpenCL-specific qualifiers
636642 # in a function's signature
637- return fdecl
643+ return preambles , fdecl
638644
639- fdecl = fdecl .subdecl
645+ return preambles , FunctionDeclarationWrapper (
646+ self ._wrap_kernel_decl (codegen_state , schedule_index , fdecl .subdecl ))
640647
648+ def _wrap_kernel_decl (
649+ self , codegen_state : CodeGenerationState , schedule_index : int ,
650+ fdecl : Declarator ) -> Declarator :
641651 from cgen .opencl import CLKernel , CLRequiredWorkGroupSize
642652 fdecl = CLKernel (fdecl )
643653
@@ -654,7 +664,7 @@ def get_function_declaration(self, codegen_state, codegen_result,
654664
655665 fdecl = CLRequiredWorkGroupSize (local_sizes , fdecl )
656666
657- return FunctionDeclarationWrapper ( fdecl )
667+ return fdecl
658668
659669 def generate_top_of_body (self , codegen_state ):
660670 from loopy .kernel .data import ImageArg
0 commit comments