|
# host-side functionality for receiving method calls from the GPU

# number of hostcall slots in the ring buffer.
# NOTE(review): the "~64MB" figure presumably assumes a ~4KB `Hostcall` struct
# (16384 slots * sizeof(Hostcall)) -- confirm against the definition of `Hostcall`.
const HOSTCALL_POOL_SIZE = UInt32(1024*16) # ~64MB
# ring buffer helpers assume pow2
@assert ispow2(HOSTCALL_POOL_SIZE)
# we should be able to request slots for a full warp, or we would deadlock
@assert HOSTCALL_POOL_SIZE >= 32
# head and tail pointers can exceed HOSTCALL_POOL_SIZE, so overflow behaviour should match
# (i.e. the pool size must evenly divide the UInt32 value space, so that wrap-around of
# the free-running pointers doesn't corrupt the modulo arithmetic)
@assert (typemax(UInt32)+1)%HOSTCALL_POOL_SIZE == 0
| 10 | + |
# per-context state for servicing hostcalls: a ring buffer of `Hostcall` slots backed by
# page-locked, device-mapped host memory, plus the free-running head/tail pointers that
# index into it.
struct HostcallPool
    context::CuContext

    # mapped host storage for ring buffer pointers
    #
    # we can't perform operations that are atomic wrt. both the CPU and GPU, only wrt. to
    # a single device, but that's okay as the tail pointer is only moved by the CPU, while
    # the head pointer is only moved by the GPU. stale reads from either device will only
    # result in under-estimated capacities.
    pointer_buf::Mem.HostBuffer # owns the mapped allocation that `pointers` wraps
    pointers::Vector{UInt32} # [head, tail], 0-indexed for simplified modulo arithmetic

    # mapped host storage for actual hostcall objects
    call_buf::Mem.HostBuffer # owns the mapped allocation that `calls` wraps
    calls::Vector{Hostcall}
end
| 27 | + |
# small helpers for pow2 ring buffer management.
# - the head is where the producer inserts, the tail is where the consumer reads
# - tail == head indicates an empty buffer
# - head and tail pointers can be 0 or 1 indexed, and do not need to fall within size bounds
#
# number of items currently stored in the buffer
function ring_count(head, tail, size)
    mask = size - 1
    return (head - tail) & mask
end
# number of free slots, expressed as the ring distance from one-past-the-head back to
# the tail
function ring_space(head, tail, size)
    return ring_count(tail, head + 1, size)
end
# NOTE: one item is left unused, as a full buffer means head==tail which also means empty
| 35 | + |
# create and return the hostcall pool for each context
const hostcall_pools = Dict{CuContext, HostcallPool}()
hostcall_pool(ctx::CuContext) = get!(hostcall_pools, ctx) do
    @context! ctx begin
        # NOTE: we allocate the host memory manually, instead of just registering an array,
        # to avoid accidentally re-registering a memory range.
        # both allocations are device-mapped so the GPU can read/write them directly.
        pointer_buf = Mem.alloc(Mem.Host, 2*sizeof(UInt32), Mem.HOSTALLOC_DEVICEMAP)
        pointer_ptr = convert(Ptr{UInt32}, pointer_buf)
        pointers = unsafe_wrap(Array, pointer_ptr, 2)
        # head == tail == 0: start with an empty ring buffer
        fill!(pointers, 0)

        call_buf = Mem.alloc(Mem.Host, HOSTCALL_POOL_SIZE*sizeof(Hostcall), Mem.HOSTALLOC_DEVICEMAP)
        call_ptr = convert(Ptr{Hostcall}, call_buf)
        calls = unsafe_wrap(Array, call_ptr, HOSTCALL_POOL_SIZE)

        pool = HostcallPool(ctx, pointer_buf, pointers, call_buf, calls)
        # flipped to 1 by the watcher after every polling pass; used by
        # `hostcall_synchronize` to detect that a full pass has completed.
        marker = Threads.Atomic{Int}(0)

        # poll for submitted hostcalls every 100ms until the context goes away.
        # `invokelatest` so that hostcall targets registered after this task starts
        # are still callable.
        watcher = @async begin
            while isvalid(ctx)
                Base.invokelatest(check_hostcalls, pool)
                marker[] = 1
                sleep(0.1)
            end
        end
        # surface watcher failures instead of dying silently (API added in 1.7)
        VERSION >= v"1.7-" && errormonitor(watcher)

        # NOTE(review): the mapped buffers are never freed when the context is
        # invalidated, and the dict entry is never removed -- presumably acceptable for
        # context-lifetime state, but confirm there's no leak across context churn.
        hostcall_markers[ctx] = marker
        return pool
    end
