Skip to content

Commit 17bc81a

Browse files
galenlynchstevengjgiordano
authored
Add get_num_threads (#171)
* Add get_num_threads This commit adds `get_num_threads`, which returns the number of threads used by the planner, and is the complement to `set_num_threads`. This simply wraps the function `fftw_planner_nthreads`, which was [newly added to fftw in version 3.3.9](https://github.com/FFTW/fftw3/blob/34082eb5d6ed7dc9436915df69f376c06fc39762/NEWS#L3). * Set FFTW_jll compat to 3.3.9 `get_num_threads` requires FFTW_jll v3.3.9+7, but it doesn't seem possible to specify a particular build in the compat section of Project.toml files. However, this should work in most cases, as the most recent build of `FFTW_jll` should be downloaded upon updating. * bump to 1.3 for the new function * Make test for get_num_threads fftw specific No equivalent function for mkl * Typo... * another typo * Add vendor check to `get_num_threads` * Add a method of `set_num_threads` that restores the original nthreads Additionally, separate previous `set_num_threads` method into a base function, `_set_num_threads`, that wraps the `ccalls`, and `set_num_threads`, which will acquire the `fftwlock`. * Provide support for `get_num_threads` with MKL's FFTW While MKL's FFTW does not provide access to the number of threads available to the planner, this can be simulated by caching the value last passed to `set_num_threads` and returning it with `get_num_threads` if `fftw_vendor == :mkl`. * Implement suggestions of @stevengj * Fix typo in set_num_threads * Add test for set_num_threads method that restores original num_threads * Rename `nthreads` variable to `num_threads` to avoid shadowing Threads.nthreads Since FFTW uses `Base.Threads`, and `nthreads` is a function defined in `Base.Threads`, then the function argument `nthreads` shadows a function already in the namespace of every function. While there is no inherent issue with this, it can make debugging this code more confusing. * Make one-line method of `set_num_threads` one line. 
* First attempt at adding `num_threads` to `plan_...` functions As suggested by @stevengj, I have added a `num_threads` keyword to the `plan_...` functions. My approach here is fairly naive, and adds a bunch of redundant boilerplate code to every `plan_` function. Co-authored-by: Steven G. Johnson <[email protected]> Co-authored-by: Mosè Giordano <[email protected]>
1 parent 683a6e8 commit 17bc81a

File tree

5 files changed

+128
-22
lines changed

5 files changed

+128
-22
lines changed

Project.toml

+2-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FFTW"
22
uuid = "7a1cc6ca-52ef-59f5-83cd-3a7055c09341"
3-
version = "1.4.6"
3+
version = "1.5.0"
44

55
[deps]
66
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
@@ -12,7 +12,7 @@ Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
1212

1313
[compat]
1414
AbstractFFTs = "1.0"
15-
FFTW_jll = "3.3"
15+
FFTW_jll = "3.3.9"
1616
MKL_jll = "2019.0.117, 2020, 2021, 2022"
1717
Preferences = "1.2"
1818
Reexport = "0.2, 1.0"

src/dct.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
# (This is part of the FFTW module.)
44

55
"""
6-
plan_dct!(A [, dims [, flags [, timelimit]]])
6+
plan_dct!(A [, dims [, flags [, timelimit [, num_threads]]]])
77
88
Same as [`plan_dct`](@ref), but operates in-place on `A`.
99
"""
1010
function plan_dct! end
1111

1212
"""
13-
plan_idct(A [, dims [, flags [, timelimit]]])
13+
plan_idct(A [, dims [, flags [, timelimit [, num_threads]]]])
1414
1515
Pre-plan an optimized inverse discrete cosine transform (DCT), similar to
1616
[`plan_fft`](@ref) except producing a function that computes
@@ -20,7 +20,7 @@ Pre-plan an optimized inverse discrete cosine transform (DCT), similar to
2020
function plan_idct end
2121

2222
"""
23-
plan_dct(A [, dims [, flags [, timelimit]]])
23+
plan_dct(A [, dims [, flags [, timelimit [, num_threads]]]])
2424
2525
Pre-plan an optimized discrete cosine transform (DCT), similar to
2626
[`plan_fft`](@ref) except producing a function that computes
@@ -30,7 +30,7 @@ Pre-plan an optimized discrete cosine transform (DCT), similar to
3030
function plan_dct end
3131

3232
"""
33-
plan_idct!(A [, dims [, flags [, timelimit]]])
33+
plan_idct!(A [, dims [, flags [, timelimit [, num_threads]]]])
3434
3535
Same as [`plan_idct`](@ref), but operates in-place on `A`.
3636
"""

src/fft.jl

+110-16
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,14 @@ an array of real or complex floating-point numbers.
3838
function r2r! end
3939

4040
"""
41-
plan_r2r!(A, kind [, dims [, flags [, timelimit]]])
41+
plan_r2r!(A, kind [, dims [, flags [, timelimit [, num_threads]]]])
4242
4343
Similar to [`plan_fft`](@ref), but corresponds to [`r2r!`](@ref).
4444
"""
4545
function plan_r2r! end
4646

4747
"""
48-
plan_r2r(A, kind [, dims [, flags [, timelimit]]])
48+
plan_r2r(A, kind [, dims [, flags [, timelimit [, num_threads]]]])
4949
5050
Pre-plan an optimized r2r transform, similar to [`plan_fft`](@ref)
5151
except that the transforms (and the first three arguments)
@@ -171,9 +171,33 @@ end
171171

172172
# Threads
173173

174-
@exclusive function set_num_threads(nthreads::Integer)
175-
ccall((:fftw_plan_with_nthreads,libfftw3[]), Cvoid, (Int32,), nthreads)
176-
ccall((:fftwf_plan_with_nthreads,libfftw3f[]), Cvoid, (Int32,), nthreads)
174+
# Must only be called after acquiring fftwlock
175+
function _set_num_threads(num_threads::Integer)
176+
@static if fftw_provider == "mkl"
177+
_last_num_threads[] = num_threads
178+
end
179+
ccall((:fftw_plan_with_nthreads,libfftw3[]), Cvoid, (Int32,), num_threads)
180+
ccall((:fftwf_plan_with_nthreads,libfftw3f[]), Cvoid, (Int32,), num_threads)
181+
end
182+
183+
@exclusive set_num_threads(num_threads::Integer) = _set_num_threads(num_threads)
184+
185+
function get_num_threads()
186+
@static if fftw_provider == "fftw"
187+
ccall((:fftw_planner_nthreads,libfftw3[]), Cint, ())
188+
else
189+
_last_num_threads[]
190+
end
191+
end
192+
193+
@exclusive function set_num_threads(f::Function, num_threads::Integer)
194+
orig_num_threads = get_num_threads()
195+
_set_num_threads(num_threads)
196+
try
197+
f()
198+
finally
199+
_set_num_threads(orig_num_threads)
200+
end
177201
end
178202

179203
# pointer type for fftw_plan (opaque pointer)
@@ -684,22 +708,43 @@ for (f,direction) in ((:fft,FORWARD), (:bfft,BACKWARD))
684708
@eval begin
685709
function $plan_f(X::StridedArray{T,N}, region;
686710
flags::Integer=ESTIMATE,
687-
timelimit::Real=NO_TIMELIMIT) where {T<:fftwComplex,N}
711+
timelimit::Real=NO_TIMELIMIT,
712+
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N}
713+
if num_threads !== nothing
714+
plan = set_num_threads(num_threads) do
715+
$plan_f(X, region; flags = flags, timelimit = timelimit)
716+
end
717+
return plan
718+
end
688719
cFFTWPlan{T,$direction,false,N}(X, fakesimilar(flags, X, T),
689720
region, flags, timelimit)
690721
end
691722

692723
function $plan_f!(X::StridedArray{T,N}, region;
693-
flags::Integer=ESTIMATE,
694-
timelimit::Real=NO_TIMELIMIT) where {T<:fftwComplex,N}
724+
flags::Integer=ESTIMATE,
725+
timelimit::Real=NO_TIMELIMIT,
726+
num_threads::Union{Nothing, Integer} = nothing ) where {T<:fftwComplex,N}
727+
if num_threads !== nothing
728+
plan = set_num_threads(num_threads) do
729+
$plan_f!(X, region; flags = flags, timelimit = timelimit)
730+
end
731+
return plan
732+
end
695733
cFFTWPlan{T,$direction,true,N}(X, X, region, flags, timelimit)
696734
end
697735
$plan_f(X::StridedArray{<:fftwComplex}; kws...) =
698736
$plan_f(X, 1:ndims(X); kws...)
699737
$plan_f!(X::StridedArray{<:fftwComplex}; kws...) =
700738
$plan_f!(X, 1:ndims(X); kws...)
701739

702-
function plan_inv(p::cFFTWPlan{T,$direction,inplace,N}) where {T<:fftwComplex,N,inplace}
740+
function plan_inv(p::cFFTWPlan{T,$direction,inplace,N};
741+
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwComplex,N,inplace}
742+
if num_threads !== nothing
743+
plan = set_num_threads(num_threads) do
744+
plan_inv(p)
745+
end
746+
return plan
747+
end
703748
X = Array{T}(undef, p.sz)
704749
Y = inplace ? X : fakesimilar(p.flags, X, T)
705750
ScaledPlan(cFFTWPlan{T,$idirection,inplace,N}(X, Y, p.region,
@@ -735,15 +780,29 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
735780
@eval begin
736781
function plan_rfft(X::StridedArray{$Tr,N}, region;
737782
flags::Integer=ESTIMATE,
738-
timelimit::Real=NO_TIMELIMIT) where N
783+
timelimit::Real=NO_TIMELIMIT,
784+
num_threads::Union{Nothing, Integer} = nothing) where N
785+
if num_threads !== nothing
786+
plan = set_num_threads(num_threads) do
787+
plan_rfft(X, region; flags = flags, timelimit = timelimit)
788+
end
789+
return plan
790+
end
739791
osize = rfft_output_size(X, region)
740792
Y = flags&ESTIMATE != 0 ? FakeArray{$Tc}(osize) : Array{$Tc}(undef, osize)
741793
rFFTWPlan{$Tr,$FORWARD,false,N}(X, Y, region, flags, timelimit)
742794
end
743795

744796
function plan_brfft(X::StridedArray{$Tc,N}, d::Integer, region;
745797
flags::Integer=ESTIMATE,
746-
timelimit::Real=NO_TIMELIMIT) where N
798+
timelimit::Real=NO_TIMELIMIT,
799+
num_threads::Union{Nothing, Integer} = nothing) where N
800+
if num_threads !== nothing
801+
plan = set_num_threads(num_threads) do
802+
plan_brfft(X, d, region; flags = flags, timelimit = timelimit)
803+
end
804+
return plan
805+
end
747806
osize = brfft_output_size(X, d, region)
748807
Y = flags&ESTIMATE != 0 ? FakeArray{$Tr}(osize) : Array{$Tr}(undef, osize)
749808

@@ -763,7 +822,14 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
763822
plan_rfft(X::StridedArray{$Tr};kws...)=plan_rfft(X,1:ndims(X);kws...)
764823
plan_brfft(X::StridedArray{$Tr};kws...)=plan_brfft(X,1:ndims(X);kws...)
765824

766-
function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N}) where N
825+
function plan_inv(p::rFFTWPlan{$Tr,$FORWARD,false,N}; # `;` so num_threads is a keyword, matching the other plan_inv methods
826+
num_threads::Union{Nothing, Integer} = nothing) where N
827+
if num_threads !== nothing
828+
plan = set_num_threads(num_threads) do
829+
plan_inv(p)
830+
end
831+
return plan
832+
end
767833
X = Array{$Tr}(undef, p.sz)
768834
Y = p.flags&ESTIMATE != 0 ? FakeArray{$Tc}(p.osz) : Array{$Tc}(undef, p.osz)
769835
ScaledPlan(rFFTWPlan{$Tc,$BACKWARD,false,N}(Y, X, p.region,
@@ -773,7 +839,14 @@ for (Tr,Tc) in ((:Float32,:(Complex{Float32})),(:Float64,:(Complex{Float64})))
773839
normalization(X, p.region))
774840
end
775841

776-
function plan_inv(p::rFFTWPlan{$Tc,$BACKWARD,false,N}) where N
842+
function plan_inv(p::rFFTWPlan{$Tc,$BACKWARD,false,N};
843+
num_threads::Union{Nothing, Integer} = nothing) where N
844+
if num_threads !== nothing
845+
plan = set_num_threads(num_threads) do
846+
plan_inv(p)
847+
end
848+
return plan
849+
end
777850
X = Array{$Tc}(undef, p.sz)
778851
Y = p.flags&ESTIMATE != 0 ? FakeArray{$Tr}(p.osz) : Array{$Tr}(undef, p.osz)
779852
ScaledPlan(rFFTWPlan{$Tr,$FORWARD,false,N}(Y, X, p.region,
@@ -832,14 +905,28 @@ end
832905

833906
function plan_r2r(X::StridedArray{T,N}, kinds, region;
834907
flags::Integer=ESTIMATE,
835-
timelimit::Real=NO_TIMELIMIT) where {T<:fftwNumber,N}
908+
timelimit::Real=NO_TIMELIMIT,
909+
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwNumber,N}
910+
if num_threads !== nothing
911+
plan = set_num_threads(num_threads) do
912+
plan_r2r(X, kinds, region; flags = flags, timelimit = timelimit)
913+
end
914+
return plan
915+
end
836916
r2rFFTWPlan{T,Any,false,N}(X, fakesimilar(flags, X, T), region, kinds,
837917
flags, timelimit)
838918
end
839919

840920
function plan_r2r!(X::StridedArray{T,N}, kinds, region;
841921
flags::Integer=ESTIMATE,
842-
timelimit::Real=NO_TIMELIMIT) where {T<:fftwNumber,N}
922+
timelimit::Real=NO_TIMELIMIT,
923+
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwNumber,N}
924+
if num_threads !== nothing
925+
plan = set_num_threads(num_threads) do
926+
plan_r2r!(X, kinds, region; flags = flags, timelimit = timelimit) # recurse into the in-place planner, not plan_r2r
927+
end
928+
return plan
929+
end
843930
r2rFFTWPlan{T,Any,true,N}(X, X, region, kinds, flags, timelimit)
844931
end
845932

@@ -861,7 +948,14 @@ function logical_size(n::Integer, k::Integer)
861948
return 2n
862949
end
863950

864-
function plan_inv(p::r2rFFTWPlan{T,K,inplace,N}) where {T<:fftwNumber,K,inplace,N}
951+
function plan_inv(p::r2rFFTWPlan{T,K,inplace,N};
952+
num_threads::Union{Nothing, Integer} = nothing) where {T<:fftwNumber,K,inplace,N}
953+
if num_threads !== nothing
954+
plan = set_num_threads(num_threads) do # bind the do-block's result here; an assignment inside the closure stays local to it
955+
plan_inv(p)
956+
end
957+
return plan
958+
end
865959
X = Array{T}(undef, p.sz)
866960
iK = fix_kinds(p.region, [inv_kind[k] for k in K])
867961
Y = inplace ? X : fakesimilar(p.flags, X, T)

src/providers.jl

+1
Original file line numberDiff line numberDiff line change
@@ -85,4 +85,5 @@ end
8585
import MKL_jll
8686
libfftw3[] = MKL_jll.libmkl_rt_path
8787
libfftw3f[] = MKL_jll.libmkl_rt_path
88+
const _last_num_threads = Ref(Cint(1))
8889
end

test/runtests.jl

+11
Original file line numberDiff line numberDiff line change
@@ -528,3 +528,14 @@ end
528528
@test occursin("dft-thr", string(p2))
529529
end
530530
end
531+
532+
@testset "Setting and getting planner nthreads" begin
533+
FFTW.set_num_threads(1)
534+
@test FFTW.get_num_threads() == 1
535+
FFTW.set_num_threads(2)
536+
@test FFTW.get_num_threads() == 2
537+
plan = FFTW.set_num_threads(1) do # Should leave get_num_threads unchanged
538+
plan_rfft(m4, 1)
539+
end
540+
@test FFTW.get_num_threads() == 2 # Unchanged
541+
end

0 commit comments

Comments
 (0)