Skip to content

Commit c3270bc

Browse files
committed
cleanup
1 parent 6492cad commit c3270bc

4 files changed

Lines changed: 93 additions & 96 deletions

File tree

examples/tasking.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ function task_noop(a)
3232
end
3333

3434
function test_driver()
35-
Legate.ensure_runtime!()
3635
N = 1000
3736
rt = Legate.get_runtime()
3837
lib = Legate.create_library("test")

src/api/tasks.jl

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,6 @@ const GLOBAL_TASK_REGISTRY = Dict{UInt32, UfiMetadata}()
2525
const SUBMITTED_COUNT = Threads.Atomic{Int}(0)
2626
const NEXT_TASK_ID = Threads.Atomic{UInt32}(50000)
2727

28-
function wrap_task(f::Function; is_gpu=false)
29-
if is_gpu
30-
return JuliaGPUTask(f, 0)
31-
else
32-
return JuliaCPUTask(f, 0)
33-
end
34-
end
35-
3628
function create_task(rt::CxxPtr{Runtime}, lib::Library, id::LocalTaskID)
3729
impl = LegateInternal.create_auto_task(rt, lib, id)
3830
@debug "Creating auto task $(impl)"
@@ -90,20 +82,21 @@ function add_constraint(task::LegateTask, c::Constraint)
9082
end
9183

9284

93-
function create_julia_task(rt::Any, lib::Any, task_obj::JuliaTask)
94-
is_cpu = (task_obj isa JuliaCPUTask)
95-
96-
# bypass create_task to avoid recursion in JIT/Method lookup
97-
impl_ptr = ccall((:legate_create_julia_task_wrapper, Legate.WRAPPER_LIB_PATH), Ptr{Cvoid}, (Ptr{Cvoid}, Ptr{Cvoid}, Cint), rt.cpp_object, lib.cpp_object, is_cpu ? 0 : 1)
98-
99-
# Wrap raw pointer directly into LegateTask
100-
task = LegateTask(CxxWrap.CxxPtr{LegateInternal.AutoTask}(impl_ptr))
101-
102-
task.fun = task_obj.fun
103-
# ONLY Julia tasks get the task_id scalar and tracking
85+
function create_julia_task(rt, lib, task_obj::JuliaTask{CPUBackend})
86+
create_julia_task_impl(rt, lib, task_obj, 0)
87+
end
88+
89+
function create_julia_task(rt, lib, task_obj::JuliaTask{GPUBackend})
90+
create_julia_task_impl(rt, lib, task_obj, 1)
91+
end
92+
93+
94+
function create_julia_task_impl(rt, lib, task_obj, backend_flag::Cint)
95+
# returns an Legate AutoTask object ptr
96+
impl_ptr = ccall((:legate_create_julia_task_wrapper, Legate.WRAPPER_LIB_PATH), Ptr{Cvoid}, (Ptr{Cvoid}, Ptr{Cvoid}, Cint), rt.cpp_object, lib.cpp_object, backend_flag)
97+
task = LegateTask(CxxWrap.CxxPtr{LegateInternal.AutoTask}(impl_ptr), task_obj.fun)
10498
task.task_id = Threads.atomic_add!(NEXT_TASK_ID, UInt32(1))
105-
106-
# Prepend internal task_id as scalar 0 on cpp side
99+
# Prepend internal task_id as scalar 0 on cpp Legate side
107100
LegateInternal.add_scalar(task.impl, Scalar(UInt32(task.task_id)).impl)
108101
return task
109102
end
@@ -136,7 +129,8 @@ function submit_task(rt::CxxPtr{Runtime}, task::LegateTask)
136129

137130
# Principled warmup: Force JIT compilation safely on submission thread
138131
# 1. Precompile the internal statically-typed dispatcher
139-
precompile(_do_call, (typeof(task.fun), Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, typeof(meta.dims), typeof(sig)))
132+
precompile(Legate._do_call, (typeof(task.fun), Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, Ptr{Ptr{Cvoid}}, typeof(meta.dims), typeof(sig)))
133+
precompile(Legate._extract_and_call, (typeof(meta), Vector{Ptr{Cvoid}}, Vector{Ptr{Cvoid}}, Vector{Ptr{Cvoid}}, typeof(sig)))
140134

141135
# 2. Precompile the user-provided function with exact types
142136
user_arg_types = Any[]

src/api/types.jl