end
| 67 | + |
# wait for all hostcalls to complete.
# XXX: add to `synchronize()`?
const hostcall_markers = Dict{CuContext, Threads.Atomic{Int}}()
function hostcall_synchronize(ctx::CuContext=context())
    # no pool was ever created for this context, so there's nothing to wait for
    haskey(hostcall_pools, ctx) || return
    flag = hostcall_markers[ctx]
    # reset the marker and block until the watcher task flips it back, which proves that
    # at least one full polling pass completed after this point
    flag[] = 0
    while iszero(flag[])
        sleep(0.1)
    end
    return
end
| 80 | + |
# check whether a pool has any outstanding hostcalls, and execute them
function check_hostcalls(pool::HostcallPool)
    # snapshot both pointers once; only the GPU advances the head, so a stale read here
    # merely under-estimates the pending count (remaining calls are seen on a later pass)
    head0, tail0 = pool.pointers
    while ring_count(head0, tail0, HOSTCALL_POOL_SIZE) >= 1
        # map the free-running 0-indexed tail onto a 1-based index into `pool.calls`
        # (`&` binds tighter than `+` in Julia, so this is `(tail0 & mask) + 1`)
        slot = tail0 & (HOSTCALL_POOL_SIZE - 0x1) + 0x1
        hostcall = pool.calls[slot]
        hostcall_ptr = pointer(pool.calls, slot)

        if hostcall.state == HOSTCALL_SUBMITTED
            # Setfield.jl chokes on the 4k tuple, so we manually create pointers to fields.
            state_ptr = reinterpret(Ptr{HostcallState}, hostcall_ptr) + fieldoffset(Hostcall, 1)
            buffer_ptr = hostcall_ptr + fieldoffset(Hostcall, fieldcount(Hostcall))

            try
                # look up the registered target's signature and return type;
                # `process_hostcall` is a function barrier so the call specializes
                sig, rettyp = hostcall_targets[hostcall.target]
                # function barrier for specialization
                state = process_hostcall(sig, rettyp, buffer_ptr)
                unsafe_store!(state_ptr, state)
            catch ex
                # report the failure on the host, but still mark the slot READY --
                # presumably so the device-side waiter doesn't spin forever on a call
                # that will never complete (confirm against the device-side protocol)
                Base.display_error(ex, catch_backtrace())
                unsafe_store!(state_ptr, HOSTCALL_READY)
            end
        end

        # publish the consumed slot; only the CPU ever moves the tail pointer
        tail0 += 0x1
        pool.pointers[2] = tail0
    end
end
| 109 | + |
# generate code that loads a tuple of values with types `sig.parameters` from the raw
# byte buffer at `ptr`, placing each value at its naturally-aligned offset.
# NOTE(review): this layout must match how the device side packs the arguments into the
# buffer -- confirm against the device-side writer.
@inline @generated function read_hostcall_arguments(ptr, sig)
    args = []
    last_offset = 0
    for typ in sig.parameters
        sz = sizeof(typ)
        arg = if sz > 0
            # round the running offset up to this type's alignment before loading
            align = Base.datatype_alignment(typ)
            offset = Base.cld(last_offset, align) * align
            last_offset = offset + sz
            if last_offset > HOSTCALL_BUFFER_SIZE
                # replace the whole generated body with a runtime error
                return :(error("hostcall arguments exceed maximum buffer size"))
            end
            :(unsafe_load(reinterpret(Ptr{$typ}, ptr+$offset)))
        else
            # zero-size type: occupies no buffer bytes, just splice in its singleton
            :($(typ.instance))
        end
        push!(args, arg)
    end

    quote
        ($(args...))
    end
end
| 133 | + |
# execute a single hostcall whose packed arguments live at `buffer_ptr`, writing any
# return value back into that same buffer. returns the state the slot should transition
# to: HOSTCALL_READY when there's nothing to hand back, HOSTCALL_RETURNED otherwise.
# @noinline: keeps this a function barrier so the body specializes on sig/rettyp.
@noinline function process_hostcall(sig::Type{T}, rettyp::Type{U}, buffer_ptr) where {T,U}
    # the first buffer entry is the callee itself, the remainder are its arguments
    func, args... = read_hostcall_arguments(buffer_ptr, sig)
    ret = Base.invokelatest(func, args...)::rettyp

    # nothing to communicate back for `Nothing`-returning calls
    rettyp === Nothing && return HOSTCALL_READY

    # store the return value, re-using the argument buffer
    if sizeof(rettyp) > HOSTCALL_BUFFER_SIZE
        error("hostcall return value exceeds maximum buffer size")
    end
    unsafe_store!(reinterpret(Ptr{rettyp}, buffer_ptr), ret)
    return HOSTCALL_RETURNED
end
0 commit comments