Skip to content

Commit 3d433df

Browse files
authored
Merge pull request #182 from magerton/pull-request/61d75528
Fix gradients for ScaledInterpolation with NoInterp
2 parents 62dd04b + 61d7552 commit 3d433df

File tree

2 files changed

+19
-2
lines changed

2 files changed

+19
-2
lines changed

src/scaling/scaling.jl

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,12 @@ gradient(sitp::ScaledInterpolation{T,N,ITPT,IT,GT}, xs...) where {T,N,ITPT,IT<:D
9898
quote
9999
length(g) == $(count_interp_dims(IT, N)) || throw(ArgumentError(string("The length of the provided gradient vector (", length(g), ") did not match the number of interpolating dimensions (", $(count_interp_dims(IT, N)), ")")))
100100
gradient!(g, sitp.itp, $(interp_indices...))
101-
for i in eachindex(g)
102-
g[i] = rescale_gradient(sitp.ranges[i], g[i])
101+
cntr = 0
102+
for i = 1:N
103+
if $(interp_dimens)[i]
104+
cntr += 1
105+
g[cntr] = rescale_gradient(sitp.ranges[i], g[cntr])
106+
end
103107
end
104108
g
105109
end

test/scaling/nointerp.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,5 +21,18 @@ end
2121

2222
@test length(gradient(sitp, pi/3, 2)) == 1
2323

24+
# check for case where initial/middle indices are NoInterp but later ones are <:BSpline
25+
srand(1234)
26+
z0 = rand(10,10)
27+
za = copy(z0)
28+
zb = copy(z0')
29+
30+
itpa = interpolate(za, (BSpline(Linear()), NoInterp()), OnGrid())
31+
itpb = interpolate(zb, (NoInterp(), BSpline(Linear())), OnGrid())
32+
33+
rng = linspace(1.0, 19.0, 10)
34+
sitpa = scale(itpa, rng, 1:10)
35+
sitpb = scale(itpb, 1:10, rng)
36+
@test gradient(sitpa, 3.0, 3) == gradient(sitpb, 3, 3.0)
2437

2538
end

0 commit comments

Comments
 (0)