Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.61.0"
version = "1.61.1"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -20,6 +20,7 @@ SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"

[compat]
Adapt = "3.4.0, 4"
AxisArrays = "0.4.7"
ChainRulesCore = "1.20"
ChainRulesTestUtils = "1.5"
Compat = "3.46, 4.2"
Expand All @@ -41,6 +42,7 @@ SuiteSparse = "1"
julia = "1.6"

[extras]
AxisArrays = "39de3d68-74b9-583c-8d2d-e117c070f3a9"
ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
Expand All @@ -50,4 +52,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
test = ["AxisArrays", "ChainRulesTestUtils", "FiniteDifferences", "JLArrays", "JuliaInterpreter", "Random", "StaticArrays", "Test"]
11 changes: 4 additions & 7 deletions src/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,21 +125,18 @@ Base.:(+)(oe1::OneElement, oe2::OneElement) = +(collect(oe1), oe2)
This returns roughly `dx = zero(x)`, except that this is guaranteed to be mutable via `similar`,
and its element type is wide enough to allow `setindex!(dx, dy, inds...)`, which is exactly what
`∇getindex` does next.

It's unfortunate to close over `x`, but `similar(typeof(x), axes(x))` doesn't
allow `eltype(dy)`, nor does it work for many structured matrices.
"""
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy), axes(x)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy), axes(x)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds::Integer...) = fill!(similar(x, typeof(dy)), false)
_setindex_zero(x::AbstractArray{<:Number}, dy, inds...) = fill!(similar(x, eltype(dy)), false)
function _setindex_zero(x::AbstractArray, dy, inds::Integer...)
# This allows for types which don't define zero (like Vector) and types whose zero special (like Tangent),
# but always makes an abstract type. TODO: make it infer concrete type for e.g. vectors of SVectors
T = Union{typeof(dy), ZeroTangent}
return fill!(similar(x, T, axes(x)), ZeroTangent())
return fill!(similar(x, T), ZeroTangent())
end
function _setindex_zero(x::AbstractArray, dy, inds...)
T = Union{eltype(dy), ZeroTangent}
return fill!(similar(x, T, axes(x)), ZeroTangent())
return fill!(similar(x, T), ZeroTangent())
end
ChainRules.@non_differentiable _setindex_zero(x::AbstractArray, dy::Any, inds::Any...)

Expand Down
7 changes: 7 additions & 0 deletions test/rulesets/Base/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,13 @@ end
@test dx23[3] == dxfix[3]
end

@testset "getindex(::AxisArray{<:Number})" begin
X = randn((2, 3))
A = AxisArray(X; row=[:a, :b], col=[:x, :y, :z])
dA, back = rrule(getindex, A, [:a], [:x, :z])
unthunk(back(ones(1, 2))[2]) == [1.0 0.0 1.0; 0.0 0.0 0.0]
end

@testset "second derivatives: ∇getindex" begin
@eval using ChainRules: ∇getindex
# Forward, scalar result
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ using Test, ChainRulesCore, ChainRulesTestUtils
@nospecialize

using Adapt
using AxisArrays
using Base.Broadcast: broadcastable
using ChainRules
using ChainRules: stack
Expand Down