
Commit 2d756c4

Kernel Abstractions Extension Improvements (#78)
* Add runtime generated KA support
* Add KA tests
* Remove kwargs concrete_input_type and closures_size and related functions
* Remove return_type from FunctionCalls
1 parent 9d5d094 · commit 2d756c4

23 files changed: +459 −460 lines


.github/workflows/unit_tests.yml

Lines changed: 18 additions & 1 deletion
@@ -18,6 +18,16 @@ jobs:
         julia-version: ['1.10', '1.11', '1.12']
         julia-arch: [x64]
         os: [ubuntu-latest]
+        backend:
+          - { container: "ubuntu:24.04", cpu: "1", kacpu: "0", cuda: "0", amdgpu: "0", oneapi: "0", metal: "0" }
+          - { container: "ubuntu:24.04", cpu: "0", kacpu: "1", cuda: "0", amdgpu: "0", oneapi: "0", metal: "0" }
+          # large runners with gpu hardware cost money
+          #- { container: "nvidia/cuda:12.9.1-devel-ubuntu24.04", cpu: "0", kacpu: "0", cuda: "1", amdgpu: "0", oneapi: "0", metal: "0" }
+          #- { container: "rocm/dev-ubuntu-24.04:6.4.2", cpu: "0", kacpu: "0", cuda: "0", amdgpu: "1", oneapi: "0", metal: "0" }
+          #- { container: "intel/oneapi-hpckit:2025.2.0-0-devel-ubuntu24.04", cpu: "0", kacpu: "0", cuda: "0", amdgpu: "0", oneapi: "1", metal: "0" }
+
+    #container:
+    #  image: ${{matrix.backend.container}}
 
     steps:
       - name: Checkout repository
@@ -35,7 +45,14 @@ jobs:
         uses: julia-actions/cache@v2
 
       - name: Instantiate
-        run: julia --project=./ -e 'using Pkg; Pkg.instantiate()'
+        run: |
+          julia --project=./ -e 'using Pkg; Pkg.instantiate()'
+          echo "TEST_CPU=${{ matrix.backend.cpu }}" >> $GITHUB_ENV
+          echo "TEST_KACPU=${{ matrix.backend.kacpu }}" >> $GITHUB_ENV
+          echo "TEST_CUDA=${{ matrix.backend.cuda }}" >> $GITHUB_ENV
+          echo "TEST_AMDGPU=${{ matrix.backend.amdgpu }}" >> $GITHUB_ENV
+          echo "TEST_ONEAPI=${{ matrix.backend.oneapi }}" >> $GITHUB_ENV
+          echo "TEST_METAL=${{ matrix.backend.metal }}" >> $GITHUB_ENV
 
       - name: Run tests
         uses: julia-actions/julia-runtest@v1
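
The exported TEST_* variables are presumably read by the test suite to decide which backend tests to run; this commit does not show that side. A minimal sketch of the pattern, with a hypothetical test file name:

    # test/runtests.jl pattern (assumed, not part of this diff):
    # run a backend's tests only when the workflow set its flag to "1"
    if get(ENV, "TEST_KACPU", "0") == "1"
        using KernelAbstractions
        include("ka_cpu_tests.jl")  # hypothetical file name
    end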

Project.toml

Lines changed: 3 additions & 1 deletion
@@ -35,9 +35,11 @@ julia = "1.10"
 oneAPI = "1, 2"
 
 [extras]
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
+Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
 SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["SafeTestsets", "Test", "StatsBase"]
+test = ["SafeTestsets", "Test", "Pkg", "StatsBase", "KernelAbstractions"]

benchmark/QEDFD.jl

Lines changed: 2 additions & 2 deletions
@@ -5,7 +5,7 @@ using QEDprocesses
 using QEDcore
 using QEDbase
 
-RNG = Xoshiro(1)
+RNG = Xoshiro(143)
 MODEL = PerturbativeQED()
 PROC = ScatteringProcess(
     (Electron(), Photon()),
@@ -20,7 +20,7 @@ PSP = PhaseSpacePoint(PROC, MODEL, INPSL, tuple(rand(SFourMomentum, number_incom
 @show g
 
 @info "Building the function"
-@time f = compute_function(g, PROC, cpu_st(), @__MODULE__; closures_size = 100, concrete_input_type = typeof(PSP));
+@time f = compute_function(g, PROC, cpu_st(), @__MODULE__)
 
 #=@info "Writing llvm code"
 @time open("llvm.out", write = true) do file

docs/make.jl

Lines changed: 1 addition & 0 deletions
@@ -41,6 +41,7 @@ pages = [
         "Code Generation" => "lib/internals/code_gen.md",
         "Devices" => "lib/internals/devices.md",
         "Utility" => "lib/internals/utility.md",
+        "KernelAbstractions Extension" => "lib/internals/ka_extension.md",
     ],
     "Contribution" => "contribution.md",
 ]

docs/src/lib/internals/code_gen.md

Lines changed: 0 additions & 1 deletion
@@ -8,7 +8,6 @@ Order = [:type, :constant, :function]
 ```
 
 ## Function Generation
-Implementations for generation of a callable function. A function generated this way cannot immediately be called. One Julia World Age has to pass before this is possible, which happens when the global Julia scope advances. If the DAG and therefore the generated function becomes too large, use the tape machine instead, since compiling large functions becomes infeasible.
 ```@autodocs
 Modules = [ComputableDAGs]
 Pages = ["code_gen/function.jl"]

docs/src/lib/internals/ka_extension.md

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+# Kernel Abstractions Extension
+
+```@autodocs
+Modules = [ComputableDAGs]
+Pages = ["ext/KernelAbstractionsExt.jl"]
+Order = [:function]
+```
+
+## Kernel Wrapping
+```@autodocs
+Modules = [ComputableDAGs]
+Pages = ["ext/kernel_wrapper.jl"]
+Order = [:type, :function]
+```

docs/src/manual.md

Lines changed: 1 addition & 1 deletion
@@ -31,7 +31,7 @@ When the CDAG is ready, it can be compiled for the machine you're running on.
 
 Now, [`compute_function`](@ref) can be used to create a function that can be called on inputs. [`compute_function`](@ref) supports and in some cases requires keyword arguments, please refer to its documentation for more information.
 
-Alternatively, GPU kernels can be generated by using [`kernel`](@ref) instead of [`compute_function`](@ref). This is implemented for several GPU backends and produces a regular function for the given backend. Since RuntimeGeneratedFunctions.jl does not support GPU kernels at this time, this function will only be callable if the world age has been increased since its generation. Furthermore, the compute functions in the graph need to comply with all the normal requirements for GPU kernels, such as not calling dynamic functions.
+Alternatively, [KernelAbstractions](https://juliagpu.github.io/KernelAbstractions.jl/stable/) kernels can be generated by using [`kernel`](@ref) instead of [`compute_function`](@ref). The returned value is a KernelAbstractions kernel object that can be called like any such kernel by giving it a backend and block size. The compute functions in the graph need to comply with all the normal requirements for GPU kernels, such as not calling dynamic functions. For more details, refer to the function's docs.
 
 ## Application repositories
 
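To make the new calling convention concrete, here is a hedged sketch of driving the generated kernel on the CPU backend (dag, instance, the input vector, and result_type are placeholders; the need to call init_kernel first is inferred from the extension code below):

    using KernelAbstractions
    using ComputableDAGs

    ComputableDAGs.init_kernel(@__MODULE__)     # defines the _ka_broadcast! kernel in this module
    k = kernel(dag, instance, @__MODULE__)      # dag/instance: placeholders for a real CDAG setup

    in_vec  = inputs                            # one input element per kernel thread
    out_vec = similar(in_vec, result_type)      # result_type: placeholder for the output element type

    # called like any KernelAbstractions kernel: backend and block size first, then data
    k(CPU(), 64)(in_vec, out_vec; ndrange = length(in_vec))
    KernelAbstractions.synchronize(CPU())
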
ext/KernelAbstractionsExt.jl

Lines changed: 23 additions & 18 deletions
@@ -1,29 +1,34 @@
 module KernelAbstractionsExt
 
 using ComputableDAGs
+using KernelAbstractions
 using UUIDs
 using Random
 
-function ComputableDAGs.kernel(graph::DAG, instance, context_module::Module)
-    machine = cpu_st()
-    tape = ComputableDAGs.gen_tape(graph, instance, machine, context_module)
+include("kernel_wrapper.jl")
 
+function ComputableDAGs.init_kernel(mod::Module)
+    mod.eval(Meta.parse("@kernel inbounds = true function _ka_broadcast!(@Const(in::AbstractVector), out::AbstractVector, val::Val)
+        id = @index(Global)
+        @inline out[id] = _compute_expr(in[id], val)
+    end"))
+    return nothing
+end
+
+function ComputableDAGs.kernel(dag::DAG, instance, context_module::Module)
+    tape = ComputableDAGs.gen_tape(dag, instance, cpu_st(), ComputableDAGs.GreedyScheduler())
+
+    code = ComputableDAGs.gen_function_body(tape)
     assign_inputs = Expr(:block, ComputableDAGs.expr_from_fc.(tape.input_assign_code)...)
-    # TODO: use gen_function_body here
-    code = Expr(:block, ComputableDAGs.expr_from_fc.(tape.schedule)...)
-
-    function_id = ComputableDAGs.to_var_name(UUIDs.uuid1(TaskLocalRNG()))
-    expr = Meta.parse(
-        "@kernel function compute_$(function_id)(input_vector, output_vector)
-            id = @index(Global)
-            @inline input = input_vector[id]
-            $(assign_inputs)
-            $code
-            @inline output_vector[id] = $(tape.output_symbol)
-        end"
-    )
-
-    return expr
+
+    expr = Expr(:block, assign_inputs, code, :(return $(tape.output_symbol)))
+
+    # generate random UUID for type independent lookup in the expression cache
+    val = Val(UUIDs.uuid1(TaskLocalRNG()))
+    getfield(context_module, ComputableDAGs.EXPR_SYM)[val] = expr
+
+    # wrap the kernel together with the generated Val{UUID} to opaquely insert it for the caller later
+    return KAWrapper(context_module._ka_broadcast!, val)
 end
 
 end
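
A note on the mechanism: the kernel body calls _compute_expr(in[id], val), presumably a @generated function that looks up the expression stored under the Val{UUID} key and splices it in as its method body. A standalone toy of that technique (EXPR_CACHE and _compute_expr here are illustrative stand-ins, not the package's actual definitions):

    # expressions registered under Val keys before the generated function is first called
    const EXPR_CACHE = Dict{Any, Expr}()

    @generated function _compute_expr(input, ::Val{ID}) where {ID}
        # the generator runs at compile time and returns this ID's cached
        # expression, which becomes the compiled method body
        return EXPR_CACHE[Val{ID}()]
    end

    EXPR_CACHE[Val{:double}()] = :(return 2 * input)
    _compute_expr(21, Val(:double))  # evaluates the cached expression: 42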

ext/kernel_wrapper.jl

Lines changed: 35 additions & 0 deletions
@@ -0,0 +1,35 @@
+"""
+    KAWrapper{T, ID}
+
+A wrapper around a KernelAbstractions kernel. Takes the `kernel::T` and an `ID::Val`.
+
+This is necessary to insert the id to the KernelAbstractions kernel without needing the user to do it manually.
+The Val itself is necessary to be able to define multiple different kernels working on the same input type. It is used in the expression cache as the key, and dispatched on in the `@generated` function.
+"""
+struct KAWrapper{T, ID}
+    kernel::T
+    id::ID
+end
+
+"""
+    KAWrapperKernel{T, ID, Args, KWArgs}
+
+The second level of wrapping, to imitate the way that KernelAbstractions kernels are called: `kernel(<kernel config/backend>)(<runtime arguments>)`.
+"""
+struct KAWrapperKernel{T, ID, Args, KWArgs}
+    kernel::T
+    id::ID
+    args::Args
+    kwargs::KWArgs
+end
+
+# initial level, args and kwargs are the kernel config, stored in the KAWrapperKernel
+@inline function (k::KAWrapper{T, ID})(args...; kwargs...) where {T, ID}
+    return KAWrapperKernel(k.kernel, k.id, args, kwargs)
+end
+
+# second level, wraps the actual call, inserting the kernel config args/kwargs, and calling with the runtime args + the stored id
+@inline function (k::KAWrapperKernel{T, ID, Args})(args...; kwargs...) where {T, ID, Args}
+    k.kernel(k.args...; k.kwargs...)(args..., k.id; kwargs...)
+    return nothing
+end
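
As a self-contained illustration of the two-level callable pattern (a toy, not package code): the first call stores configuration, the second forwards runtime arguments plus the stored extra value, mirroring how KAWrapper appends the Val{UUID} id.

    # factory(config...) returns a callable, like kernel(backend, groupsize) in KernelAbstractions
    struct Curried{F, X}
        factory::F
        extra::X
    end

    struct Configured{F, X, A}
        factory::F
        extra::X
        config::A
    end

    (c::Curried)(config...) = Configured(c.factory, c.extra, config)
    (c::Configured)(args...) = c.factory(c.config...)(args..., c.extra)

    scale_then_add(scale) = (x, extra) -> scale * x + extra
    wrapped = Curried(scale_then_add, 10)   # 10 plays the role of the opaque id
    wrapped(3)(5)                           # scale_then_add(3)(5, 10) == 25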

src/code_gen/function.jl

Lines changed: 36 additions & 31 deletions
@@ -1,46 +1,22 @@
 """
-    compute_function(
+    compute_function_expr(
         dag::DAG,
         instance,
         machine::Machine,
-        context_module::Module
+        scheduler::AbstractScheduler
     )
 
-Return a function of signature `compute_<id>(input::input_type(instance))`, which will return the result of the DAG computation on the given input.
-The final argument `context_module` should always be `@__MODULE__` to be able to use functions defined in the caller's environment.
-
-## Keyword Arguments
-
-`closures_size` (default=0 (off)): The size of closures to use in the main generated code. This specifies the size of code blocks across which the
-compiler cannot optimize. For sufficiently large functions, a larger value means longer compile times but potentially faster execution time.
-**Note** that the actually used closure size might be different than the one passed here, since the function automatically chooses a size that
-is close to a n-th root of the total number of loc, based off the given size.
-`concrete_input_type` (default=`input_type(instance)`): A type that will be used as the expected input type of the generated function. If
-omitted, the `input_type` of the problem instance is used. Note that the `input_type` of the instance will still be used as the annotated
-type in the generated function header.
+Helper function, returning the complete function expression.
 """
-function compute_function(
+function compute_function_expr(
     dag::DAG,
     instance,
     machine::Machine,
-    context_module::Module;
-    closures_size::Int = 0,
-    concrete_input_type::Type = Nothing,
+    scheduler::AbstractScheduler
 )
-    global INITIALIZED_MODULES
-    if !(context_module in INITIALIZED_MODULES)
-        RuntimeGeneratedFunctions.init(context_module)
-        push!(INITIALIZED_MODULES, context_module)
-    end
-
-    tape = gen_tape(dag, instance, machine, context_module)
+    tape = gen_tape(dag, instance, machine, scheduler)
 
-    code = gen_function_body(
-        tape,
-        context_module;
-        closures_size = closures_size,
-        concrete_input_type = concrete_input_type,
-    )
+    code = gen_function_body(tape)
     assign_inputs = Expr(:block, expr_from_fc.(tape.input_assign_code)...)
 
     function_id = to_var_name(UUIDs.uuid1(TaskLocalRNG()))
@@ -60,5 +36,34 @@ function compute_function
         ), # function body
     )
 
+    return expr
+end
+
+"""
+    compute_function(
+        dag::DAG,
+        instance,
+        machine::Machine,
+        context_module::Module,
+        scheduler::AbstractScheduler = GreedyScheduler(),
+    )
+
+Return a function of signature `compute_<id>(input::input_type(instance))`, which will return the result of the DAG computation on the given input.
+The final argument `context_module` should always be `@__MODULE__` to be able to use functions defined in the caller's environment.
+"""
+function compute_function(
+    dag::DAG,
+    instance,
+    machine::Machine,
+    context_module::Module,
+    scheduler::AbstractScheduler = GreedyScheduler()
+)
+    global INITIALIZED_MODULES
+    if !(context_module in INITIALIZED_MODULES)
+        RuntimeGeneratedFunctions.init(context_module)
+        push!(INITIALIZED_MODULES, context_module)
+    end
+
+    expr = compute_function_expr(dag, instance, machine, scheduler)
     return invokelatest(RuntimeGeneratedFunction, @__MODULE__, context_module, expr)
 end
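
For orientation, a hedged sketch of calling the new signature (dag, instance, and input are placeholders; cpu_st() and GreedyScheduler() are taken from the diff above):

    using ComputableDAGs

    f = compute_function(dag, instance, cpu_st(), @__MODULE__)
    # or, with an explicit scheduler instead of the default:
    f = compute_function(dag, instance, cpu_st(), @__MODULE__, ComputableDAGs.GreedyScheduler())

    result = f(input)   # input must match input_type(instance)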
