-
Notifications
You must be signed in to change notification settings - Fork 57
/
Copy pathglobal-hooks.jl
101 lines (94 loc) · 4.07 KB
/
global-hooks.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# Registry of the default global hooks, keyed by a `Symbol` (e.g. `:__global_exception_flag`).
# Every hook registered in this file is a function taking `(gbl, mod, device)`.
const default_global_hooks = Dict{Symbol,Function}()
"""
    boundscheck_hook(boundscheck::Bool)

Return a hook `(gbl, mod, device) -> ...` that writes `boundscheck` through `gbl` as a
single `UInt8` (`0x01` for `true`, `0x00` for `false`).

The `mod` and `device` arguments of the returned hook are accepted but not used.
"""
function boundscheck_hook(boundscheck::Bool)
    flag = UInt8(boundscheck)
    return function (gbl, mod, device)
        dest = Base.unsafe_convert(Ptr{UInt8}, gbl)
        return Base.unsafe_store!(dest, flag)
    end
end
# Hook for the global output context: creates a `Device.OutputContext` over `stdout`
# (named `:__global_output`, `timeout=nothing`) for `device` and stores it through `gbl`.
default_global_hooks[:__global_output_context] = function (gbl, mod, device)
    dest = Base.unsafe_convert(Ptr{AMDGPU.Device.GLOBAL_OUTPUT_CONTEXT_TYPE}, gbl)
    ctx = Device.OutputContext(stdout; device, name=:__global_output, timeout=nothing)
    return Base.unsafe_store!(dest, ctx)
end
# Hook for the global printf context: obtains (or creates) the per-device hostcall named
# `:__global_printf` and stores it through `gbl`.
default_global_hooks[:__global_printf_context] = function (gbl, mod, device)
    # The hostcall returns an `Int` to force synchronizing behavior.
    ret_type = Int
    args_type = Tuple{LLVMPtr{UInt8, AS.Global}}
    dest = Base.unsafe_convert(Ptr{HostCall{ret_type, args_type}}, gbl)

    # NOTE(review): the handler below reads `printf_hc`, which is only assigned once
    # `named_perdevice_hostcall` returns — presumably the handler cannot fire before
    # that point; confirm.
    printf_hc = Device.named_perdevice_hostcall(device, :__global_printf) do
        HostCall(ret_type, args_type; device, continuous=true, buf_len=2^16, timeout=nothing) do _
            buf_ptr = reinterpret(
                LLVMPtr{AMDGPU.Device.ROCPrintfBuffer,AS.Global}, printf_hc.buf_ptr
            )
            fmt, all_args = unsafe_load(buf_ptr)
            for raw_args in all_args
                # Turn C strings into Julia strings; every other argument passes through.
                args = map(raw_args) do x
                    x isa Cstring ? unsafe_string(x) : x
                end
                @debug "@rocprintf with $fmt and $(args)"
                try
                    # `@printf` consumes its format at macro-expansion time, hence the
                    # `@eval` with `fmt` and the arguments interpolated in.
                    @eval @printf($fmt, $(args...))
                catch err
                    # A failing print is logged; remaining argument sets are still handled.
                    @error "@rocprintf error" exception=(err,catch_backtrace())
                end
            end
            return 0
        end
    end
    return Base.unsafe_store!(dest, printf_hc)
end
# Hook for the global exception flag: writes `0` through `gbl`, viewed as an `Int64`.
default_global_hooks[:__global_exception_flag] = function (gbl, mod, device)
    flag_ptr = Base.unsafe_convert(Ptr{Int64}, gbl)
    return Base.unsafe_store!(flag_ptr, 0)
end
# Hook for the exception ring buffer: stores the address of `mod.exceptions` through `gbl`
# and fills in the initial `ExceptionEntry` slots.
default_global_hooks[:__global_exception_ring] = function (gbl, mod, device)
    ring_slot = Base.unsafe_convert(Ptr{Ptr{ExceptionEntry}}, gbl)
    ring = Base.unsafe_convert(Ptr{ExceptionEntry}, mod.exceptions)
    unsafe_store!(ring_slot, ring)

    null_ptr = LLVMPtr{UInt8,1}(0)
    stride = sizeof(ExceptionEntry)

    # The first `MAX_EXCEPTIONS - 1` slots get a `0` in their first field.
    cursor = ring
    for _ in 1:(Runtime.MAX_EXCEPTIONS - 1)
        unsafe_store!(cursor, ExceptionEntry(0, null_ptr))
        cursor += stride
    end

    # The tail slot is distinguished by a `1` in its first field.
    return unsafe_store!(cursor, ExceptionEntry(1, null_ptr))
end
# Hook for the malloc hostcall: obtains (or creates) the per-device hostcall named
# `:__global_malloc` and stores it through `gbl`.
default_global_hooks[:__global_malloc_hostcall] = function (gbl, mod, device)
    ret_type = Ptr{Cvoid}
    args_type = Tuple{UInt64, Csize_t}
    dest = Base.unsafe_convert(Ptr{HostCall{ret_type, args_type}}, gbl)

    # Handler: allocate `sz` bytes on `device`, record the buffer against `kern` in
    # `mod.metadata`, and return the buffer's pointer.
    function allocate(kern, sz)
        buf = Mem.alloc(device, sz; coherent=true)
        # FIXME: Lock
        push!(mod.metadata, Runtime.KernelMetadata(kern, buf))
        @debug "Allocated $(buf.ptr) ($sz bytes) for kernel $kern on device $device"
        return buf.ptr
    end

    # Factory passed to `named_perdevice_hostcall` to build the hostcall.
    function make_hostcall()
        return HostCall(allocate, ret_type, args_type; device, continuous=true, timeout=nothing)
    end

    malloc_hc = Device.named_perdevice_hostcall(make_hostcall, device, :__global_malloc)
    return Base.unsafe_store!(dest, malloc_hc)
end
# Hook for the free hostcall: obtains (or creates) the per-device hostcall named
# `:__global_free` and stores it through `gbl`.
default_global_hooks[:__global_free_hostcall] = function (gbl, mod, device)
    ret_type = Nothing
    args_type = Tuple{UInt64, Ptr{Cvoid}}
    dest = Base.unsafe_convert(Ptr{HostCall{ret_type, args_type}}, gbl)

    # Handler: free the most recently recorded buffer in `mod.metadata` whose kernel and
    # pointer both match, and drop its entry. Does nothing when no entry matches.
    function release(kern, ptr)
        # FIXME: Lock
        idx = findlast(m -> (m.kern == kern) & (m.buf.ptr == ptr), mod.metadata)
        if idx !== nothing
            meta = mod.metadata[idx]
            Mem.free(meta.buf)
            deleteat!(mod.metadata, idx)
            @debug "Freed $ptr ($(meta.buf.bytesize) bytes) for kernel $kern on device $device."
        end
        return nothing
    end

    # Factory passed to `named_perdevice_hostcall` to build the hostcall.
    function make_hostcall()
        return HostCall(release, ret_type, args_type; device, continuous=true, timeout=nothing)
    end

    free_hc = Device.named_perdevice_hostcall(make_hostcall, device, :__global_free)
    return Base.unsafe_store!(dest, free_hc)
end