Lines changed: 19 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,29 +17,34 @@
1717
* Ethan Meitz <emeitz@andrew.cmu.edu>
1818
=#
1919

20-
struct JuliaCPUTask
21-
fun::Function
22-
task_id::UInt32
23-
end
2420

25-
struct JuliaGPUTask
26-
fun::Function
21+
abstract type TaskBackend end
22+
struct CPUBackend <: TaskBackend end
23+
struct GPUBackend <: TaskBackend end
24+
25+
struct JuliaTask{B<:TaskBackend, F}
26+
fun::F
2727
task_id::UInt32
2828
end
2929

30-
JuliaTask = Union{JuliaCPUTask, JuliaGPUTask}
30+
wrap_task(f, ::Type{CPUBackend}) =
31+
JuliaTask{CPUBackend, typeof(f)}(f, 0)
3132

32-
mutable struct LegateTask{I}
33+
wrap_task(f, ::Type{GPUBackend}) =
34+
JuliaTask{GPUBackend, typeof(f)}(f, 0)
35+
36+
mutable struct LegateTask{I, F}
3337
impl::I
34-
fun::Union{Nothing, Function}
38+
fun::F
3539
task_id::UInt32
3640
input_types::Vector{DataType}
3741
output_types::Vector{DataType}
3842
scalar_types::Vector{DataType}
3943
arg_dims::Vector{Union{Nothing, NTuple}}
4044
end
4145

42-
LegateTask(impl::I) where I = LegateTask{I}(impl, nothing, UInt32(0), DataType[], DataType[], DataType[], Union{Nothing, NTuple}[])
46+
LegateTask(impl::I, fun::F) where {I, F} = LegateTask{I, F}(impl, fun, UInt32(0), DataType[], DataType[], DataType[], Union{Nothing, NTuple}[])
47+
4348

4449
const AutoTask = LegateTask{AutoTaskImpl}
4550
const ManualTask = LegateTask{ManualTaskImpl}
@@ -57,13 +62,12 @@ end
5762

5863
struct UfiSignature{InT, OutT, ScT} end
5964

60-
struct UfiMetadata
61-
fun::Function
62-
sig::Any
63-
dims::Tuple
65+
struct UfiMetadata{F,S,D}
66+
fun::F
67+
sig::S
68+
dims::D
6469
end
6570

66-
6771
struct Scalar{T}
6872
impl::ScalarImpl
6973
end

src/ufi.jl

Lines changed: 58 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
* Ethan Meitz <emeitz@andrew.cmu.edu>
1818
=#
1919

20-
# Realm-defined max dimension
2120
const REALM_MAX_DIM = 6
2221
const MAX_UFI_SLOTS_VAL = 32
2322
const SLOT_REQUEST_PTRS = Vector{Ptr{Cvoid}}(undef, MAX_UFI_SLOTS_VAL)
@@ -31,6 +30,8 @@ const DISPATCH_LOCK = ReentrantLock()
3130
const IN_POLL = Threads.Atomic{Int}(0)
3231
const PENDING_JOBS = Threads.Atomic{Int}(0)
3332

