Skip to content

Commit f8ceabc

Browse files
committed
updates
1 parent fccda38 commit f8ceabc

File tree

2 files changed

+18
-33
lines changed

2 files changed

+18
-33
lines changed

ext/ParallelTransforms.jl

Lines changed: 12 additions & 27 deletions
Original file line number | Diff line number | Diff line change
@@ -729,7 +729,8 @@ function SHTnsKit.dist_synthesis(cfg::SHTnsKit.SHTConfig, Alm::AbstractMatrix; p
 729  729   end
 730  730
 731  731   function SHTnsKit.dist_synthesis(cfg::SHTnsKit.SHTConfig, Alm::PencilArray; prototype_θφ::PencilArray, real_output::Bool=true, use_rfft::Bool=false)
 732      -     return SHTnsKit.dist_synthesis(cfg, Array(Alm); prototype_θφ, real_output, use_rfft)
      732 +     Alm_dense = SHTnsKit.spectral_pencil_to_matrix(cfg, Alm)
      733 +     return SHTnsKit.dist_synthesis(cfg, Alm_dense; prototype_θφ, real_output, use_rfft)
 733  734   end
 734  735
 735  736   function SHTnsKit.dist_synthesis!(plan::DistPlan, fθφ_out::PencilArray, Alm::PencilArray; real_output::Bool=true)
@@ -1593,34 +1594,18 @@ function dist_analysis_distributed(cfg::SHTnsKit.SHTConfig, fθφ::PencilArray;
1593 1594           end
1594 1595       end
1595 1596
1596      -     # Create output distributed array
1597      -     result = create_distributed_spectral_array(plan, ComplexF64)
1598      -
1599      -     # Pack all coefficients in l-major order grouped by owner rank, then Allreduce
1600      -     # and extract the local portion for this rank
1601      -     total_nlm = sum(plan.recv_counts)
1602      -     local_contribs_packed = Vector{ComplexF64}(undef, total_nlm)
1603      -
1604      -     # Pack in l-major order, grouped by owner rank
1605      -     # recv_counts[r+1] = count for rank r, where rank r owns l values where l % nprocs == r
1606      -     idx = 0
1607      -     for owner_rank in 0:(plan.nprocs - 1)
1608      -         for l in 0:lmax
1609      -             if l % plan.nprocs == owner_rank
1610      -                 for m in 0:min(l, mmax)
1611      -                     idx += 1
1612      -                     local_contribs_packed[idx] = local_contrib[l+1, m+1]
1613      -                 end
1614      -             end
1615      -         end
     1597 +     # Only reduce if θ is distributed across ranks (if all ranks have all latitudes,
     1598 +     # each rank's local_contrib is already the complete answer)
     1599 +     θ_is_distributed = (nθ_local < cfg.nlat)
     1600 +     if θ_is_distributed
     1601 +         MPI.Allreduce!(MPI.IN_PLACE, local_contrib, +, comm)
1616 1602       end
1617 1603
1618      -     # Allreduce the packed buffer, then extract the local portion for this rank
1619      -     full_reduced = similar(local_contribs_packed)
1620      -     MPI.Allreduce!(local_contribs_packed, full_reduced, +, comm)
1621      -     offset = plan.recv_displs[plan.rank + 1]
1622      -     count = plan.recv_counts[plan.rank + 1]
1623      -     copyto!(result.local_coeffs, 1, full_reduced, offset + 1, count)
     1604 +     # Create output distributed array and extract local portion
     1605 +     result = create_distributed_spectral_array(plan, ComplexF64)
     1606 +     for (i, (l, m)) in enumerate(plan.local_lm_indices)
     1607 +         result.local_coeffs[i] = local_contrib[l+1, m+1]
     1608 +     end
1624 1609
1625 1610       return result
1626 1611   end

test/parallel/test_mpi_comprehensive.jl

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -534,8 +534,8 @@ function test_plan_synthesis(cfg::SHTnsKit.SHTConfig, pen::Pencil)
 534  534       # Fill in spectral data
 535  535       for (ii, il) in enumerate(axes(Alm_pencil, 1))
 536  536           for (jj, jm) in enumerate(axes(Alm_pencil, 2))
 537      -             lval = globalindices(Alm_pencil, 1)[ii] - 1
 538      -             mval = globalindices(Alm_pencil, 2)[jj] - 1
      537 +             lval = ParExt.globalindices(Alm_pencil, 1)[ii] - 1
      538 +             mval = ParExt.globalindices(Alm_pencil, 2)[jj] - 1
 539  539               if lval >= mval && mval <= mmax && lval <= lmax
 540  540                   Alm_pencil[il, jm] = alm_orig[lval+1, mval+1]
 541  541               else
@@ -819,8 +819,8 @@ function test_point_evaluation(cfg::SHTnsKit.SHTConfig, pen::Pencil)
 819  819       Alm_pencil = PencilArray{ComplexF64}(undef, spec_pen)
 820  820       for (ii, il) in enumerate(axes(Alm_pencil, 1))
 821  821           for (jj, jm) in enumerate(axes(Alm_pencil, 2))
 822      -             lval = globalindices(Alm_pencil, 1)[ii] - 1
 823      -             mval = globalindices(Alm_pencil, 2)[jj] - 1
      822 +             lval = ParExt.globalindices(Alm_pencil, 1)[ii] - 1
      823 +             mval = ParExt.globalindices(Alm_pencil, 2)[jj] - 1
 824  824               if lval >= mval && mval <= mmax && lval <= lmax
 825  825                   Alm_pencil[il, jm] = alm[lval+1, mval+1]
 826  826               else
@@ -868,8 +868,8 @@ function test_latitude_evaluation(cfg::SHTnsKit.SHTConfig, pen::Pencil)
 868  868       Alm_pencil = PencilArray{ComplexF64}(undef, spec_pen)
 869  869       for (ii, il) in enumerate(axes(Alm_pencil, 1))
 870  870           for (jj, jm) in enumerate(axes(Alm_pencil, 2))
 871      -             lval = globalindices(Alm_pencil, 1)[ii] - 1
 872      -             mval = globalindices(Alm_pencil, 2)[jj] - 1
      871 +             lval = ParExt.globalindices(Alm_pencil, 1)[ii] - 1
      872 +             mval = ParExt.globalindices(Alm_pencil, 2)[jj] - 1
 873  873               if lval >= mval && mval <= mmax && lval <= lmax
 874  874                   Alm_pencil[il, jm] = alm[lval+1, mval+1]
 875  875               else

0 commit comments

Comments
 (0)