Skip to content

Commit 1fe2b4c

Browse files
committed
Add an experimental hostcall interface.
1 parent 1fd86b7 commit 1fe2b4c

File tree

8 files changed

+490
-3
lines changed

8 files changed

+490
-3
lines changed

src/CUDA.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ include("device/utils.jl")
5858
include("device/pointer.jl")
5959
include("device/array.jl")
6060
include("device/intrinsics.jl")
61+
include("device/hostcall.jl")
6162
include("device/runtime.jl")
6263
include("device/texture.jl")
6364
include("device/random.jl")
@@ -75,8 +76,9 @@ export CUPTI, NVTX
7576

7677
# compiler implementation
7778
include("compiler/gpucompiler.jl")
78-
include("compiler/execution.jl")
7979
include("compiler/exceptions.jl")
80+
include("compiler/hostcall.jl")
81+
include("compiler/execution.jl")
8082
include("compiler/reflection.jl")
8183

8284
# array implementation

src/compiler/execution.jl

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,10 @@ end
453453

454454
# create the kernel state object
455455
exception_ptr = create_exceptions!(mod)
456-
state = KernelState(exception_ptr)
456+
pool = hostcall_pool(ctx)
457+
state = KernelState(exception_ptr,
458+
reinterpret(LLVMPtr{UInt32, AS.Global}, pointer(pool.pointers)),
459+
reinterpret(LLVMPtr{Hostcall, AS.Global}, pointer(pool.calls)))
457460

458461
return HostKernel{typeof(job.source.f),job.source.tt}(job.source.f, ctx, mod, fun, state)
459462
end

src/compiler/hostcall.jl

Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
# host-side functionality for receiving method calls from the GPU

# number of slots in the per-context hostcall ring buffer.
# NOTE(review): the "~64MB" figure implies sizeof(Hostcall) ≈ 4KiB — confirm against
# the device-side Hostcall definition.
const HOSTCALL_POOL_SIZE = UInt32(1024*16) # ~64MB
# ring buffer helpers assume pow2 (they mask with `size - 1`)
@assert ispow2(HOSTCALL_POOL_SIZE)
# we should be able to request slots for a full warp, or we would deadlock
@assert HOSTCALL_POOL_SIZE >= 32
# head and tail pointers can exceed HOSTCALL_POOL_SIZE, so overflow behaviour should match
# (UInt32 wrap-around must be a multiple of the pool size for the modulo math to stay valid)
@assert (typemax(UInt32)+1)%HOSTCALL_POOL_SIZE == 0
10+
11+
"""
    HostcallPool

Per-context state for servicing hostcalls: ring-buffer pointers and call slots kept in
host memory that is mapped into the device's address space, so both the CPU and the GPU
can read and write them.
"""
struct HostcallPool
    context::CuContext

    # mapped host storage for ring buffer pointers
    #
    # we can't perform operations that are atomic wrt. both the CPU and GPU, only wrt. to
    # a single device, but that's okay as the tail pointer is only moved by the CPU, while
    # the head pointer is only moved by the GPU. stale reads from either device will only
    # result in under-estimated capacities.
    pointer_buf::Mem.HostBuffer
    pointers::Vector{UInt32} # [head, tail], 0-indexed for simplified modulo arithmetic

    # mapped host storage for actual hostcall objects
    call_buf::Mem.HostBuffer
    calls::Vector{Hostcall} # HOSTCALL_POOL_SIZE slots, wrapping (aliasing) call_buf
end
27+
28+
# Ring-buffer bookkeeping for power-of-two capacities.
#
# Conventions:
# - the producer inserts at `head`, the consumer reads from `tail`
# - `head == tail` means the buffer is empty
# - the pointers may be 0- or 1-based and may run past `size`; the pow2 mask brings
#   them back into range, so free-running (even wrapping) counters are fine
# - one slot always stays unused: a completely full buffer would make `head == tail`,
#   which is indistinguishable from empty

# number of items currently stored in the buffer
function ring_count(head, tail, size)
    return (head - tail) & (size - 1)
end

# number of items that can still be inserted
function ring_space(head, tail, size)
    return ring_count(tail, head + 1, size)
end
35+
36+
# create and return the hostcall pool for each context
const hostcall_pools = Dict{CuContext, HostcallPool}()

# lazily set up the pool for `ctx`: allocate device-mapped host memory for the ring
# buffer, and spawn an asynchronous watcher task that polls it for submitted calls.
# NOTE(review): neither the buffers nor the watcher are ever torn down explicitly;
# presumably the context's lifetime bounds them — confirm.
hostcall_pool(ctx::CuContext) = get!(hostcall_pools, ctx) do
    @context! ctx begin
        # NOTE: we allocate the host memory manually, instead of just registering an array,
        # to avoid accidentally re-registering a memory range.
        pointer_buf = Mem.alloc(Mem.Host, 2*sizeof(UInt32), Mem.HOSTALLOC_DEVICEMAP)
        pointer_ptr = convert(Ptr{UInt32}, pointer_buf)
        pointers = unsafe_wrap(Array, pointer_ptr, 2)
        fill!(pointers, 0)  # head == tail: the ring buffer starts out empty

        call_buf = Mem.alloc(Mem.Host, HOSTCALL_POOL_SIZE*sizeof(Hostcall), Mem.HOSTALLOC_DEVICEMAP)
        call_ptr = convert(Ptr{Hostcall}, call_buf)
        calls = unsafe_wrap(Array, call_ptr, HOSTCALL_POOL_SIZE)

        pool = HostcallPool(ctx, pointer_buf, pointers, call_buf, calls)
        # handshake flag for hostcall_synchronize: set to 1 after every polling pass
        marker = Threads.Atomic{Int}(0)

        # poll for outstanding hostcalls as long as the context is alive.
        # invokelatest: presumably so targets registered after this task starts are
        # still callable (world-age) — TODO confirm.
        watcher = @async begin
            while isvalid(ctx)
                Base.invokelatest(check_hostcalls, pool)
                marker[] = 1
                sleep(0.1)
            end
        end
        # errormonitor (1.7+) surfaces otherwise-silent failures of the watcher task
        VERSION >= v"1.7-" && errormonitor(watcher)

        hostcall_markers[ctx] = marker
        return pool
    end