33+
const UFI_ERROR = 217
34+
3435
struct TaskJob
3536
slot_id::Int
3637
in_args::Vector{Ptr{Cvoid}}
@@ -52,20 +53,20 @@ zero-allocation in the hot path.
5253
"""
5354
@generated function _do_call(f, in_p_ptr::Ptr{Ptr{Cvoid}}, out_p_ptr::Ptr{Ptr{Cvoid}}, scal_p_ptr::Ptr{Ptr{Cvoid}}, dims::Tuple, ::UfiSignature{InT, OutT, ScT}) where {InT, OutT, ScT}
5455
exprs = []
55-
cursor = 1
56+
dim_cursor = 1
5657

5758
# Inputs
5859
for (i, T) in enumerate(InT.parameters)
5960
E = eltype(T)
60-
push!(exprs, :(unsafe_wrap(Array, Ptr{$E}(unsafe_load(in_p_ptr, $i)), dims[$cursor])))
61-
cursor += 1
61+
push!(exprs, :(unsafe_wrap(Array, Ptr{$E}(unsafe_load(in_p_ptr, $i)), dims[$dim_cursor])))
62+
dim_cursor += 1
6263
end
6364

6465
# Outputs
6566
for (i, T) in enumerate(OutT.parameters)
6667
E = eltype(T)
67-
push!(exprs, :(unsafe_wrap(Array, Ptr{$E}(unsafe_load(out_p_ptr, $i)), dims[$cursor])))
68-
cursor += 1
68+
push!(exprs, :(unsafe_wrap(Array, Ptr{$E}(unsafe_load(out_p_ptr, $i)), dims[$dim_cursor])))
69+
dim_cursor += 1
6970
end
7071

7172
# Scalars
@@ -78,8 +79,10 @@ zero-allocation in the hot path.
7879
end
7980
end
8081

81-
function _extract_and_call(meta::UfiMetadata, in_p_ptr::Ptr{Ptr{Cvoid}}, out_p_ptr::Ptr{Ptr{Cvoid}}, scal_p_ptr::Ptr{Ptr{Cvoid}}, sig::UfiSignature)
82-
_do_call(meta.fun, in_p_ptr, out_p_ptr, scal_p_ptr, meta.dims, sig)
82+
function _extract_and_call(meta::UfiMetadata{F, S, D}, in_args::Vector{Ptr{Cvoid}}, out_args::Vector{Ptr{Cvoid}}, scal_args::Vector{Ptr{Cvoid}}, sig::S) where {F, S, D}
83+
GC.@preserve in_args out_args scal_args begin
84+
_do_call(meta.fun, pointer(in_args), pointer(out_args), pointer(scal_args), meta.dims, sig)
85+
end
8386
end
8487

8588
function ufi_has_pending_work()
@@ -107,66 +110,57 @@ function ufi_poll()
107110
if !UFI_INITIALIZED[]; return false; end
108111
if Threads.atomic_cas!(IN_POLL, 0, 1) != 0; return false; end
109112

110-
try
111-
slot_id = Int(ccall((:legate_pop_pending_slot_nonblocking, Legate.WRAPPER_LIB_PATH), Cint, ()))
112-
if slot_id != -1
113-
base_ptr = SLOT_REQUEST_PTRS[slot_id + 1]
114-
task_id = unsafe_load(Ptr{UInt32}(base_ptr + 4))
115-
116-
meta = lock(REGISTRY_LOCK) do
117-
get(GLOBAL_TASK_REGISTRY, task_id, nothing)
118-
end
119-
120-
if isnothing(meta)
121-
# This should NOT happen now with correct metadata alignment
122-
println(stderr, "[UFI Error] Task ID $task_id not found in registry!")
123-
ccall((:completion_callback_from_julia, Legate.WRAPPER_LIB_PATH), Cvoid, (Cint,), Cint(slot_id))
124-
return true
125-
end
126-
127-
# Extract pointers immediately into stable vectors
128-
sig_type = typeof(meta.sig)
129-
in_p_ptr = unsafe_load(Ptr{Ptr{Ptr{Cvoid}}}(base_ptr + 8))
130-
out_p_ptr = unsafe_load(Ptr{Ptr{Ptr{Cvoid}}}(base_ptr + 16))
131-
scal_p_ptr = unsafe_load(Ptr{Ptr{Ptr{Cvoid}}}(base_ptr + 24))
132-
133-
in_args = [unsafe_load(in_p_ptr, i) for i in 1:length(sig_type.parameters[1].parameters)]
134-
out_args = [unsafe_load(out_p_ptr, i) for i in 1:length(sig_type.parameters[2].parameters)]
135-
136-
# User scalars start at index 1 of scal_p_ptr (skipping task_id at index 0)
137-
num_user_scars = length(sig_type.parameters[3].parameters)
138-
scal_args = [unsafe_load(scal_p_ptr, i) for i in 1:num_user_scars]
139-
140-
Threads.atomic_add!(PENDING_JOBS, 1)
141-
put!(JOB_QUEUE[], TaskJob(slot_id, in_args, out_args, scal_args, meta))
142-
return true
113+
slot_id = Int(ccall((:legate_pop_pending_slot_nonblocking, Legate.WRAPPER_LIB_PATH), Cint, ()))
114+
if slot_id != -1
115+
base_ptr = SLOT_REQUEST_PTRS[slot_id + 1]
116+
task_id = unsafe_load(Ptr{UInt32}(base_ptr + 4))
117+
118+
meta = lock(REGISTRY_LOCK) do
119+
get(GLOBAL_TASK_REGISTRY, task_id, nothing)
143120
end
144-
finally
145-
IN_POLL[] = 0
121+
122+
if isnothing(meta)
123+
println(stderr, "[UFI Error] Task ID $task_id not found in registry!")
124+
exit(UFI_ERROR)
125+
end
126+
127+
# Extract pointers immediately into stable vectors
128+
sig_type = typeof(meta.sig)
129+
in_p_ptr = unsafe_load(Ptr{Ptr{Ptr{Cvoid}}}(base_ptr + 8))
130+
out_p_ptr = unsafe_load(Ptr{Ptr{Ptr{Cvoid}}}(base_ptr + 16))
131+
scal_p_ptr = unsafe_load(Ptr{Ptr{Ptr{Cvoid}}}(base_ptr + 24))
132+
133+
in_args = [unsafe_load(in_p_ptr, i) for i in 1:length(sig_type.parameters[1].parameters)]
134+
out_args = [unsafe_load(out_p_ptr, i) for i in 1:length(sig_type.parameters[2].parameters)]
135+
scal_args = [unsafe_load(scal_p_ptr, i) for i in 1:length(sig_type.parameters[3].parameters)]
136+
137+
Threads.atomic_add!(PENDING_JOBS, 1)
138+
put!(JOB_QUEUE[], TaskJob(slot_id, in_args, out_args, scal_args, meta))
139+
return true
146140
end
141+
IN_POLL[] = 0
147142
return false
148143
end
149144

150145
function _ufi_worker_loop()
151-
(ccall(:jl_generating_output, Cint, ()) != 0) && return
146+
_is_precompiling && return
152147
while !UFI_SHUTDOWN[]
148+
job = take!(JOB_QUEUE[])
153149
try
154-
job = take!(JOB_QUEUE[])
155-
try
156-
_extract_and_call(job.meta, pointer(job.in_args), pointer(job.out_args), pointer(job.scal_args), job.meta.sig)
157-
catch e
158-
println(stderr, "[UFI Worker Error] Slot $(job.slot_id): $e")
159-
Base.display_error(stderr, e, catch_backtrace())
160-
finally
161-
Threads.atomic_sub!(PENDING_JOBS, 1)
162-
ccall((:completion_callback_from_julia, Legate.WRAPPER_LIB_PATH), Cvoid, (Cint,), Cint(job.slot_id))
163-
end
164-
catch e; end
150+
_extract_and_call(job.meta, job.in_args, job.out_args, job.scal_args, job.meta.sig)
151+
catch e
152+
println(stderr, "[UFI Worker Error] Slot $(job.slot_id): $e")
153+
Base.display_error(stderr, e, catch_backtrace())
154+
exit(UFI_ERROR)
155+
finally
156+
Threads.atomic_sub!(PENDING_JOBS, 1)
157+
ccall((:completion_callback_from_julia, Legate.WRAPPER_LIB_PATH), Cvoid, (Cint,), Cint(job.slot_id))
158+
end
165159
end
166160
end
167161

168162
function _ufi_poller_loop()
169-
(ccall(:jl_generating_output, Cint, ()) != 0) && return
163+
_is_precompiling && return
170164
while !UFI_SHUTDOWN[]
171165
if !ufi_poll()
172166
yield()
@@ -182,9 +176,13 @@ function init_ufi()
182176
UFI_INITIALIZED[] && return
183177
_is_precompiling() && return
184178

185-
JOB_QUEUE[] = Channel{TaskJob}(1000)
179+
JOB_QUEUE[] = Channel{TaskJob}(128)
186180

187181
max_slots = ccall((:legate_get_max_slots, Legate.WRAPPER_LIB_PATH), Cint, ())
182+
if max_slots <= 0
183+
exit(UFI_ERROR)
184+
end
185+
188186
for i in 1:max_slots
189187
SLOT_REQUEST_PTRS[i] = ccall((:legate_get_slot_request_ptr, Legate.WRAPPER_LIB_PATH), Ptr{Cvoid}, (Cint,), Cint(i-1))
190188
end
@@ -205,8 +203,10 @@ function init_ufi()
205203
end
206204
yield()
207205

208-
println(stderr, "[UFI] System Initialized (Concurrent Count-Sync Mode)")
206+
println(stderr, "[UFI] System Initialized (Concurrent Count-Sync Mode) with $(num_workers) workers\n")
209207
end
208+
209+
@debug "fuck"
210210
end
211211

212212
function shutdown_ufi()

0 commit comments

Comments
 (0)