diff --git a/loopy/__init__.py b/loopy/__init__.py
index de50eb2d3..f74ce826d 100644
--- a/loopy/__init__.py
+++ b/loopy/__init__.py
@@ -120,6 +120,7 @@ from loopy.target.execution import ExecutorBase
 from loopy.target.ispc import ISPCTarget
 from loopy.target.opencl import OpenCLTarget
+from loopy.target.pycuda import PyCudaTarget, PyCudaWithPackedArgsTarget
 from loopy.target.pyopencl import PyOpenCLTarget
 from loopy.tools import Optional, clear_in_mem_caches, memoize_on_disk, t_unit_to_python
 from loopy.transform.add_barrier import add_barrier
@@ -151,6 +152,7 @@
     tag_array_axes,
     tag_data_axes,
 )
+from loopy.transform.domain import decouple_domain
 from loopy.transform.fusion import fuse_kernels
 from loopy.transform.iname import (
     add_inames_for_unused_hw_axes,
@@ -206,6 +208,10 @@
     unprivatize_temporaries_with_inames,
 )
 from loopy.transform.realize_reduction import realize_reduction
+from loopy.transform.reduction import (
+    extract_multiplicative_terms_in_sum_reduction_as_subst,
+    hoist_invariant_multiplicative_terms_in_sum_reduction,
+)
+from loopy.transform.reindex import reindex_temporary_using_seghir_loechner_scheme
 from loopy.transform.save import save_and_reload_temporaries
 from loopy.transform.subst import (
     assignment_to_subst,
@@ -277,6 +283,8 @@
     "Options",
     "OrderedAtomic",
     "PreambleInfo",
+    "PyCudaTarget",
+    "PyCudaWithPackedArgsTarget",
     "PyOpenCLTarget",
     "Reduction",
     "ScalarCallable",
@@ -317,8 +325,10 @@
     "clear_in_mem_caches",
     "collect_common_factors_on_increment",
     "concatenate_arrays",
+    "decouple_domain",
     "duplicate_inames",
     "expand_subst",
+    "extract_multiplicative_terms_in_sum_reduction_as_subst",
     "extract_subst",
     "find_instructions",
     "find_most_recent_global_barrier",
@@ -349,6 +359,7 @@
     "get_subkernels",
     "get_synchronization_map",
     "has_schedulable_iname_nesting",
+    "hoist_invariant_multiplicative_terms_in_sum_reduction",
     "infer_arg_descr",
     "infer_unknown_types",
     "inline_callable_kernel",
@@ -378,6 +389,7 @@
     "register_preamble_generators",
     "register_reduction_parser",
     "register_symbol_manglers",
+    "reindex_temporary_using_seghir_loechner_scheme",
     "remove_inames_from_insn",
     "remove_instructions",
     "remove_predicates_from_insn",
@@ -416,6 +428,14 @@
     "untag_inames",
 ]
 
+try:
+    import loopy.relations as relations
+except ImportError:
+    # catching ImportErrors to avoid making minikanren a hard-dep
+    pass
+else:
+    __all__ += ["relations"]
+
 # }}}
 
diff --git a/loopy/codegen/loop.py b/loopy/codegen/loop.py
index 6a6ada7b5..2e0e0afa7 100644
--- a/loopy/codegen/loop.py
+++ b/loopy/codegen/loop.py
@@ -29,10 +29,12 @@
 import islpy as isl
 from islpy import dim_type
 from pymbolic.mapper.stringifier import PREC_NONE
+from typing import FrozenSet
 
 from loopy.codegen.control import build_loop_nest
 from loopy.codegen.result import CodeGenerationResult, merge_codegen_results
 from loopy.diagnostic import LoopyError, warn
+from loopy.kernel import LoopKernel
 from loopy.symbolic import flatten
 
 
@@ -369,6 +371,16 @@ def set_up_hw_parallel_loops(
 
 # {{{ sequential loop
 
+def _get_intersecting_inames(kernel: LoopKernel, iname: str) -> FrozenSet[str]:
+    from functools import reduce
+    return reduce(frozenset.union,
+                  ((kernel.id_to_insn[insn].within_inames
+                    | kernel.id_to_insn[insn].reduction_inames()
+                    | kernel.id_to_insn[insn].sub_array_ref_inames())
+                   for insn in kernel.iname_to_insns()[iname]),
+                  frozenset())
+
+
 def generate_sequential_loop_dim_code(
         codegen_state: CodeGenerationState,
         sched_index: int,
@@ -386,8 +398,18 @@ def generate_sequential_loop_dim_code(
 
     from loopy.codegen.bounds import get_usable_inames_for_conditional
 
     # Note: this does not include loop_iname itself!
+
+    # usable_inames = get_usable_inames_for_conditional(
+    #     kernel, sched_index, codegen_state.codegen_cache_manager)
+
+    # # get rid of disjoint loop nests, see
+    # #
+    # usable_inames = usable_inames & _get_intersecting_inames(kernel,
+    #                                                          loop_iname)
     usable_inames = get_usable_inames_for_conditional(kernel,
             sched_index, codegen_state.codegen_cache_manager)
 
     domain = kernel.get_inames_domain(loop_iname)
 
diff --git a/loopy/kernel/tools.py b/loopy/kernel/tools.py
index 484071f92..b184a3ac7 100644
--- a/loopy/kernel/tools.py
+++ b/loopy/kernel/tools.py
@@ -2218,6 +2218,65 @@ def get_outer_params(domains):
 
 # }}}
 
+
+# {{{ get access map from an instruction
+
+class _IndexCollector(CombineMapper):
+    """Collects the index tuples of all subscripts into the variable *var*."""
+    def __init__(self, var):
+        self.var = var
+        super().__init__()
+
+    def combine(self, values):
+        import operator
+        from functools import reduce
+        return reduce(operator.or_, values, frozenset())
+
+    def map_subscript(self, expr):
+        if expr.aggregate.name == self.var:
+            return (super().map_subscript(expr)
+                    | frozenset([expr.index_tuple]))
+        else:
+            return super().map_subscript(expr)
+
+    def map_algebraic_leaf(self, expr):
+        return frozenset()
+
+    map_constant = map_algebraic_leaf
+
+
+def _project_out_inames_from_maps(amaps, inames_to_project_out):
+    new_amaps = []
+    for amap in amaps:
+        for iname in inames_to_project_out:
+            dt, pos = amap.get_var_dict()[iname]
+            amap = amap.project_out(dt, pos, 1)
+
+        new_amaps.append(amap)
+
+    return new_amaps
+
+
+def _union_amaps(amaps):
+    import islpy as isl
+    from functools import reduce
+    return reduce(isl.Map.union, amaps[1:], amaps[0])
+
+
+def get_insn_access_map(kernel, insn_id, var):
+    """Returns an :class:`islpy.Map` from the execution instances of the
+    instruction *insn_id* to the indices of *var* accessed by it."""
+    from loopy.transform.subst import expand_subst
+    from loopy.symbolic import get_access_map
+
+    insn = kernel.id_to_insn[insn_id]
+
+    kernel = expand_subst(kernel)
+    indices = list(_IndexCollector(var)((insn.expression,
+                                         insn.assignees,
+                                         tuple(insn.predicates))))
+
+    amaps = [get_access_map(kernel.get_inames_domain(insn.within_inames),
+                            idx, kernel.assumptions)
+             for idx in indices]
+
+    return _union_amaps(amaps)
+
+# }}}
+
 
 def get_hw_axis_base_for_codegen(kernel: LoopKernel, iname: str) -> isl.Aff:
     """
diff --git a/loopy/relations.py b/loopy/relations.py
new file mode 100644
index 000000000..5d47bfa1d
--- /dev/null
+++ b/loopy/relations.py
@@ -0,0 +1,122 @@
+"""miniKanren relations over the entities of a :class:`loopy.LoopKernel`."""
+from kanren import Relation, facts
+
+
+def get_inameo(kernel):
+    inameo = Relation()
+    for iname in kernel.all_inames():
+        facts(inameo, (iname,))
+    return inameo
+
+
+def get_argo(kernel):
+    argo = Relation()
+    for arg in kernel.args:
+        facts(argo, (arg.name,))
+
+    return argo
+
+
+def get_tempo(kernel):
+    tempo = Relation()
+    # (iterating over temporary_variables yields the variables' names)
+    for tv in kernel.temporary_variables:
+        facts(tempo, (tv,))
+
+    return tempo
+
+
+def get_insno(kernel):
+    insno = Relation()
+    for insn in kernel.instructions:
+        facts(insno, (insn.id,))
+
+    return insno
+
+
+def get_taggedo(kernel):
+    taggedo = Relation()
+
+    for arg_name, arg in kernel.arg_dict.items():
+        for tag in arg.tags:
+            facts(taggedo, (arg_name, tag))
+
+    for iname_name, iname in kernel.inames.items():
+        for tag in iname.tags:
+            facts(taggedo, (iname_name, tag))
+
+    for insn in kernel.instructions:
+        for tag in insn.tags:
+            facts(taggedo, (insn.id, tag))
+
+    return taggedo
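+
+
+# A rough usage sketch (``knl`` is a hypothetical kernel; ``get_producero``
+# is defined below; only the `kanren` API used above is assumed):
+#
+#     from kanren import run, var
+#     producero = get_producero(knl)
+#     insn = var()
+#     writer_ids = run(0, insn, producero(insn, "out"))
+#     # -> ids of instructions assigning to "out"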
+
+
+def get_taggedo_of_type(kernel, tag_type):
+    taggedo = Relation()
+
+    for arg_name, arg in kernel.arg_dict.items():
+        for tag in arg.tags_of_type(tag_type):
+            facts(taggedo, (arg_name, tag))
+
+    for iname_name, iname in kernel.inames.items():
+        for tag in iname.tags_of_type(tag_type):
+            facts(taggedo, (iname_name, tag))
+
+    for insn in kernel.instructions:
+        for tag in insn.tags_of_type(tag_type):
+            facts(taggedo, (insn.id, tag))
+
+    return taggedo
+
+
+def get_producero(kernel):
+    producero = Relation()
+
+    for insn in kernel.instructions:
+        for var in insn.assignee_var_names():
+            facts(producero, (insn.id, var))
+
+    return producero
+
+
+def get_consumero(kernel):
+    consumero = Relation()
+
+    for insn in kernel.instructions:
+        for var in insn.read_dependency_names():
+            facts(consumero, (insn.id, var))
+
+    return consumero
+
+
+def get_withino(kernel):
+    withino = Relation()
+
+    for insn in kernel.instructions:
+        facts(withino, (insn.id, insn.within_inames))
+
+    return withino
+
+
+def get_reduce_insno(kernel):
+    reduce_insno = Relation()
+
+    for insn in kernel.instructions:
+        if insn.reduction_inames():
+            facts(reduce_insno, (insn.id,))
+
+    return reduce_insno
+
+
+def get_reduce_inameo(kernel):
+    from functools import reduce
+    reduce_inameo = Relation()
+
+    for iname in reduce(frozenset.union,
+                        (insn.reduction_inames()
+                         for insn in kernel.instructions),
+                        frozenset()):
+        facts(reduce_inameo, (iname,))
+
+    return reduce_inameo
+
+# vim: fdm=marker
diff --git a/loopy/schedule/__init__.py b/loopy/schedule/__init__.py
index bdb804b16..b6e627788 100644
--- a/loopy/schedule/__init__.py
+++ b/loopy/schedule/__init__.py
@@ -1036,6 +1036,162 @@ def key(x: ScheduleItem) -> tuple[str, ...]:
 
 # {{{ legacy scheduling algorithm
 
+def _get_outermost_diverging_inames(tree, within1, within2):
+    """
+    For loop nestings *within1* and *within2*, returns the outermost inames
+    at which the two loop nests diverge in the loop nesting tree *tree*.
+
+    :arg tree: A :class:`loopy.tools.Tree` of inames, denoting a loop nesting.
+    :arg within1: A :class:`frozenset` of inames.
+    :arg within2: A :class:`frozenset` of inames.
+    """
+    common_ancestors = (within1 & within2) | {""}
+
+    innermost_parent = max(common_ancestors,
+                           key=lambda k: tree.depth(k))
+    iname1, = frozenset(tree.children(innermost_parent)) & within1
+    iname2, = frozenset(tree.children(innermost_parent)) & within2
+
+    return iname1, iname2
+
+
+def generate_loop_schedules_v2(kernel):
+    from functools import reduce
+    from pytools.graph import compute_topological_order
+    from loopy.kernel.data import ConcurrentTag, IlpBaseTag, VectorizeTag
+    # raise the error type that the caller catches for its fallback to the
+    # legacy scheduler
+    from loopy.schedule.tools import V2SchedulerNotImplementedError, get_loop_tree
+
+    concurrent_inames = {iname for iname in kernel.all_inames()
+                         if kernel.iname_tags_of_type(iname, ConcurrentTag)}
+    ilp_inames = {iname for iname in kernel.all_inames()
+                  if kernel.iname_tags_of_type(iname, IlpBaseTag)}
+    vec_inames = {iname for iname in kernel.all_inames()
+                  if kernel.iname_tags_of_type(iname, VectorizeTag)}
+    parallel_inames = (concurrent_inames - ilp_inames - vec_inames)
+
+    # {{{ can the v2 scheduler handle this kernel?
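+
+    # The checks below make the v2 scheduler bail out (via
+    # V2SchedulerNotImplementedError) so that the caller can fall back to
+    # the legacy scheduler for features it does not support yet.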
+
+    if any(len(insn.conflicts_with_groups) != 0 for insn in kernel.instructions):
+        raise V2SchedulerNotImplementedError("v2 scheduler cannot schedule"
+            " kernels with instructions having conflicts with groups.")
+
+    if any(insn.priority != 0 for insn in kernel.instructions):
+        raise V2SchedulerNotImplementedError("v2 scheduler cannot schedule"
+            " kernels with instruction priorities set.")
+
+    if kernel.linearization is not None:
+        # cannot handle preschedules yet
+        raise V2SchedulerNotImplementedError("v2 scheduler cannot schedule"
+            " prescheduled kernels.")
+
+    if ilp_inames or vec_inames:
+        raise V2SchedulerNotImplementedError("v2 scheduler cannot schedule"
+            " loops tagged with 'ilp'/'vec' as they are not guaranteed to"
+            " be single entry loops.")
+
+    # }}}
+
+    loop_nest_tree = get_loop_tree(kernel)
+
+    # loop_inames: inames that are realized as loops. Concurrent inames aren't
+    # realized as a loop in the generated code for a loopy.TargetBase.
+    loop_inames = (reduce(frozenset.union, (insn.within_inames
+                                            for insn in kernel.instructions),
+                          frozenset())
+                   - parallel_inames)
+
+    # The idea here is to build a DAG whose nodes are schedule items; an edge
+    # from schedule item A to schedule item B in the DAG means that B *must*
+    # come after A in the linearized result.
+
+    dag = {}
+
+    # LeaveLoop(i) *must* follow EnterLoop(i)
+    dag.update({EnterLoop(iname=iname): frozenset({LeaveLoop(iname=iname)})
+                for iname in loop_inames})
+    dag.update({LeaveLoop(iname=iname): frozenset()
+                for iname in loop_inames})
+    dag.update({RunInstruction(insn_id=insn.id): frozenset()
+                for insn in kernel.instructions})
+
+    # {{{ add constraints imposed by the loop nesting
+
+    for outer_loop in loop_nest_tree.nodes():
+        if outer_loop == "":
+            continue
+
+        for child in loop_nest_tree.children(outer_loop):
+            inner_loop = child
+            dag[EnterLoop(iname=outer_loop)] |= {EnterLoop(iname=inner_loop)}
+            dag[LeaveLoop(iname=inner_loop)] |= {LeaveLoop(iname=outer_loop)}
+
+    # }}}
+
+    # {{{ add deps. b/w schedule items coming from insn. dependencies
+
+    for insn in kernel.instructions:
+        insn_loop_inames = insn.within_inames & loop_inames
+        for dep_id in insn.depends_on:
+            dep = kernel.id_to_insn[dep_id]
+            dep_loop_inames = dep.within_inames & loop_inames
+            # Enforce instruction dep:
+            dag[RunInstruction(insn_id=dep_id)] |= {RunInstruction(insn_id=insn.id)}
+
+            # {{{ register deps on loop entry/leave because of insn. deps
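+            # E.g. if dep is within {i} and insn within {i, j}, dep must run
+            # before EnterLoop(j); conversely, if dep is within {i, j} and
+            # insn within {i}, LeaveLoop(j) must precede insn.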
+
+            if dep_loop_inames < insn_loop_inames:
+                for iname in insn_loop_inames - dep_loop_inames:
+                    dag[RunInstruction(insn_id=dep.id)] |= {EnterLoop(iname=iname)}
+            elif insn_loop_inames < dep_loop_inames:
+                for iname in dep_loop_inames - insn_loop_inames:
+                    dag[LeaveLoop(iname=iname)] |= {RunInstruction(insn_id=insn.id)}
+            elif dep_loop_inames != insn_loop_inames:
+                insn_iname, dep_iname = _get_outermost_diverging_inames(
+                    loop_nest_tree, insn_loop_inames, dep_loop_inames)
+                dag[LeaveLoop(iname=dep_iname)] |= {EnterLoop(iname=insn_iname)}
+            else:
+                pass
+
+            # }}}
+
+        for iname in insn_loop_inames:
+            # For an insn within a loop nest 'i'
+            # for i
+            #   insn
+            # end i
+            # 'insn' *must* come b/w 'for i' and 'end i'
+            dag[EnterLoop(iname=iname)] |= {RunInstruction(insn_id=insn.id)}
+            dag[RunInstruction(insn_id=insn.id)] |= {LeaveLoop(iname=iname)}
+
+    # }}}
+
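+    # Tie-breaking keys for the topological sort below: the keys are built
+    # from each iname's path in the loop nest tree, so that schedule items
+    # belonging to the same loop nest sort next to each other.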
+    def iname_key(iname):
+        all_ancestors = sorted(loop_nest_tree.ancestors(iname),
+                               key=lambda x: loop_nest_tree.depth(x))
+        return ",".join(all_ancestors+[iname])
+
+    def key(x):
+        if isinstance(x, RunInstruction):
+            iname = max((kernel.id_to_insn[x.insn_id].within_inames
+                         & loop_inames),
+                        key=lambda k: loop_nest_tree.depth(k),
+                        default="")
+            result = (iname_key(iname), x.insn_id)
+        elif isinstance(x, (EnterLoop, LeaveLoop)):
+            result = (iname_key(x.iname),)
+        else:
+            raise NotImplementedError
+
+        return result
+
+    return compute_topological_order(dag, key=key)
+
+
-def _generate_loop_schedules_internal(
+def generate_loop_schedules_internal(
         sched_state: SchedulerState,
         debug: ScheduleDebugger | None = None,
@@ -1092,7 +1248,7 @@ def _generate_loop_schedules_internal(
 
     if isinstance(next_preschedule_item, CallKernel):
         assert sched_state.within_subkernel is False
-        yield from _generate_loop_schedules_internal(
+        yield from generate_loop_schedules_internal(
                 sched_state.copy(
                     schedule=(*sched_state.schedule, next_preschedule_item),
                     preschedule=sched_state.preschedule[1:],
@@ -1105,7 +1261,7 @@
         assert sched_state.within_subkernel is True
         # Make sure all subkernel inames have finished.
         if sched_state.active_inames == sched_state.enclosing_subkernel_inames:
-            yield from _generate_loop_schedules_internal(
+            yield from generate_loop_schedules_internal(
                     sched_state.copy(
                         schedule=(*sched_state.schedule, next_preschedule_item),
                         preschedule=sched_state.preschedule[1:],
@@ -1124,7 +1280,7 @@
     if (
             isinstance(next_preschedule_item, Barrier)
             and next_preschedule_item.originating_insn_id is None):
-        yield from _generate_loop_schedules_internal(
+        yield from generate_loop_schedules_internal(
                 sched_state.copy(
                     schedule=(*sched_state.schedule, next_preschedule_item),
                     preschedule=sched_state.preschedule[1:]),
@@ -1301,7 +1457,7 @@ def insn_sort_key(insn_id: InsnId):
 
             # Don't be eager about entering/leaving loops--if progress has been
             # made, revert to top of scheduler and see if more progress can be
            # made.
-            for sub_sched in _generate_loop_schedules_internal(
+            for sub_sched in generate_loop_schedules_internal(
                     new_sched_state,
                    debug=debug):
                yield sub_sched
@@ -1397,7 +1553,7 @@ def insn_sort_key(insn_id: InsnId):
 
     if can_leave and not debug_mode:
-        for sub_sched in _generate_loop_schedules_internal(
+        for sub_sched in generate_loop_schedules_internal(
                 sched_state.copy(
                     schedule=(
                         (*sched_state.schedule,
@@ -1607,7 +1763,7 @@ def insn_sort_key(insn_id: InsnId):
                         key_iname),
                     reverse=True):
 
-            for sub_sched in _generate_loop_schedules_internal(
+            for sub_sched in generate_loop_schedules_internal(
                     sched_state.copy(
                         schedule=(
                             (*sched_state.schedule, EnterLoop(iname=iname))),
@@ -2226,14 +2382,16 @@ def _generate_loop_schedules_inner(
     if debug_args is None:
         debug_args = {}
 
+    debug = ScheduleDebugger(**debug_args)
+
     from loopy.kernel import KernelState
     if kernel.state not in (KernelState.PREPROCESSED, KernelState.LINEARIZED):
         raise LoopyError("cannot schedule a kernel that has not been "
-                "preprocessed")
+                         "preprocessed")
 
     from loopy.schedule.tools import V2SchedulerNotImplementedError
     try:
-        gen_sched = _generate_loop_schedules_v2(kernel)
+        gen_sched = generate_loop_schedules_v2(kernel)
         yield _postprocess_schedule(kernel, callables_table, gen_sched)
         return
 
@@ -2246,8 +2404,6 @@
 
     schedule_count = 0
 
-    debug = ScheduleDebugger(**debug_args)
-
     preschedule = (kernel.linearization
                    if kernel.state == KernelState.LINEARIZED
                    else ())
@@ -2343,7 +2499,7 @@ def print_longest_dead_end():
                 debug.debug_length = len(debug.longest_rejected_schedule)
                 while True:
                     try:
-                        for _ in _generate_loop_schedules_internal(
+                        for _ in generate_loop_schedules_internal(
                                 sched_state, debug=debug, **schedule_gen_kwargs):
                             pass
@@ -2354,7 +2510,7 @@
                         break
 
             try:
-                for gen_sched in _generate_loop_schedules_internal(
+                for gen_sched in generate_loop_schedules_internal(
                         sched_state, debug=debug, **schedule_gen_kwargs):
                     debug.stop()
diff --git a/loopy/schedule/tools.py b/loopy/schedule/tools.py
index 02c5c892f..b97ba59da 100644
--- a/loopy/schedule/tools.py
+++ b/loopy/schedule/tools.py
@@ -70,6 +70,7 @@
 from constantdict import constantdict
 
 import islpy as isl
+
 from pytools import memoize_method, memoize_on_first_arg
 
 from loopy.diagnostic import LoopyError
diff --git a/loopy/symbolic.py b/loopy/symbolic.py
index 89e8922dc..3148ccba8 100644
--- a/loopy/symbolic.py
+++ b/loopy/symbolic.py
@@ -31,6 +31,7 @@
 from dataclasses import dataclass, replace
 from functools import cached_property, reduce
 from sys import intern
+
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -38,9 +39,15 @@
     Concatenate,
     Generic,
     Literal as LiteralT,
+    Mapping,
+    Sequence,
+    Tuple,
+    TypeAlias,
     TypeVar,
+    Union,
     cast,
 )
+
 from warnings import warn
 
 import numpy as np
@@ -226,6 +233,26 @@ def map_type_annotation(
 
         return type(expr)(expr.type, new_child)
 
+    def map_type_cast(self, expr, *args, **kwargs):
+        return self.rec(expr.child, *args, **kwargs)
+
+    def map_sub_array_ref(self, expr, *args, **kwargs):
+        return self.combine((
+            self.rec(expr.subscript, *args, **kwargs),
+            self.combine(tuple(
+                self.rec(idx, *args, **kwargs)
+                for idx in expr.swept_inames))))
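+    # (combines the result for the subscript with the results for all of
+    # the swept inames)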
+
     def map_sub_array_ref(self, expr: SubArrayRef,
                 *args: P.args, **kwargs: P.kwargs) -> Expression:
         new_inames = self.rec(expr.swept_inames, *args, **kwargs)
@@ -268,6 +295,16 @@ def is_expr_integer_valued(self, expr: Expression) -> bool:
         return True
 
 
+ArithmeticOrExpressionT = TypeVar(
+    "ArithmeticOrExpressionT",
+    ArithmeticExpression,
+    Expression)
+
+
 def flatten(expr: ArithmeticOrExpressionT) -> ArithmeticOrExpressionT:
     return cast("ArithmeticOrExpressionT", FlattenMapper()(expr))
 
@@ -2978,4 +3015,29 @@ def is_tuple_of_expressions_equal(
 
 # }}}
 
+
+def _is_isl_set_universe(isl_set: Union[isl.BasicSet, isl.Set]):
+    if isinstance(isl_set, isl.BasicSet):
+        return isl_set.is_universe()
+    else:
+        assert isinstance(isl_set, isl.Set)
+        return isl_set.complement().is_empty()
+
+
+def pw_qpolynomial_to_expr(pw_qpoly: isl.PwQPolynomial
+                           ) -> Expression:
+    from pymbolic.primitives import If
+
+    result = 0
+
+    for bset, qpoly in reversed(pw_qpoly.get_pieces()):
+        if _is_isl_set_universe(bset):
+            result = qpolynomial_to_expr(qpoly)
+        else:
+            result = If(set_to_cond_expr(bset),
+                        qpolynomial_to_expr(qpoly),
+                        result)
+
+    return result
+
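+# (Example for pw_qpolynomial_to_expr above: a piecewise count such as
+#     [n] -> { 1 : n = 0; n : n >= 1 }
+# becomes, roughly, If(n == 0, 1, If(n >= 1, n, 0)).)
+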
 
 # vim: foldmethod=marker
diff --git a/loopy/target/c/__init__.py b/loopy/target/c/__init__.py
index 8bf97fb50..a747c0733 100644
--- a/loopy/target/c/__init__.py
+++ b/loopy/target/c/__init__.py
@@ -332,8 +332,9 @@ def _preamble_generator(preamble_info, func_qualifier="inline"):
                     n = -n;
                 }}""")
 
         yield (f"07_{func.c_name}", f"""
-        inline {res_ctype} {func.c_name}({base_ctype} x, {exp_ctype} n) {{
+        {func_qualifier} {res_ctype} {func.c_name}({base_ctype} x, {exp_ctype} n) {{
         if (n == 0)
             return 1;
         {re.sub(r"^", 14*" ", signed_exponent_preamble, flags=re.M)}
diff --git a/loopy/target/cuda.py b/loopy/target/cuda.py
index cda008b47..9111b2121 100644
--- a/loopy/target/cuda.py
+++ b/loopy/target/cuda.py
@@ -334,6 +334,12 @@ def known_callables(self):
         callables.update(get_cuda_callables())
         return callables
 
+    def symbol_manglers(self):
+        from loopy.target.opencl import opencl_symbol_mangler
+        return (
+            super().symbol_manglers() + [
+                opencl_symbol_mangler
+            ])
 
 # }}}
 
 # {{{ top-level codegen
diff --git a/loopy/target/opencl.py b/loopy/target/opencl.py
index a574adda1..57f0db9be 100644
--- a/loopy/target/opencl.py
+++ b/loopy/target/opencl.py
@@ -488,6 +488,7 @@ def get_opencl_callables():
 # {{{ symbol mangler
 
 def opencl_symbol_mangler(kernel, name):
+    # Also used by loopy.target.cuda.CudaCASTBuilder.symbol_manglers
     # FIXME: should be more picky about exact names
     if name.startswith("FLT_"):
         return NumpyType(np.dtype(np.float32)), name
diff --git a/loopy/target/pycuda.py b/loopy/target/pycuda.py
new file mode 100644
index 000000000..e21a87d44
--- /dev/null
+++ b/loopy/target/pycuda.py
@@ -0,0 +1,657 @@
+"""CUDA target integrated with PyCUDA."""
+
+__copyright__ = """
+Copyright (C) 2015 Andreas Kloeckner
+Copyright (C) 2022 Kaushik Kulkarni
+"""
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import numpy as np
+import pymbolic.primitives as p
+import genpy
+
+from loopy.target.cuda import (CudaTarget, CUDACASTBuilder,
+                               ExpressionToCudaCExpressionMapper)
+from loopy.target.python import PythonASTBuilderBase
+from typing import Sequence, Tuple
+from loopy.codegen import CodeGenerationState
+from loopy.codegen.result import CodeGenerationResult
+from loopy.target.c import CMathCallable
+from loopy.diagnostic import LoopyError
+from loopy.types import NumpyType
+from cgen import Generable
+
+import logging
+logger = logging.getLogger(__name__)
+
+
+# {{{ preamble generator
+
+def pycuda_preamble_generator(preamble_info):
+    has_complex = False
+
+    for dtype in preamble_info.seen_dtypes:
+        if dtype.involves_complex():
+            has_complex = True
+
+    if has_complex:
+        yield ("03_include_complex_header", """
+            #include <pycuda-complex.hpp>
+            """)
+
+# }}}
+
+
+# {{{ PyCudaCallable
+
+class PyCudaCallable(CMathCallable):
+    def with_types(self, arg_id_to_dtype, callables_table):
+        if any(dtype.is_complex() for dtype in arg_id_to_dtype.values()):
+            if self.name in ["abs", "real", "imag"]:
+                if not (set(arg_id_to_dtype) <= {0, -1}):
+                    raise LoopyError(f"'{self.name}' takes only one argument")
+                if arg_id_to_dtype.get(0) is None:
+                    # not specialized enough
+                    return (self.copy(arg_id_to_dtype=arg_id_to_dtype),
+                            callables_table)
+                else:
+                    real_dtype = np.empty(0,
+                                          arg_id_to_dtype[0].numpy_dtype).real.dtype
+                    arg_id_to_dtype = arg_id_to_dtype.copy()
+                    arg_id_to_dtype[-1] = NumpyType(real_dtype)
+                    return (self.copy(arg_id_to_dtype=arg_id_to_dtype,
+                                      name_in_target=self.name),
+                            callables_table)
+            elif self.name in ["sqrt", "conj",
+                               "sin", "cos", "tan",
+                               "sinh", "cosh", "tanh", "exp",
+                               "log", "log10"]:
+                if not (set(arg_id_to_dtype) <= {0, -1}):
+                    raise LoopyError(f"'{self.name}' takes only one argument")
+                if arg_id_to_dtype.get(0) is None:
+                    # not specialized enough
+                    return (self.copy(arg_id_to_dtype=arg_id_to_dtype),
+                            callables_table)
+                else:
+                    arg_id_to_dtype = arg_id_to_dtype.copy()
+                    arg_id_to_dtype[-1] = arg_id_to_dtype[0]
+                    return (self.copy(arg_id_to_dtype=arg_id_to_dtype,
+                                      name_in_target=self.name),
+                            callables_table)
+            else:
+                raise LoopyError(f"'{self.name}' does not take complex"
+                                 " arguments.")
+        else:
+            if self.name in ["real", "imag", "conj"]:
+                if arg_id_to_dtype.get(0):
+                    raise NotImplementedError(f"'{self.name}' for real arguments"
+                                              ", not yet supported.")
+            return super().with_types(arg_id_to_dtype, callables_table)
+
+
+def get_pycuda_callables():
+    cmath_ids = ["abs", "acos", "asin", "atan", "cos", "cosh", "sin",
+                 "sinh", "pow", "atan2", "tanh", "exp", "log", "log10",
+                 "sqrt", "ceil", "floor", "max", "min", "fmax", "fmin",
+                 "fabs", "tan", "erf", "erfc", "isnan", "real", "imag",
+                 "conj"]
+    return {id_: PyCudaCallable(id_) for id_ in cmath_ids}
+
+# }}}
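+
+
+# A rough usage sketch (assumes an already-active CUDA context, e.g. via
+# ``import pycuda.autoinit``; the kernel and array names are illustrative):
+#
+#     knl = lp.make_kernel("{[i]: 0<=i<n}", "out[i] = 2*a[i]",
+#                          target=PyCudaTarget())
+#     evt, (out,) = knl(a=a_dev, allocator=cuda.mem_alloc)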
+
+
+# {{{ expression mapper
+
+def _get_complex_tmplt_arg(dtype) -> str:
+    if dtype == np.complex128:
+        return "double"
+    elif dtype == np.complex64:
+        return "float"
+    else:
+        raise RuntimeError(f"unsupported complex type {dtype}.")
+
+
+class ExpressionToPyCudaCExpressionMapper(ExpressionToCudaCExpressionMapper):
+    """
+    .. note::
+
+        - PyCUDA (very conveniently) provides access to complex arithmetic
+          headers, which is not the default in CUDA. To access these
+          additional features we introduce this mapper.
+    """
+    def wrap_in_typecast_lazy(self, actual_type_func, needed_dtype, s):
+        if needed_dtype.is_complex():
+            return self.wrap_in_typecast(actual_type_func(), needed_dtype, s)
+        else:
+            return super().wrap_in_typecast_lazy(actual_type_func,
+                                                 needed_dtype, s)
+
+    def wrap_in_typecast(self, actual_type, needed_dtype, s):
+        if not actual_type.is_complex() and needed_dtype.is_complex():
+            tmplt_arg = _get_complex_tmplt_arg(needed_dtype.numpy_dtype)
+            return p.Variable(f"pycuda::complex<{tmplt_arg}>")(s)
+        else:
+            return super().wrap_in_typecast(actual_type,
+                                            needed_dtype, s)
+
+    def map_constant(self, expr, type_context):
+        if isinstance(expr, (complex, np.complexfloating)):
+            try:
+                dtype = expr.dtype
+            except AttributeError:
+                # (COMPLEX_GUESS_LOGIC) This made it through type 'guessing' in
+                # type inference, and it was concluded there (search for
+                # COMPLEX_GUESS_LOGIC in loopy.type_inference), that no
+                # accuracy was lost by using single precision.
+                dtype = np.complex64
+
+            tmplt_arg = _get_complex_tmplt_arg(dtype)
+            return p.Variable(f"pycuda::complex<{tmplt_arg}>")(
+                self.rec(expr.real, type_context),
+                self.rec(expr.imag, type_context))
+
+        return super().map_constant(expr, type_context)
+
+# }}}
+
+
+# {{{ target
+
+class PyCudaTarget(CudaTarget):
+    """A code generation target that takes special advantage of :mod:`pycuda`
+    features such as run-time knowledge of the target device (to generate
+    warnings) and support for complex numbers.
+    """
+
+    # FIXME make prefixes conform to naming rules
+    # (see Reference: Loopy's Model of a Kernel)
+
+    host_program_name_prefix = "_lpy_host_"
+    host_program_name_suffix = ""
+
+    def __init__(self, pycuda_module_name="_lpy_cuda"):
+        # import pycuda.tools to populate the TYPE_REGISTRY
+        import pycuda.tools  # noqa: F401
+        super().__init__()
+        self.pycuda_module_name = pycuda_module_name
+
+    # NB: Not including 'device', as that is handled specially here.
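+    # (pycuda_module_name affects the generated host-side code, so it
+    # participates in hashing and comparison below)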
+ hash_fields = CudaTarget.hash_fields + ( + "pycuda_module_name",) + comparison_fields = CudaTarget.comparison_fields + ( + "pycuda_module_name",) + + def get_host_ast_builder(self): + return PyCudaPythonASTBuilder(self) + + def get_device_ast_builder(self): + return PyCudaCASTBuilder(self) + + def get_kernel_executor_cache_key(self, **kwargs): + return (kwargs["entrypoint"],) + + def get_dtype_registry(self): + from pycuda.compyte.dtypes import TYPE_REGISTRY + return TYPE_REGISTRY + + def preprocess_translation_unit_for_passed_args(self, t_unit, epoint, + passed_args_dict): + + # {{{ ValueArgs -> GlobalArgs if passed as array shapes + + from loopy.kernel.data import ValueArg, GlobalArg + import pycuda.gpuarray as cu_np + + knl = t_unit[epoint] + new_args = [] + + for arg in knl.args: + if isinstance(arg, ValueArg): + if (arg.name in passed_args_dict + and isinstance(passed_args_dict[arg.name], cu_np.GPUArray) + and passed_args_dict[arg.name].shape == ()): + arg = GlobalArg(name=arg.name, dtype=arg.dtype, shape=(), + is_output=False, is_input=True) + + new_args.append(arg) + + knl = knl.copy(args=new_args) + + t_unit = t_unit.with_kernel(knl) + + # }}} + + return t_unit + + def get_kernel_executor(self, t_unit, **kwargs): + from loopy.target.pycuda_execution import PyCudaKernelExecutor + + epoint = kwargs.pop("entrypoint") + t_unit = self.preprocess_translation_unit_for_passed_args(t_unit, + epoint, + kwargs) + + return PyCudaKernelExecutor(t_unit, entrypoint=epoint) + + +class PyCudaWithPackedArgsTarget(PyCudaTarget): + + def get_kernel_executor(self, t_unit, **kwargs): + from loopy.target.pycuda_execution import PyCudaWithPackedArgsKernelExecutor + + epoint = kwargs.pop("entrypoint") + t_unit = self.preprocess_translation_unit_for_passed_args(t_unit, + epoint, + kwargs) + + return PyCudaWithPackedArgsKernelExecutor(t_unit, entrypoint=epoint) + + def get_host_ast_builder(self): + return PyCudaWithPackedArgsPythonASTBuilder(self) + + def get_device_ast_builder(self): + return PyCudaWithPackedArgsCASTBuilder(self) + +# }}} + + +# {{{ host ast builder + +class PyCudaPythonASTBuilder(PythonASTBuilderBase): + """A Python host AST builder for integration with PyCuda. + """ + + # {{{ code generation guts + + def get_function_definition( + self, codegen_state, codegen_result, + schedule_index: int, function_decl, function_body: genpy.Generable + ) -> genpy.Function: + assert schedule_index == 0 + + from loopy.schedule.tools import get_kernel_arg_info + kai = get_kernel_arg_info(codegen_state.kernel) + + args = ( + ["_lpy_cuda_functions"] + + [arg_name for arg_name in kai.passed_arg_names] + + ["wait_for=()", "allocator=None", "stream=None"]) + + from genpy import (For, Function, Suite, Return, Line, Statement as S) + return Function( + codegen_result.current_program(codegen_state).name, + args, + Suite([ + Line(), + ] + [ + Line(), + function_body, + Line(), + ] + ([ + For("_tv", "_global_temporaries", + # Free global temporaries. + # Zero-size temporaries allocate as None, tolerate that. 
+ S("if _tv is not None: _tv.free()")) + ] if self._get_global_temporaries(codegen_state) else [] + ) + [ + Line(), + Return("_lpy_evt"), + ])) + + def get_function_declaration( + self, codegen_state: CodeGenerationState, + codegen_result: CodeGenerationResult, schedule_index: int + ) -> Tuple[Sequence[Tuple[str, str]], genpy.Generable]: + # no such thing in Python + return [], None + + def _get_global_temporaries(self, codegen_state): + from loopy.kernel.data import AddressSpace + + return sorted( + (tv for tv in codegen_state.kernel.temporary_variables.values() + if tv.address_space == AddressSpace.GLOBAL), + key=lambda tv: tv.name) + + def get_temporary_decls(self, codegen_state, schedule_index): + from genpy import Assign, Comment, Line + + from pymbolic.mapper.stringifier import PREC_NONE + ecm = self.get_expression_to_code_mapper(codegen_state) + + global_temporaries = self._get_global_temporaries(codegen_state) + if not global_temporaries: + return [] + + allocated_var_names = [] + code_lines = [] + code_lines.append(Line()) + code_lines.append(Comment("{{{ allocate global temporaries")) + code_lines.append(Line()) + + for tv in global_temporaries: + if not tv.base_storage: + nbytes_str = ecm(tv.nbytes, PREC_NONE, "i") + allocated_var_names.append(tv.name) + code_lines.append(Assign(tv.name, + f"allocator({nbytes_str})")) + + code_lines.append(Assign("_global_temporaries", "[{tvs}]".format( + tvs=", ".join(tv for tv in allocated_var_names)))) + + code_lines.append(Line()) + code_lines.append(Comment("}}}")) + code_lines.append(Line()) + + return code_lines + + def get_kernel_call(self, + codegen_state, subkernel_name, grid, block): + from genpy import Suite, Assign, Line, Comment, Statement + from pymbolic.mapper.stringifier import PREC_NONE + + from loopy.schedule.tools import get_subkernel_arg_info + skai = get_subkernel_arg_info( + codegen_state.kernel, subkernel_name) + + # make grid/block a 3-tuple + grid = grid + (1,) * (3-len(grid)) + block = block + (1,) * (3-len(block)) + ecm = self.get_expression_to_code_mapper(codegen_state) + + grid_str = ecm(grid, prec=PREC_NONE, type_context="i") + block_str = ecm(block, prec=PREC_NONE, type_context="i") + + return Suite([ + Comment("{{{ launch %s" % subkernel_name), + Line(), + Statement("for _lpy_cu_evt in wait_for: _lpy_cu_evt.synchronize()"), + Line(), + Assign("_lpy_knl", f"_lpy_cuda_functions['{subkernel_name}']"), + Line(), + Statement("_lpy_knl.prepared_async_call(" + f"{grid_str}, {block_str}, " + "stream, " + f"{', '.join(arg_name for arg_name in skai.passed_names)}" + ")",), + Assign("_lpy_evt", "_lpy_cuda.Event().record(stream)"), + Assign("wait_for", "[_lpy_evt]"), + Line(), + Comment("}}}"), + Line(), + ]) + + # }}} + + +class PyCudaWithPackedArgsPythonASTBuilder(PyCudaPythonASTBuilder): + + def get_kernel_call(self, + codegen_state, subkernel_name, grid, block): + from genpy import Suite, Assign, Line, Comment, Statement + from pymbolic.mapper.stringifier import PREC_NONE + from loopy.kernel.data import ValueArg, ArrayArg + + from loopy.schedule.tools import get_subkernel_arg_info + kernel = codegen_state.kernel + skai = get_subkernel_arg_info(kernel, subkernel_name) + + # make grid/block a 3-tuple + grid = grid + (1,) * (3-len(grid)) + block = block + (1,) * (3-len(block)) + ecm = self.get_expression_to_code_mapper(codegen_state) + + grid_str = ecm(grid, prec=PREC_NONE, type_context="i") + block_str = ecm(block, prec=PREC_NONE, type_context="i") + + struct_format = [] + for arg_name in skai.passed_names: + if arg_name 
in codegen_state.kernel.all_inames(): + struct_format.append(kernel.index_dtype.numpy_dtype.char) + if kernel.index_dtype.numpy_dtype.itemsize < 8: + struct_format.append("x") + elif arg_name in codegen_state.kernel.temporary_variables: + struct_format.append("P") + else: + knl_arg = codegen_state.kernel.arg_dict[arg_name] + if isinstance(knl_arg, ValueArg): + struct_format.append(knl_arg.dtype.numpy_dtype.char) + if knl_arg.dtype.numpy_dtype.itemsize < 8: + struct_format.append("x") + else: + struct_format.append("P") + + def _arg_cast(arg_name: str) -> str: + if arg_name in skai.passed_inames: + return ("_lpy_np" + f".{codegen_state.kernel.index_dtype.numpy_dtype.name}" + f"({arg_name})") + elif arg_name in skai.passed_temporaries: + return f"_lpy_np.uintp(int({arg_name}))" + else: + knl_arg = codegen_state.kernel.arg_dict[arg_name] + if isinstance(knl_arg, ValueArg): + assert knl_arg.dtype is not None + return f"_lpy_np.{knl_arg.dtype.numpy_dtype.name}({arg_name})" + else: + assert isinstance(knl_arg, ArrayArg) + return f"_lpy_np.uintp(int({arg_name}))" + + return Suite([ + Comment("{{{ launch %s" % subkernel_name), + Line(), + Statement("for _lpy_cu_evt in wait_for: _lpy_cu_evt.synchronize()"), + Line(), + Assign("_lpy_knl", f"_lpy_cuda_functions['{subkernel_name}']"), + Line(), + Assign("_lpy_args_on_dev", f"allocator({len(skai.passed_names)*8})"), + Assign("_lpy_args_on_host", + f"_lpy_struct.pack('{''.join(struct_format)}'," + f"{','.join(_arg_cast(arg) for arg in skai.passed_names)})"), + Statement("_lpy_cuda.memcpy_htod(_lpy_args_on_dev, _lpy_args_on_host)"), + Line(), + Statement("_lpy_knl.prepared_async_call(" + f"{grid_str}, {block_str}, " + "stream, _lpy_args_on_dev)"), + Assign("_lpy_evt", "_lpy_cuda.Event().record(stream)"), + Assign("wait_for", "[_lpy_evt]"), + Line(), + Comment("}}}"), + Line(), + ]) + +# }}} + + +# {{{ device ast builder + +class PyCudaCASTBuilder(CUDACASTBuilder): + """A C device AST builder for integration with PyCUDA. 
+ """ + + # {{{ library + + def preamble_generators(self): + return ([pycuda_preamble_generator] + + super().preamble_generators()) + + @property + def known_callables(self): + callables = super().known_callables + callables.update(get_pycuda_callables()) + return callables + + # }}} + + def get_expression_to_c_expression_mapper(self, codegen_state): + return ExpressionToPyCudaCExpressionMapper(codegen_state) + + +class PyCudaWithPackedArgsCASTBuilder(PyCudaCASTBuilder): + def arg_struct_name(self, kernel_name: str): + return f"_lpy_{kernel_name}_packed_args" + + def get_function_declaration(self, codegen_state, codegen_result, + schedule_index): + from loopy.target.c import FunctionDeclarationWrapper + from cgen import FunctionDeclaration, Value, Pointer, Extern + from cgen.cuda import CudaGlobal, CudaDevice, CudaLaunchBounds + + kernel = codegen_state.kernel + + assert kernel.linearization is not None + name = codegen_result.current_program(codegen_state).name + arg_type = self.arg_struct_name(name) + + if self.target.fortran_abi: + name += "_" + + fdecl = FunctionDeclaration( + Value("void", name), + [Pointer(Value(arg_type, "_lpy_args"))]) + + if codegen_state.is_entrypoint: + fdecl = CudaGlobal(fdecl) + if self.target.extern_c: + fdecl = Extern("C", fdecl) + + from loopy.schedule import get_insn_ids_for_block_at + _, lsize = kernel.get_grid_sizes_for_insn_ids_as_exprs( + get_insn_ids_for_block_at(kernel.linearization, schedule_index), + codegen_state.callables_table) + + from loopy.symbolic import get_dependencies + if not get_dependencies(lsize): + # Sizes can't have parameter dependencies if they are + # to be used in static thread block size. + from pytools import product + nthreads = product(lsize) + + fdecl = CudaLaunchBounds(nthreads, fdecl) + + return [], FunctionDeclarationWrapper(fdecl) + else: + return [], CudaDevice(fdecl) + + def get_function_definition( + self, codegen_state: CodeGenerationState, + codegen_result: CodeGenerationResult, + schedule_index: int, function_decl: Generable, function_body: Generable + ) -> Generable: + from typing import cast + from loopy.target.c import generate_array_literal + from loopy.schedule import CallKernel + from loopy.schedule.tools import get_subkernel_arg_info + from loopy.kernel.data import ValueArg, AddressSpace + from cgen import (FunctionBody, + Module as Collection, + Initializer, + Line, Value, Pointer, Struct as GenerableStruct) + kernel = codegen_state.kernel + assert kernel.linearization is not None + assert isinstance(kernel.linearization[schedule_index], CallKernel) + kernel_name = (cast(CallKernel, + kernel.linearization[schedule_index]) + .kernel_name) + + skai = get_subkernel_arg_info(kernel, kernel_name) + + result = [] + + # We only need to write declarations for global variables with + # the first device program. `is_first_dev_prog` determines + # whether this is the first device program in the schedule. 
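+        # (In C, these globals may be defined only once per generated
+        # source; later device programs in the same module simply refer
+        # to them.)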
+ is_first_dev_prog = codegen_state.is_generating_device_code + for i in range(schedule_index): + if isinstance(kernel.linearization[i], CallKernel): + is_first_dev_prog = False + break + if is_first_dev_prog: + for tv in sorted( + kernel.temporary_variables.values(), + key=lambda key_tv: key_tv.name): + + if tv.address_space == AddressSpace.GLOBAL and ( + tv.initializer is not None): + assert tv.read_only + + decl = self.wrap_global_constant( + self.get_temporary_var_declarator(codegen_state, tv)) + + if tv.initializer is not None: + decl = Initializer(decl, generate_array_literal( + codegen_state, tv, tv.initializer)) + + result.append(decl) + + # {{{ declare+unpack the struct type + + struct_fields = [] + + for arg_name in skai.passed_names: + if arg_name in skai.passed_inames: + struct_fields.append( + Value(self.target.dtype_to_typename(kernel.index_dtype), + f"{arg_name}, __padding_{arg_name}")) + elif arg_name in skai.passed_temporaries: + tv = kernel.temporary_variables[arg_name] + struct_fields.append(Pointer( + Value(self.target.dtype_to_typename(tv.dtype), arg_name))) + else: + knl_arg = kernel.arg_dict[arg_name] + if isinstance(knl_arg, ValueArg): + struct_fields.append( + Value(self.target.dtype_to_typename(knl_arg.dtype), + f"{arg_name}, __padding_{arg_name}")) + else: + struct_fields.append( + Pointer(Value(self.target.dtype_to_typename(knl_arg.dtype), + arg_name))) + + function_body.insert(0, Line()) + for arg_name in skai.passed_names[::-1]: + function_body.insert(0, Initializer( + self.arg_to_cgen_declarator( + kernel, arg_name, + arg_name in kernel.get_written_variables()), + f"_lpy_args->{arg_name}" + )) + + # }}} + + fbody = FunctionBody(function_decl, function_body) + + return Collection([*result, + Line(), + GenerableStruct(self.arg_struct_name(kernel_name), + struct_fields), + Line(), + fbody]) + + +# }}} + +# vim: foldmethod=marker diff --git a/loopy/target/pycuda_execution.py b/loopy/target/pycuda_execution.py new file mode 100644 index 000000000..ec0c1834a --- /dev/null +++ b/loopy/target/pycuda_execution.py @@ -0,0 +1,394 @@ +__copyright__ = """ +Copyright (C) 2012 Andreas Kloeckner +Copyright (C) 2022 Kaushik Kulkarni +""" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. 
+""" + + +from typing import (Sequence, Tuple, Union, Callable, Any, Optional, + TYPE_CHECKING) +from dataclasses import dataclass + +import numpy as np +from immutables import Map + +from pytools import memoize_method +from pytools.codegen import Indentation, CodeGenerator + +from loopy.types import LoopyType +from loopy.typing import ExpressionT +from loopy.kernel import LoopKernel +from loopy.kernel.data import ArrayArg +from loopy.translation_unit import TranslationUnit +from loopy.schedule.tools import KernelArgInfo +from loopy.target.execution import ( + KernelExecutorBase, ExecutionWrapperGeneratorBase) +import logging +logger = logging.getLogger(__name__) + + +if TYPE_CHECKING: + import pycuda.driver as cuda + + +# {{{ invoker generation + +# /!\ This code runs in a namespace controlled by the user. +# Prefix all auxiliary variables with "_lpy". + + +class PyCudaExecutionWrapperGenerator(ExecutionWrapperGeneratorBase): + """ + Specialized form of the :class:`ExecutionWrapperGeneratorBase` for + pycuda execution + """ + + def __init__(self): + system_args = [ + "_lpy_cuda_functions", "stream=None", "allocator=None", "wait_for=()", + # ignored if options.no_numpy + "out_host=None" + ] + super().__init__(system_args) + + def python_dtype_str_inner(self, dtype): + from pycuda.tools import dtype_to_ctype + # Test for types built into numpy. dtype.isbuiltin does not work: + # https://github.com/numpy/numpy/issues/4317 + # Guided by https://numpy.org/doc/stable/reference/arrays.scalars.html + if issubclass(dtype.type, (np.bool_, np.number)): + name = dtype.name + if dtype.type == np.bool_: + name = "bool8" + return f"_lpy_np.dtype(_lpy_np.{name})" + else: + return ('_lpy_cuda_tools.get_or_register_dtype("%s")' + % dtype_to_ctype(dtype)) + + # {{{ handle non-numpy args + + def handle_non_numpy_arg(self, gen, arg): + gen("if isinstance(%s, _lpy_np.ndarray):" % arg.name) + with Indentation(gen): + gen("# retain originally passed array") + gen(f"_lpy_{arg.name}_np_input = {arg.name}") + gen("# synchronous, nothing to worry about") + gen("%s = _lpy_cuda_array.to_gpu_async(" + "%s, allocator=allocator, stream=stream)" + % (arg.name, arg.name)) + gen("_lpy_encountered_numpy = True") + gen("elif %s is not None:" % arg.name) + with Indentation(gen): + gen("_lpy_encountered_dev = True") + gen("_lpy_%s_np_input = None" % arg.name) + gen("else:") + with Indentation(gen): + gen("_lpy_%s_np_input = None" % arg.name) + + gen("") + + # }}} + + # {{{ handle allocation of unspecified arguments + + def handle_alloc( + self, gen: CodeGenerator, arg: ArrayArg, + strify: Callable[[Union[ExpressionT, Tuple[ExpressionT]]], str], + skip_arg_checks: bool) -> None: + """ + Handle allocation of non-specified arguments for pycuda execution + """ + from pymbolic import var + + from loopy.kernel.array import get_strides + strides = get_strides(arg) + num_axes = len(strides) + + assert arg.dtype is not None + itemsize = arg.dtype.numpy_dtype.itemsize + for i in range(num_axes): + gen("_lpy_ustrides_%d = %s" % (i, strify(strides[i]))) + + if not skip_arg_checks: + for i in range(num_axes): + gen("assert _lpy_ustrides_%d >= 0, " + "\"'%s' has negative stride in axis %d\"" + % (i, arg.name, i)) + + assert isinstance(arg.shape, tuple) + sym_ustrides = tuple( + var("_lpy_ustrides_%d" % i) + for i in range(num_axes)) + sym_shape = tuple(arg.shape[i] for i in range(num_axes)) + + size_expr = (sum(astrd*(alen-1) + for alen, astrd in zip(sym_shape, sym_ustrides)) + + 1) + + gen("_lpy_size = %s" % strify(size_expr)) + sym_strides 
= tuple(itemsize*s_i for s_i in sym_ustrides)
+
+        dtype_name = self.python_dtype_str(gen, arg.dtype.numpy_dtype)
+        gen(f"{arg.name} = _lpy_cuda_array.GPUArray({strify(sym_shape)}, "
+            f"{dtype_name}, strides={strify(sym_strides)}, "
+            f"gpudata=allocator({strify(itemsize * var('_lpy_size'))}), "
+            "allocator=allocator)")
+
+        for i in range(num_axes):
+            gen("del _lpy_ustrides_%d" % i)
+        gen("del _lpy_size")
+        gen("")
+
+    # }}}
+
+    def target_specific_preamble(self, gen):
+        """
+        Add default pycuda imports to preamble
+        """
+        gen.add_to_preamble("import numpy as _lpy_np")
+        gen.add_to_preamble("import pycuda.driver as _lpy_cuda")
+        gen.add_to_preamble("import pycuda.gpuarray as _lpy_cuda_array")
+        gen.add_to_preamble("import pycuda.tools as _lpy_cuda_tools")
+        gen.add_to_preamble("import struct as _lpy_struct")
+        from loopy.target.c.c_execution import DEF_EVEN_DIV_FUNCTION
+        gen.add_to_preamble(DEF_EVEN_DIV_FUNCTION)
+
+    def initialize_system_args(self, gen):
+        """
+        Initializes possibly empty system arguments
+        """
+        gen("if allocator is None:")
+        with Indentation(gen):
+            gen("allocator = _lpy_cuda.mem_alloc")
+        gen("")
+
+    # {{{ generate invocation
+
+    def generate_invocation(self, gen: CodeGenerator, kernel: LoopKernel,
+            kai: KernelArgInfo, host_program_name: str,
+            args: Sequence[str]) -> None:
+        arg_list = (["_lpy_cuda_functions"]
+                    + list(args)
+                    + ["stream=stream", "wait_for=wait_for",
+                       "allocator=allocator"])
+        gen(f"_lpy_evt = {host_program_name}({', '.join(arg_list)})")
+
+    # }}}
+
+    # {{{ generate_output_handler
+
+    def generate_output_handler(self, gen: CodeGenerator,
+            kernel: LoopKernel, kai: KernelArgInfo) -> None:
+        options = kernel.options
+
+        if not options.no_numpy:
+            gen("if out_host is None and (_lpy_encountered_numpy "
+                "and not _lpy_encountered_dev):")
+            with Indentation(gen):
+                gen("out_host = True")
+
+            for arg_name in kai.passed_arg_names:
+                arg = kernel.arg_dict[arg_name]
+                if arg.is_output:
+                    np_name = "_lpy_%s_np_input" % arg.name
+                    gen("if out_host or %s is not None:" % np_name)
+                    with Indentation(gen):
+                        # (GPUArray.get() takes no stream argument; use
+                        # get_async, which copies on the given stream)
+                        gen("%s = %s.get_async(stream=stream, ary=%s)"
+                            % (arg.name, arg.name, np_name))
+
+            gen("")
+
+        if options.return_dict:
+            gen("return _lpy_evt, {%s}"
+                % ", ".join(f'"{arg_name}": {arg_name}'
+                            for arg_name in kai.passed_arg_names
+                            if kernel.arg_dict[arg_name].is_output))
+        else:
+            out_names = [arg_name for arg_name in kai.passed_arg_names
+                         if kernel.arg_dict[arg_name].is_output]
+            if out_names:
+                gen("return _lpy_evt, (%s,)"
+                    % ", ".join(out_names))
+            else:
+                gen("return _lpy_evt, ()")
+
+    # }}}
+
+    def generate_host_code(self, gen, codegen_result):
+        gen.add_to_preamble(codegen_result.host_code())
+
+    def get_arg_pass(self, arg):
+        return "%s.gpudata" % arg.name
+
+# }}}
+
+
+@dataclass(frozen=True)
+class _KernelInfo:
+    t_unit: TranslationUnit
+    cuda_functions: Map[str, "cuda.Function"]
+    invoker: Callable[..., Any]
+
+
+# {{{ kernel executor
+
+class PyCudaKernelExecutor(KernelExecutorBase):
+    """
+    An object connecting a kernel to :mod:`pycuda` for execution.
+
+    .. automethod:: __init__
+    ..
automethod:: __call__ + """ + + def get_invoker_uncached(self, t_unit, entrypoint, codegen_result): + generator = PyCudaExecutionWrapperGenerator() + return generator(t_unit, entrypoint, codegen_result) + + def get_wrapper_generator(self): + return PyCudaExecutionWrapperGenerator() + + def _get_arg_dtypes(self, knl, subkernel_name): + from loopy.schedule.tools import get_subkernel_arg_info + from loopy.kernel.data import ValueArg + + skai = get_subkernel_arg_info(knl, subkernel_name) + arg_dtypes = [] + for arg in skai.passed_names: + if arg in skai.passed_inames: + arg_dtypes.append(knl.index_dtype.numpy_dtype) + elif arg in skai.passed_temporaries: + arg_dtypes.append("P") + else: + assert arg in knl.arg_dict + if isinstance(knl.arg_dict[arg], ValueArg): + arg_dtypes.append(knl.arg_dict[arg].dtype.numpy_dtype) + else: + # Array Arg + arg_dtypes.append("P") + + return arg_dtypes + + @memoize_method + def translation_unit_info(self, + arg_to_dtype: Optional[Map[str, LoopyType]] = None + ) -> _KernelInfo: + t_unit = self.get_typed_and_scheduled_translation_unit(self.entrypoint, + arg_to_dtype) + + # FIXME: now just need to add the types to the arguments + from loopy.codegen import generate_code_v2 + from loopy.target.execution import get_highlighted_code + codegen_result = generate_code_v2(t_unit) + + dev_code = codegen_result.device_code() + epoint_knl = t_unit[self.entrypoint] + + if t_unit[self.entrypoint].options.write_code: + #FIXME: redirect to "translation unit" level option as well. + output = dev_code + if self.t_unit[self.entrypoint].options.allow_terminal_colors: + output = get_highlighted_code(output) + + if epoint_knl.options.write_code is True: + print(output) + else: + with open(epoint_knl.options.write_code, "w") as outf: + outf.write(output) + + if epoint_knl.options.edit_code: + #FIXME: redirect to "translation unit" level option as well. + from pytools import invoke_editor + dev_code = invoke_editor(dev_code, "code.cu") + + from pycuda.compiler import SourceModule + from loopy.kernel.tools import get_subkernels + + #FIXME: redirect to "translation unit" level option as well. + src_module = SourceModule(dev_code, + options=epoint_knl.options.build_options) + + cuda_functions = Map({name: (src_module + .get_function(name) + .prepare(self._get_arg_dtypes(epoint_knl, name)) + ) + for name in get_subkernels(epoint_knl)}) + return _KernelInfo( + t_unit=t_unit, + cuda_functions=cuda_functions, + invoker=self.get_invoker(t_unit, self.entrypoint, codegen_result)) + + def __call__(self, *, + stream=None, allocator=None, wait_for=(), out_host=None, + **kwargs): + """ + :arg allocator: a callable that accepts a byte count and returns + an instance of :class:`pycuda.driver.DeviceAllocation`. Typically + one of :func:`pycuda.driver.mem_alloc` or + :meth:`pycuda.tools.DeviceMemoryPool.allocate`. + :arg wait_for: A sequence of :class:`pycuda.driver.Event` instances + for which to wait before launching the CUDA kernels. + :arg out_host: :class:`bool` + Decides whether output arguments (i.e. arguments + written by the kernel) are to be returned as + :mod:`numpy` arrays. *True* for yes, *False* for no. + + For the default value of *None*, if all (input) array + arguments are :mod:`numpy` arrays, defaults to + returning :mod:`numpy` arrays as well. + + :returns: ``(evt, output)`` where *evt* is a + :class:`pycuda.driver.Event` that is recorded right after the + kernel has been launched and output is a tuple of output arguments + (arguments that are written as part of the kernel). 
The order is + given by the order of kernel arguments. If this order is + unspecified (such as when kernel arguments are inferred + automatically), enable :attr:`loopy.Options.return_dict` to make + *output* a :class:`dict` instead, with keys of argument names and + values of the returned arrays. + """ + + if "entrypoint" in kwargs: + assert kwargs.pop("entrypoint") == self.entrypoint + from warnings import warn + warn("Obtained a redundant argument 'entrypoint'. This will" + " be an error in 2023.", DeprecationWarning, stacklevel=2) + + if __debug__: + self.check_for_required_array_arguments(kwargs.keys()) + + if self.packing_controller is not None: + kwargs = self.packing_controller(kwargs) + + translation_unit_info = self.translation_unit_info(self.arg_to_dtype(kwargs)) + + return translation_unit_info.invoker( + translation_unit_info.cuda_functions, stream, allocator, wait_for, + out_host, **kwargs) + + +class PyCudaWithPackedArgsKernelExecutor(PyCudaKernelExecutor): + + def _get_arg_dtypes(self, knl, subkernel_name): + return ["P"] + +# }}} + +# vim: foldmethod=marker diff --git a/loopy/tools.py b/loopy/tools.py index 463fc2cbb..e1255e65e 100644 --- a/loopy/tools.py +++ b/loopy/tools.py @@ -26,7 +26,18 @@ import logging from functools import cached_property from sys import intern -from typing import TYPE_CHECKING, ClassVar, Generic, Literal, TypeVar, cast, overload +from typing import ( + TYPE_CHECKING, + ClassVar, + FrozenSet, + Generic, + Literal, + TypeVar, + Iterator, + cast, + overload +) +from dataclasses import dataclass import numpy as np from constantdict import constantdict diff --git a/loopy/transform/domain.py b/loopy/transform/domain.py new file mode 100644 index 000000000..03bef1547 --- /dev/null +++ b/loopy/transform/domain.py @@ -0,0 +1,90 @@ +__copyright__ = "Copyright (C) 2023 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + +__doc__ = """ +.. currentmodule:: loopy + +.. autofunction:: decouple_domain +""" + +import islpy as isl + +from loopy.translation_unit import for_each_kernel +from loopy.kernel import LoopKernel +from loopy.diagnostic import LoopyError +from collections.abc import Collection + + +@for_each_kernel +def decouple_domain(kernel: LoopKernel, + inames: Collection[str], + parent_inames: Collection[str]) -> LoopKernel: + r""" + Returns a copy of *kernel* with altered domains. The home domain of + *inames* i.e. 
+
+    .. note::
+
+        An error is raised if the *inames* do not all share the same home
+        domain in *kernel*.
+    """
+
+    if not inames:
+        raise LoopyError("No inames were provided to decouple into"
+                         " a different domain.")
+
+    hdi = kernel.get_home_domain_index(next(iter(inames)))
+    for iname in inames:
+        if kernel.get_home_domain_index(iname) != hdi:
+            raise LoopyError("inames are not a part of the same home domain.")
+
+    for parent_iname in parent_inames:
+        if parent_iname not in set(kernel.domains[hdi].get_var_dict()):
+            raise LoopyError(f"Parent iname '{parent_iname}' not a part of the"
+                             f" corresponding home domain"
+                             f" '{kernel.domains[hdi]}'.")
+
+    all_dims = frozenset(kernel.domains[hdi].get_var_dict())
+    dom1 = kernel.domains[hdi]
+    dom2 = kernel.domains[hdi]
+
+    for iname in sorted(all_dims):
+        if iname in inames:
+            dt, pos = dom1.get_var_dict()[iname]
+            dom1 = dom1.project_out(dt, pos, 1)
+        elif iname in parent_inames:
+            dt, pos = dom2.get_var_dict()[iname]
+            if dt != isl.dim_type.param:
+                n_params = dom2.dim(isl.dim_type.param)
+                dom2 = dom2.move_dims(isl.dim_type.param, n_params, dt, pos, 1)
+        else:
+            dt, pos = dom2.get_var_dict()[iname]
+            dom2 = dom2.project_out(dt, pos, 1)
+
+    new_domains = kernel.domains[:]
+    new_domains[hdi] = dom1
+    new_domains.append(dom2)
+    kernel = kernel.copy(domains=new_domains)
+    return kernel
diff --git a/loopy/transform/loop_fusion.py b/loopy/transform/loop_fusion.py
index 2a9927fe3..060207755 100644
--- a/loopy/transform/loop_fusion.py
+++ b/loopy/transform/loop_fusion.py
@@ -46,14 +46,15 @@
 if TYPE_CHECKING:
-    from collections.abc import Callable, Collection, Iterable, Mapping, Sequence, Set
-
+    from collections.abc import (
+        Callable, Collection, Iterable,
+        Mapping, Sequence, Set,
+    )
+    from typing import Optional
+
     from loopy.kernel.instruction import InstructionBase
     from loopy.match import RuleStack
     from loopy.schedule.tools import LoopNestTree
     from loopy.typing import InameStr, InameStrSet
 
-
 __doc__ = """
 .. autofunction:: rename_inames_in_batch
 .. autofunction:: get_kennedy_unweighted_fusion_candidates
@@ -203,13 +204,11 @@ def done(self):
         """
         return LoopDependenceGraph.new(self._dag, self._is_infusible)
 
-
 # }}}
 
 
 # {{{ _build_ldg
 
-
 @dataclass(frozen=True, eq=True, repr=True)
 class PreLDGNode:
     """
@@ -615,6 +614,7 @@ def _fuse_sequential_loops_within_outer_loops(
         outer_inames: frozenset[InameStr],
         name_gen: Callable[[str], str],
         prefix: str,
+        force_infusible: Callable[[str, str], bool],
 ):
     from collections import deque
@@ -632,7 +632,8 @@ def _fuse_sequential_loops_within_outer_loops(
     queue = deque([loops_with_no_preds[0]])
     for node in loops_with_no_preds[1:]:
-        queue.append(node)
+        if not force_infusible(node, loops_with_no_preds[0]):
+            queue.append(node)
 
     loops_to_be_fused: set[InameStr] = set()
     non_fusible_loops: set[InameStr] = set()
@@ -658,7 +659,8 @@ def _fuse_sequential_loops_within_outer_loops(
             loops_to_be_fused.add(next_loop_in_queue)
 
             for succ in ldg.successors[next_loop_in_queue]:
-                if ldg.is_infusible.get((next_loop_in_queue, succ), False):
+                if (ldg.is_infusible.get((next_loop_in_queue, succ), False)
+                        or force_infusible(next_loop_in_queue, succ)):
                     non_fusible_loops.add(succ)
                 else:
                     queue.append(succ)
@@ -796,8 +798,9 @@ def _get_partial_loop_nest_tree_for_fusion(kernel: LoopKernel):
 def get_kennedy_unweighted_fusion_candidates(
         kernel: LoopKernel, candidates: Collection[InameStr],
-        *, prefix: str = "ifused"
-        ) -> Mapping[InameStr, Collection[InameStr]]:
+        *,
+        force_infusible: Optional[Callable[[str, str], bool]] = None,
+        prefix: str = "ifused") -> Mapping[InameStr, Collection[InameStr]]:
     """
     Returns the fusion candidates mapping that could be fed to
     :func:`rename_inames_in_batch` similar to Kennedy's unweighted
@@ -837,6 +840,8 @@
     candidates = frozenset(candidates)
     vng = kernel.get_var_name_generator()
+    if force_infusible is None:
+        force_infusible = lambda x, y: False  # noqa: E731
 
     # {{{ implementation scope / sanity checks
@@ -944,6 +949,7 @@
                 outer_inames,
                 vng,
                 prefix,
+                force_infusible,
             )
         )
diff --git a/loopy/transform/precompute.py b/loopy/transform/precompute.py
index e4e991cf1..46196b1e3 100644
--- a/loopy/transform/precompute.py
+++ b/loopy/transform/precompute.py
@@ -371,8 +371,9 @@ def map_kernel(self, kernel):
                 dep_insn = kernel.id_to_insn[dep]
                 if (frozenset(dep_insn.assignee_var_names())
                         & self.compute_read_variables):
-                    self.compute_insn_depends_on.update(
-                            insn.depends_on - excluded_insn_ids)
+                    self.compute_insn_depends_on.add(dep)
 
         new_insns.append(insn)
diff --git a/loopy/transform/reduction.py b/loopy/transform/reduction.py
new file mode 100644
index 000000000..8824dd1c1
--- /dev/null
+++ b/loopy/transform/reduction.py
@@ -0,0 +1,292 @@
+"""
+.. currentmodule:: loopy
+
+.. autofunction:: hoist_invariant_multiplicative_terms_in_sum_reduction
+
+.. autofunction:: extract_multiplicative_terms_in_sum_reduction_as_subst
+"""
+
+__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+"""
+
+import pymbolic.primitives as p
+
+from typing import (Any, Callable, FrozenSet, Iterable, List, Optional,
+                    Sequence, Tuple, TypeVar, Union)
+
+from loopy.diagnostic import LoopyError
+from loopy.kernel import LoopKernel
+from loopy.kernel.data import SubstitutionRule
+from loopy.symbolic import CombineMapper, IdentityMapper, Reduction
+
+
+# {{{ partition (copied from more-itertools)
+
+Tpart = TypeVar("Tpart")
+
+
+def partition(pred: Callable[[Tpart], bool],
+              iterable: Iterable[Tpart]) -> Tuple[List[Tpart],
+                                                  List[Tpart]]:
+    """
+    Use a predicate to partition entries into false entries and true
+    entries.
+    """
+    # Inspired by https://docs.python.org/3/library/itertools.html
+    # partition(is_odd, range(10)) --> 0 2 4 6 8 and 1 3 5 7 9
+    from itertools import filterfalse, tee
+    t1, t2 = tee(iterable)
+    return list(filterfalse(pred, t1)), list(filter(pred, t2))
+
+# }}}
+
+
+# {{{ hoist_reduction_invariant_terms
+
+class EinsumTermsHoister(IdentityMapper):
+    """
+    Mapper to hoist products out of a sum-reduction.
+
+    .. attribute:: reduction_inames
+
+        Inames of the reduction expressions over which the hoisting is
+        performed.
+    """
+    def __init__(self, reduction_inames: FrozenSet[str]):
+        super().__init__()
+        self.reduction_inames = reduction_inames
+
+    # type-ignore-reason: super-class.map_reduction returns 'Any'
+    def map_reduction(self, expr: Reduction  # type: ignore[override]
+                      ) -> p.Expression:
+        if frozenset(expr.inames) != self.reduction_inames:
+            return super().map_reduction(expr)
+
+        from loopy.library.reduction import SumReductionOperation
+        from loopy.symbolic import get_dependencies
+        if isinstance(expr.operation, SumReductionOperation):
+            if isinstance(expr.expr, p.Product):
+                from pymbolic.primitives import flattened_product
+                multiplicative_terms = (flattened_product(self.rec(expr.expr)
+                                                          .children)
+                                        .children)
+            else:
+                multiplicative_terms = (expr.expr,)
+
+            invariants, variants = partition(
+                lambda x: bool(get_dependencies(x) & self.reduction_inames),
+                multiplicative_terms)
+            if not variants:
+                # -> everything is invariant
+                return self.rec(expr.expr) * Reduction(
+                    expr.operation,
+                    inames=expr.inames,
+                    expr=1,  # FIXME: invalid dtype (not sure how?)
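+                    # NOTE: summing the literal 1 over the reduction domain
+                    # simply counts its points, so the hoisted product is
+                    # scaled correctly; the constant's dtype, however, need
+                    # not match the original summand's -- which is what the
+                    # FIXME above is about.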
+                    allow_simultaneous=expr.allow_simultaneous)
+            if not invariants:
+                # -> nothing to hoist
+                return Reduction(
+                    expr.operation,
+                    inames=expr.inames,
+                    expr=self.rec(expr.expr),
+                    allow_simultaneous=expr.allow_simultaneous)
+
+            return p.Product(tuple(invariants)) * Reduction(
+                expr.operation,
+                inames=expr.inames,
+                expr=p.Product(tuple(variants)),
+                allow_simultaneous=expr.allow_simultaneous)
+        else:
+            return super().map_reduction(expr)
+
+
+def hoist_invariant_multiplicative_terms_in_sum_reduction(
+        kernel: LoopKernel,
+        reduction_inames: Union[str, FrozenSet[str]],
+        within: Any = None) -> LoopKernel:
+    """
+    Hoists loop-invariant multiplicative terms out of a sum-reduction
+    expression.
+
+    :arg reduction_inames: The inames of the reduction expression that is to
+        be transformed, i.e. the inames over which the sum is carried out.
+    :arg within: A match expression understood by :func:`loopy.match.parse_match`
+        that specifies the instructions over which the transformation is to be
+        performed.
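+
+    For example (kernel and variable names are hypothetical), an instruction
+    of the form::
+
+        y[i] = sum(j, 2*A[i, j]*x[j])
+
+    would be rewritten as::
+
+        y[i] = 2*sum(j, A[i, j]*x[j])
+
+    since ``2`` does not depend on the reduction iname ``j``.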
+ """ + def __init__(self, + *, + terms_filter: Callable[[p.Expression], bool], + subst_name: str, + subst_arguments: Tuple[str, ...]) -> None: + self.subst_name = subst_name + self.subst_arguments = subst_arguments + self.terms_filter = terms_filter + super().__init__() + + # mutable state to record the expression collected by the terms_filter + self.collected_subst_rule: Optional[SubstitutionRule] = None + + # type-ignore-reason: super-class.map_reduction returns 'Any' + def map_reduction(self, expr: Reduction) -> Reduction: # type: ignore[override] + from loopy.library.reduction import SumReductionOperation + from loopy.symbolic import SubstitutionMapper + if isinstance(expr.operation, SumReductionOperation): + if self.collected_subst_rule is not None: + # => there was already a sum-reduction operation -> raise + raise ValueError("Multiple sum reduction expressions found -> not" + " allowed.") + + if isinstance(expr.expr, p.Product): + from pymbolic.primitives import flattened_product + terms = flattened_product(expr.expr.children).children + else: + terms = (expr.expr,) + + unfiltered_terms, filtered_terms = partition(self.terms_filter, terms) + submap = SubstitutionMapper({ + argument_expr: p.Variable(f"arg{i}") + for i, argument_expr in enumerate(self.subst_arguments)}.get) + self.collected_subst_rule = SubstitutionRule( + name=self.subst_name, + arguments=tuple(f"arg{i}" for i in range(len(self.subst_arguments))), + expression=submap(p.Product(tuple(filtered_terms)) + if filtered_terms + else 1) + ) + return Reduction( + expr.operation, + expr.inames, + p.Product((p.Variable(self.subst_name)(*self.subst_arguments), + *unfiltered_terms)), + expr.allow_simultaneous) + else: + return super().map_reduction(expr) + + +def extract_multiplicative_terms_in_sum_reduction_as_subst( + kernel: LoopKernel, + within: Any, + subst_name: str, + arguments: Sequence[p.Expression], + terms_filter: Callable[[p.Expression], bool], +) -> LoopKernel: + """ + Returns a copy of *kernel* with a new substitution named *subst_name* and + *arguments* as arguments for the aggregated multiplicative terms in a + sum-reduction expression. + + :arg within: A match expression understood by :func:`loopy.match.parse_match` + to specify the instructions over which the transformation is to be + performed. + :arg terms_filter: A callable to filter which terms of the sum-reduction + comprise the body of substitution rule. + :arg arguments: The sub-expressions of the product of the filtered terms that + form the arguments of the extract substitution rule in the same order. + + .. note:: + + A ``LoopyError`` is raised if none or more than 1 sum-reduction expression + appear in *within*. 
+ """ + from loopy.match import parse_match + within = parse_match(within) + + matched_insns = [ + insn + for insn in kernel.instructions + if within(kernel, insn) and ContainsSumReduction()((insn.expression, + tuple(insn.predicates))) + ] + + if len(matched_insns) == 0: + raise LoopyError(f"No instructions found matching '{within}'" + " with sum-reductions found.") + if len(matched_insns) > 1: + raise LoopyError(f"More than one instruction found matching '{within}'" + " with sum-reductions found -> not allowed.") + + insn, = matched_insns + replacer = MultiplicativeTermReplacer(subst_name=subst_name, + subst_arguments=tuple(arguments), + terms_filter=terms_filter) + new_insn = insn.with_transformed_expressions(replacer) + new_rule = replacer.collected_subst_rule + new_substitutions = dict(kernel.substitutions).copy() + if subst_name in new_substitutions: + raise LoopyError(f"Kernel '{kernel.name}' already contains a substitution" + " rule named '{subst_name}'.") + assert new_rule is not None + new_substitutions[subst_name] = new_rule + + return kernel.copy(instructions=[new_insn if insn.id == new_insn.id else insn + for insn in kernel.instructions], + substitutions=new_substitutions) + +# }}} + + +# vim: foldmethod=marker diff --git a/loopy/transform/reindex.py b/loopy/transform/reindex.py new file mode 100644 index 000000000..3d4e7c562 --- /dev/null +++ b/loopy/transform/reindex.py @@ -0,0 +1,329 @@ +""" +.. currentmodule:: loopy + +.. autofunction:: reindex_temporary_using_seghir_loechner_scheme +""" + +__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni" + +__license__ = """ +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +""" + + +import islpy as isl +from typing import Union, Iterable, Tuple +# from loopy.typing import ExpressionT +from loopy.typing import Expression +from loopy.kernel import LoopKernel +from loopy.diagnostic import LoopyError +from loopy.symbolic import CombineMapper +from loopy.kernel.instruction import (MultiAssignmentBase, + CInstruction, BarrierInstruction) +from loopy.symbolic import RuleAwareIdentityMapper + + +ISLMapT = Union[isl.BasicMap, isl.Map] +ISLSetT = Union[isl.BasicSet, isl.Set] + + +def _add_prime_to_dim_names(isl_map: ISLMapT, + dts: Iterable[isl.dim_type]) -> ISLMapT: + """ + Returns a copy of *isl_map* with dims of types *dts* having their names + suffixed with an apostrophe (``'``). + + .. testsetup:: + + >>> import islpy as isl + >>> from loopy.transform.reindex import _add_prime_to_dim_names + + .. 
+    .. doctest::
+
+        >>> amap = isl.Map("{[i]->[j=2i]}")
+        >>> _add_prime_to_dim_names(amap, [isl.dim_type.in_, isl.dim_type.out])
+        Map("{ [i'] -> [j' = 2i'] }")
+    """
+    for dt in dts:
+        for idim in range(isl_map.dim(dt)):
+            old_name = isl_map.get_dim_name(dt, idim)
+            new_name = f"{old_name}'"
+            isl_map = isl_map.set_dim_name(dt, idim, new_name)
+
+    return isl_map
+
+
+def _get_seghir_loechner_reindexing_from_range(access_range: ISLSetT
+                                               ) -> Tuple[isl.PwQPolynomial,
+                                                          isl.PwQPolynomial]:
+    """
+    Returns ``(reindex_map, new_shape)``, where:
+
+    * ``reindex_map`` is a quasi-polynomial of the form ``[i1, .., in] ->
+      {f(i1, .., in)}`` representing that an array indexed via the subscripts
+      ``[i1, .., in]`` should be re-indexed into a 1-dimensional array as
+      ``f(i1, .., in)``.
+    * ``new_shape`` is a quasi-polynomial corresponding to the shape of the
+      re-indexed 1-dimensional array.
+    """
+
+    # {{{ create amap: an identity map on access_range
+
+    amap = isl.BasicMap.identity(
+        access_range
+        .space
+        .add_dims(isl.dim_type.in_, access_range.dim(isl.dim_type.out)))
+
+    # set amap's dim names
+    for idim in range(amap.dim(isl.dim_type.in_)):
+        amap = amap.set_dim_name(isl.dim_type.in_, idim,
+                                 f"_lpy_in_{idim}")
+        amap = amap.set_dim_name(isl.dim_type.out, idim,
+                                 f"_lpy_out_{idim}")
+
+    amap = amap.intersect_domain(access_range)
+
+    # }}}
+
+    n_in = amap.dim(isl.dim_type.in_)
+    n_out = amap.dim(isl.dim_type.out)
+
+    amap_lexmin = amap.lexmin()
+    primed_amap_lexmin = _add_prime_to_dim_names(amap_lexmin,
+                                                 [isl.dim_type.in_,
+                                                  isl.dim_type.out])
+
+    lex_lt_map = isl.Map.lex_lt_map(primed_amap_lexmin, amap_lexmin)
+
+    # make the lexmin map parametric in terms of its previous access
+    # expressions
+    lex_lt_set = (lex_lt_map
+                  .move_dims(isl.dim_type.param, 0, isl.dim_type.out, 0, n_in)
+                  .domain())
+
+    # {{{ initialize amap_to_count
+
+    amap_to_count = _add_prime_to_dim_names(amap, [isl.dim_type.in_])
+    amap_to_count = amap_to_count.insert_dims(isl.dim_type.param, 0, n_in)
+
+    for idim in range(n_in):
+        amap_to_count = amap_to_count.set_dim_name(
+            isl.dim_type.param, idim,
+            amap.get_dim_name(isl.dim_type.in_, idim))
+
+    amap_to_count = amap_to_count.intersect_domain(lex_lt_set)
+
+    # }}}
+
+    result = amap_to_count.range().card()
+
+    # {{{ simplify 'result' by gisting with 'access_range'
+
+    aligned_access_range = access_range.move_dims(isl.dim_type.param, 0,
+                                                  isl.dim_type.set, 0, n_out)
+
+    for idim in range(result.dim(isl.dim_type.param)):
+        aligned_access_range = (
+            aligned_access_range
+            .set_dim_name(isl.dim_type.param, idim,
+                          result.space.get_dim_name(isl.dim_type.param,
+                                                    idim)))
+
+    result = result.gist_params(aligned_access_range.params())
+
+    # }}}
+
+    return result, access_range.card()
+
+
+class _IndexCollector(CombineMapper):
+    """
+    A mapper that collects all instances of
+    :class:`pymbolic.primitives.Subscript` accessing :attr:`var_name`.
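+
+    .. attribute:: var_name
+
+        Name of the array variable whose subscript expressions are to be
+        collected.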
+ """ + def __init__(self, var_name): + super().__init__() + self.var_name = var_name + + def combine(self, values): + from functools import reduce + return reduce(frozenset.union, values, frozenset()) + + def map_subscript(self, expr): + if expr.aggregate.name == self.var_name: + return frozenset([expr]) | super().map_subscript(expr) + else: + return super().map_subscript(expr) + + def map_constant(self, expr): + return frozenset() + + map_variable = map_constant + map_function_symbol = map_constant + map_tagged_variable = map_constant + map_type_cast = map_constant + map_nan = map_constant + + +class ReindexingApplier(RuleAwareIdentityMapper): + def __init__(self, rule_mapping_context, + var_to_reindex, + reindexed_var_name, + new_index_expr, + index_names): + + super().__init__(rule_mapping_context) + + self.var_to_reindex = var_to_reindex + self.reindexed_var_name = reindexed_var_name + self.new_index_expr = new_index_expr + self.index_names = index_names + + def map_subscript(self, expr, expn_state): + if expr.aggregate.name != self.var_to_reindex: + return super().map_subscript(expr, expn_state) + + from loopy.symbolic import SubstitutionMapper + from pymbolic.mapper.substitutor import make_subst_func + from pymbolic.primitives import Subscript, Variable + + rec_indices = tuple(self.rec(idx, expn_state) for idx in expr.index_tuple) + + assert len(self.index_names) == len(rec_indices) + subst_func = make_subst_func({idx_name: rec_idx + for idx_name, rec_idx in zip(self.index_names, + rec_indices)}) + + return SubstitutionMapper(subst_func)( + Subscript(Variable(self.reindexed_var_name), + self.new_index_expr) + ) + + +def reindex_temporary_using_seghir_loechner_scheme(kernel: LoopKernel, + var_name: str, + ) -> LoopKernel: + """ + Returns a kernel with expressions of the form ``var_name[i1, .., in]`` + replaced with ``var_name_reindexed[f(i1, .., in)]`` where ``f`` is a + quasi-polynomial as outlined in [Seghir_2006]_. + """ + from loopy.transform.subst import expand_subst + from loopy.symbolic import (BatchedAccessMapMapper, pw_qpolynomial_to_expr, + SubstitutionRuleMappingContext) + + if var_name not in kernel.temporary_variables: + raise LoopyError(f"'{var_name}' not in temporary variable in kernel" + f" '{kernel.name}'.") + + # {{{ compute the access_range of *var_name* in *kernel* + + subst_kernel = expand_subst(kernel) + access_map_recorder = BatchedAccessMapMapper( + subst_kernel, + frozenset([var_name])) + + # access_exprs: Tuple[ExpressionT, ...] + access_exprs: Tuple[Expression, ...] 
+    for insn in subst_kernel.instructions:
+        if var_name in insn.dependency_names():
+            if isinstance(insn, MultiAssignmentBase):
+                access_exprs = (insn.assignees,
+                                insn.expression,
+                                tuple(insn.predicates))
+            elif isinstance(insn, (CInstruction, BarrierInstruction)):
+                access_exprs = tuple(insn.predicates)
+            else:
+                raise NotImplementedError(type(insn))
+
+            access_map_recorder(access_exprs, insn.within_inames)
+
+    vng = kernel.get_var_name_generator()
+    new_var_name = vng(var_name+"_reindexed")
+
+    access_range = access_map_recorder.get_access_range(var_name)
+
+    del subst_kernel
+    del access_map_recorder
+
+    # }}}
+
+    subst, new_shape = _get_seghir_loechner_reindexing_from_range(
+        access_range)
+
+    # {{{ simplify new_shape with the assumptions from kernel
+
+    new_shape = new_shape.gist_params(kernel.assumptions)
+
+    # }}}
+
+    # {{{ update kernel.temporary_variables
+
+    new_shape = new_shape.drop_unused_params()
+
+    new_temps = dict(kernel.temporary_variables)
+    new_temps[new_var_name] = new_temps.pop(var_name).copy(
+        name=new_var_name,
+        shape=pw_qpolynomial_to_expr(new_shape),
+        strides=None,
+        dim_tags=None,
+        dim_names=None,
+    )
+
+    kernel = kernel.copy(temporary_variables=new_temps)
+
+    # }}}
+
+    # {{{ perform the substitution, i.e. reindex the accesses
+
+    subst_expr = pw_qpolynomial_to_expr(subst)
+    subst_dim_names = tuple(
+        subst.space.get_dim_name(isl.dim_type.param, idim)
+        for idim in range(access_range.dim(isl.dim_type.out)))
+    assert not (set(subst_dim_names) & kernel.all_variable_names())
+
+    rule_mapping_context = SubstitutionRuleMappingContext(kernel.substitutions,
+                                                          vng)
+    reindexing_mapper = ReindexingApplier(rule_mapping_context,
+                                          var_name, new_var_name,
+                                          subst_expr, subst_dim_names)
+
+    def _does_access_var_name(kernel, insn, *args):
+        return var_name in insn.dependency_names()
+
+    kernel = reindexing_mapper.map_kernel(kernel,
+                                          within=_does_access_var_name,
+                                          map_args=False,
+                                          map_tvs=False)
+    kernel = rule_mapping_context.finish_kernel(kernel)
+
+    # }}}
+
+    # Note: Distributing a piece of code that depends on loopy, and that
+    # conditionally or unconditionally calls this routine, does *NOT* make
+    # that code a derivative work under GPLv2, since per point (0) of GPLv2
+    # a derivative is defined as: "a work containing the Program or a
+    # portion of it, either verbatim or with modifications and/or translated
+    # into another language."
+    #
+    # Loopy does *NOT* contain any portion of the barvinok library in its
+    # source code.
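+    #
+    # (Also note that the ``card`` calls used above are only available if
+    # islpy is built with barvinok support; callers are expected to guard on
+    # ``hasattr(isl.Set, "card")``, as the tests accompanying this
+    # transformation do.)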
+
+    return kernel
+
+# vim: fdm=marker
diff --git a/loopy/transform/subst.py b/loopy/transform/subst.py
index 4d196de61..0971370c1 100644
--- a/loopy/transform/subst.py
+++ b/loopy/transform/subst.py
@@ -28,10 +28,14 @@
 from typing import TYPE_CHECKING, cast
 
 from pymbolic import Expression, Variable, var
+from pytools import memoize_on_first_arg
 
 from loopy.diagnostic import LoopyError
 from loopy.kernel.function_interface import CallableKernel, ScalarCallable
-from loopy.symbolic import RuleAwareIdentityMapper, SubstitutionRuleMappingContext
+from loopy.symbolic import (
+    RuleAwareIdentityMapper,
+    SubstitutionRuleMappingContext,
+)
 from loopy.transform.iname import remove_any_newly_unused_inames
 from loopy.translation_unit import TranslationUnit, for_each_kernel
@@ -526,6 +530,7 @@ def _accesses_lhs(
 # {{{ expand_subst
 
 @for_each_kernel
+@memoize_on_first_arg
 def expand_subst(kernel, within=None):
     """
     Returns an instance of :class:`loopy.LoopKernel` with the substitutions
diff --git a/requirements.txt b/requirements.txt
index 0751539e6..dffe232bf 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,3 +13,5 @@ ply>=3.6
 
 # Optional, for testing special math function
 scipy
+# Optional, kanren-style relation helpers
+git+https://github.com/pythological/kanren.git#egg=miniKanren
diff --git a/test/test_pycuda_invoker.py b/test/test_pycuda_invoker.py
new file mode 100644
index 000000000..8cf5bcad4
--- /dev/null
+++ b/test/test_pycuda_invoker.py
@@ -0,0 +1,305 @@
+__copyright__ = "Copyright (C) 2022 Kaushik Kulkarni"
+
+__license__ = """
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+""" + +import sys +import numpy as np +import loopy as lp +import pytest +pytest.importorskip("pycuda") +import pycuda.gpuarray as cu_np +import itertools + +import logging +logger = logging.getLogger(__name__) + +try: + import faulthandler +except ImportError: + pass +else: + faulthandler.enable() + +from typing import Tuple, Any +from pycuda.tools import init_cuda_context_fixture +from loopy.version import LOOPY_USE_LANGUAGE_VERSION_2018_2 # noqa + + +@pytest.fixture(autouse=True) +def init_cuda_context(): + yield from init_cuda_context_fixture() + + +def get_random_array(rng, shape: Tuple[int, ...], dtype: np.dtype[Any]): + if np.issubdtype(dtype, np.complexfloating): + subdtype = np.empty(0, dtype=dtype).real.dtype + return (get_random_array(rng, shape, subdtype) + + dtype.type(1j) * get_random_array(rng, shape, subdtype)) + else: + assert np.issubdtype(dtype, np.floating) + return rng.random(shape, dtype=dtype) + + +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +def test_pycuda_invoker(target): + m = 5 + n = 6 + + knl = lp.make_kernel( + "{[i, j]: 0<=i tmp[i] = sin(x[i]) + z[i] = 2 * tmp[i] + """, + target=target) + knl = lp.set_temporary_address_space(knl, "tmp", lp.AddressSpace.GLOBAL) + + evt, (out,) = knl(x=x, out_host=False) + np.testing.assert_allclose(2*np.sin(x), out.get(), rtol=1e-6) + + +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +@pytest.mark.parametrize("dtype", [np.float32, np.float64]) +def test_multi_entrypoints(target, dtype): + rng = np.random.default_rng(seed=314) + x = rng.random(42, dtype=dtype) + + knl1 = lp.make_kernel( + "{[i]: 0<=i tmp[i] = 21*sin(x[i]) + 864.5*cos(y[i]) + z[i] = 2 * tmp[i] + """, + [lp.GlobalArg("x,y", + offset=lp.auto, shape=lp.auto), + ...], + target=target) + knl = lp.set_temporary_address_space(knl, "tmp", lp.AddressSpace.GLOBAL) + + evt, (out,) = knl(x=x, y=y) + np.testing.assert_allclose(42*np.sin(x) + 1729*np.cos(y), out) + + +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +@pytest.mark.parametrize("dtype,rtol", [(np.complex64, 1e-6), + (np.complex128, 1e-14), + (np.float32, 1e-6), + (np.float64, 1e-14)]) +def test_sum_of_array(target, dtype, rtol): + # Reported by Mit Kotak + rng = np.random.default_rng(seed=0) + knl = lp.make_kernel( + "{[i]: 0 <= i < N}", + """ + out = sum(i, x[i]) + """, + target=target) + x = get_random_array(rng, (42,), np.dtype(dtype)) + evt, (out,) = knl(x=x) + np.testing.assert_allclose(np.sum(x), out, rtol=rtol) + + +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +@pytest.mark.parametrize("dtype,rtol", [(np.complex64, 1e-6), + (np.complex128, 1e-14), + (np.float32, 1e-6), + (np.float64, 1e-14)]) +def test_int_pow(target, dtype, rtol): + rng = np.random.default_rng(seed=0) + knl = lp.make_kernel( + "{[i]: 0 <= i < N}", + """ + out[i] = x[i] ** i + """, + target=target) + x = get_random_array(rng, (10,), np.dtype(dtype)) + evt, (out,) = knl(x=x) + np.testing.assert_allclose(x ** np.arange(10, dtype=np.int32), out, + rtol=rtol) + + +@pytest.mark.parametrize("target", [lp.PyCudaTarget(), + lp.PyCudaWithPackedArgsTarget()]) +@pytest.mark.parametrize("dtype", [np.complex64, np.complex128, + np.float32, np.float64]) +@pytest.mark.parametrize("func", ["abs", "sqrt", + "sin", "cos", "tan", + "sinh", "cosh", "tanh", + "exp", "log", "log10"]) +def test_math_functions(target, dtype, func): + rng = np.random.default_rng(seed=0) + knl = 
+    knl = lp.make_kernel(
+        "{[i]: 0 <= i < N}",
+        f"""
+        y[i] = {func}(x[i])
+        """,
+        target=target)
+    x = get_random_array(rng, (42,), np.dtype(dtype))
+    _, (out,) = knl(x=x)
+    np.testing.assert_allclose(getattr(np, func)(x),
+                               out, rtol=1e-6)
+
+
+def test_pycuda_packargs_tgt_avoids_param_space_overflow():
+    from pymbolic.primitives import Sum
+    from loopy.symbolic import parse
+
+    nargs = 1_000
+    rng = np.random.default_rng(32)
+    knl = lp.make_kernel(
+        "{[i]: 0<=i 1:
+        exec(sys.argv[1])
+    else:
+        from pytest import main
+        main([__file__])
+
+# vim: foldmethod=marker
diff --git a/test/test_transform.py b/test/test_transform.py
index 592004f9c..625cf5e75 100644
--- a/test/test_transform.py
+++ b/test/test_transform.py
@@ -1641,6 +1641,156 @@ def test_concatenate_arrays(ctx_factory: cl.CtxFactory):
     lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit)
 
 
+def test_reindexing_strided_access(ctx_factory):
+    import islpy as isl
+
+    if not hasattr(isl.Set, "card"):
+        pytest.skip("No barvinok support")
+
+    ctx = ctx_factory()
+
+    tunit = lp.make_kernel(
+        "{[i, j]: 0<=j,i<10}",
+        """
+        <> tmp[2*i, 2*j] = a[i, j]
+        out[i, j] = tmp[2*i, 2*j]**2
+        """)
+
+    tunit = lp.add_dtypes(tunit, {"a": "float64"})
+    ref_tunit = tunit
+
+    knl = lp.reindex_temporary_using_seghir_loechner_scheme(
+        tunit.default_entrypoint, "tmp")
+    tunit = tunit.with_kernel(knl)
+
+    tv, = tunit.default_entrypoint.temporary_variables.values()
+    assert tv.shape == (100,)
+
+    lp.auto_test_vs_ref(ref_tunit, ctx, tunit)
+
+
+def test_reindexing_figurate(ctx_factory):
+    import islpy as isl
+
+    if not hasattr(isl.Set, "card"):
+        pytest.skip("No barvinok support")
+
+    ctx = ctx_factory()
+
+    tunit = lp.make_kernel(
+        "{[i, j]: 0<=j<=i<10}",
+        """
+        <> tmp[2*i, 2*j] = a[i, j]
+        out[i, j] = tmp[2*i, 2*j]**2
+        """)
+
+    tunit = lp.add_dtypes(tunit, {"a": "float64"})
+    ref_tunit = tunit
+
+    knl = lp.reindex_temporary_using_seghir_loechner_scheme(
+        tunit.default_entrypoint, "tmp")
+    tunit = tunit.with_kernel(knl)
+
+    tv, = tunit.default_entrypoint.temporary_variables.values()
+    assert tv.shape == (55,)
+
+    lp.auto_test_vs_ref(ref_tunit, ctx, tunit)
+
+
+def test_reindexing_figurate_parametric_shape(ctx_factory):
+    import islpy as isl
+    from loopy.symbolic import parse
+
+    if not hasattr(isl.Set, "card"):
+        pytest.skip("No barvinok support")
+
+    ctx = ctx_factory()
+
+    tunit = lp.make_kernel(
+        "{[i, j]: 0<=j<=i<n}",
+        """
+        <> tmp[i, j] = a[i, j]
+        out[i, j] = tmp[i, j]**2
+        """,
+        assumptions="n > 0",
+    )
+
+    tunit = lp.add_dtypes(tunit, {"a": "float64"})
+    tunit = lp.set_temporary_address_space(tunit, "tmp",
+                                           lp.AddressSpace.GLOBAL)
+    ref_tunit = tunit
+
+    knl = lp.reindex_temporary_using_seghir_loechner_scheme(
+        tunit.default_entrypoint, "tmp")
+    tunit = tunit.with_kernel(knl)
+
+    tv, = tunit.default_entrypoint.temporary_variables.values()
+    assert tv.shape == (parse("(n + n**2) // 2"),)
+
+    lp.auto_test_vs_ref(ref_tunit, ctx, tunit, parameters={"n": 20})
+
+
+def test_sum_redn_algebraic_transforms(ctx_factory):
+    from pymbolic import variables
+    from loopy.symbolic import Reduction
+
+    t_unit = lp.make_kernel(
+        "{[e,i,j,x,r]: 0<=e