Skip to content

Commit b811519

Browse files
committed
Fix TLAS build stability + keepalive + sync
- Move _build_flat_blas_arrays! from adapt_structure to sync! so data is ready before any kernel dispatch. - Add KA.synchronize after AK.sortperm in BLAS/TLAS build to ensure sort temp buffers aren't freed while GPU is still using them. - Replace merge_sort_by_key! with sortperm in TLAS topology (avoids 64-bit Int payload corruption on Lava). - Add _REBUILD_KEEPALIVE and synchronize in multitypeset rebuild_static!. - Remove underscore-prefixed function names in RaycoreLavaExt. - Add instanced BVH test.
1 parent 17f62ae commit b811519

4 files changed

Lines changed: 54 additions & 18 deletions

File tree

ext/RaycoreLavaExt/RaycoreLavaExt.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -23,9 +23,9 @@ using Lava: LavaBackend, LavaArray, LavaBLAS, LavaTLAS,
2323
lava_global_invocation_id_x,
2424
lava_rt_primitive_id, lava_rt_instance_custom_index,
2525
lava_rt_launch_id_x,
26-
_lava_rt_ignore_intersection, _lava_rt_terminate_ray,
27-
_lava_rt_payload_store_f32_at, _lava_rt_payload_load_f32_at,
28-
_lava_rt_trace_ray,
26+
lava_rt_ignore_intersection, lava_rt_terminate_ray,
27+
lava_rt_payload_store_f32_at, lava_rt_payload_load_f32_at,
28+
lava_rt_trace_ray,
2929
mat4_to_vk_transform
3030

3131
const LavaHWTLAS = HWTLAS{LavaBackend}
@@ -139,14 +139,14 @@ Raycore.rt_primitive_id(::LavaHWAdapted) = lava_rt_primitive_id()
139139
Raycore.rt_instance_custom_index(::LavaHWAdapted) = lava_rt_instance_custom_index()
140140
Raycore.rt_launch_id_x(::LavaHWAdapted) = lava_rt_launch_id_x()
141141
Raycore.rt_global_invocation_id_x(::LavaHWAdapted) = lava_global_invocation_id_x()
142-
Raycore.rt_ignore_intersection(::LavaHWAdapted) = _lava_rt_ignore_intersection()
143-
Raycore.rt_terminate_ray(::LavaHWAdapted) = _lava_rt_terminate_ray()
144-
Raycore.rt_payload_store!(::LavaHWAdapted, val, slot) = _lava_rt_payload_store_f32_at(val, slot)
145-
Raycore.rt_payload_load(::LavaHWAdapted, slot) = _lava_rt_payload_load_f32_at(slot)
142+
Raycore.rt_ignore_intersection(::LavaHWAdapted) = lava_rt_ignore_intersection()
143+
Raycore.rt_terminate_ray(::LavaHWAdapted) = lava_rt_terminate_ray()
144+
Raycore.rt_payload_store!(::LavaHWAdapted, val, slot) = lava_rt_payload_store_f32_at(val, slot)
145+
Raycore.rt_payload_load(::LavaHWAdapted, slot) = lava_rt_payload_load_f32_at(slot)
146146

