Metal PJRT backend via MPSGraph + pure-Julia plugin #2489
Dale-Black wants to merge 1 commit into EnzymeAD:main
Codecov Report
❌ Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## main #2489 +/- ##
===========================================
- Coverage 68.16% 34.13% -34.03%
===========================================
Files 109 214 +105
Lines 11779 30852 +19073
===========================================
+ Hits 8029 10531 +2502
- Misses 3750 20321 +16571
|
ext/ReactantMetalExt/MLIRWalker.jl
Outdated
| # ============================================================================ | ||
|
|
||
| """Extract contracting_dims from dot_general op text.""" | ||
| function parse_contracting_dims(op_text::AbstractString) |
There was a problem hiding this comment.
We don't need to parse the string here; we should be able to query the operation to extract this info.
|
Split out the C++/Bazel changes into #2490 per @avik-pal's request. That PR adds only Once the JLL is rebuilt with that symbol, the Julia changes in this PR will work against the new JLL (no more I'll rebase this PR to remove the |
MLIR API Refactor: Addressing String Parsing Feedback
@avik-pal, thanks for the review. I interpreted your comment about not needing to parse strings as referring to the
What changed (latest commit)
Replaced all string-based attribute/type extraction with API calls:
Net result is -163 lines since the API calls are more concise than the regex parsers.
What we tested
I'll share some screenshots from the local Pluto benchmark notebook in a follow-up comment. Please let me know if this is what you had in mind or if there are other areas that need attention — still learning my way around the MLIR infrastructure here. |
mofeing
left a comment
There was a problem hiding this comment.
Most comments are for code cleaning and formatting, but the most critical change is that stablehlo.reduce can silently translate to wrong code and not error. Unfortunately I don't know much about MPS so I can't lend a hand there.
ext/ReactantMetalExt/MLIRWalker.jl
Outdated
| raw_body_op = API.mlirBlockGetFirstOperation(body_block) | ||
| while !(IR.mlirIsNull(raw_body_op)) | ||
| bop = IR.Operation(raw_body_op) | ||
| bop_name = IR.name(bop) | ||
| if startswith(bop_name, "stablehlo.") | ||
| body_op_name = bop_name | ||
| break | ||
| end | ||
| raw_body_op = API.mlirOperationGetNextInBlock(bop) | ||
| end |
There was a problem hiding this comment.
Block implements the Iterator interface and returns Operations so this can be simplified to
| for bop in body_block | |
| bop_name = IR.name(bop) | |
| if startswith(bop_name, "stablehlo.") | |
| body_op_name = bop_name | |
| break | |
| end | |
| end |
There was a problem hiding this comment.
also, the problem with this approach is that stablehlo.reduce accepts arbitrary code, and matching on the first op most probably won't be correct.
take the following example (a sign-alternating add reduction). your code will translate it to a regular add reduction and it won't be correct.
%result = "stablehlo.reduce"(%input, %init_value) ({
^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
%1 = "stablehlo.negate"(%0) : (tensor<i64>) -> tensor<i64>
"stablehlo.return"(%1) : (tensor<i64>) -> ()
}) {
dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>
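One conservative way to avoid this kind of silent mistranslation is to accept a reducer only when the whole body has a known shape, and to reject everything else. The sketch below illustrates the idea on a plain Julia stand-in for the block. `BodyOp`, `SUPPORTED_REDUCERS` and `classify_reducer` are hypothetical names for illustration, not Reactant or MLIR API; a real version would read the same information from the `IR.Operation` objects in the region.

```julia
# Hypothetical stand-in for an MLIR operation: its name, the SSA values it
# defines, and the SSA values it consumes.
struct BodyOp
    name::String
    results::Vector{Symbol}
    operands::Vector{Symbol}
end

# Reducers this sketch knows how to translate (assumed set, for illustration).
const SUPPORTED_REDUCERS = Dict("stablehlo.add" => :sum, "stablehlo.maximum" => :max)

"""
Return `:sum` or `:max` only if the body is exactly one supported binary op
applied to the two block arguments, followed by a return of that op's result.
Return `nothing` for any other body so the caller can raise an error.
"""
function classify_reducer(block_args::Vector{Symbol}, body::Vector{BodyOp})
    length(body) == 2 || return nothing
    op, ret = body
    ret.name == "stablehlo.return" || return nothing
    kind = get(SUPPORTED_REDUCERS, op.name, nothing)
    kind === nothing && return nothing
    # The op must consume exactly the two block arguments (order is irrelevant
    # for the commutative ops supported here).
    length(op.operands) == 2 || return nothing
    Set(op.operands) == Set(block_args) || return nothing
    # The block must return the op's result and nothing else.
    ret.operands == op.results || return nothing
    return kind
end

args = [:arg0, :arg1]
plain_add = [
    BodyOp("stablehlo.add", [:r0], [:arg0, :arg1]),
    BodyOp("stablehlo.return", Symbol[], [:r0]),
]
alternating_add = [
    BodyOp("stablehlo.add", [:r0], [:arg0, :arg1]),
    BodyOp("stablehlo.negate", [:r1], [:r0]),
    BodyOp("stablehlo.return", Symbol[], [:r1]),
]

println(classify_reducer(args, plain_add))        # accepted as a sum reduction
println(classify_reducer(args, alternating_add))  # rejected: caller should error
```

With this shape check, the sign-alternating example is rejected rather than being translated as a regular add reduction.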
ext/ReactantMetalExt/MLIRWalker.jl
Outdated
|
|
||
| # Pattern-match the body to determine reduction type | ||
| body_op_name = "" | ||
| if IR.nregions(op) > 0 |
There was a problem hiding this comment.
stablehlo.reduce must have a region to contain the reducer code, so this seems redundant?
ext/ReactantMetalExt/MLIRWalker.jl
Outdated
| t = Metal.MPSGraphs.transposeTensor(graph, tensor, 0, 4, "$(name)_rev1") | ||
| return Metal.MPSGraphs.transposeTensor(graph, t, 1, 3, "$(name)_rev2") | ||
| else | ||
| error("mps_reverse_dims: unsupported rank $rank") |
There was a problem hiding this comment.
a permutation can be decomposed into a series of transpositions using, for example, the Coxeter decomposition, which is implemented in Permutations.jl
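As a rough illustration of that idea, the sketch below decomposes an arbitrary axis permutation into pairwise swaps with a simple selection-style pass. Note that this is not the Coxeter decomposition (which uses adjacent transpositions only) and it does not use Permutations.jl; `transpositions` and `apply_swaps` are hypothetical helper names. How each swap would map onto a `transposeTensor` call (0-based axes, call order) is an assumption that would need checking against MPSGraph.

```julia
"""
Decompose `perm` (1-based; `perm[i]` is the source axis that ends up at
position `i`) into a list of axis swaps. Applying the swaps in order to
`1:n` reproduces `perm`. At most `n - 1` swaps are produced.
"""
function transpositions(perm::Vector{Int})
    isperm(perm) || throw(ArgumentError("not a permutation: $perm"))
    current = collect(1:length(perm))
    swaps = Tuple{Int,Int}[]
    for i in eachindex(perm)
        # Positions before `i` are already final, so the wanted axis sits at j >= i.
        j = findfirst(==(perm[i]), current)
        if j != i
            current[i], current[j] = current[j], current[i]
            push!(swaps, (i, j))
        end
    end
    return swaps
end

"""Apply a list of swaps to the identity order `1:n` (used here as a check)."""
function apply_swaps(n::Int, swaps)
    order = collect(1:n)
    for (i, j) in swaps
        order[i], order[j] = order[j], order[i]
    end
    return order
end

reverse5 = [5, 4, 3, 2, 1]
swaps = transpositions(reverse5)
println(swaps)
println(apply_swaps(5, swaps) == reverse5)
```

For the rank-5 reversal, this yields the swaps (1, 5) and (2, 4), which correspond to the two hard-coded transposes (0, 4) and (1, 3) in the snippet under review, without needing a separate branch per rank.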
ext/ReactantMetalExt/MLIRWalker.jl
Outdated
| const OP_HANDLERS = Dict{String, Function}() | ||
|
|
||
| function get_op_handlers() | ||
| if isempty(OP_HANDLERS) | ||
| OP_HANDLERS["stablehlo.add"] = handle_add | ||
| OP_HANDLERS["stablehlo.subtract"] = handle_subtract | ||
| OP_HANDLERS["stablehlo.multiply"] = handle_multiply | ||
| OP_HANDLERS["stablehlo.divide"] = handle_divide | ||
| OP_HANDLERS["stablehlo.maximum"] = handle_maximum | ||
| OP_HANDLERS["stablehlo.negate"] = handle_negate | ||
| OP_HANDLERS["stablehlo.exponential"] = handle_exponential | ||
| OP_HANDLERS["stablehlo.exp"] = handle_exponential | ||
| OP_HANDLERS["stablehlo.log"] = handle_log | ||
| OP_HANDLERS["stablehlo.tanh"] = handle_tanh | ||
| OP_HANDLERS["stablehlo.sqrt"] = handle_sqrt | ||
| OP_HANDLERS["stablehlo.rsqrt"] = handle_rsqrt | ||
| OP_HANDLERS["stablehlo.abs"] = handle_abs | ||
| OP_HANDLERS["stablehlo.sine"] = handle_sin | ||
| OP_HANDLERS["stablehlo.sin"] = handle_sin | ||
| OP_HANDLERS["stablehlo.cosine"] = handle_cos | ||
| OP_HANDLERS["stablehlo.cos"] = handle_cos | ||
| OP_HANDLERS["stablehlo.convert"] = handle_convert | ||
| OP_HANDLERS["stablehlo.constant"] = handle_constant | ||
| OP_HANDLERS["stablehlo.dot_general"] = handle_dot_general | ||
| OP_HANDLERS["stablehlo.dot"] = handle_dot_general | ||
| OP_HANDLERS["stablehlo.broadcast_in_dim"] = handle_broadcast_in_dim | ||
| OP_HANDLERS["stablehlo.reshape"] = handle_reshape | ||
| OP_HANDLERS["stablehlo.transpose"] = handle_transpose | ||
| OP_HANDLERS["stablehlo.reverse"] = handle_reverse | ||
| OP_HANDLERS["stablehlo.concatenate"] = handle_concatenate | ||
| OP_HANDLERS["stablehlo.convolution"] = handle_convolution | ||
| OP_HANDLERS["stablehlo.slice"] = handle_slice | ||
| OP_HANDLERS["stablehlo.scatter"] = handle_scatter | ||
| end | ||
| return OP_HANDLERS | ||
| end |
There was a problem hiding this comment.
passing a Function this way will incur dynamic dispatch when translating from stablehlo to MPSGraph (so increased compile-time).
given that this function is (1) quite trivial, (2) doesn't require user extensibility, and (3) it's only used once, do you mind changing wherever you call this function for if-elseif code?
(actually I see that the code does an if-elseif but then Claude went lazy?)
There was a problem hiding this comment.
if else if vs dispatch is close in performance when it gets this big (dynamic dispatch is like 2 ns)
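For readers following the thread, the two styles under discussion look roughly like the sketch below. The handlers are toy stand-ins (hypothetical, two-argument arithmetic functions) rather than the PR's `handle_*` functions, and no performance claim is made here; the point is only the structural difference between a `Dict{String,Function}` lookup and explicit branches.

```julia
# Toy handlers standing in for handle_add, handle_subtract, ...
handle_add(x, y) = x + y
handle_subtract(x, y) = x - y
handle_multiply(x, y) = x * y

# Variant 1: table lookup. The value type `Function` is abstract, so the call
# through the table is resolved at run time (dynamic dispatch).
const DEMO_HANDLERS = Dict{String,Function}(
    "stablehlo.add" => handle_add,
    "stablehlo.subtract" => handle_subtract,
    "stablehlo.multiply" => handle_multiply,
)

function translate_with_table(op_name::String, x, y)
    handler = get(DEMO_HANDLERS, op_name, nothing)
    handler === nothing && error("Unsupported StableHLO op: $op_name")
    return handler(x, y)
end

# Variant 2: explicit branches. Each call site names a concrete function, so
# the compiler can resolve every call statically.
function translate_with_branches(op_name::String, x, y)
    if op_name == "stablehlo.add"
        return handle_add(x, y)
    elseif op_name == "stablehlo.subtract"
        return handle_subtract(x, y)
    elseif op_name == "stablehlo.multiply"
        return handle_multiply(x, y)
    else
        error("Unsupported StableHLO op: $op_name")
    end
end

println(translate_with_table("stablehlo.add", 2, 3))
println(translate_with_branches("stablehlo.add", 2, 3))
```

Both variants reject unknown op names with the same error, so the choice between them is about dispatch cost and readability rather than behavior.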
ext/ReactantMetalExt/MLIRWalker.jl
Outdated
| handlers = get_op_handlers() | ||
| raw_op = API.mlirBlockGetFirstOperation(func_block) | ||
| while !(IR.mlirIsNull(raw_op)) | ||
| op = IR.Operation(raw_op) | ||
| op_name = IR.name(op) | ||
| ctx.op_count += 1 | ||
|
|
||
|
|
||
| if op_name == "func.return" | ||
| for j in 1:IR.noperands(op) | ||
| ret_val = IR.operand(op, j) | ||
| if haskey(ctx.value_map, ret_val) | ||
| push!(ctx.outputs, ctx.value_map[ret_val]) | ||
| ir_shape, dtype = get_type_info(ret_val) | ||
| julia_shape = length(ir_shape) >= 2 ? reverse(ir_shape) : ir_shape | ||
| push!(ctx.output_shapes, julia_shape) | ||
| push!(ctx.output_dtypes, dtype) | ||
| else | ||
| @warn "Return value not found in value_map" | ||
| end | ||
| end | ||
| elseif op_name == "stablehlo.reduce" | ||
| handle_reduce(ctx, op) | ||
| elseif op_name == "stablehlo.reduce_window" | ||
| handle_reduce_window(ctx, op) | ||
| elseif haskey(handlers, op_name) | ||
| handlers[op_name](ctx, op) | ||
| else | ||
| error("Unsupported StableHLO op: $op_name") | ||
| end | ||
|
|
||
| raw_op = API.mlirOperationGetNextInBlock(op) | ||
| end |
There was a problem hiding this comment.
you can iterate directly on the Block
| handlers = get_op_handlers() | |
| for op in func_block | |
| op_name = IR.name(op) | |
| ctx.op_count += 1 | |
| if op_name == "func.return" | |
| for j in 1:IR.noperands(op) | |
| ret_val = IR.operand(op, j) | |
| if haskey(ctx.value_map, ret_val) | |
| push!(ctx.outputs, ctx.value_map[ret_val]) | |
| ir_shape, dtype = get_type_info(ret_val) | |
| julia_shape = length(ir_shape) >= 2 ? reverse(ir_shape) : ir_shape | |
| push!(ctx.output_shapes, julia_shape) | |
| push!(ctx.output_dtypes, dtype) | |
| else | |
| @warn "Return value not found in value_map" | |
| end | |
| end | |
| elseif op_name == "stablehlo.reduce" | |
| handle_reduce(ctx, op) | |
| elseif op_name == "stablehlo.reduce_window" | |
| handle_reduce_window(ctx, op) | |
| elseif haskey(handlers, op_name) | |
| handlers[op_name](ctx, op) | |
| else | |
| error("Unsupported StableHLO op: $op_name") | |
| end | |
| end |
ext/ReactantMetalExt/PJRTPlugin.jl
Outdated
| struct MetalPJRT_Api_Version | ||
| struct_size::UInt64 # offset 0, 8 bytes | ||
| extension_start::Ptr{Cvoid} # offset 8, 8 bytes | ||
| major_version::Int32 # offset 16, 4 bytes | ||
| minor_version::Int32 # offset 20, 4 bytes | ||
| end # 24 bytes total (0x18) | ||
|
|
||
| struct MetalPJRT_Api | ||
| struct_size::UInt64 # offset 0 | ||
| extension_start::Ptr{Cvoid} # offset 8 | ||
| pjrt_api_version::MetalPJRT_Api_Version # offset 16, 24 bytes | ||
| fns::NTuple{128,Ptr{Cvoid}} # offset 40, 1024 bytes | ||
| end # Total: 8 + 8 + 24 + 1024 = 1064 bytes (0x428) |
There was a problem hiding this comment.
would you mind using the bindings in https://github.com/EnzymeAD/Reactant.jl/blob/main/src/xla/PJRT/CAPI.jl ? this way we avoid duplications and can track changes in the PJRT API
src/accelerators/Metal.jl
Outdated
| """ | ||
| setup_metal!() | ||
|
|
||
| metal_pjrt_plugin_path = joinpath(path, "pjrt_plugin_metal_14.dylib") | ||
| if !isfile(metal_pjrt_plugin_path) | ||
| zip_file_path = joinpath(path, "pjrt-plugin-metal.zip") | ||
| tmp_dir = joinpath(path, "tmp") | ||
| Downloads.download( | ||
| if Sys.ARCH === :aarch64 | ||
| "https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl" | ||
| elseif Sys.ARCH === :x86_64 | ||
| "https://files.pythonhosted.org/packages/87/ec/9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f/jax_metal-0.1.1-py3-none-macosx_10_14_x86_64.whl" | ||
| else | ||
| error("Unsupported architecture: $(Sys.ARCH)") | ||
| end, | ||
| zip_file_path, | ||
| ) | ||
| run(pipeline(`$(p7zip()) x -tzip -o$(tmp_dir) -- $(zip_file_path)`, devnull)) | ||
| mv( | ||
| joinpath(tmp_dir, "jax_plugins", "metal_plugin", "pjrt_plugin_metal_14.dylib"), | ||
| metal_pjrt_plugin_path, | ||
| ) | ||
| rm(tmp_dir; recursive=true) | ||
| rm(zip_file_path; recursive=true) | ||
| end | ||
| Placeholder hook for external callers. The actual Metal PJRT client is | ||
| created inside `ReactantMetalExt.__init__()`, which Julia loads automatically | ||
| whenever `Metal` is brought into scope as a weak dependency. | ||
| """ | ||
| function setup_metal!() | ||
| # Metal client registration is handled by ReactantMetalExt.__init__() | ||
| # when Metal.jl is loaded as a weak dependency. | ||
| return nothing |
src/xla/Device.jl
Outdated
| end | ||
| else | ||
| @warn "`get_properties` not implemented for platform: $(pname)" maxlog = 1 | ||
| @debug "`get_properties` not implemented for platform: $(pname)" |
There was a problem hiding this comment.
is there a reason for this change? it looks quite arbitrary
|
Hi @mofeing — thanks for the thorough review. I (plus Claude) have tried to address each of your 10 comments in individual commits:
I'm still learning the codebase and some of this is admittedly hacky, especially the CAPI.jl integration. If I butchered anything or you'd prefer a different approach on any of these, please let me know — very happy to rework. |
There was a problem hiding this comment.
I'm still learning the codebase and some of this is admittedly hacky, especially the CAPI.jl integration. If I butchered anything or you'd prefer a different approach on any of these, please let me know — very happy to rework.
it's great for an initial version and thanks for working on this. I confess I'm a lil bit picky; most of my requests are minor things that can be changed in subsequent PRs.
for me the most critical thing holding this PR is that you just cannot match the first op in the reduce block of stablehlo.reduce and choose the reducer function based on it. the reducer code can be more complex, and the way it is coded right now it will silently generate wrong results. if implementing a fix for it is too hard right now, I would prefer it to be left unimplemented and just error, or, if needed, match the full block for the already implemented special cases like add- or max-reductions (i.e. match up to the return).
also, it seems like Claude prefers to use Ptr{Cvoid} and hardcode the field offsets in PJRTPlugin.jl. we should use the types in Reactant.XLA.PJRT.CAPI instead of Cvoid for the pointers, and fieldoffset instead of hardcoded numbers.
src/xla/PJRT/CAPI.jl
Outdated
| const PJRT_API_MINOR = 90 | ||
|
|
||
| const _PJRT_API_STRUCT_FIELD = fn_type(fn_type) * fn_type | ||
| # const _PJRT_API_STRUCT_FIELD = fn_type(fn_type) * fn_type # untranslatable C macro |
There was a problem hiding this comment.
this file is auto-generated, so you should refrain from making changes there. if it's breaking sth, tell us so we can fix it in the generator script.
ext/ReactantMetalExt/PJRTPlugin.jl
Outdated
|
|
||
| # PJRT_LoadedExecutable_Destroy_Args: | ||
| # offset 16: executable* (input, 8) — our handle | ||
| function _loaded_exec_destroy(args::Ptr{Cvoid})::Ptr{Cvoid} |
There was a problem hiding this comment.
using Ptr{Cvoid} and unsafe_load/unsafe_store! with hardcoded offsets is fragile and makes fixing bugs in the future way more difficult.
the point of XLA.PJRT.CAPI is also that you can use the structs defined there in these functions.
function _loaded_exec_destroy(args::Ptr{PJRT_LoadedExecutable_Destroy_Args})::Ptr{Cvoid}
instead of the hardcoded offsets, you can use fieldoffset
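A minimal sketch of the `fieldoffset` approach, under the assumption that the args struct is an isbits Julia struct mirroring the C layout. `DemoDestroyArgs`, `field_offset` and `load_field` are hypothetical names for illustration; the real struct definitions live in `Reactant.XLA.PJRT.CAPI`, and the PR itself appears to use a helper such as `Reactant.unsafe_store_field!` for the store direction.

```julia
# Hypothetical args struct; its fields are assumptions chosen to resemble a
# PJRT "*_Args" layout, not the generated CAPI definition.
struct DemoDestroyArgs
    struct_size::Csize_t
    extension_start::Ptr{Cvoid}
    executable::Ptr{Cvoid}
end

# Byte offset of a field, looked up by name instead of hard-coded.
field_offset(::Type{T}, name::Symbol) where {T} =
    fieldoffset(T, Base.fieldindex(T, name))

# Typed load of a single field through a typed pointer to the struct.
function load_field(args::Ptr{T}, name::Symbol) where {T}
    FT = fieldtype(T, name)
    return unsafe_load(Ptr{FT}(args + field_offset(T, name)))
end

handle = Ptr{Cvoid}(UInt(0x1234))
ref = Ref(DemoDestroyArgs(sizeof(DemoDestroyArgs), C_NULL, handle))
GC.@preserve ref begin
    p = Base.unsafe_convert(Ptr{DemoDestroyArgs}, ref)
    println(field_offset(DemoDestroyArgs, :executable))
    println(load_field(p, :executable) == handle)
end
```

If the struct definition changes (for example, a field is inserted by a newer PJRT header), the offsets follow automatically, which is the maintenance benefit the comment is pointing at.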
|
Hi @mofeing — following up on your feedback about The changes are purely mechanical — no logic changes, just signature types and field access patterns:
One note: CAPI.jl line 3932 ( All tests pass (sincos, autodiff, CNN + non-square conv). As always, let me know if anything looks off or if you'd prefer a different approach. |
|
@Dale-Black after finding a pretty bad "accidentally ccalled with the wrong number of arguments" bug, we landed (in the interim) a refactor of all the ccall/ABI code so that the ccall code is autogenerated; you should instead call the wrapper functions within API.x. Can you update your code to use those? The relevant struct definitions should be there as well.
Force-pushed from a87035a to c8f7396
|
Hi @wsmoses — just rebased onto main and updated. We only had one direct |
| # Maps C handle address (UInt64) -> MetalExecutable. | ||
| # Using Any to avoid dependency on MLIRWalker.jl (included after this file). | ||
| const LOADED_EXECUTABLES = Dict{UInt64, Any}() | ||
|
|
There was a problem hiding this comment.
I don't quite get why you need to have these dictionaries, instead of having something like
mutable struct Executable
    # ... all the data
end
ext/ReactantMetalExt/PJRTPlugin.jl
Outdated
|
|
||
| handle = _handle_alloc() | ||
| @lock PJRT_LOCK begin | ||
| METAL_BUFFERS[UInt64(handle)] = (data=data_gpu, dims_c=dims_c, ndims=num_dims, |
There was a problem hiding this comment.
so I think this needlessly goes through an extra level of indirection.
presumably Metal.jl calls into Apple's allocator function (which can just malloc/free or similar, like a normal allocator). This is then stored in a GC'd object which is freed at end of use. To avoid that free, we put it in a dict.
Rather than having the dict to avoid the GC to avoid the free, can we just call (via Metal.jl, potentially internals) the actual Apple allocate/free functions?
That way we avoid the dict, the indirection/race issues, and a decent chunk of overhead.
There was a problem hiding this comment.
if we need to store data, we can make our own
mutable struct MetalBuffer
ptr::Ptr{Cvoid} # actual data
eltype::Enum # element type
# ... whatever other data we need
end
we control alloc/free, so in the allocation function we allocate both the data and Libc.malloc a struct of size MetalBuffer (or whatnot); then in the free function, we free both
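A minimal sketch of that explicit alloc/free pairing, assuming host memory from `Libc.malloc` as a stand-in for the Metal allocation (the real backend would call Metal's allocator instead). `DemoBufferHeader`, `demo_alloc` and `demo_free` are hypothetical names, and the header is an immutable isbits struct so that it can be stored in malloc'd memory.

```julia
# Hypothetical header describing one buffer; isbits, so it can live in
# malloc'd memory outside the Julia GC.
struct DemoBufferHeader
    data::Ptr{Cvoid}     # payload (host memory here; an MTLBuffer in practice)
    nbytes::Csize_t      # payload size in bytes
    eltype_code::UInt32  # integer element-type code (assumed encoding)
end

# Allocate payload and header together; the returned pointer is the handle.
function demo_alloc(nbytes::Integer, eltype_code::Integer)
    data = Libc.malloc(nbytes)
    data == C_NULL && throw(OutOfMemoryError())
    header = Ptr{DemoBufferHeader}(Libc.malloc(sizeof(DemoBufferHeader)))
    if header == C_NULL
        Libc.free(data)
        throw(OutOfMemoryError())
    end
    unsafe_store!(header, DemoBufferHeader(data, nbytes, eltype_code))
    return header
end

# Free both allocations; nothing here is tracked by the GC or by a Dict.
function demo_free(header::Ptr{DemoBufferHeader})
    h = unsafe_load(header)
    Libc.free(h.data)
    Libc.free(header)
    return nothing
end

buf = demo_alloc(64, 11)
println(unsafe_load(buf).nbytes)
demo_free(buf)
```

Because the handle is the header pointer itself, no lookup table or GC root is needed; the trade-off is that every `demo_alloc` must be matched by exactly one `demo_free`.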
|
Hi @wsmoses — just pushed the Dict-free buffer/executable rearchitecture you described. Is this the type of approach you're talking about? What changed:
All tests pass (sincos, autodiff, CNN, non-square conv). |
ext/ReactantMetalExt/MLIRWalker.jl
Outdated
| # Serialized MLIR text for _exec_optimized_program (set after construction in _client_compile) | ||
| mlir_text::String | ||
| # Execution cache — lazily built on first execute! | ||
| _input_mtl::Vector{Any} # cached MtlArrays for inputs |
There was a problem hiding this comment.
this seems not to be used?
can this be trimmed to just the essentials?
ext/ReactantMetalExt/PJRTPlugin.jl
Outdated
| # GC roots for MetalExecutable objects — prevents GC collection while PJRT holds the handle. | ||
| # NOT used for lookup: the handle IS the pointer to the Julia object (pointer_from_objref). | ||
| # To retrieve: unsafe_pointer_to_objref(handle). To destroy: delete! from set. | ||
| const EXEC_GC_ROOTS = Set{Any}() |
There was a problem hiding this comment.
this shouldn't be necessary any more if we actually malloc/free it explicitly, right?
|
@wsmoses — good catches. Removed the dead cache fields from MetalExecutable (trimmed to just the essentials). On |
|
yeah so @Dale-Black GC errors are a sufficiently non-deterministic pain when things go wrong (and to try to fix) that, since at the end of the day all of the types stored in MetalExecutable can be non-Julia types, we should make it a struct just allocated with malloc/free and explicitly managed. for example, for the dtypes we can use whatever integer enum for types XLA already has [and then convert to the Julia type when requested]
|
to be clear, we can (and should) completely use all the nice Julia/Metal.jl setup side for building the executable struct, but once built, the executable object itself should not contain Julia GC objects, if possible [to avoid weird memory-corruption debugging in our future]
|
Here's my attempt at the C-allocatable MetalExecutable refactor, @wsmoses — MetalExecutableData is now Libc.malloc'd with raw ObjC ids + explicit retain/release. EXEC_GC_ROOTS eliminated. Let me know what you think. |
ext/ReactantMetalExt/PJRTPlugin.jl
Outdated
| # Total: 15 + n_in + n_out + n_c | ||
| # Plus: n_c MTLBuffer allocations via Metal.alloc (freed via Metal.free) | ||
| # ============================================================ | ||
| function freeze_executable(exec)::Ptr{MetalExecutableData} |
There was a problem hiding this comment.
this is super verbose as a file, can you separate it out into logical components [e.g. an Executable.jl]
|
Split PJRTPlugin.jl into Buffer.jl + Executable.jl + PJRTPlugin.jl per your feedback, @wsmoses. No logic changes, just file organization. |
mofeing
left a comment
There was a problem hiding this comment.
I agree with @wsmoses. Would you mind splitting PJRTPlugin.jl into more files? One file per API topic would be great. You already have Buffer and Executable, but there should also be files for Client, Plugin, LoadedExecutable, Event, Device and DeviceDescription (I believe).
It requires some effort to review through these files, but also it will help with debugging and refactoring the global state design it seems Claude has decided to follow.
deps/build_local.jl
Outdated
| push!(build_cmd_list, "--sandbox_debug") | ||
|
|
||
| push!(build_cmd_list, "--linkopt=-fuse-ld=lld") | ||
| # push!(build_cmd_list, "--linkopt=-fuse-ld=lld") # lld not available on macOS |
There was a problem hiding this comment.
mmm was this commented out by Claude? this flag does take effect on OSes other than macOS, so you should wrap it inside if !Sys.isapple()
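A small sketch of the suggested guard. `linker_flags` and its `is_apple` keyword are hypothetical (the keyword exists only so the branch can be exercised on any platform); in `build_local.jl` the flag would be pushed onto `build_cmd_list` directly.

```julia
# Collect linker-related Bazel flags, adding lld only where it is available.
function linker_flags(; is_apple::Bool=Sys.isapple())
    flags = String[]
    if !is_apple
        # lld is not available on macOS, so only request it elsewhere.
        push!(flags, "--linkopt=-fuse-ld=lld")
    end
    return flags
end

println(linker_flags(; is_apple=true))
println(linker_flags(; is_apple=false))
```

This keeps the existing behavior on Linux and other platforms while avoiding the failing flag on macOS.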
ext/ReactantMetalExt/Buffer.jl
Outdated
| # mtl_buf_id: raw ObjC id for the MTLBuffer (from Metal.alloc, SharedStorage). | ||
| # The MTLBuffer has retain count 1 from alloc — we release it in _buffer_destroy. | ||
| # data_ptr: CPU-accessible pointer from Metal.MTL.contents() (stable for SharedStorage). | ||
| struct MetalBufferData | ||
| mtl_buf_id::UInt64 | ||
| data_ptr::Ptr{Cvoid} | ||
| eltype::UInt32 | ||
| dims::Ptr{Int64} # Libc.malloc'd dims array | ||
| ndims::Int | ||
| nbytes::Int | ||
| end |
There was a problem hiding this comment.
I think what you're looking for here is a Metal.MTL.MTLTensor, although you can also use a Metal.MtlPtr and wrap whatever else you need.
ext/ReactantMetalExt/Buffer.jl
Outdated
| # PJRT_Buffer_Type enum value → Julia element type | ||
| # PJRT_Buffer_Type::UInt32: PRED=1,S8=2,S16=3,S32=4,S64=5,U8=6,F16=10,F32=11,F64=22 | ||
| function pjrt_type_to_julia(t::UInt32) | ||
| return if t == 11 | ||
| Float32 | ||
| elseif t == 22 | ||
| Float64 | ||
| elseif t == 10 | ||
| Float16 | ||
| elseif t == 4 | ||
| Int32 | ||
| elseif t == 5 | ||
| Int64 | ||
| else | ||
| Float32 | ||
| end | ||
| end | ||
|
|
||
| function julia_type_to_pjrt(T) | ||
| return if T == Float32 | ||
| UInt32(11) | ||
| elseif T == Float64 | ||
| UInt32(22) | ||
| elseif T == Float16 | ||
| UInt32(10) | ||
| elseif T == Int32 | ||
| UInt32(4) | ||
| elseif T == Int64 | ||
| UInt32(5) | ||
| else | ||
| UInt32(11) | ||
| end | ||
| end |
There was a problem hiding this comment.
you already have this functionality implemented in Reactant.XLA.primitive_type and Reactant.XLA.julia_type
ext/ReactantMetalExt/Executable.jl
Outdated
| # ============================================================ | ||
| # freeze_executable: Convert Julia MetalExecutable → C-allocated MetalExecutableData | ||
| # | ||
| # Retains all ObjC objects (graph, placeholders, output tensors). | ||
| # Copies all metadata to Libc.malloc'd C arrays. | ||
| # Allocates fresh MTLBuffers for constant data (independent of MtlArray GC). | ||
| # The returned pointer IS the PJRT executable handle. | ||
| # | ||
| # Libc.malloc count (for n_in inputs, n_out outputs, n_c constants): | ||
| # Fixed: 15 allocations (arrays + struct) | ||
| # Per-input shape: n_in allocations | ||
| # Per-output shape: n_out allocations | ||
| # Per-const shape: n_c allocations | ||
| # Total: 15 + n_in + n_out + n_c | ||
| # Plus: n_c MTLBuffer allocations via Metal.alloc (freed via Metal.free) | ||
| # ============================================================ |
There was a problem hiding this comment.
minor thing, but it would be cool if comments like this were docstrings
ext/ReactantMetalExt/PJRTPlugin.jl
Outdated
| global CLIENT_HANDLE = Libc.malloc(64) | ||
| global DEVICE_HANDLE = Libc.malloc(64) | ||
| global DEVDESC_HANDLE = Libc.malloc(64) | ||
| global MEMORY_HANDLE = Libc.malloc(64) | ||
|
|
||
| unsafe_store!(Ptr{Int64}(CLIENT_HANDLE), Int64(0xDEADBEEF)) | ||
| unsafe_store!(Ptr{Int64}(DEVICE_HANDLE), Int64(0xCAFEBABE)) | ||
| unsafe_store!(Ptr{Int64}(DEVDESC_HANDLE), Int64(0xF00DCAFE)) | ||
| unsafe_store!(Ptr{Int64}(MEMORY_HANDLE), Int64(0xFEEDFACE)) |
There was a problem hiding this comment.
what are these magic numbers? as far as I can tell, they are just words you can spell in hexadecimal, like 'dead beef' or 'cafe babe'.
also, why are these pointers needed?
There was a problem hiding this comment.
Again, most of this is from Claude but I THINK it makes sense. Supposedly this is a single-device plugin thing — PJRT requires non-NULL pointers for client/device/memory handles (the C++ side dereferences them and segfaults on NULL). Since there's exactly one Metal GPU, one memory space, and one client, these handles don't need to carry real state — they just need to be valid, distinct, non-NULL pointers.
The hex values are conventional debug markers so they're easy to spot in lldb. Named them as SENTINEL_CLIENT, SENTINEL_DEVICE, etc. in the latest commit so they're self-documenting now.
At least this is what claude is telling me so I made that more clear in the recent commits
ext/ReactantMetalExt/PJRTPlugin.jl
Outdated
| function _client_create(args::Ptr{CAPI.PJRT_Client_Create_Args})::Ptr{Cvoid} | ||
| Reactant.unsafe_store_field!(args, CLIENT_HANDLE, Val{:client}()) | ||
| return C_NULL | ||
| end | ||
|
|
||
| function _client_destroy(args::Ptr{CAPI.PJRT_Client_Destroy_Args})::Ptr{Cvoid} | ||
| return C_NULL | ||
| end |
There was a problem hiding this comment.
I think this is cheating 😆 Here instead of creating a client with the requested configuration, it's ignoring all that and setting a ""random"" magic number for the pointer to the client.
Effectively it's working as a single client with global state, when PJRT would like to have all that data contained in an allocatable object.
I believe this is the reason why so many global pointers appear scattered around the code.
opinions @wsmoses ?
There was a problem hiding this comment.
I think this is correct: there's one Metal GPU on the system so there's no meaningful "configuration" to process from the create args. The global state is a consequence of single-device reality, not a design shortcut. That said, if multi-device support matters in the future (e.g. multiple M-series chips), we'd refactor to allocate real client state per create call. Happy to rework this now if you have a specific design in mind.
Force-pushed from de4cc7a to c58d751
|
Hey — squashed everything down to a single clean commit and addressed the feedback from the last round:
23 files in the diff now, all Metal extension + minimal core touches. Tests passing locally (elementwise, autodiff, CNN, non-square conv). Am I on the right track with this? Happy to adjust anything. |
src/xla/PJRT/Client.jl
Outdated
| GC.@preserve errstr begin | ||
| client = @ccall MLIR.API.mlir_c.MakeClientFromApi( | ||
| api_ptr::Ptr{Cvoid}, | ||
| device_type::Cstring, | ||
| client_name::Cstring, | ||
| errstr::Ptr{Cstring}, | ||
| )::Ptr{Cvoid} | ||
| end |
There was a problem hiding this comment.
Avoid this ccall and use the MakeClientFromApi from libmlir_h.jl
Force-pushed from c58d751 to 694f865
src/xla/XLA.jl
Outdated
| # Apple Silicon: Metal PJRT backend via ReactantMetalExt/MPSGraph | ||
| if Accelerators.Metal.has_metal() | ||
| if was_initialized && haskey(state.clients, "metal") | ||
| free_client(state.clients["metal"]) | ||
| $(runtime).metal_client_count[] -= 1 | ||
| end | ||
| gpu = $(runtime).MetalClient(; | ||
| metal_pjrt_plugin_path=Accelerators.Metal.get_metal_pjrt_plugin_path(), | ||
| common_kwargs..., | ||
| ) | ||
| state.clients["metal"] = gpu | ||
| # Don't put this in the default_client since metal support is fairly | ||
| # limited | ||
| =# | ||
| # Metal PJRT plugin is not yet compatible with latest OpenXLA | ||
| catch e | ||
| println(stdout, e) | ||
| try | ||
| metal = $(runtime).MetalClient() | ||
| state.clients["metal"] = metal | ||
| state.default_client = metal | ||
| catch e | ||
| println(stdout, e) | ||
| end |
There was a problem hiding this comment.
Once you rebase, you can register the plugin in the MetalExt itself
Adds a Metal GPU backend for Reactant.jl as a package extension (ReactantMetalExt). Uses Apple's MPSGraph framework to compile and execute StableHLO operations on Metal GPUs.

Key components:
- PJRT plugin implementation via Julia @cfunction callbacks
- MLIR walker that translates StableHLO ops to MPSGraph operations
- MTLTensor-based buffer management with proper retain/release
- Support for: elementwise ops, conv2d/3d, pooling, reduce, matmul, reshape, transpose, broadcast, concatenate, pad, dot_general, and more
- Automatic differentiation works through Enzyme

Requires macOS 26+ with Metal.jl >= 1.8.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Force-pushed from 694f865 to 04b344a
|
I should have some more free time this weekend, if there is anything else I can do to help get this to the finish line |



Summary
Pure-Julia Metal GPU backend for Reactant on Apple Silicon. Instead of depending on an external PJRT plugin shared library (the old jax-metal `.dylib` approach, which is no longer compatible with the current OpenXLA), this implements the full PJRT callback interface directly in Julia using `@cfunction` pointers, then walks the optimized StableHLO IR to build an equivalent MPSGraph that executes on the Metal GPU.

Target UX: `using Reactant, Metal; @jit f(x)` (transparent dispatch, no special API).

How it works
The optimization pipeline has two layers: XLA/MLIR does high-level fusion and CSE on the IR, then MPSGraph does Metal-specific kernel fusion and scheduling on the GPU side.

What's included
- (`MakeClientFromApi`): Registers a Julia-allocated `PJRT_Api` struct directly with XLA, no `dlopen` needed
- (`PJRTPlugin.jl`): Full `PJRT_Api` implementation covering client lifecycle, device/memory discovery, buffer management, compilation, and execution
- (`MLIRWalker.jl`): Translates StableHLO ops to MPSGraph nodes; supports element-wise ops, `dot_general`, `broadcast_in_dim`, `reshape`, `transpose`, `reduce` (sum/max), `conv2d`/`conv3d`, `reduce_window` (pooling 2D/3D), `concatenate`, `slice`, `scatter`, `reverse`, and `constant`
- (`XLACompiler.jl`): MPSGraph operations not wrapped by Metal.jl
- `METAL_XLA_LOCK` serializes buffer operations to prevent heap corruption from concurrent GC finalizer and main thread access to `PjRtCApiClient`
- `@jit` calls to avoid per-call allocation
- `lld` linker (unavailable on macOS) and enables platform-aware Bazel toolchain resolution

What works today
- (`sin`, `cos`, `exp`, `tanh`, `relu`, etc.)
- `Chain` models

Architecture decisions
- (`ReactantMetalExt`): Loaded automatically when `using Metal` brings Metal.jl into scope. No changes needed to user code.
- `__precompile__(false)`: Required because the extension overrides `Base.convert`, `XLA.free_buffer`, and `XLA.to_host` for thread-safety. Julia disallows method overwrites during precompilation.
- `@cfunction` pointers stored in a `Libc.malloc`'d struct. This eliminates the need for any external binary beyond the existing `libReactantExtra`.
- `placeholderTensor` auto-reverses Julia shapes. The walker uses IR shapes directly for all operations, with layout permutations only at conv/pool boundaries.

Development process
This backend was developed over ~48 commits using an autonomous agent loop ("ralph loop") powered by Claude Code. The agent iteratively implemented and verified each component, from the initial PJRT callback prototype through conv layout bugs and thread-safety fixes. This PR is a clean 5-commit squash of that work onto `origin/main`, containing only the necessary production code. All development scaffolding (research files, debug tests, benchmark notebooks) has been removed.

Known limitations
- `stablehlo.convert` is identity-only (no actual dtype casting yet)
- `reduce` for min/prod

Files changed (15 files, +3,395 / -77)
- `deps/ReactantExtra/API.cpp`: `MakeClientFromApi()`
- `deps/ReactantExtra/BUILD`
- `deps/build_local.jl`
- `src/accelerators/Metal.jl`: `has_metal()` / `setup_metal!()`
- `src/xla/Device.jl`: `@warn` → `@debug`
- `src/xla/PJRT/Client.jl`: `MakeMetalClientFromApi`, `_metal_pjrt_api_ptr`
- `src/xla/XLA.jl`
- `ext/ReactantMetalExt.jl`
- `ext/ReactantMetalExt/MLIRWalker.jl`
- `ext/ReactantMetalExt/PJRTPlugin.jl`
- `ext/ReactantMetalExt/XLACompiler.jl`
- `Project.toml`
- `test/Project.toml`
- `test/plugins/metal.jl`
- `test/runtests.jl`

Test plan
- `julia test/plugins/metal.jl` on macOS with Apple Silicon: sincos, autodiff, CNN all pass
- `julia -e 'using Reactant; println(Reactant.XLA.default_backend())'`: basic Reactant still works on non-Mac
- `Metal` is NOT in `[deps]` (only `[weakdeps]`): no new mandatory dependency

🤖 Generated with Claude Code