Skip to content

Commit f7215bb

Browse files
committed
Add warning and separate aligned function for the vloada extract_smatrix
1 parent e431d8b commit f7215bb

File tree

3 files changed

+24
-2
lines changed

3 files changed

+24
-2
lines changed

src/general/abstract_system.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,18 @@ end
8484

8585
# Optimized version for 2D, which uses SIMD.jl to combine the 4 loads of the 2x2 matrix
8686
# into a single wide load. This is significantly faster on GPUs than 4 individual loads.
87-
@inline function extract_smatrix(A, ::Val{2}, particle)
87+
# WARNING:
88+
# 1. This only works if the matrix elements are stored contiguously in memory.
89+
# The 4 elements of the 2x2 matrix for a particle must immediately follow the
90+
# 4 elements of the previous particle's matrix, with no padding in between.
91+
# 2. The pointer of `A` must be aligned to the size of `Vec{4, eltype(A)}`.
92+
# This is guaranteed if `A` is allocated by Julia and has the correct size,
93+
# but may not be true if `A` is a view or subarray.
94+
@propagate_inbounds function extract_smatrix_aligned(A, system, particle)
95+
return extract_smatrix_aligned(A, Val(ndims(system)), particle)
96+
end
97+
98+
@inline function extract_smatrix_aligned(A, ::Val{2}, particle)
8899
@boundscheck checkbounds(A, 2, 2, particle)
89100

90101
# Note that this doesn't work in 3D because it requires a stride of 2^n.
@@ -93,6 +104,11 @@ end
93104
return SMatrix{2, 2}(Tuple(x))
94105
end
95106

107+
# Fall back to the generic version when not in 2D.
108+
@propagate_inbounds function extract_smatrix_aligned(A, ::Val{NDIMS}, particle) where {NDIMS}
109+
return extract_smatrix(A, Val(NDIMS), particle)
110+
end
111+
96112
# Specifically get the current coordinates of a particle for all system types.
97113
@propagate_inbounds function current_coords(u, system, particle)
98114
return extract_svector(current_coordinates(u, system), system, particle)

src/general/corrections.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,9 @@ end
395395

396396
function correction_matrix_inversion_step!(corr_matrix, system, semi)
397397
@threaded semi for particle in eachparticle(system)
398-
L = extract_smatrix(corr_matrix, system, particle)
398+
# `corr_matrix` is not a view, so we can use the fast `extract_smatrix_aligned`
399+
# instead of the generic `extract_smatrix`.
400+
L = extract_smatrix_aligned(corr_matrix, system, particle)
399401

400402
# The matrix `L` only becomes singular when the particle and all neighbors
401403
# are collinear (in 2D) or lie all in the same plane (in 3D).

src/schemes/structure/total_lagrangian_sph/system.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -356,9 +356,13 @@ end
356356
end
357357

358358
@propagate_inbounds function deformation_gradient(system, particle)
359+
# `deformation_grad` is not a view, so we can use the fast `extract_smatrix_aligned`
360+
# instead of the generic `extract_smatrix`.
359361
extract_smatrix(system.deformation_grad, system, particle)
360362
end
361363
@propagate_inbounds function pk1_rho2(system, particle)
364+
# `pk1_rho2` is not a view, so we can use the fast `extract_smatrix_aligned`
365+
# instead of the generic `extract_smatrix`.
362366
extract_smatrix(system.pk1_rho2, system, particle)
363367
end
364368

0 commit comments

Comments
 (0)