Skip to content

Commit c3ff412

Browse files
committed
More validation + fixup tests
1 parent 09662e5 commit c3ff412

File tree

2 files changed

+20
-20
lines changed

2 files changed

+20
-20
lines changed

src/fft.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,11 @@ mutable struct MtlFFTPlan{T <: FFTNumber, S <: FFTNumber, backward, inplace, N,
7878

7979
function MtlFFTPlan{T, S, backward, inplace, N, R}(input_size::NTuple{N, Int}, output_size::NTuple{N, Int}, region::NTuple{R, Int}) where {T <: FFTNumber, S <: FFTNumber, backward, inplace, N, R}
8080
# Validate region
81-
for r in region
82-
1 <= r <= N || throw(ArgumentError("Invalid FFT dimension $r for array with $N dimensions"))
81+
if any(diff(collect(region)) .< 1)
82+
throw(ArgumentError("region must be an increasing sequence"))
83+
end
84+
if any(region .< 1 .|| region .> N)
85+
throw(ArgumentError("region can only refer to valid dimensions"))
8386
end
8487
backward isa Bool || throw(ArgumentError("FFT backward argument must be a Bool"))
8588
inplace isa Bool || throw(ArgumentError("FFT inplace argument must be a Bool"))
@@ -277,14 +280,15 @@ function LinearAlgebra.mul!(y::MtlArray{T, N}, p::MtlFFTPlan{T, S, backward, inp
277280
end
278281

279282
function Base.:(*)(p::MtlFFTPlan{T, S, backward, true}, x::MtlArray{S}) where {T, S, backward}
280-
# assert_applicable(p, x)
281-
LinearAlgebra.mul!(x, p, x)
283+
assert_applicable(p, x)
284+
285+
unsafe_execute!(p, x, x)
282286
return x
283287
end
284288
function Base.:(*)(p::MtlFFTPlan{T, S, backward, false}, x::MtlArray{S}) where {T, S, backward}
285-
# assert_applicable(p, x)
289+
assert_applicable(p, x)
286290

287291
y = MtlArray{T}(undef, p.output_size)
288-
LinearAlgebra.mul!(y, p, x)
292+
unsafe_execute!(p, x, y)
289293
return y
290294
end

test/fft.jl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -159,8 +159,8 @@ if MPS.is_supported(device())
159159

160160
p = plan_fft(d_X, region)
161161
d_Y = p * d_X
162-
# d_X2 = reshape(d_X, (size(d_X)..., 1))
163-
# @test_throws ArgumentError p * d_X2
162+
d_X2 = reshape(d_X, (size(d_X)..., 1))
163+
@test_throws ArgumentError p * d_X2
164164

165165
Y = Array(d_Y)
166166
@test isapprox(Y, fftw_X, rtol = rtol(T), atol = atol(T))
@@ -238,19 +238,17 @@ if MPS.is_supported(device())
238238
test_complex_batched(X, region)
239239
end
240240

241-
# X = rand(T, dims)
242-
# @test_throws ArgumentError test_complex_batched(X, (3, 1))
241+
X = rand(T, dims)
242+
@test_throws ArgumentError test_complex_batched(X, (3, 1))
243243
end
244244
@testset "Batch 2D (in 4D)" begin
245245
dims = (N1, N2, N3, N4)
246246
for region in [(1, 2), (1, 4), (3, 4), (1, 3), (2, 3), (2,), (3,)]
247247
X = rand(T, dims)
248248
test_complex_batched(X, region)
249249
end
250-
# for region in [(2, 4)]
251-
# X = rand(T, dims)
252-
# @test_throws ArgumentError test_complex_batched(X, region)
253-
# end
250+
X = rand(T, dims)
251+
test_complex_batched(X, (2, 4))
254252
end
255253
end
256254
end
@@ -338,8 +336,8 @@ if MPS.is_supported(device())
338336
test_real_batched(X, region)
339337
end
340338

341-
# X = rand(T, dims)
342-
# @test_throws ArgumentError test_real_batched(X, (3, 1))
339+
X = rand(T, dims)
340+
@test_throws ArgumentError test_real_batched(X, (3, 1))
343341
end
344342

345343
@testset "Batch 2D (in 4D)" begin
@@ -348,10 +346,8 @@ if MPS.is_supported(device())
348346
X = rand(T, dims)
349347
test_real_batched(X, region)
350348
end
351-
# for region in [(2,4)]
352-
# X = rand(T, dims)
353-
# @test_throws ArgumentError test_real_batched(X, region)
354-
# end
349+
X = rand(T, dims)
350+
test_real_batched(X, (2, 4))
355351
end
356352

357353
@testset "3D" begin

0 commit comments

Comments
 (0)