end
67+
68+
# wait for all hostcalls to complete.
# XXX: add to `synchronize()`?
const hostcall_markers = Dict{CuContext, Threads.Atomic{Int}}()

# Block until the watcher task for `ctx` signals a completed polling pass: clear the
# context's marker, then wait (polling) for the watcher to set it again. A no-op when
# no hostcall pool exists for `ctx`.
#
# NOTE(review): a call submitted while a polling pass is already underway may not have
# been processed when this returns — confirm whether a second handshake is needed.
function hostcall_synchronize(ctx::CuContext=context())
    if !haskey(hostcall_pools, ctx)
        return
    end

    flag = hostcall_markers[ctx]
    flag[] = 0
    while iszero(flag[])
        sleep(0.1)
    end

    return
end
80+
81+
# check whether a pool has any outstanding hostcalls, and execute them
function check_hostcalls(pool::HostcallPool)
    # snapshot of [head, tail]; the head is only advanced by the GPU and the tail only
    # by us (see HostcallPool), so a stale head at worst under-estimates the work to do.
    head0, tail0 = pool.pointers
    while ring_count(head0, tail0, HOSTCALL_POOL_SIZE) >= 1
        # 1-based slot index; `&` binds tighter than `+` in Julia, i.e. (tail0 & mask) + 1
        slot = tail0 & (HOSTCALL_POOL_SIZE - 0x1) + 0x1
        hostcall = pool.calls[slot]
        hostcall_ptr = pointer(pool.calls, slot)

        # NOTE(review): slots that are not HOSTCALL_SUBMITTED are skipped but still
        # retired below — confirm the device cannot submit into a slot after the head
        # has passed it.
        if hostcall.state == HOSTCALL_SUBMITTED
            # Setfield.jl chokes on the 4k tuple, so we manually create pointers to fields.
            state_ptr = reinterpret(Ptr{HostcallState}, hostcall_ptr) + fieldoffset(Hostcall, 1)
            buffer_ptr = hostcall_ptr + fieldoffset(Hostcall, fieldcount(Hostcall))

            try
                sig, rettyp = hostcall_targets[hostcall.target]
                # function barrier for specialization
                state = process_hostcall(sig, rettyp, buffer_ptr)
                unsafe_store!(state_ptr, state)
            catch ex
                # report the failure but keep servicing the ring; mark the call ready
                # so a device-side waiter does not spin forever
                Base.display_error(ex, catch_backtrace())
                unsafe_store!(state_ptr, HOSTCALL_READY)
            end
        end

        # retire the slot and publish the new tail so the GPU sees the freed capacity
        tail0 += 0x1
        pool.pointers[2] = tail0
    end
end
109+
110+
# reconstruct a hostcall's argument values from the mapped buffer at `ptr`.
#
# NOTE(review): in a @generated function `sig` below is the *type* of the passed
# argument (i.e. `Type{T}` when a type `T` is passed), so `sig.parameters` contains
# the signature tuple type itself, not its element types — confirm against the
# device-side writer whether this is meant to load the tuple wholesale or should
# iterate `sig.parameters[1].parameters` per argument.
@inline @generated function read_hostcall_arguments(ptr, sig)
    args = []
    last_offset = 0  # running end-of-previous-value offset into the buffer
    for typ in sig.parameters
        sz = sizeof(typ)
        arg = if sz > 0
            # round the offset up to the value's natural alignment before loading
            align = Base.datatype_alignment(typ)
            offset = Base.cld(last_offset, align) * align
            last_offset = offset + sz
            if last_offset > HOSTCALL_BUFFER_SIZE
                # emit a call-time error instead of failing during generation
                return :(error("hostcall arguments exceed maximum buffer size"))
            end
            :(unsafe_load(reinterpret(Ptr{$typ}, ptr+$offset)))
        else
            # zero-size singletons (e.g. functions) are not stored in the buffer;
            # splice in their unique instance directly
            :($(typ.instance))
        end
        push!(args, arg)
    end

    quote
        ($(args...))
    end
end
133+
134+
# execute a single hostcall: decode the callee and its arguments from the buffer,
# invoke the target, and (for non-Nothing return types) write the result back into
# the same buffer. returns the state the call slot should transition to.
@noinline function process_hostcall(sig::Type{T}, rettyp::Type{U}, buffer_ptr) where {T,U}
    # the first decoded value is the callee itself, the rest are its arguments
    callee, rest... = read_hostcall_arguments(buffer_ptr, sig)
    # invokelatest so targets defined after this method was compiled remain callable
    result = Base.invokelatest(callee, rest...)::rettyp

    # nothing to transfer back to the device
    rettyp === Nothing && return HOSTCALL_READY

    # store the return value where the arguments used to live
    if sizeof(rettyp) > HOSTCALL_BUFFER_SIZE
        error("hostcall return value exceeds maximum buffer size")
    end
    unsafe_store!(reinterpret(Ptr{rettyp}, buffer_ptr), result)
    return HOSTCALL_RETURNED
end

0 commit comments

Comments
 (0)