
Metal PJRT backend via MPSGraph + pure-Julia plugin #2489

Open
Dale-Black wants to merge 1 commit into EnzymeAD:main from Dale-Black:metal-pjrt-backend

Conversation

@Dale-Black (Contributor)

Summary

Pure-Julia Metal GPU backend for Reactant on Apple Silicon. Instead of depending on an external PJRT plugin shared library (the old jax-metal .dylib approach, which is no longer compatible with the current OpenXLA), this implements the full PJRT callback interface directly in Julia using @cfunction pointers, then walks the optimized StableHLO IR to build an equivalent MPSGraph that executes on the Metal GPU.

Target UX: using Reactant, Metal; @jit f(x) — transparent dispatch, no special API.
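That target UX can be sketched end to end (a hypothetical session on this branch; `f` and the array sizes are illustrative):

```julia
using Reactant, Metal   # loading Metal activates the ReactantMetalExt extension

f(x) = sum(tanh.(x) .^ 2)   # any traceable Julia function

x = Reactant.to_rarray(rand(Float32, 1024))
y = @jit f(x)   # traced → StableHLO → PJRT compile callback → MPSGraph → Metal GPU
```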

How it works

Julia code → Reactant tracing → XLA/MLIR optimization (fusion, CSE, layout opt)
→ Optimized StableHLO IR → PJRT compile callback → MLIR walker → MPSGraph → Metal GPU

The optimization pipeline has two layers: XLA/MLIR does high-level fusion and CSE on the IR, then MPSGraph does Metal-specific kernel fusion and scheduling on the GPU side.

What's included

  • C++ bridge (MakeClientFromApi): Registers a Julia-allocated PJRT_Api struct directly with XLA — no dlopen needed
  • 30 PJRT callbacks (PJRTPlugin.jl): Full PJRT_Api implementation covering client lifecycle, device/memory discovery, buffer management, compilation, and execution
  • MLIR walker (MLIRWalker.jl): Translates StableHLO ops to MPSGraph nodes — supports element-wise ops, dot_general, broadcast_in_dim, reshape, transpose, reduce (sum/max), conv2d/conv3d, reduce_window (pooling 2D/3D), concatenate, slice, scatter, reverse, and constant
  • @objc bindings (XLACompiler.jl): MPSGraph operations not wrapped by Metal.jl
  • Thread-safety: METAL_XLA_LOCK serializes buffer operations to prevent heap corruption from concurrent GC finalizer and main thread access to PjRtCApiClient
  • MtlArray pool: Recycles GPU buffers across @jit calls to avoid per-call allocation
  • macOS build fix: Disables lld linker (unavailable on macOS) and enables platform-aware Bazel toolchain resolution
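The thread-safety bullet boils down to one coarse lock around anything that can reach `PjRtCApiClient` from two threads at once. A minimal sketch of the pattern, assuming nothing beyond base Julia (`METAL_XLA_LOCK` is the PR's name; the wrapped operation is a stand-in for the real buffer callbacks in PJRTPlugin.jl):

```julia
# Sketch of the serialization pattern described above.
const METAL_XLA_LOCK = ReentrantLock()

function with_metal_lock(f)
    # Both GC finalizers (freeing buffers) and the main thread (allocating /
    # copying) reach the client; funneling them through one lock prevents the
    # heap corruption seen with concurrent access.
    lock(METAL_XLA_LOCK) do
        f()
    end
end

# e.g. a free path would run as: with_metal_lock(() -> free_buffer_impl(buf))
```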

What works today

  • Element-wise math (sin, cos, exp, tanh, relu, etc.)
  • Dense layers and Chain models
  • Conv2D and Conv3D with arbitrary layouts
  • Max/avg pooling (2D and 3D)
  • Enzyme autodiff (forward and reverse mode)
  • Full Lux CNN (Conv → Pool → Dense pipeline)

Architecture decisions

  1. Package extension (ReactantMetalExt): loaded automatically when `using Metal` brings Metal.jl into scope; no changes to user code are needed.
  2. __precompile__(false): Required because the extension overrides Base.convert, XLA.free_buffer, and XLA.to_host for thread-safety. Julia disallows method overwrites during precompilation.
  3. Direct PJRT_Api registration (Option A): Rather than building a C shared library, all 30 PJRT callbacks are Julia @cfunction pointers stored in a Libc.malloc'd struct. This eliminates the need for any external binary beyond the existing libReactantExtra.
  4. IR convention for tensors: MPSGraph tensors use IR (row-major) convention internally because placeholderTensor auto-reverses Julia shapes. The walker uses IR shapes directly for all operations, with layout permutations only at conv/pool boundaries.
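Decision 3 can be illustrated with plain Julia, independent of PJRT: each callback becomes a `@cfunction` pointer, and the pointers are stored in a `Libc.malloc`'d struct that C code can hold indefinitely. The two-field struct below is invented for illustration; the real `PJRT_Api` has ~30 callback slots:

```julia
# Illustration of the "Option A" registration pattern: no shared library,
# just Julia-created C function pointers in a manually managed struct.
struct MiniApi
    struct_size::Csize_t      # PJRT-style size field for versioning
    on_create::Ptr{Cvoid}     # stand-in for one of the 30 callbacks
end

# Top-level (non-closure) function, so @cfunction yields a stable pointer.
mini_on_create(arg::Ptr{Cvoid})::Cint = 0

function make_mini_api()
    cb = @cfunction(mini_on_create, Cint, (Ptr{Cvoid},))
    p = Ptr{MiniApi}(Libc.malloc(sizeof(MiniApi)))  # outlives Julia GC scopes
    unsafe_store!(p, MiniApi(sizeof(MiniApi), cb))
    return p   # in this PR, a struct like this is what MakeClientFromApi receives
end
```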

Development process

This backend was developed over ~48 commits using an autonomous agent loop ("ralph loop") powered by Claude Code. The agent iteratively implemented and verified each component — from the initial PJRT callback prototype through conv layout bugs and thread-safety fixes. This PR is a clean 5-commit squash of that work onto origin/main, containing only the necessary production code. All development scaffolding (research files, debug tests, benchmark notebooks) has been removed.

Known limitations

  • Conv-after-concat with non-square spatial dims has a known shape mismatch (the "L7 problem" in UNet patterns) — under investigation
  • stablehlo.convert is identity-only (no actual dtype casting yet)
  • No reduce for min/prod
  • Float64 and Int64 are silently downcast to Float32/Int32 (MPSGraph limitation)

Files changed (15 files, +3,395 / -77)

| File | Change |
|------|--------|
| deps/ReactantExtra/API.cpp | +21: `MakeClientFromApi()` |
| deps/ReactantExtra/BUILD | +1: export symbol |
| deps/build_local.jl | ~8: macOS build fix |
| src/accelerators/Metal.jl | rewrite: `has_metal()`/`setup_metal!()` |
| src/xla/Device.jl | +1: `@warn` → `@debug` |
| src/xla/PJRT/Client.jl | +26: `MakeMetalClientFromApi`, `_metal_pjrt_api_ptr` |
| src/xla/XLA.jl | ~22: enable Metal client init |
| ext/ReactantMetalExt.jl | +147: extension entry point |
| ext/ReactantMetalExt/MLIRWalker.jl | +1,576: MLIR → MPSGraph |
| ext/ReactantMetalExt/PJRTPlugin.jl | +1,197: 30 PJRT callbacks |
| ext/ReactantMetalExt/XLACompiler.jl | +369: `@objc` MPSGraph bindings |
| Project.toml | +3: Metal in weakdeps + compat + extension |
| test/Project.toml | +1: Metal in test deps |
| test/plugins/metal.jl | ~16: fixes for Metal backend |
| test/runtests.jl | ~4: enable Metal tests on macOS |

Test plan

  • julia test/plugins/metal.jl on macOS with Apple Silicon — sincos, autodiff, CNN all pass
  • julia -e 'using Reactant; println(Reactant.XLA.default_backend())' — basic Reactant still works on non-Mac
  • Verify CI passes on Linux/CUDA (no functional changes to non-Metal paths)
  • Verify Metal is NOT in [deps] (only [weakdeps]) — no new mandatory dependency

🤖 Generated with Claude Code

@codecov

codecov bot commented Feb 20, 2026

Codecov Report

❌ Patch coverage is 0.53381% with 1677 lines in your changes missing coverage. Please review.
✅ Project coverage is 34.13%. Comparing base (b39a1fc) to head (04b344a).
⚠️ Report is 788 commits behind head on main.

Files with missing lines Patch % Lines
ext/ReactantMetalExt/MLIRWalker.jl 0.00% 721 Missing ⚠️
ext/ReactantMetalExt/Executable.jl 0.00% 305 Missing ⚠️
ext/ReactantMetalExt/PJRTPlugin.jl 0.00% 228 Missing ⚠️
ext/ReactantMetalExt/XLACompiler.jl 0.00% 152 Missing ⚠️
ext/ReactantMetalExt/Buffer.jl 0.74% 134 Missing ⚠️
ext/ReactantMetalExt/Device.jl 0.00% 43 Missing ⚠️
ext/ReactantMetalExt/Client.jl 0.00% 28 Missing ⚠️
ext/ReactantMetalExt/Memory.jl 0.00% 22 Missing ⚠️
ext/ReactantMetalExt/ReactantMetalExt.jl 6.25% 15 Missing ⚠️
ext/ReactantMetalExt/Event.jl 0.00% 13 Missing ⚠️
... and 3 more
Additional details and impacted files
@@             Coverage Diff             @@
##             main    #2489       +/-   ##
===========================================
- Coverage   68.16%   34.13%   -34.03%     
===========================================
  Files         109      214      +105     
  Lines       11779    30852    +19073     
===========================================
+ Hits         8029    10531     +2502     
- Misses       3750    20321    +16571     

☔ View full report in Codecov by Sentry.


# ============================================================================

"""Extract contracting_dims from dot_general op text."""
function parse_contracting_dims(op_text::AbstractString)
Collaborator

We don't need to parse the string here; we should be able to query the operation to extract this info

@Dale-Black
Contributor Author

Split out the C++/Bazel changes into #2490 per @avik-pal's request. That PR adds only MakeClientFromApi to API.cpp and the symbol export in BUILD.

Once the JLL is rebuilt with that symbol, the Julia changes in this PR will work against the new JLL (no more LocalPreferences.toml / local build requirement).

I'll rebase this PR to remove the deps/ReactantExtra/ commits once #2490 is merged.

@Dale-Black
Contributor Author

MLIR API Refactor — Addressing String Parsing Feedback

@avik-pal — thanks for the review. I interpreted your comment about not needing to parse strings as referring to the parse_* functions in MLIRWalker.jl that were using regex on string(op) and string(IR.type(...)) to extract attributes and type information. I've replaced all 12 of those functions with proper MLIR C API calls.

What changed (latest commit)

Replaced all string-based attribute/type extraction with API calls:

  • Simple attributes (broadcast_dimensions, permutation, dimensions, etc.): IR.getattr(op, name) + DenseArray indexing
  • Structured attributes (conv dimension_numbers, dot dimension_numbers): StableHLO C API functions (stablehloConvDimensionNumbersGet*, stablehloDotDimensionNumbersGet*)
  • Dense elements (padding): API.mlirDenseElementsAttrGetInt64Value (the Julia-level wrapper has a known bug for Int64/Float32 element types, so we call the C API directly)
  • Type inspection: IR.type / IR.ndims / IR.size / IR.eltype instead of regex-parsing "tensor<4x8xf32>" strings

Net result is -163 lines since the API calls are more concise than the regex parsers.
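For concreteness, the before/after style looks roughly like this. The function names follow the comment above (`IR.getattr`, `IR.type`, `IR.ndims`, `IR.size`); `IR.result` and the indexing details are assumptions, and the exact Reactant.MLIR signatures may differ:

```julia
# Old style (removed): regex over the printed op, e.g.
#   m = match(r"broadcast_dimensions = \[([\d, ]*)\]", string(op))

# New style (sketch): query the operation directly via the MLIR bindings.
using Reactant: MLIR
const IR = MLIR.IR

function broadcast_dims(op)
    attr = IR.getattr(op, "broadcast_dimensions")   # dense i64 array attribute
    Int[attr[i] for i in 1:length(attr)]            # DenseArray-style indexing
end

function result_shape(op)
    t = IR.type(IR.result(op, 1))   # no parsing of "tensor<4x8xf32>" strings
    ntuple(i -> IR.size(t, i), IR.ndims(t))
end
```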

What we tested

  • test/plugins/metal.jl all 3 testsets pass (sincos, autodiff with Enzyme, CNN) — verified after each individual function replacement
  • Quick conv test with non-square input (24x16) to catch layout regressions
  • One remaining string(op) call is used for error messages on unrecognized ops (diagnostic, not attribute extraction)

I'll share some screenshots from the local Pluto benchmark notebook in a follow-up comment. Please let me know if this is what you had in mind or if there are other areas that need attention — still learning my way around the MLIR infrastructure here.

@Dale-Black
Contributor Author

[screenshots: benchmark result tables from the local Pluto notebook]

Since most of this is agent-coded, I have been verifying with this local notebook, which I think makes it hard to purely hallucinate these results (and they look promising as far as I can tell!)

Local Pluto Notebook
### A Pluto.jl notebook ###
# v0.20.13

using Markdown
using InteractiveUtils

# ╔═╡ a1b2c3d4-0003-0001-0001-000000000001
# ╠═╡ show_logs = false
begin
	import Pkg
	env = mktempdir()
	Pkg.activate(env)
	# Write LocalPreferences BEFORE any Pkg ops (precompilation caches the JLL path)
	open(joinpath(env, "LocalPreferences.toml"), "w") do io
		println(io, "[Reactant_jll]")
		println(io, "libReactantExtra_path = \"", expanduser("~/Documents/dev/julia/Reactant.jl/deps/ReactantExtra/bazel-bin/libReactantExtra.so"), "\"")
	end
	# Use the local Reactant.jl checkout (metal-pjrt-backend branch)
	Pkg.develop(path=expanduser("~/Documents/dev/julia/Reactant.jl"))
	# Reactant_jll must be a DIRECT dep for Preferences.jl to find LocalPreferences.toml
	Pkg.add(["Lux", "Metal", "Reactant_jll", "Statistics"])
end

# ╔═╡ a1b2c3d4-0004-0001-0001-000000000001
begin
	using Random
	using Lux
	using Metal
	using Metal: MtlArray
	using Reactant
	using Statistics
end

# ╔═╡ a1b2c3d4-0001-0001-0001-000000000001
md"""
# Metal GPU Backend Benchmark

## CPU vs Metal.jl vs Reactant+Metal

This notebook compares three ways to run a Lux neural network on Apple Silicon:

1. **CPU only** — plain Julia arrays
2. **Metal.jl only** — GPU via `MtlArray`, no compiler optimization
3. **Reactant + Metal** — GPU via `@jit`, with XLA/MLIR optimization (op fusion, CSE, etc.)

The key insight: Reactant+Metal preserves **two** optimization layers — XLA/MLIR graph optimization AND Metal GPU execution — which neither CPU nor Metal.jl alone can match.
"""

# ╔═╡ a1b2c3d4-0002-0001-0001-000000000001
md"""
## Setup
"""

# ╔═╡ a1b2c3d4-0005-0001-0001-000000000001
md"""
## Model Definition

A larger `Dense` chain: 2048 → 1024 → 512 → 256 → 10. GPU benefits only appear when matrix operations are big enough that compute time dominates dispatch overhead. We test at batch sizes 256, 1024, and 4096.
"""

# ╔═╡ a1b2c3d4-0006-0001-0001-000000000001
begin
	const INPUT_DIM = 2048
	const BATCH_SIZES = [256, 1024, 4096]
	const N_WARMUP = 5
	const N_TRIALS = 20

	rng = Random.MersenneTwister(42)

	model = Chain(
		Dense(INPUT_DIM => 1024, relu),
		Dense(1024 => 512, relu),
		Dense(512 => 256, relu),
		Dense(256 => 10),
	)
	ps_cpu, st_cpu = Lux.setup(rng, model)
	nothing
end

# ╔═╡ a1b2c3d4-0007-0001-0001-000000000001
md"""
## Benchmark 1: CPU Only

Standard Julia arrays, no GPU. This is the baseline.
"""

# ╔═╡ a1b2c3d4-0008-0001-0001-000000000001
function bench_cpu(model, ps, st, input_dim, batch_size; n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x = randn(Float32, input_dim, batch_size)
	f(m, x, ps, st) = first(m(x, ps, st))

	# Warmup
	for _ in 1:n_warmup
		f(model, x, ps, st)
	end

	# Timed runs
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed f(model, x, ps, st)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end

	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024)
end

# ╔═╡ a1b2c3d4-0009-0001-0001-000000000001
cpu_results = Dict(
	bs => bench_cpu(model, ps_cpu, st_cpu, INPUT_DIM, bs)
	for bs in BATCH_SIZES
)

# ╔═╡ a1b2c3d4-0010-0001-0001-000000000001
md"""
## Benchmark 2: Metal.jl Only

Move arrays to GPU via `MtlArray`. Lux supports this via `gpu_device()`. The model runs on Metal GPU but without any graph-level optimization — each op is dispatched individually.
"""

# ╔═╡ a1b2c3d4-0011-0001-0001-000000000001
begin
	gdev = gpu_device()
	cdev = cpu_device()
	ps_mtl, st_mtl = (ps_cpu, st_cpu) |> gdev
	md"Metal device: $(gdev)"
end

# ╔═╡ a1b2c3d4-0012-0001-0001-000000000001
function bench_metal(model, ps_gpu, st_gpu, input_dim, batch_size, gdev;
		n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x_gpu = MtlArray(randn(Float32, input_dim, batch_size))
	f(m, x, ps, st) = first(m(x, ps, st))

	# Warmup (+ sync)
	for _ in 1:n_warmup
		y = f(model, x_gpu, ps_gpu, st_gpu)
		Metal.synchronize()
	end

	# Timed runs — sync after each to get true GPU time
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		Metal.synchronize()
		stats = @timed begin
			y = f(model, x_gpu, ps_gpu, st_gpu)
			Metal.synchronize()
		end
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end

	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024)
end

# ╔═╡ a1b2c3d4-0013-0001-0001-000000000001
metal_results = Dict(
	bs => bench_metal(model, ps_mtl, st_mtl, INPUT_DIM, bs, gdev)
	for bs in BATCH_SIZES
)

# ╔═╡ a1b2c3d4-0014-0001-0001-000000000001
md"""
## Benchmark 3: Reactant + Metal

`@compile` traces the function, optimizes the MLIR graph (fusion, CSE, constant folding), and returns a compiled executable for Metal GPU via our PJRT plugin. Compilation happens once; the compiled function is called repeatedly for timing. (Note: `@jit` recompiles every call — always use `@compile` for benchmarks.)
"""

# ╔═╡ a1b2c3d4-0015-0001-0001-000000000001
function bench_reactant_metal(model, ps, st, input_dim, batch_size;
		n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x = randn(Float32, input_dim, batch_size)

	x_ra = Reactant.to_rarray(x)
	ps_ra = Reactant.to_rarray(ps)
	st_ra = Reactant.to_rarray(st)

	f(m, x, ps, st) = first(m(x, ps, st))

	# Compile ONCE — @compile returns a cached compiled function
	compile_stats = @timed begin
		compiled_f = @compile f(model, x_ra, ps_ra, st_ra)
	end
	compile_ms = compile_stats.time * 1000

	# Warmup the compiled function (no recompilation)
	for _ in 1:n_warmup
		compiled_f(model, x_ra, ps_ra, st_ra)
	end

	# Timed runs — execute only, no host transfer (apples-to-apples with Metal.jl)
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed begin
			compiled_f(model, x_ra, ps_ra, st_ra)
		end
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end

	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024,
	   compile_ms)
end

# ╔═╡ a1b2c3d4-0016-0001-0001-000000000001
reactant_results = Dict(
	bs => bench_reactant_metal(model, ps_cpu, st_cpu, INPUT_DIM, bs)
	for bs in BATCH_SIZES
)

# ╔═╡ a1b2c3d4-0017-0001-0001-000000000001
md"""
## Correctness Check

Before trusting the benchmarks, verify all three backends produce the same result.
"""

# ╔═╡ a1b2c3d4-0018-0001-0001-000000000001
begin
	x_test = randn(Float32, INPUT_DIM, 8)
	f_test(m, x, ps, st) = first(m(x, ps, st))

	# CPU
	y_cpu = f_test(model, x_test, ps_cpu, st_cpu)

	# Metal.jl
	x_mtl_test = MtlArray(x_test)
	y_mtl = Array(f_test(model, x_mtl_test, ps_mtl, st_mtl))

	# Reactant + Metal (compile once, then execute)
	x_ra_test = Reactant.to_rarray(x_test)
	ps_ra_test = Reactant.to_rarray(ps_cpu)
	st_ra_test = Reactant.to_rarray(st_cpu)
	compiled_test = @compile f_test(model, x_ra_test, ps_ra_test, st_ra_test)
	y_reactant = Array(compiled_test(model, x_ra_test, ps_ra_test, st_ra_test))

	err_metal = maximum(abs.(y_cpu .- y_mtl))
	err_reactant = maximum(abs.(y_cpu .- y_reactant))

	md"""
	| Backend | Max Error vs CPU |
	|---------|-----------------|
	| Metal.jl | $(round(err_metal; sigdigits=3)) |
	| Reactant+Metal | $(round(err_reactant; sigdigits=3)) |

	Both should be < 1e-5 (float32 precision).
	"""
end

# ╔═╡ a1b2c3d4-0019-0001-0001-000000000001
md"""
## Results Comparison
"""

# ╔═╡ a1b2c3d4-0020-0001-0001-000000000001
begin
	header = "| Batch Size | CPU (ms) | Metal.jl (ms) | Reactant+Metal (ms) | Speedup vs CPU | Speedup vs Metal |"
	sep    = "|-----------|---------|--------------|--------------------|--------------:|----------------:|"
	rows = String[]
	for bs in BATCH_SIZES
		c = cpu_results[bs]
		m = metal_results[bs]
		r = reactant_results[bs]
		speedup_cpu = round(c.median_ms / r.median_ms; digits=1)
		speedup_metal = round(m.median_ms / r.median_ms; digits=1)
		push!(rows, "| $bs | $(round(c.median_ms; digits=3)) | $(round(m.median_ms; digits=3)) | $(round(r.median_ms; digits=3)) | $(speedup_cpu)x | $(speedup_metal)x |")
	end

	Markdown.parse(join([header, sep, rows...], "\n"))
end

# ╔═╡ a1b2c3d4-0021-0001-0001-000000000001
begin
	alloc_header = "| Batch Size | CPU (KB) | Metal.jl (KB) | Reactant+Metal (KB) |"
	alloc_sep    = "|-----------|---------|--------------|--------------------:|"
	alloc_rows = String[]
	for bs in BATCH_SIZES
		c = cpu_results[bs]
		m = metal_results[bs]
		r = reactant_results[bs]
		push!(alloc_rows, "| $bs | $(round(c.median_alloc_kb; digits=1)) | $(round(m.median_alloc_kb; digits=1)) | $(round(r.median_alloc_kb; digits=1)) |")
	end

	Markdown.parse(join(["### Allocations (median per call)", "", alloc_header, alloc_sep, alloc_rows...], "\n"))
end

# ╔═╡ a1b2c3d4-0022-0001-0001-000000000001
begin
	compile_header = "| Batch Size | First-call compile (ms) |"
	compile_sep    = "|-----------|----------------------:|"
	compile_rows = [
		"| $bs | $(round(reactant_results[bs].compile_ms; digits=1)) |"
		for bs in BATCH_SIZES
	]

	Markdown.parse(join(["### Reactant Compile Time (one-time cost)", "", compile_header, compile_sep, compile_rows...], "\n"))
end

# ╔═╡ b2c3d4e5-0001-0001-0001-000000000001
md"""
## Benchmark 4: Fusion Stress Test

This tests where Reactant+Metal should **dominate**: a long chain of element-wise ops on large tensors. Metal.jl dispatches each broadcast as a separate GPU kernel. Reactant fuses them ALL into a single kernel — one launch instead of many.
"""

# ╔═╡ b2c3d4e5-0002-0001-0001-000000000001
begin
	"""10 element-wise ops that XLA fuses into 1-2 kernels (vs 10 kernel launches in Metal.jl)."""
	function elementwise_chain(x)
		x = x .* 2.0f0
		x = x .+ 1.0f0
		x = tanh.(x)
		x = x .* x            # square
		x = x .- 0.5f0
		x = exp.(x)
		x = x ./ (x .+ 1.0f0) # sigmoid-like
		x = abs.(x)
		x = x .* 3.0f0
		x = x .- x .* 0.1f0
		return x
	end

	const FUSION_SIZES = [2048*256, 2048*1024, 2048*4096]
	md"Defined `elementwise_chain` — 10 chained broadcasts."
end

# ╔═╡ b2c3d4e5-0003-0001-0001-000000000001
function bench_cpu_fusion(f, n; n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x = randn(Float32, n)
	for _ in 1:n_warmup; f(x); end
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed f(x)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024)
end

# ╔═╡ b2c3d4e5-0004-0001-0001-000000000001
function bench_metal_fusion(f, n; n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x_gpu = MtlArray(randn(Float32, n))
	for _ in 1:n_warmup
		f(x_gpu)
		Metal.synchronize()
	end
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		Metal.synchronize()
		stats = @timed begin
			f(x_gpu)
			Metal.synchronize()
		end
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024)
end

# ╔═╡ b2c3d4e5-0005-0001-0001-000000000001
function bench_reactant_fusion(f, n; n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x = randn(Float32, n)
	x_ra = Reactant.to_rarray(x)

	compile_stats = @timed begin
		compiled_f = @compile f(x_ra)
	end
	compile_ms = compile_stats.time * 1000

	for _ in 1:n_warmup
		compiled_f(x_ra)
	end

	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed compiled_f(x_ra)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024,
	   compile_ms)
end

# ╔═╡ b2c3d4e5-0006-0001-0001-000000000001
begin
	fusion_cpu = Dict(n => bench_cpu_fusion(elementwise_chain, n) for n in FUSION_SIZES)
	fusion_metal = Dict(n => bench_metal_fusion(elementwise_chain, n) for n in FUSION_SIZES)
	fusion_reactant = Dict(n => bench_reactant_fusion(elementwise_chain, n) for n in FUSION_SIZES)
	md"Fusion benchmarks complete."
end

# ╔═╡ b2c3d4e5-0007-0001-0001-000000000001
begin
	fh = "| Elements | CPU (ms) | Metal.jl (ms) | Reactant+Metal (ms) | Speedup vs CPU | Speedup vs Metal |"
	fs = "|---------|---------|--------------|--------------------|--------------:|----------------:|"
	frows = String[]
	for n in FUSION_SIZES
		c = fusion_cpu[n]
		m = fusion_metal[n]
		r = fusion_reactant[n]
		sp_cpu = round(c.median_ms / r.median_ms; digits=1)
		sp_metal = round(m.median_ms / r.median_ms; digits=1)
		push!(frows, "| $(n) | $(round(c.median_ms; digits=3)) | $(round(m.median_ms; digits=3)) | $(round(r.median_ms; digits=3)) | $(sp_cpu)x | $(sp_metal)x |")
	end

	Markdown.parse(join(["### Fusion Stress Test Results", "", "10 chained element-wise ops on large 1D tensors. Metal.jl = 10 kernel launches. Reactant = 1-2 fused kernels.", "", fh, fs, frows...], "\n"))
end

# ╔═╡ b2c3d4e5-0008-0001-0001-000000000001
begin
	fa_header = "| Elements | CPU (KB) | Metal.jl (KB) | Reactant+Metal (KB) |"
	fa_sep    = "|---------|---------|--------------|--------------------:|"
	fa_rows = String[]
	for n in FUSION_SIZES
		c = fusion_cpu[n]
		m = fusion_metal[n]
		r = fusion_reactant[n]
		push!(fa_rows, "| $(n) | $(round(c.median_alloc_kb; digits=1)) | $(round(m.median_alloc_kb; digits=1)) | $(round(r.median_alloc_kb; digits=1)) |")
	end

	Markdown.parse(join(["### Fusion Allocations (median per call)", "", fa_header, fa_sep, fa_rows...], "\n"))
end

# ╔═╡ c3d4e5f6-0001-0001-0001-000000000001
md"""
## Benchmark 5: 2D UNet-like Model (Conv + Pool + Skip Connections)

A model with **convolution**, **max pooling**, **residual add**, and **concatenation skip connections** — the core building blocks of a UNet.

This exercises three new op handlers:
- `stablehlo.convolution` → `MPSGraph convolution2D`
- `stablehlo.reduce_window` → `MPSGraph maxPooling2D`
- `stablehlo.concatenate` → `MPSGraph concatTensors`

**No Metal.jl column:** Metal.jl has no native GPU kernels for conv/pool (NNlib falls back to CPU scalar indexing). This is exactly why Reactant+Metal exists — it compiles these ops directly to MPSGraph.
"""

# ╔═╡ c3d4e5f6-0002-0001-0001-000000000001
begin
	const IMG_SIZE = (64, 64)
	const IMG_CH = 1
	const UNET_BATCHES = [1, 4, 16]

	unet_model = Chain(
		# Initial projection
		Conv((3, 3), 1 => 16, relu; pad=1),                  # 64×64×16

		# Residual block (addition skip connection)
		SkipConnection(
			Chain(
				Conv((3, 3), 16 => 16, relu; pad=1),
				Conv((3, 3), 16 => 16, relu; pad=1),
			),
			+
		),                                                      # 64×64×16

		# Downsample
		MaxPool((2, 2)),                                        # 32×32×16

		# Concatenation skip block (channel-cat skip connection)
		SkipConnection(
			Chain(
				Conv((3, 3), 16 => 16, relu; pad=1),
				Conv((3, 3), 16 => 16, relu; pad=1),
			),
			(mx, x) -> cat(mx, x; dims=3)                      # concat along C dim
		),                                                      # 32×32×32

		# More convolution
		Conv((3, 3), 32 => 32, relu; pad=1),                   # 32×32×32

		# Downsample again
		MaxPool((2, 2)),                                        # 16×16×32

		# Bottleneck
		Conv((3, 3), 32 => 64, relu; pad=1),                   # 16×16×64

		# Output head
		Conv((1, 1), 64 => 1),                                  # 16×16×1
	)

	unet_ps_cpu, unet_st_cpu = Lux.setup(rng, unet_model)
	md"2D UNet model: 8 Conv layers, 2 MaxPool, 1 residual add, 1 concat skip."
end

# ╔═╡ c3d4e5f6-0003-0001-0001-000000000001
function bench_unet_cpu(model, ps, st, batch_size; n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x = randn(Float32, IMG_SIZE..., IMG_CH, batch_size)
	f(m, x, ps, st) = first(m(x, ps, st))
	for _ in 1:n_warmup; f(model, x, ps, st); end
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed f(model, x, ps, st)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024)
end

# ╔═╡ c3d4e5f6-0004-0001-0001-000000000001
unet_cpu_results = Dict(
	bs => bench_unet_cpu(unet_model, unet_ps_cpu, unet_st_cpu, bs)
	for bs in UNET_BATCHES
)

# ╔═╡ c3d4e5f6-0005-0001-0001-000000000001
function bench_unet_reactant(model, ps, st, batch_size;
		n_warmup=N_WARMUP, n_trials=N_TRIALS)
	x = randn(Float32, IMG_SIZE..., IMG_CH, batch_size)
	x_ra = Reactant.to_rarray(x)
	ps_ra = Reactant.to_rarray(ps)
	st_ra = Reactant.to_rarray(st)
	f(m, x, ps, st) = first(m(x, ps, st))

	compile_stats = @timed begin
		compiled_f = @compile f(model, x_ra, ps_ra, st_ra)
	end
	compile_ms = compile_stats.time * 1000

	for _ in 1:n_warmup
		compiled_f(model, x_ra, ps_ra, st_ra)
	end

	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed compiled_f(model, x_ra, ps_ra, st_ra)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024,
	   compile_ms)
end

# ╔═╡ c3d4e5f6-0006-0001-0001-000000000001
unet_reactant_results = Dict(
	bs => bench_unet_reactant(unet_model, unet_ps_cpu, unet_st_cpu, bs)
	for bs in UNET_BATCHES
)

# ╔═╡ c3d4e5f6-0007-0001-0001-000000000001
begin
	x_unet_test = randn(Float32, IMG_SIZE..., IMG_CH, 2)
	f_unet(m, x, ps, st) = first(m(x, ps, st))

	# CPU reference
	y_unet_cpu = f_unet(unet_model, x_unet_test, unet_ps_cpu, unet_st_cpu)

	# Reactant+Metal
	x_unet_ra = Reactant.to_rarray(x_unet_test)
	ps_unet_ra = Reactant.to_rarray(unet_ps_cpu)
	st_unet_ra = Reactant.to_rarray(unet_st_cpu)
	compiled_unet = @compile f_unet(unet_model, x_unet_ra, ps_unet_ra, st_unet_ra)
	y_unet_reactant = Array(compiled_unet(unet_model, x_unet_ra, ps_unet_ra, st_unet_ra))

	unet_err_reactant = maximum(abs.(y_unet_cpu .- y_unet_reactant))

	md"""
	### 2D UNet Correctness Check

	| Backend | Max Error vs CPU |
	|---------|-----------------|
	| Reactant+Metal | $(round(unet_err_reactant; sigdigits=3)) |

	Should be < 1e-4 (conv accumulation allows slightly more error than dense).
	"""
end

# ╔═╡ c3d4e5f6-0008-0001-0001-000000000001
begin
	uh = "| Batch | CPU (ms) | Reactant+Metal (ms) | Speedup | Compile (ms) |"
	us = "|------|---------|--------------------|---------:|------------:|"
	urows = String[]
	for bs in UNET_BATCHES
		c = unet_cpu_results[bs]
		r = unet_reactant_results[bs]
		sp = round(c.median_ms / r.median_ms; digits=1)
		push!(urows, "| $bs | $(round(c.median_ms; digits=3)) | $(round(r.median_ms; digits=3)) | $(sp)x | $(round(r.compile_ms; digits=1)) |")
	end

	Markdown.parse(join(["### 2D UNet Results (CPU vs Reactant+Metal)", "",
		"8 Conv layers, 2 MaxPool, residual add + concat skip. Metal.jl cannot run this on GPU.", "",
		uh, us, urows...], "\n"))
end

# ╔═╡ c3d4e5f6-0009-0001-0001-000000000001
begin
	ua_header = "| Batch | CPU alloc (KB) | Reactant+Metal alloc (KB) |"
	ua_sep    = "|------|---------------:|-------------------------:|"
	ua_rows = String[]
	for bs in UNET_BATCHES
		c = unet_cpu_results[bs]
		r = unet_reactant_results[bs]
		push!(ua_rows, "| $bs | $(round(c.median_alloc_kb; digits=1)) | $(round(r.median_alloc_kb; digits=1)) |")
	end

	Markdown.parse(join(["### 2D UNet Allocations", "", ua_header, ua_sep, ua_rows...], "\n"))
end

# ╔═╡ d4e5f6a7-0001-0001-0001-000000000001
md"""
## Benchmark 6: 3D UNet-like Model (3D Conv + Skip Connections)

A volumetric model with **3D convolution**, **residual add**, and **concatenation skip connections** — the building blocks of a 3D UNet for medical imaging (CT/MRI segmentation).

This exercises the 3D convolution handler:
- `stablehlo.convolution` (5D tensors) → `MPSGraph convolution3D`
- `stablehlo.concatenate` → `MPSGraph concatTensors`

**No pooling:** MPSGraph has no 5D pooling. Downsampling uses stride-2 convolutions instead (common in modern architectures like V-Net).

**No Metal.jl column:** Same as 2D — Metal.jl has no native GPU kernels for 3D conv.
"""

# ╔═╡ d4e5f6a7-0002-0001-0001-000000000001
begin
	const VOL_SIZE = (32, 32, 32)
	const VOL_CH = 1
	const UNET3D_BATCHES = [1, 2]

	unet3d_model = Chain(
		# Initial projection: 32³×1 → 32³×8
		Conv((3, 3, 3), 1 => 8, relu; pad=1),

		# Residual block (addition skip)
		SkipConnection(
			Chain(
				Conv((3, 3, 3), 8 => 8, relu; pad=1),
				Conv((3, 3, 3), 8 => 8, relu; pad=1),
			),
			+
		),                                              # 32³×8

		# Downsample via stride-2 conv: 32³×8 → 16³×16
		Conv((2, 2, 2), 8 => 16, relu; stride=2),

		# Concatenation skip block
		SkipConnection(
			Chain(
				Conv((3, 3, 3), 16 => 16, relu; pad=1),
				Conv((3, 3, 3), 16 => 16, relu; pad=1),
			),
			(mx, x) -> cat(mx, x; dims=4)              # concat along C dim (dim 4 for 5D)
		),                                              # 16³×32

		# More convolution
		Conv((3, 3, 3), 32 => 32, relu; pad=1),        # 16³×32

		# Downsample again: 16³×32 → 8³×64
		Conv((2, 2, 2), 32 => 64, relu; stride=2),

		# Bottleneck
		Conv((3, 3, 3), 64 => 64, relu; pad=1),        # 8³×64

		# Output head
		Conv((1, 1, 1), 64 => 1),                       # 8³×1
	)

	unet3d_ps_cpu, unet3d_st_cpu = Lux.setup(rng, unet3d_model)
	md"3D UNet model: 8 Conv3D layers, 2 stride-2 downsample, 1 residual add, 1 concat skip."
end

# ╔═╡ d4e5f6a7-0003-0001-0001-000000000001
function bench_unet3d_cpu(model, ps, st, batch_size; n_warmup=3, n_trials=10)
	x = randn(Float32, VOL_SIZE..., VOL_CH, batch_size)
	f(m, x, ps, st) = first(m(x, ps, st))
	for _ in 1:n_warmup; f(model, x, ps, st); end
	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed f(model, x, ps, st)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024)
end

# ╔═╡ d4e5f6a7-0004-0001-0001-000000000001
unet3d_cpu_results = Dict(
	bs => bench_unet3d_cpu(unet3d_model, unet3d_ps_cpu, unet3d_st_cpu, bs)
	for bs in UNET3D_BATCHES
)

# ╔═╡ d4e5f6a7-0005-0001-0001-000000000001
function bench_unet3d_reactant(model, ps, st, batch_size;
		n_warmup=3, n_trials=10)
	x = randn(Float32, VOL_SIZE..., VOL_CH, batch_size)
	x_ra = Reactant.to_rarray(x)
	ps_ra = Reactant.to_rarray(ps)
	st_ra = Reactant.to_rarray(st)
	f(m, x, ps, st) = first(m(x, ps, st))

	compile_stats = @timed begin
		compiled_f = @compile f(model, x_ra, ps_ra, st_ra)
	end
	compile_ms = compile_stats.time * 1000

	for _ in 1:n_warmup
		compiled_f(model, x_ra, ps_ra, st_ra)
	end

	times = Float64[]
	allocs = Int[]
	for _ in 1:n_trials
		stats = @timed compiled_f(model, x_ra, ps_ra, st_ra)
		push!(times, stats.time)
		push!(allocs, Int(stats.bytes))
	end
	(; median_ms = median(times) * 1000,
	   min_ms = minimum(times) * 1000,
	   median_alloc_kb = median(allocs) / 1024,
	   compile_ms)
end

# ╔═╡ d4e5f6a7-0006-0001-0001-000000000001
unet3d_reactant_results = Dict(
	bs => bench_unet3d_reactant(unet3d_model, unet3d_ps_cpu, unet3d_st_cpu, bs)
	for bs in UNET3D_BATCHES
)

# ╔═╡ d4e5f6a7-0007-0001-0001-000000000001
begin
	x_unet3d_test = randn(Float32, VOL_SIZE..., VOL_CH, 1)
	f_unet3d(m, x, ps, st) = first(m(x, ps, st))

	# CPU reference
	y_unet3d_cpu = f_unet3d(unet3d_model, x_unet3d_test, unet3d_ps_cpu, unet3d_st_cpu)

	# Reactant+Metal
	x_unet3d_ra = Reactant.to_rarray(x_unet3d_test)
	ps_unet3d_ra = Reactant.to_rarray(unet3d_ps_cpu)
	st_unet3d_ra = Reactant.to_rarray(unet3d_st_cpu)
	compiled_unet3d = @compile f_unet3d(unet3d_model, x_unet3d_ra, ps_unet3d_ra, st_unet3d_ra)
	y_unet3d_reactant = Array(compiled_unet3d(unet3d_model, x_unet3d_ra, ps_unet3d_ra, st_unet3d_ra))

	unet3d_err_reactant = maximum(abs.(y_unet3d_cpu .- y_unet3d_reactant))

	md"""
	### 3D UNet Correctness Check

	| Backend | Max Error vs CPU |
	|---------|-----------------|
	| Reactant+Metal | $(round(unet3d_err_reactant; sigdigits=3)) |

	Should be < 1e-4 (3D conv accumulation allows slightly more error).
	"""
end

# ╔═╡ d4e5f6a7-0008-0001-0001-000000000001
begin
	u3h = "| Batch | CPU (ms) | Reactant+Metal (ms) | Speedup | Compile (ms) |"
	u3s = "|------:|---------:|--------------------:|---------:|------------:|"
	u3rows = String[]
	for bs in UNET3D_BATCHES
		c = unet3d_cpu_results[bs]
		r = unet3d_reactant_results[bs]
		sp = round(c.median_ms / r.median_ms; digits=1)
		push!(u3rows, "| $bs | $(round(c.median_ms; digits=3)) | $(round(r.median_ms; digits=3)) | $(sp)x | $(round(r.compile_ms; digits=1)) |")
	end

	Markdown.parse(join(["### 3D UNet Results (CPU vs Reactant+Metal)", "",
		"8 Conv3D layers, stride-2 downsample, residual add + concat skip. 32³ voxel input.", "",
		u3h, u3s, u3rows...], "\n"))
end

# ╔═╡ d4e5f6a7-0009-0001-0001-000000000001
begin
	u3a_header = "| Batch | CPU alloc (KB) | Reactant+Metal alloc (KB) |"
	u3a_sep    = "|------|---------------:|-------------------------:|"
	u3a_rows = String[]
	for bs in UNET3D_BATCHES
		c = unet3d_cpu_results[bs]
		r = unet3d_reactant_results[bs]
		push!(u3a_rows, "| $bs | $(round(c.median_alloc_kb; digits=1)) | $(round(r.median_alloc_kb; digits=1)) |")
	end

	Markdown.parse(join(["### 3D UNet Allocations", "", u3a_header, u3a_sep, u3a_rows...], "\n"))
end

# ╔═╡ a1b2c3d4-0023-0001-0001-000000000001
md"""
## What's Happening Under the Hood

Julia code → Reactant tracing → StableHLO MLIR
→ XLA/MLIR optimization (op fusion, CSE, constant folding)
→ Optimized MLIR → PJRT plugin → MLIR walker
→ MPSGraph builder → Metal GPU execution


**Why Reactant+Metal can beat Metal.jl alone:**
- XLA fuses multiple operations into single GPU kernels (fewer launches)
- Constant folding eliminates redundant computation at compile time
- CSE (common subexpression elimination) removes duplicate work
- The MPSGraph layer adds Apple's own Metal-specific optimizations on top

**Why GPU beats CPU (at sufficient scale):**
- GPU parallelism for large matrix operations
- GPU memory bandwidth advantage for data-heavy workloads

**Important caveats:**
- Small models (< ~100K params) are faster on CPU — GPU dispatch overhead dominates
- Reactant+Metal has per-call PJRT overhead (~ms) that only pays off with larger compute
- First `@jit` call includes compilation (seconds); subsequent calls reuse cached executable
- This is a prototype PJRT plugin — production performance would be better
"""

# ╔═╡ Cell order:
# ╟─a1b2c3d4-0001-0001-0001-000000000001
# ╟─a1b2c3d4-0002-0001-0001-000000000001
# ╠═a1b2c3d4-0003-0001-0001-000000000001
# ╠═a1b2c3d4-0004-0001-0001-000000000001
# ╟─a1b2c3d4-0005-0001-0001-000000000001
# ╠═a1b2c3d4-0006-0001-0001-000000000001
# ╟─a1b2c3d4-0007-0001-0001-000000000001
# ╠═a1b2c3d4-0008-0001-0001-000000000001
# ╠═a1b2c3d4-0009-0001-0001-000000000001
# ╟─a1b2c3d4-0010-0001-0001-000000000001
# ╠═a1b2c3d4-0011-0001-0001-000000000001
# ╠═a1b2c3d4-0012-0001-0001-000000000001
# ╠═a1b2c3d4-0013-0001-0001-000000000001
# ╟─a1b2c3d4-0014-0001-0001-000000000001
# ╠═a1b2c3d4-0015-0001-0001-000000000001
# ╠═a1b2c3d4-0016-0001-0001-000000000001
# ╟─a1b2c3d4-0017-0001-0001-000000000001
# ╠═a1b2c3d4-0018-0001-0001-000000000001
# ╟─a1b2c3d4-0019-0001-0001-000000000001
# ╠═a1b2c3d4-0020-0001-0001-000000000001
# ╠═a1b2c3d4-0021-0001-0001-000000000001
# ╠═a1b2c3d4-0022-0001-0001-000000000001
# ╟─b2c3d4e5-0001-0001-0001-000000000001
# ╠═b2c3d4e5-0002-0001-0001-000000000001
# ╠═b2c3d4e5-0003-0001-0001-000000000001
# ╠═b2c3d4e5-0004-0001-0001-000000000001
# ╠═b2c3d4e5-0005-0001-0001-000000000001
# ╠═b2c3d4e5-0006-0001-0001-000000000001
# ╟─b2c3d4e5-0007-0001-0001-000000000001
# ╠═b2c3d4e5-0008-0001-0001-000000000001
# ╟─c3d4e5f6-0001-0001-0001-000000000001
# ╠═c3d4e5f6-0002-0001-0001-000000000001
# ╠═c3d4e5f6-0003-0001-0001-000000000001
# ╠═c3d4e5f6-0004-0001-0001-000000000001
# ╠═c3d4e5f6-0005-0001-0001-000000000001
# ╠═c3d4e5f6-0006-0001-0001-000000000001
# ╠═c3d4e5f6-0007-0001-0001-000000000001
# ╟─c3d4e5f6-0008-0001-0001-000000000001
# ╠═c3d4e5f6-0009-0001-0001-000000000001
# ╟─d4e5f6a7-0001-0001-0001-000000000001
# ╠═d4e5f6a7-0002-0001-0001-000000000001
# ╠═d4e5f6a7-0003-0001-0001-000000000001
# ╠═d4e5f6a7-0004-0001-0001-000000000001
# ╠═d4e5f6a7-0005-0001-0001-000000000001
# ╠═d4e5f6a7-0006-0001-0001-000000000001
# ╠═d4e5f6a7-0007-0001-0001-000000000001
# ╟─d4e5f6a7-0008-0001-0001-000000000001
# ╠═d4e5f6a7-0009-0001-0001-000000000001
# ╟─a1b2c3d4-0023-0001-0001-000000000001

@mofeing (Collaborator) left a comment:

Most comments are for code cleaning and formatting, but the most critical change is that stablehlo.reduce can silently translate to wrong code and not error. Unfortunately I don't know much about MPS so I can't lend a hand there.

Comment on lines +488 to +497
raw_body_op = API.mlirBlockGetFirstOperation(body_block)
while !(IR.mlirIsNull(raw_body_op))
    bop = IR.Operation(raw_body_op)
    bop_name = IR.name(bop)
    if startswith(bop_name, "stablehlo.")
        body_op_name = bop_name
        break
    end
    raw_body_op = API.mlirOperationGetNextInBlock(bop)
end
Collaborator:

Block implements the Iterator interface and returns Operations so this can be simplified to

Suggested change

for bop in body_block
    bop_name = IR.name(bop)
    if startswith(bop_name, "stablehlo.")
        body_op_name = bop_name
        break
    end
end

Collaborator:

also, the problem with this approach is that stablehlo.reduce accepts arbitrary code, and matching on the first op most probably won't be correct.

take the following example (a sign-alternating add reduction). your code will translate it to a regular add reduction and it won't be correct.

%result = "stablehlo.reduce"(%input, %init_value) ({
  ^bb0(%arg0: tensor<i64>, %arg1: tensor<i64>):
    %0 = "stablehlo.add"(%arg0, %arg1) : (tensor<i64>, tensor<i64>) -> tensor<i64>
    %1 = "stablehlo.negate"(%0) : (tensor<i64>) -> tensor<i64>
    "stablehlo.return"(%1) : (tensor<i64>) -> ()
}) {
  dimensions = array<i64: 1>
} : (tensor<1x6xi64>, tensor<i64>) -> tensor<1xi64>

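One conservative direction, per mofeing's suggestion, is to match the full body up to the return rather than the first op. A minimal pure-Julia sketch of that idea, with op names as strings standing in for the real MLIR operations (`classify_reducer` and `KNOWN_REDUCER_BODIES` are illustrative names, not code from this PR); a real fix would additionally have to verify that the returned value is the result of the matched op:

```julia
# Only exact, known reducer bodies translate; anything else errors
# instead of silently producing a wrong reduction.
const KNOWN_REDUCER_BODIES = Dict(
    ["stablehlo.add", "stablehlo.return"]     => :sum,
    ["stablehlo.maximum", "stablehlo.return"] => :max,
)

function classify_reducer(body_ops::Vector{String})
    kind = get(KNOWN_REDUCER_BODIES, body_ops, nothing)
    kind === nothing &&
        error("unsupported stablehlo.reduce body: $(join(body_ops, ", "))")
    return kind
end
```

With this, the sign-alternating example above (`add`, `negate`, `return`) errors out rather than compiling to a plain add-reduction.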

# Pattern-match the body to determine reduction type
body_op_name = ""
if IR.nregions(op) > 0
Collaborator:

stablehlo.reduce must have a region to contain the reducer code, so this seems redundant?

t = Metal.MPSGraphs.transposeTensor(graph, tensor, 0, 4, "$(name)_rev1")
return Metal.MPSGraphs.transposeTensor(graph, t, 1, 3, "$(name)_rev2")
else
error("mps_reverse_dims: unsupported rank $rank")
Collaborator:

a permutation can be decomposed into a series of transpositions using, for example, the Coxeter decomposition, which is implemented in Permutations.jl

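The decomposition mofeing mentions can be sketched in a few lines of pure Julia (illustrative helper names, not Permutations.jl's API): a cycle-following pass records two-index swaps, and applying those swaps in reverse to the identity rebuilds the permutation, so an arbitrary-rank reverse/transpose could be lowered to repeated two-axis `transposeTensor` calls.

```julia
# Hypothetical helper: decompose `perm` into (i, j) swaps. Applying the
# recorded swaps to `perm` sorts it to the identity, so applying them in
# reverse order to the identity rebuilds `perm`.
function transposition_sequence(perm::Vector{Int})
    p = copy(perm)
    swaps = Tuple{Int,Int}[]
    for i in eachindex(p)
        while p[i] != i
            j = p[i]
            p[i], p[j] = p[j], p[i]  # places value j at index j
            push!(swaps, (i, j))
        end
    end
    return swaps
end

# Rebuild a permutation from its swap sequence (for checking).
function rebuild(n::Int, swaps)
    q = collect(1:n)
    for (i, j) in reverse(swaps)
        q[i], q[j] = q[j], q[i]
    end
    return q
end
```

Each recorded `(i, j)` pair would then map to one two-axis transpose on the MPSGraph side.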
Comment on lines +1171 to +1206
const OP_HANDLERS = Dict{String, Function}()

function get_op_handlers()
if isempty(OP_HANDLERS)
OP_HANDLERS["stablehlo.add"] = handle_add
OP_HANDLERS["stablehlo.subtract"] = handle_subtract
OP_HANDLERS["stablehlo.multiply"] = handle_multiply
OP_HANDLERS["stablehlo.divide"] = handle_divide
OP_HANDLERS["stablehlo.maximum"] = handle_maximum
OP_HANDLERS["stablehlo.negate"] = handle_negate
OP_HANDLERS["stablehlo.exponential"] = handle_exponential
OP_HANDLERS["stablehlo.exp"] = handle_exponential
OP_HANDLERS["stablehlo.log"] = handle_log
OP_HANDLERS["stablehlo.tanh"] = handle_tanh
OP_HANDLERS["stablehlo.sqrt"] = handle_sqrt
OP_HANDLERS["stablehlo.rsqrt"] = handle_rsqrt
OP_HANDLERS["stablehlo.abs"] = handle_abs
OP_HANDLERS["stablehlo.sine"] = handle_sin
OP_HANDLERS["stablehlo.sin"] = handle_sin
OP_HANDLERS["stablehlo.cosine"] = handle_cos
OP_HANDLERS["stablehlo.cos"] = handle_cos
OP_HANDLERS["stablehlo.convert"] = handle_convert
OP_HANDLERS["stablehlo.constant"] = handle_constant
OP_HANDLERS["stablehlo.dot_general"] = handle_dot_general
OP_HANDLERS["stablehlo.dot"] = handle_dot_general
OP_HANDLERS["stablehlo.broadcast_in_dim"] = handle_broadcast_in_dim
OP_HANDLERS["stablehlo.reshape"] = handle_reshape
OP_HANDLERS["stablehlo.transpose"] = handle_transpose
OP_HANDLERS["stablehlo.reverse"] = handle_reverse
OP_HANDLERS["stablehlo.concatenate"] = handle_concatenate
OP_HANDLERS["stablehlo.convolution"] = handle_convolution
OP_HANDLERS["stablehlo.slice"] = handle_slice
OP_HANDLERS["stablehlo.scatter"] = handle_scatter
end
return OP_HANDLERS
end
Collaborator:

passing a Function this way will incur dynamic dispatch when translating from stablehlo to MPSGraph (so increased compile time).

given that this function is (1) quite trivial, (2) doesn't require user extensibility, and (3) only used once, do you mind replacing the call site with if-elseif code?

(actually I see that the code does an if-elseif but then Claude went lazy?)

Collaborator:

if else if vs dispatch is close in performance when it gets this big (dynamic dispatch is like 2 ns)

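The trade-off being debated can be shown with a toy example (plain arithmetic functions standing in for the real handlers; `dispatch_dict` and `dispatch_branch` are illustrative names): a `Dict{String,Function}` lookup yields an abstract `Function`, so the call site is a dynamic dispatch, while an if-elseif chain keeps each call concrete and inferable.

```julia
# Dict-based: the looked-up value has abstract type Function,
# so calling it is a dynamic dispatch.
const HANDLERS = Dict{String,Function}("add" => +, "subtract" => -)
dispatch_dict(name, a, b) = HANDLERS[name](a, b)

# if-elseif: every call site is concrete; Julia can infer and inline.
dispatch_branch(name, a, b) =
    name == "add"      ? a + b :
    name == "subtract" ? a - b :
    error("unsupported op: $name")
```

As noted above, at this handler count the difference is small in absolute terms (dynamic dispatch is on the order of nanoseconds), so the choice is mostly about inference quality and compile time.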
Comment on lines +1249 to +1281
handlers = get_op_handlers()
raw_op = API.mlirBlockGetFirstOperation(func_block)
while !(IR.mlirIsNull(raw_op))
    op = IR.Operation(raw_op)
    op_name = IR.name(op)
    ctx.op_count += 1

    if op_name == "func.return"
        for j in 1:IR.noperands(op)
            ret_val = IR.operand(op, j)
            if haskey(ctx.value_map, ret_val)
                push!(ctx.outputs, ctx.value_map[ret_val])
                ir_shape, dtype = get_type_info(ret_val)
                julia_shape = length(ir_shape) >= 2 ? reverse(ir_shape) : ir_shape
                push!(ctx.output_shapes, julia_shape)
                push!(ctx.output_dtypes, dtype)
            else
                @warn "Return value not found in value_map"
            end
        end
    elseif op_name == "stablehlo.reduce"
        handle_reduce(ctx, op)
    elseif op_name == "stablehlo.reduce_window"
        handle_reduce_window(ctx, op)
    elseif haskey(handlers, op_name)
        handlers[op_name](ctx, op)
    else
        error("Unsupported StableHLO op: $op_name")
    end

    raw_op = API.mlirOperationGetNextInBlock(op)
end
Collaborator:

you can iterate directly on the Block

Suggested change

handlers = get_op_handlers()
for op in func_block
    op_name = IR.name(op)
    ctx.op_count += 1
    if op_name == "func.return"
        for j in 1:IR.noperands(op)
            ret_val = IR.operand(op, j)
            if haskey(ctx.value_map, ret_val)
                push!(ctx.outputs, ctx.value_map[ret_val])
                ir_shape, dtype = get_type_info(ret_val)
                julia_shape = length(ir_shape) >= 2 ? reverse(ir_shape) : ir_shape
                push!(ctx.output_shapes, julia_shape)
                push!(ctx.output_dtypes, dtype)
            else
                @warn "Return value not found in value_map"
            end
        end
    elseif op_name == "stablehlo.reduce"
        handle_reduce(ctx, op)
    elseif op_name == "stablehlo.reduce_window"
        handle_reduce_window(ctx, op)
    elseif haskey(handlers, op_name)
        handlers[op_name](ctx, op)
    else
        error("Unsupported StableHLO op: $op_name")
    end
end

Comment on lines +20 to +32
struct MetalPJRT_Api_Version
struct_size::UInt64 # offset 0, 8 bytes
extension_start::Ptr{Cvoid} # offset 8, 8 bytes
major_version::Int32 # offset 16, 4 bytes
minor_version::Int32 # offset 20, 4 bytes
end # 24 bytes total (0x18)

struct MetalPJRT_Api
struct_size::UInt64 # offset 0
extension_start::Ptr{Cvoid} # offset 8
pjrt_api_version::MetalPJRT_Api_Version # offset 16, 24 bytes
fns::NTuple{128,Ptr{Cvoid}} # offset 40, 1024 bytes
end # Total: 8 + 8 + 24 + 1024 = 1064 bytes (0x428)
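One way to sanity-check hand-maintained layout comments like these is to ask Julia for the offsets instead of trusting the annotations. This sketch just re-declares the structs from the snippet above and compares against `fieldoffset`/`sizeof`:

```julia
# Structs copied from the snippet above; the @asserts confirm the
# hand-written offset/size comments against Julia's computed layout.
struct MetalPJRT_Api_Version
    struct_size::UInt64
    extension_start::Ptr{Cvoid}
    major_version::Int32
    minor_version::Int32
end

struct MetalPJRT_Api
    struct_size::UInt64
    extension_start::Ptr{Cvoid}
    pjrt_api_version::MetalPJRT_Api_Version
    fns::NTuple{128,Ptr{Cvoid}}
end

@assert sizeof(MetalPJRT_Api_Version) == 24   # 0x18
@assert fieldoffset(MetalPJRT_Api, 3) == 16   # pjrt_api_version
@assert fieldoffset(MetalPJRT_Api, 4) == 40   # fns
@assert sizeof(MetalPJRT_Api) == 1064         # 0x428
```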
@mofeing (Collaborator) commented on Feb 22, 2026:

would you mind using the bindings in https://github.com/EnzymeAD/Reactant.jl/blob/main/src/xla/PJRT/CAPI.jl ? this way we avoid duplications and can track changes in the PJRT API

Comment on lines +23 to +33
"""
setup_metal!()

metal_pjrt_plugin_path = joinpath(path, "pjrt_plugin_metal_14.dylib")
if !isfile(metal_pjrt_plugin_path)
zip_file_path = joinpath(path, "pjrt-plugin-metal.zip")
tmp_dir = joinpath(path, "tmp")
Downloads.download(
if Sys.ARCH === :aarch64
"https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl"
elseif Sys.ARCH === :x86_64
"https://files.pythonhosted.org/packages/87/ec/9bb7f7f0ffd06c3fb89813126b2f698636ac7a4263ed7bdd1ff7d7c94f8f/jax_metal-0.1.1-py3-none-macosx_10_14_x86_64.whl"
else
error("Unsupported architecture: $(Sys.ARCH)")
end,
zip_file_path,
)
run(pipeline(`$(p7zip()) x -tzip -o$(tmp_dir) -- $(zip_file_path)`, devnull))
mv(
joinpath(tmp_dir, "jax_plugins", "metal_plugin", "pjrt_plugin_metal_14.dylib"),
metal_pjrt_plugin_path,
)
rm(tmp_dir; recursive=true)
rm(zip_file_path; recursive=true)
end
Placeholder hook for external callers. The actual Metal PJRT client is
created inside `ReactantMetalExt.__init__()`, which Julia loads automatically
whenever `Metal` is brought into scope as a weak dependency.
"""
function setup_metal!()
# Metal client registration is handled by ReactantMetalExt.__init__()
# when Metal.jl is loaded as a weak dependency.
return nothing
Collaborator:

this can be removed

end
else
@warn "`get_properties` not implemented for platform: $(pname)" maxlog = 1
@debug "`get_properties` not implemented for platform: $(pname)"
Collaborator:

is there a reason for this change? it looks quite arbitrary

@Dale-Black (Contributor, Author):

Hi @mofeing — thanks for the thorough review. I (plus Claude) have tried to address each of your 10 comments in individual commits:

  1. Revert @warn@debug — reverted, removes Device.jl from the diff entirely
  2. Remove setup_metal!() — deleted the no-op placeholder
  3. Generalize MakeClientFromApi — now takes device_type/client_name args for reuse
  4. Block iterator in reduce body — removed redundant nregions guard, uses for bop in body_block
  5. Validate reduce body — errors on non-trivial multi-op bodies instead of silently matching the first op
  6. Simplify mps_reverse_dims — now calls the existing apply_permutation, works for arbitrary rank
  7. OP_HANDLERS Dict → if-elseif — eliminated dynamic dispatch
  8. Block iterator in main walk loop — for op in func_block
  9. Use CAPI.jl bindings — removed duplicate structs from PJRTPlugin.jl (note: this also required commenting out an untranslatable C macro on CAPI.jl line 3932 and exposing CAPI as a submodule in PJRT.jl — happy to adjust if there's a better way)

I'm still learning the codebase and some of this is admittedly hacky, especially the CAPI.jl integration. If I butchered anything or you'd prefer a different approach on any of these, please let me know — very happy to rework.

@mofeing (Collaborator) left a comment:

I'm still learning the codebase and some of this is admittedly hacky, especially the CAPI.jl integration. If I butchered anything or you'd prefer a different approach on any of these, please let me know — very happy to rework.

it's great for an initial version and thanks for working on this. I confess I'm a lil bit picky; most of my requests are minor things that can be changed in subsequent PRs.

for me the most critical thing holding this PR is that you just cannot match the first op in the reduce block of stablehlo.reduce and choose the reducer function based on it. the reducer code can be more complex, and the way it is coded right now it will silently generate wrong results. if implementing a fix for it is too hard right now, I would prefer it to be left unimplemented and just error, or if needed, match the full block for the already implemented special cases like add- or max-reductions (i.e. match up to the return).

also, it seems like Claude prefers to use Ptr{Cvoid} and hardcode the field offsets in PJRTPlugin.jl. we should instead use the types in Reactant.XLA.PJRT.CAPI instead of Cvoid for the pointers, and fieldoffset instead of hardcoded numbers.

const PJRT_API_MINOR = 90

const _PJRT_API_STRUCT_FIELD = fn_type(fn_type) * fn_type
# const _PJRT_API_STRUCT_FIELD = fn_type(fn_type) * fn_type # untranslatable C macro
Collaborator:

this file is auto-generated, so you should refrain from making changes there. if it's breaking sth, tell us so we can fix it in the generator script.


# PJRT_LoadedExecutable_Destroy_Args:
# offset 16: executable* (input, 8) — our handle
function _loaded_exec_destroy(args::Ptr{Cvoid})::Ptr{Cvoid}
Collaborator:

using Ptr{Cvoid} and unsafe_load/unsafe_store! with hardcoded offsets is fragile and makes fixing bugs in the future way more difficult.

the point of XLA.PJRT.CAPI is also that you can use the structs defined there in these functions.

function _loaded_exec_destroy(args::Ptr{PJRT_LoadedExecutable_Destroy_Args})::Ptr{Cvoid}

instead of the hardcoded offsets, you can use fieldoffset

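The pattern being asked for can be illustrated with a stand-in struct (`DemoDestroyArgs` and `load_executable` are hypothetical names, not the real CAPI types): a typed pointer in the callback signature plus `fieldoffset` instead of hardcoded byte offsets.

```julia
# Stand-in for a PJRT *_Args struct; the real one comes from CAPI.
struct DemoDestroyArgs
    struct_size::UInt64
    extension_start::Ptr{Cvoid}
    executable::Ptr{Cvoid}
end

# Hardcoded-offset style (fragile):
#   exec = unsafe_load(Ptr{Ptr{Cvoid}}(Ptr{UInt8}(args) + 16))
# fieldoffset style (robust to layout changes):
function load_executable(args::Ptr{DemoDestroyArgs})::Ptr{Cvoid}
    idx = findfirst(==(:executable), fieldnames(DemoDestroyArgs))
    off = fieldoffset(DemoDestroyArgs, idx)
    return unsafe_load(Ptr{Ptr{Cvoid}}(Ptr{UInt8}(args) + off))
end
```

If the struct gains or reorders fields, the `fieldoffset` version keeps working while the hardcoded `+ 16` silently reads the wrong field.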
@Dale-Black (Contributor, Author):

Hi @mofeing — following up on your feedback about Ptr{Cvoid} + hardcoded offsets being fragile. I (plus Claude) went through every callback in PJRTPlugin.jl and replaced them with the CAPI typed pointers + your unsafe_load_field / unsafe_store_field! helpers.

The changes are purely mechanical — no logic changes, just signature types and field access patterns:

  • All ~62 callback signatures now use Ptr{CAPI.PJRT_StructName_Args} instead of Ptr{Cvoid}
  • All field reads/writes use Reactant.unsafe_load_field / Reactant.unsafe_store_field! instead of hardcoded byte offsets
  • Zero hardcoded offsets remain (raw array pointer iteration in _loaded_exec_execute stays as-is since there's no CAPI struct for those)
  • _stub and _unimpl keep Ptr{Cvoid} since they're generic catch-all handlers

One note: CAPI.jl line 3932 (_PJRT_API_STRUCT_FIELD) is still commented out — it's a C preprocessor macro argument (fn_type) with no Julia equivalent. Uncommenting causes LoadError: UndefVarError: fn_type not defined. Might be worth a fix in the generator script at some point.

All tests pass (sincos, autodiff, CNN + non-square conv). As always, let me know if anything looks off or if you'd prefer a different approach.

@wsmoses (Member) commented on Feb 24, 2026:

@Dale-Black after finding a pretty bad "accidentally ccalled with the wrong number of arguments" bug, we in the interim just landed a refactor of all the ccall/abi stuff where the ccall code is autogenerated, and you should instead call wrapper functions within API.x. Can you update your code to use those? They should also contain the relevant struct definitions.

@Dale-Black (Contributor, Author):

Hi @wsmoses — just rebased onto main and updated. We only had one direct @ccall in the extension (stablehloDeserializePortableArtifactNoError), which now goes through the auto-generated wrapper in API. Also adapted PJRTPlugin.jl for the PJRT_Api struct growth — fn pointer count is now derived dynamically from fieldcount(CAPI.PJRT_Api), version uses CAPI.PJRT_API_MAJOR/MINOR, and added the new _device_get_attributes callback for slot 129. All tests still pass locally (sincos, autodiff, CNN). Let me know if this is what you had in mind or if I misunderstood something.

# Maps C handle address (UInt64) -> MetalExecutable.
# Using Any to avoid dependency on MLIRWalker.jl (included after this file).
const LOADED_EXECUTABLES = Dict{UInt64, Any}()

Member:

I don't quite get why you need to have these dictionaries, instead of having something like

mutable struct Executable
    ... all the data
end


handle = _handle_alloc()
@lock PJRT_LOCK begin
METAL_BUFFERS[UInt64(handle)] = (data=data_gpu, dims_c=dims_c, ndims=num_dims,
Member:

so I think this needlessly goes through an extra level of indirection.

presumably metal.jl calls into apple's allocator function (which can just malloc/free or similar like a normal allocator). This is then stored in a GC'd object which is free'd upon end of use. To avoid that free, we put it in a dict.

Rather than have the dict to avoid the gc to avoid the free, can we just call [via metal.jl potential internals] the actual apple allocate/free functions?

That way we avoid the dict, indirection/race issues, and a decent chunk of overhead?

Member:

if we need to store data, we can make our own

mutable struct MetalBuffer
    ptr::Ptr{Cvoid} # actual data
    eltype::Enum # element type
    # ... whatever other data we need
end

we control alloc/free, so in the allocation function we allocate both the data and libc.malloc a struct of size metalbuffer or whatnot, then in the free, we free both

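The alloc/free scheme sketched above can be demonstrated with plain Libc memory in place of Metal allocations (a minimal sketch; `alloc_buffer`/`free_buffer` and the exact fields are illustrative, not the PR's final code): the handle handed to PJRT is itself a malloc'd header struct, so no Julia-side Dict or GC root is needed.

```julia
# Header struct stored behind the raw handle; isbits, so it can live
# in Libc-managed memory.
struct MetalBuffer
    ptr::Ptr{Cvoid}   # actual data (would be the MTLBuffer in the real code)
    nbytes::Int
end

function alloc_buffer(nbytes::Integer)::Ptr{MetalBuffer}
    data = Libc.malloc(nbytes)                                    # payload
    handle = Ptr{MetalBuffer}(Libc.malloc(sizeof(MetalBuffer)))   # header
    unsafe_store!(handle, MetalBuffer(data, Int(nbytes)))
    return handle  # this pointer IS the PJRT buffer handle
end

function free_buffer(handle::Ptr{MetalBuffer})
    buf = unsafe_load(handle)
    Libc.free(buf.ptr)               # free payload
    Libc.free(Ptr{Cvoid}(handle))    # free header
    return nothing
end
```

Since we control both alloc and free, ownership is explicit and there is no finalizer/GC race to serialize against.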
@Dale-Black (Contributor, Author):

Hi @wsmoses — just pushed the Dict-free buffer/executable rearchitecture you described. Is this the type of approach you're talking about?

What changed:

  • Replaced METAL_BUFFERS Dict with MetalBufferData C-struct (Libc.malloc'd, 48 bytes: raw MTLBuffer id, data_ptr, eltype, dims, ndims, nbytes)
  • Replaced LOADED_EXECUTABLES/LOADED_EXECUTABLE_MLIR Dicts with pointer_from_objref(exec) — direct pointer dereference, no Dict lookup
  • Execute path feeds raw MTLBuffer directly to MPSGraph via MPSGraphTensorData(buf, shape, dtype) — no MtlArray intermediary, data stays on GPU
  • Removed all dead infrastructure: handle pool, MtlArray pool, PJRT_LOCK
  • Cleaned up __init__ client registration with TODO for generic register_backend!() API

All tests pass (sincos, autodiff, CNN, non-square conv).

# Serialized MLIR text for _exec_optimized_program (set after construction in _client_compile)
mlir_text::String
# Execution cache — lazily built on first execute!
_input_mtl::Vector{Any} # cached MtlArrays for inputs
Member:

this seems not to be used?

can this be trimmed to just the essentials?

# GC roots for MetalExecutable objects — prevents GC collection while PJRT holds the handle.
# NOT used for lookup: the handle IS the pointer to the Julia object (pointer_from_objref).
# To retrieve: unsafe_pointer_to_objref(handle). To destroy: delete! from set.
const EXEC_GC_ROOTS = Set{Any}()
Member:

this shouldn't be necessary any more if we actually malloc/free it explicitly, right?

@Dale-Black (Contributor, Author):

@wsmoses — good catches. Removed the dead cache fields from MetalExecutable (trimmed to just the essentials).

On EXEC_GC_ROOTS — I think eliminating it would require a larger refactor of MetalExecutable into a C-allocatable struct, right? It currently holds MPSGraph objects, Vectors of MPSGraphTensor, etc. that need Julia/ObjC GC. The buffer side (MetalBufferData) was straightforward to Libc.malloc since it's just scalar fields, but the executable struct is more complex. Or am I misunderstanding what you're suggesting?

@wsmoses (Member) commented on Feb 24, 2026:

yeah so @Dale-Black GC errors are a sufficiently non-deterministic pain when things go wrong (and to try to fix) that, since at the end of the day all of the types stored in MetalExecutable can be non-Julia types, we should make it a struct just allocated with malloc/free and explicitly managed. for example for the dtypes, we can use whatever integer enum for types xla already has [and then call the convert to julia type when requested]

@wsmoses (Member) commented on Feb 24, 2026:

to be clear we can (and should) completely use all the nice julia/metal.jl setup side for building the executable struct, but once built the executable object itself should not have julia GC objects, if possible [to avoid weird memory corruption debugging in our future]

@Dale-Black (Contributor, Author):

Here's my attempt at the C-allocatable MetalExecutable refactor, @wsmoses — MetalExecutableData is now Libc.malloc'd with raw ObjC ids + explicit retain/release. EXEC_GC_ROOTS eliminated. Let me know what you think.

# Total: 15 + n_in + n_out + n_c
# Plus: n_c MTLBuffer allocations via Metal.alloc (freed via Metal.free)
# ============================================================
function freeze_executable(exec)::Ptr{MetalExecutableData}
Member:

this is super verbose as a file, can you separate it out into logical components [e.g. an Executable.jl]

@wsmoses (Member) left a comment:

minus some refactorings to make this easier to read/maintain/etc, I'm okay with this now at a high level.

@avik-pal @mofeing if you want to take a look

@Dale-Black (Contributor, Author):

Split PJRTPlugin.jl into Buffer.jl + Executable.jl + PJRTPlugin.jl per your feedback, @wsmoses. No logic changes, just file organization.

@mofeing (Collaborator) left a comment:

I agree with @wsmoses. Would you mind splitting PJRTPlugin.jl into more files? one file per API topic would be great. you already have Buffer and Executable, but there should also be ones for Client, Plugin, LoadedExecutable, Event, Device and DeviceDescription (I believe).

It requires some effort to review through these files, but also it will help with debugging and refactoring the global state design it seems Claude has decided to follow.

push!(build_cmd_list, "--sandbox_debug")

push!(build_cmd_list, "--linkopt=-fuse-ld=lld")
# push!(build_cmd_list, "--linkopt=-fuse-ld=lld") # lld not available on macOS
Collaborator:

mmm was this commented out by Claude? this flag does take effect on OSes other than macOS, so you should wrap it inside if !Sys.isapple()

Comment on lines +11 to +21
# mtl_buf_id: raw ObjC id for the MTLBuffer (from Metal.alloc, SharedStorage).
# The MTLBuffer has retain count 1 from alloc — we release it in _buffer_destroy.
# data_ptr: CPU-accessible pointer from Metal.MTL.contents() (stable for SharedStorage).
struct MetalBufferData
mtl_buf_id::UInt64
data_ptr::Ptr{Cvoid}
eltype::UInt32
dims::Ptr{Int64} # Libc.malloc'd dims array
ndims::Int
nbytes::Int
end
Collaborator:

I think what you're looking for here is a Metal.MTL.MTLTensor, although you can also use a Metal.MtlPtr and wrap whatever else you need.

Comment on lines +27 to +59
# PJRT_Buffer_Type enum value → Julia element type
# PJRT_Buffer_Type::UInt32: PRED=1,S8=2,S16=3,S32=4,S64=5,U8=6,F16=10,F32=11,F64=22
function pjrt_type_to_julia(t::UInt32)
    return if t == 11
        Float32
    elseif t == 22
        Float64
    elseif t == 10
        Float16
    elseif t == 4
        Int32
    elseif t == 5
        Int64
    else
        Float32
    end
end

function julia_type_to_pjrt(T)
    return if T == Float32
        UInt32(11)
    elseif T == Float64
        UInt32(22)
    elseif T == Float16
        UInt32(10)
    elseif T == Int32
        UInt32(4)
    elseif T == Int64
        UInt32(5)
    else
        UInt32(11)
    end
end
Collaborator:

you already have this functionality implemented in Reactant.XLA.primitive_type and Reactant.XLA.julia_type

Suggested change
# PJRT_Buffer_Type enum value → Julia element type
# PJRT_Buffer_Type::UInt32: PRED=1,S8=2,S16=3,S32=4,S64=5,U8=6,F16=10,F32=11,F64=22
function pjrt_type_to_julia(t::UInt32)
return if t == 11
Float32
elseif t == 22
Float64
elseif t == 10
Float16
elseif t == 4
Int32
elseif t == 5
Int64
else
Float32
end
end
function julia_type_to_pjrt(T)
return if T == Float32
UInt32(11)
elseif T == Float64
UInt32(22)
elseif T == Float16
UInt32(10)
elseif T == Int32
UInt32(4)
elseif T == Int64
UInt32(5)
else
UInt32(11)
end
end
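For reference, a self-contained sketch of that mapping as a bidirectional table (enum values taken from the comment in the diff; in the PR itself the existing `Reactant.XLA` helpers should be used instead). Unlike the original, it fails loudly on unknown type codes rather than silently defaulting to Float32:

```julia
# PJRT_Buffer_Type values, per the comment in the diff:
# PRED=1, S8=2, S16=3, S32=4, S64=5, U8=6, F16=10, F32=11, F64=22
const PJRT_TO_JULIA = Dict{UInt32,DataType}(
    UInt32(10) => Float16,
    UInt32(11) => Float32,
    UInt32(22) => Float64,
    UInt32(4)  => Int32,
    UInt32(5)  => Int64,
)
const JULIA_TO_PJRT = Dict(v => k for (k, v) in PJRT_TO_JULIA)

# KeyError on unsupported codes instead of a silent Float32 fallback.
pjrt_type_to_julia(t::UInt32) = PJRT_TO_JULIA[t]
julia_type_to_pjrt(T::DataType) = JULIA_TO_PJRT[T]
```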

Comment on lines +46 to +61
# ============================================================
# freeze_executable: Convert Julia MetalExecutable → C-allocated MetalExecutableData
#
# Retains all ObjC objects (graph, placeholders, output tensors).
# Copies all metadata to Libc.malloc'd C arrays.
# Allocates fresh MTLBuffers for constant data (independent of MtlArray GC).
# The returned pointer IS the PJRT executable handle.
#
# Libc.malloc count (for n_in inputs, n_out outputs, n_c constants):
# Fixed: 15 allocations (arrays + struct)
# Per-input shape: n_in allocations
# Per-output shape: n_out allocations
# Per-const shape: n_c allocations
# Total: 15 + n_in + n_out + n_c
# Plus: n_c MTLBuffer allocations via Metal.alloc (freed via Metal.free)
# ============================================================
Collaborator

minor thing, but it would be cool if comments like this were docstrings
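A sketch of what that conversion could look like, abbreviated (the signature shown is a stub for illustration):

```julia
"""
    freeze_executable(exe) -> Ptr{Cvoid}

Convert a Julia `MetalExecutable` into a C-allocated `MetalExecutableData`.
Retains all ObjC objects (graph, placeholders, output tensors) and copies
metadata into `Libc.malloc`'d C arrays; the returned pointer is the PJRT
executable handle.

For `n_in` inputs, `n_out` outputs, and `n_c` constants, the `Libc.malloc`
count is `15 + n_in + n_out + n_c`, plus `n_c` MTLBuffer allocations via
`Metal.alloc` (freed via `Metal.free`).
"""
function freeze_executable end
```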

Comment on lines +406 to +414
global CLIENT_HANDLE = Libc.malloc(64)
global DEVICE_HANDLE = Libc.malloc(64)
global DEVDESC_HANDLE = Libc.malloc(64)
global MEMORY_HANDLE = Libc.malloc(64)

unsafe_store!(Ptr{Int64}(CLIENT_HANDLE), Int64(0xDEADBEEF))
unsafe_store!(Ptr{Int64}(DEVICE_HANDLE), Int64(0xCAFEBABE))
unsafe_store!(Ptr{Int64}(DEVDESC_HANDLE), Int64(0xF00DCAFE))
unsafe_store!(Ptr{Int64}(MEMORY_HANDLE), Int64(0xFEEDFACE))
Collaborator

What are these magic numbers? Actually, check it out: they're just values you can write in hexadecimal, like 'dead beef' or 'cafe babe'.

Also, why are these pointers needed?

Contributor Author

Again, most of this is from Claude but I THINK it makes sense. Supposedly this is a single-device plugin thing — PJRT requires non-NULL pointers for client/device/memory handles (the C++ side dereferences them and segfaults on NULL). Since there's exactly one Metal GPU, one memory space, and one client, these handles don't need to carry real state — they just need to be valid, distinct, non-NULL pointers.

The hex values are conventional debug markers so they're easy to spot in lldb. I named them SENTINEL_CLIENT, SENTINEL_DEVICE, etc. in the latest commit so they're self-documenting now.

At least this is what Claude is telling me, so I made that clearer in the recent commits.
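A sketch of the named-sentinel version, assuming the handles only ever need to be distinct, non-NULL, and alive for the whole process:

```julia
# Each handle is a 64-byte heap allocation tagged with a recognizable hex
# pattern so it's easy to spot in lldb. PJRT only requires the pointers to
# be non-NULL and distinct; no real state lives behind them.
function make_sentinel(tag::UInt64)
    p = Libc.malloc(64)
    unsafe_store!(Ptr{UInt64}(p), tag)
    return p
end

const SENTINEL_CLIENT  = make_sentinel(0x00000000DEADBEEF)
const SENTINEL_DEVICE  = make_sentinel(0x00000000CAFEBABE)
const SENTINEL_DEVDESC = make_sentinel(0x00000000F00DCAFE)
const SENTINEL_MEMORY  = make_sentinel(0x00000000FEEDFACE)
```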

Comment on lines +129 to +136
function _client_create(args::Ptr{CAPI.PJRT_Client_Create_Args})::Ptr{Cvoid}
Reactant.unsafe_store_field!(args, CLIENT_HANDLE, Val{:client}())
return C_NULL
end

function _client_destroy(args::Ptr{CAPI.PJRT_Client_Destroy_Args})::Ptr{Cvoid}
return C_NULL
end
Collaborator

I think this is cheating 😆 Here, instead of creating a client with the requested configuration, it's ignoring all that and setting a "random" magic number as the pointer to the client.

Effectively it's working as a single client with global state, when PJRT would like all that data contained in an allocatable object.

I believe this is the reason why so many global pointers appear scattered around the code.

opinions @wsmoses ?

Contributor Author

I think this is correct: there's one Metal GPU on the system so there's no meaningful "configuration" to process from the create args. The global state is a consequence of single-device reality, not a design shortcut. That said, if multi-device support matters in the future (e.g. multiple M-series chips), we'd refactor to allocate real client state per create call. Happy to rework this now if you have a specific design in mind.
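If that refactor ever happens, one way to sketch it is to hand PJRT a pointer to a root-preserved Julia object instead of a global sentinel (all names here are hypothetical):

```julia
# Hypothetical per-create client state behind the opaque PJRT handle.
mutable struct MetalClientState
    device_count::Int
    # per-client configuration would live here
end

# Keep created clients rooted so GC never frees state PJRT still holds.
const LIVE_CLIENTS = IdDict{MetalClientState,Nothing}()

function create_client_state()
    state = MetalClientState(1)
    LIVE_CLIENTS[state] = nothing
    return pointer_from_objref(state)   # opaque handle for PJRT
end

function destroy_client_state(handle::Ptr{Cvoid})
    state = unsafe_pointer_to_objref(handle)::MetalClientState
    delete!(LIVE_CLIENTS, state)
    return nothing
end
```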

@Dale-Black
Copy link
Contributor Author

Hey — squashed everything down to a single clean commit and addressed the feedback from the last round:

  • Moved the PJRT_Device_GetAttributes structs out of CAPI.jl and into the extension
  • Replaced regex-based feature_group_count parsing with the proper C API (IR.getattr)
  • Added multi-op body validation for reduce_window (same pattern you flagged on reduce)
  • Removed dead code (mps_relu, mps_sigmoid — defined but never called)
  • Fixed a duplicate _make_extents allocation in Buffer.jl
  • Cleaned up deps/build_local.jl, Project.toml, and test files

23 files in the diff now, all Metal extension + minimal core touches. Tests passing locally (elementwise, autodiff, CNN, non-square conv).

Am I on the right track with this? Happy to adjust anything.

Comment on lines +233 to +240
GC.@preserve errstr begin
client = @ccall MLIR.API.mlir_c.MakeClientFromApi(
api_ptr::Ptr{Cvoid},
device_type::Cstring,
client_name::Cstring,
errstr::Ptr{Cstring},
)::Ptr{Cvoid}
end
Collaborator

Avoid this ccall and use the MakeClientFromApi from libmlir_h.jl

@Dale-Black Dale-Black force-pushed the metal-pjrt-backend branch from c58d751 to 694f865 Compare March 1, 2026 18:19
src/xla/XLA.jl Outdated
Comment on lines +325 to +337
# Apple Silicon: Metal PJRT backend via ReactantMetalExt/MPSGraph
if Accelerators.Metal.has_metal()
if was_initialized && haskey(state.clients, "metal")
free_client(state.clients["metal"])
$(runtime).metal_client_count[] -= 1
end
gpu = $(runtime).MetalClient(;
metal_pjrt_plugin_path=Accelerators.Metal.get_metal_pjrt_plugin_path(),
common_kwargs...,
)
state.clients["metal"] = gpu
# Don't put this in the default_client since metal support is fairly
# limited
=#
# Metal PJRT plugin is not yet compatible with latest OpenXLA
catch e
println(stdout, e)
try
metal = $(runtime).MetalClient()
state.clients["metal"] = metal
state.default_client = metal
catch e
println(stdout, e)
end
Collaborator

Once you rebase, you can register the plugin in the MetalExt itself

Adds a Metal GPU backend for Reactant.jl as a package extension (ReactantMetalExt).
Uses Apple's MPSGraph framework to compile and execute StableHLO operations on Metal GPUs.

Key components:
- PJRT plugin implementation via Julia @cfunction callbacks
- MLIR walker that translates StableHLO ops to MPSGraph operations
- MTLTensor-based buffer management with proper retain/release
- Support for: elementwise ops, conv2d/3d, pooling, reduce, matmul, reshape, transpose, broadcast, concatenate, pad, dot_general, and more
- Automatic differentiation works through Enzyme

Requires macOS 26+ with Metal.jl >= 1.8.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
@Dale-Black Dale-Black force-pushed the metal-pjrt-backend branch from 694f865 to 04b344a Compare March 2, 2026 17:51
@Dale-Black
Copy link
Contributor Author

I should have some more free time this weekend; let me know if there's anything else I can do to help get this to the finish line.

5 participants