147147
function Raycore.rt_trace_ray!(::LavaHWAdapted, flags, mask, sbt_offset, sbt_stride, miss_idx,
148148
ox, oy, oz, tmin, dx, dy, dz, tmax)
149-
_lava_rt_trace_ray(flags, mask, sbt_offset, sbt_stride, miss_idx,
149+
lava_rt_trace_ray(flags, mask, sbt_offset, sbt_stride, miss_idx,
150150
ox, oy, oz, tmin, dx, dy, dz, tmax)
151151
end
152152

src/instanced-bvh.jl

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -830,6 +830,10 @@ function _rebuild_bvh!(tlas::TLAS)
830830
tlas.nodes = nodes
831831
tlas.root_aabb = root_aabb
832832

833+
# Build flat BLAS arrays during sync (not during adapt_structure).
834+
# This ensures the data is ready before any kernel dispatch.
835+
_build_flat_blas_arrays!(tlas)
836+
833837
tlas.dirty = false
834838
return
835839
end
@@ -933,9 +937,7 @@ The TLAS must stay alive while the StaticTLAS is in use.
933937
"""
934938
function Adapt.adapt_structure(to, tlas::TLAS)
935939
sync!(tlas)
936-
937-
# Build flat BLAS arrays for traversal (avoids pointer-in-buffer on Metal)
938-
_build_flat_blas_arrays!(tlas)
940+
# Flat BLAS arrays are built during sync! -- no rebuild here.
939941

940942
if tlas._flat_blas_nodes === nothing
941943
# Empty scene — need correct types for StaticTLAS type parameters
@@ -1259,6 +1261,7 @@ function build_blas(
12591261
# Sort primitives by Morton codes
12601262
# AcceleratedKernels only supports GPU backends, use Julia's sortperm for CPU
12611263
perm = AK.sortperm(morton_codes)
1264+
KA.synchronize(backend) # Ensure sort temp buffers aren't freed while GPU is still using them
12621265
morton_codes = morton_codes[perm]
12631266
primitives = primitives[perm]
12641267

@@ -1387,17 +1390,17 @@ function _build_tlas_topology(blas_array, instances, backend)
13871390
calc_kernel!(morton_codes, instances, blas_array, scene_min, scene_extent, ndrange=n)
13881391
KA.synchronize(backend)
13891392

1390-
# Sort indices by Morton codes
1391-
# AcceleratedKernels only supports GPU backends, use Julia's sortperm for CPU
1393+
# Sort instances by Morton codes.
1394+
# On Lava, merge_sort_by_key! with a 64-bit Int payload can corrupt the
1395+
# permutation vector, which later sends TLAS leaf creation out of bounds.
1396+
# Use sortperm like the BLAS path so the permutation type matches the backend.
13921397
if backend isa KA.CPU
13931398
sorted_indices = sortperm(morton_codes)
13941399
morton_codes .= morton_codes[sorted_indices]
13951400
else
1396-
sorted_indices = KA.allocate(backend, Int, n)
1397-
iota_k! = iota_kernel!(backend)
1398-
iota_k!(sorted_indices, ndrange=n)
1399-
KA.synchronize(backend)
1400-
AK.merge_sort_by_key!(morton_codes, sorted_indices)
1401+
sorted_indices = AK.sortperm(morton_codes)
1402+
KA.synchronize(backend) # Ensure sort temp buffers aren't freed while GPU is still using them
1403+
morton_codes = morton_codes[sorted_indices]
14011404
end
14021405

14031406
# Allocate nodes and initialize with empty values

src/multitypeset.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,7 +265,17 @@ to_tuple(mts::MultiTypeSet) = to_tuple(get_static(mts))
265265
# Internal: Rebuild the static tuple - converts CPU vectors to GPU
266266
# ============================================================================
267267

268+
const _REBUILD_KEEPALIVE = Any[]
269+
268270
function rebuild_static!(dhv::MultiTypeSet)
271+
# Keep old static arrays alive until the NEXT rebuild completes.
272+
# Without this, GC can free old buffers during adapt() of new data,
273+
# causing use-after-free (DEVICE_LOST on large scenes like killeroo).
274+
old_static = dhv.static
275+
if old_static !== nothing
276+
push!(_REBUILD_KEEPALIVE, old_static)
277+
end
278+
269279
# Convert CPU data vectors to GPU
270280
data_tuple = if isempty(dhv.data_order)
271281
()
@@ -279,6 +289,10 @@ function rebuild_static!(dhv::MultiTypeSet)
279289
Tuple(Adapt.adapt(dhv.backend, dhv.texture_isbits[T]) for T in dhv.texture_order)
280290
end
281291
dhv.static = StaticMultiTypeSet(data_tuple, tex_tuple)
292+
293+
# Material/light rebuilds enqueue GPU uploads. Make them visible before any
294+
# subsequent BLAS/TLAS kernels or later rebuilds can reuse/finalize backing storage.
295+
KA.synchronize(dhv.backend)
282296
end
283297

284298
# ============================================================================

test/test_instanced_bvh.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -783,6 +783,25 @@ else
783783
@test cl_tlas.nodes isa CLArray
784784
end
785785

786+
@testset "TLAS sync with many instances" begin
787+
mesh = make_triangle_mesh()
788+
transforms = [Mat4f(1, 0, 0, 0,
789+
0, 1, 0, 0,
790+
0, 0, 1, 0,
791+
Float32(mod(i - 1, 9)) * 1.5f0,
792+
Float32((i - 1) ÷ 9) * 1.25f0,
793+
0,
794+
1) for i in 1:81]
795+
796+
tlas = Raycore.TLAS(cl_backend)
797+
push!(tlas, mesh, transforms)
798+
sync!(tlas)
799+
800+
@test length(tlas.instances) == 81
801+
@test length(tlas.nodes) == 161
802+
@test Raycore.world_bound(tlas) isa Bounds3
803+
end
804+
786805
@testset "closest_hit_kernel! - basic intersection" begin
787806
mesh = make_triangle_mesh()
788807
tlas, _ = TLAS([mesh]; backend=cl_backend)

0 commit comments

Comments
 (0)