Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions loopy/codegen/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -431,7 +431,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
if kernel.all_inames():
seen_dtypes.add(kernel.index_dtype)

preambles = list(kernel.preambles)
preambles = kernel.preambles + codegen_result.device_preambles

preamble_info = PreambleInfo(
kernel=kernel,
Expand All @@ -445,7 +445,7 @@ def generate_code_for_a_single_kernel(kernel, callables_table, target,
preamble_generators = (list(kernel.preamble_generators)
+ list(target.get_device_ast_builder().preamble_generators()))
for prea_gen in preamble_generators:
preambles.extend(prea_gen(preamble_info))
preambles = preambles + tuple(prea_gen(preamble_info))

codegen_result = codegen_result.copy(device_preambles=preambles)

Expand Down
44 changes: 37 additions & 7 deletions loopy/codegen/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,17 @@
THE SOFTWARE.
"""

from typing import Any, Sequence, Mapping, Tuple, Optional
from typing import (Any, Sequence, Mapping, Tuple, Optional, TYPE_CHECKING, Union,
Dict, List)
from dataclasses import dataclass, replace

import islpy as isl


if TYPE_CHECKING:
from loopy.codegen import CodeGenerationState


def process_preambles(preambles: Sequence[Tuple[int, str]]) -> Sequence[str]:
seen_preamble_tags = set()
dedup_preambles = []
Expand Down Expand Up @@ -110,8 +115,8 @@ class CodeGenerationResult:
host_program: Optional[GeneratedProgram]
device_programs: Sequence[GeneratedProgram]
implemented_domains: Mapping[str, isl.Set]
host_preambles: Sequence[Tuple[int, str]] = ()
device_preambles: Sequence[Tuple[int, str]] = ()
host_preambles: Sequence[Tuple[str, str]] = ()
device_preambles: Sequence[Tuple[str, str]] = ()

def copy(self, **kwargs: Any) -> "CodeGenerationResult":
return replace(self, **kwargs)
Expand Down Expand Up @@ -170,7 +175,8 @@ def all_code(self):
+ "\n\n"
+ str(self.host_program.ast))

def current_program(self, codegen_state):
def current_program(
self, codegen_state: "CodeGenerationState") -> GeneratedProgram:
if codegen_state.is_generating_device_code:
if self.device_programs:
result = self.device_programs[-1]
Expand Down Expand Up @@ -217,7 +223,10 @@ def with_new_ast(self, codegen_state, new_ast):

# {{{ support code for AST merging

def merge_codegen_results(codegen_state, elements, collapse=True):
def merge_codegen_results(
codegen_state: "CodeGenerationState",
elements: Sequence[Union[CodeGenerationResult, Any]], collapse=True
) -> CodeGenerationResult:
elements = [el for el in elements if el is not None]

if not elements:
Expand All @@ -226,10 +235,16 @@ def merge_codegen_results(codegen_state, elements, collapse=True):
device_programs=[],
implemented_domains={})

# FIXME This is fundamentally broken. What is this even doing?
# I guess partly to blame is the fact that there's an unresolved
# tension between subkernels and callables.
# -AK, 2022-08-28

ast_els = []
new_device_programs = []
new_device_preambles: List[Tuple[str, str]] = []
dev_program_names = set()
implemented_domains = {}
implemented_domains: Dict[str, isl.Set] = {}
codegen_result = None

block_cls = codegen_state.ast_builder.ast_block_class
Expand All @@ -253,6 +268,8 @@ def merge_codegen_results(codegen_state, elements, collapse=True):
new_device_programs.append(dp)
dev_program_names.add(dp.name)

new_device_preambles.extend(el.device_preambles)

cur_ast = el.current_ast(codegen_state)
if (isinstance(cur_ast, block_cls)
and not isinstance(cur_ast, block_scope_cls)):
Expand All @@ -272,9 +289,12 @@ def merge_codegen_results(codegen_state, elements, collapse=True):
if not codegen_state.is_generating_device_code:
kwargs["device_programs"] = new_device_programs

assert codegen_result is not None

return (codegen_result
.with_new_ast(codegen_state, ast)
.copy(
device_preambles=tuple(new_device_preambles),
implemented_domains=implemented_domains,
**kwargs))

Expand Down Expand Up @@ -329,13 +349,23 @@ def generate_host_or_device_program(codegen_state, schedule_index):

cur_prog = codegen_result.current_program(codegen_state)
body_ast = cur_prog.ast
fdecl_ast = ast_builder.get_function_declaration(
fdef_preambles, fdecl_ast = ast_builder.get_function_declaration(
codegen_state, codegen_result, schedule_index)

fdef_ast = ast_builder.get_function_definition(
codegen_state, codegen_result,
schedule_index, fdecl_ast, body_ast)

if fdef_preambles:
if codegen_state.is_generating_device_code:
codegen_result = codegen_result.copy(
device_preambles=(
codegen_result.device_preambles + tuple(fdef_preambles)))
else:
codegen_result = codegen_result.copy(
host_preambles=(
codegen_result.host_preambles + tuple(fdef_preambles)))

codegen_result = codegen_result.with_new_program(
codegen_state,
cur_prog.copy(
Expand Down
17 changes: 12 additions & 5 deletions loopy/target/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,8 @@ def get_function_definition(
def get_function_declaration(
self, codegen_state: CodeGenerationState,
codegen_result: CodeGenerationResult, schedule_index: int
) -> ASTType:
) -> Tuple[Sequence[Tuple[str, str]], ASTType]:
"""Returns preambles and the AST for the function declaration."""
raise NotImplementedError

def generate_top_of_body(
Expand All @@ -224,6 +225,10 @@ def get_kernel_call(self, codegen_state: CodeGenerationState,
def ast_block_class(self):
raise NotImplementedError()

@property
def ast_block_scope_class(self):
raise NotImplementedError()

def get_expression_to_code_mapper(self, codegen_state: CodeGenerationState):
raise NotImplementedError()

Expand Down Expand Up @@ -289,14 +294,16 @@ def __str__(self):
return ""


class DummyHostASTBuilder(ASTBuilderBase):
class DummyHostASTBuilder(ASTBuilderBase[None]):
def get_function_definition(self, codegen_state, codegen_result,
schedule_index, function_decl, function_body):
return function_body

def get_function_declaration(self, codegen_state, codegen_result,
schedule_index):
return None
def get_function_declaration(
self, codegen_state, codegen_result,
schedule_index,
) -> Tuple[Sequence[Tuple[str, str]], None]:
return [], None

def get_temporary_decls(self, codegen_state, schedule_index):
return []
Expand Down
32 changes: 6 additions & 26 deletions loopy/target/c/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
THE SOFTWARE.
"""

from typing import cast, Tuple, Optional
from typing import cast, Tuple, Optional, Sequence
import re

import numpy as np # noqa
Expand Down Expand Up @@ -375,26 +375,6 @@ def map_function_decl_wrapper(self, node, *args, **kwargs):
return FunctionDeclarationWrapper(
self.rec(node.subdecl, *args, **kwargs))


class SubscriptSubsetCounter(IdentityMapper):
def __init__(self, subset_counters):
self.subset_counters = subset_counters


class ASTSubscriptCollector(CASTIdentityMapper):
def __init__(self):
self.subset_counters = {}

def map_expression(self, expr):
from pymbolic.primitives import is_constant
if isinstance(expr, CExpression) or is_constant(expr):
return expr
elif isinstance(expr, str):
return expr
else:
raise LoopyError(
"Unexpected expression type: %s" % type(expr).__name__)

# }}}


Expand Down Expand Up @@ -837,8 +817,10 @@ def get_function_definition(
else:
return Collection(result+[Line(), fbody])

def get_function_declaration(self, codegen_state: CodeGenerationState,
codegen_result: CodeGenerationResult, schedule_index: int) -> Generable:
def get_function_declaration(
self, codegen_state: CodeGenerationState,
codegen_result: CodeGenerationResult, schedule_index: int
) -> Tuple[Sequence[Tuple[str, str]], Generable]:
kernel = codegen_state.kernel

assert codegen_state.kernel.linearization is not None
Expand Down Expand Up @@ -866,7 +848,7 @@ def get_function_declaration(self, codegen_state: CodeGenerationState,
passed_names = [arg.name for arg in kernel.args]
written_names = kernel.get_written_variables()

return FunctionDeclarationWrapper(
return [], FunctionDeclarationWrapper(
FunctionDeclaration(
name,
[self.arg_to_cgen_declarator(
Expand Down Expand Up @@ -1284,8 +1266,6 @@ def emit_if(self, condition_str, ast):
# }}}

def process_ast(self, node):
sc = ASTSubscriptCollector()
sc(node)
return node


Expand Down
16 changes: 11 additions & 5 deletions loopy/target/cuda.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
THE SOFTWARE.
"""

from typing import Tuple, Sequence

import numpy as np
from pymbolic import var
from pytools import memoize_method
from cgen import Declarator, Const
from cgen import Declarator, Const, Generable

from loopy.target.c import CFamilyTarget, CFamilyASTBuilder
from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper
Expand All @@ -35,6 +37,8 @@
from loopy.kernel.array import ArrayBase, FixedStrideArrayDimTag, VectorArrayDimTag
from loopy.kernel.data import AddressSpace, ImageArg, ConstantArg, ArrayArg
from loopy.kernel.function_interface import ScalarCallable
from loopy.codegen.result import CodeGenerationResult
from loopy.codegen import CodeGenerationState


# {{{ vector types
Expand Down Expand Up @@ -316,9 +320,11 @@ def known_callables(self):

# {{{ top-level codegen

def get_function_declaration(self, codegen_state, codegen_result,
schedule_index):
fdecl = super().get_function_declaration(
def get_function_declaration(
self, codegen_state: CodeGenerationState,
codegen_result: CodeGenerationResult, schedule_index: int
) -> Tuple[Sequence[Tuple[str, str]], Generable]:
preambles, fdecl = super().get_function_declaration(
codegen_state, codegen_result, schedule_index)

from loopy.target.c import FunctionDeclarationWrapper
Expand Down Expand Up @@ -348,7 +354,7 @@ def get_function_declaration(self, codegen_state, codegen_result,

fdecl = CudaLaunchBounds(nthreads, fdecl)

return FunctionDeclarationWrapper(fdecl)
return preambles, FunctionDeclarationWrapper(fdecl)

def preamble_generators(self):

Expand Down
2 changes: 1 addition & 1 deletion loopy/target/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -594,7 +594,7 @@ def strify_tuple(t: Optional[Tuple[ExpressionT, ...]]) -> str:
if isinstance(arg, (lp.ArrayArg, lp.ConstantArg)):
args.append(self.get_arg_pass(arg))
else:
args.append("%s" % arg.name)
args.append(arg.name)

gen("")

Expand Down
10 changes: 6 additions & 4 deletions loopy/target/ispc.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
"""


from typing import cast, Tuple
from typing import cast, Tuple, Sequence

import numpy as np # noqa
import pymbolic.primitives as p
Expand Down Expand Up @@ -202,8 +202,10 @@ def get_dtype_registry(self):
class ISPCASTBuilder(CFamilyASTBuilder):
# {{{ top-level codegen

def get_function_declaration(self, codegen_state: CodeGenerationState,
codegen_result: CodeGenerationResult, schedule_index: int) -> Generable:
def get_function_declaration(
self, codegen_state: CodeGenerationState,
codegen_result: CodeGenerationResult, schedule_index: int
) -> Tuple[Sequence[Tuple[str, str]], Generable]:
name = codegen_result.current_program(codegen_state).name
kernel = codegen_state.kernel

Expand Down Expand Up @@ -243,7 +245,7 @@ def get_function_declaration(self, codegen_state: CodeGenerationState,
arg_decls))

from loopy.target.c import FunctionDeclarationWrapper
return FunctionDeclarationWrapper(result)
return [], FunctionDeclarationWrapper(result)

def get_kernel_call(self, codegen_state: CodeGenerationState,
subkernel_name: str,
Expand Down
24 changes: 17 additions & 7 deletions loopy/target/opencl.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@
THE SOFTWARE.
"""

from typing import Tuple, Sequence

import numpy as np
from pymbolic import var
from pytools import memoize_method
from cgen import Declarator
from cgen import Declarator, Generable

from loopy.target.c import CFamilyTarget, CFamilyASTBuilder
from loopy.target.c.codegen.expression import ExpressionToCExpressionMapper
Expand All @@ -36,6 +38,8 @@
from loopy.kernel.array import VectorArrayDimTag, FixedStrideArrayDimTag, ArrayBase
from loopy.kernel.data import AddressSpace, ImageArg, ConstantArg
from loopy.kernel.function_interface import ScalarCallable
from loopy.codegen import CodeGenerationState
from loopy.codegen.result import CodeGenerationResult


# {{{ dtype registry wrappers
Expand Down Expand Up @@ -624,20 +628,26 @@ def preamble_generators(self):

# {{{ top-level codegen

def get_function_declaration(self, codegen_state, codegen_result,
schedule_index):
fdecl = super().get_function_declaration(
def get_function_declaration(
self, codegen_state: CodeGenerationState,
codegen_result: CodeGenerationResult, schedule_index: int
) -> Tuple[Sequence[Tuple[str, str]], Generable]:
preambles, fdecl = super().get_function_declaration(
codegen_state, codegen_result, schedule_index)

from loopy.target.c import FunctionDeclarationWrapper
assert isinstance(fdecl, FunctionDeclarationWrapper)
if not codegen_state.is_entrypoint:
# auxiliary kernels need not mention opencl speicific qualifiers
# for a functions signature
return fdecl
return preambles, fdecl

fdecl = fdecl.subdecl
return preambles, FunctionDeclarationWrapper(
self._wrap_kernel_decl(codegen_state, schedule_index, fdecl.subdecl))

def _wrap_kernel_decl(
self, codegen_state: CodeGenerationState, schedule_index: int,
fdecl: Declarator) -> Declarator:
from cgen.opencl import CLKernel, CLRequiredWorkGroupSize
fdecl = CLKernel(fdecl)

Expand All @@ -654,7 +664,7 @@ def get_function_declaration(self, codegen_state, codegen_result,

fdecl = CLRequiredWorkGroupSize(local_sizes, fdecl)

return FunctionDeclarationWrapper(fdecl)
return fdecl

def generate_top_of_body(self, codegen_state):
from loopy.kernel.data import ImageArg
Expand Down
Loading