Skip to content

Commit 84b3079

Browse files
authored
Use CR OneElement
1 parent f8779f8 commit 84b3079

File tree

3 files changed

+2
-17
lines changed

3 files changed

+2
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2828

2929
[compat]
3030
AbstractFFTs = "1.3.1"
31-
ChainRules = "1.44.1"
31+
ChainRules = "1.51.0"
3232
ChainRulesCore = "1.9"
3333
ChainRulesTestUtils = "1"
3434
Colors = "0.12"

src/Zygote.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ import ZygoteRules: @adjoint, @adjoint!, AContext, adjoint, _pullback, pullback,
77
literal_getproperty, literal_getfield, unthunk_tangent
88

99
using ChainRulesCore
10-
using ChainRules: ChainRules, rrule, unthunk, canonicalize
10+
using ChainRules: ChainRules, rrule, unthunk, canonicalize, OneElement
1111
using IRTools
1212
using MacroTools, Requires
1313
using MacroTools: @forward

src/lib/array.jl

Lines changed: 0 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -41,21 +41,6 @@ end
4141
@adjoint (::Type{T})(sz) where {T<:Zeros} = T(sz), Δ->(nothing,)
4242
@adjoint (::Type{T})(sz) where {T<:Ones} = T(sz), Δ->(nothing,)
4343

44-
"""
45-
OneElement(val, ind, axes) <: AbstractArray
46-
47-
Extremely simple `struct` used for the gradient of scalar `getindex`.
48-
"""
49-
struct OneElement{T,N,I,A} <: AbstractArray{T,N}
50-
val::T
51-
ind::I
52-
axes::A
53-
OneElement(val::T, ind::I, axes::A) where {T<:Number, I<:NTuple{N,Int}, A<:NTuple{N,AbstractUnitRange}} where {N} = new{T,N,I,A}(val, ind, axes)
54-
end
55-
Base.size(A::OneElement) = map(length, A.axes)
56-
Base.axes(A::OneElement) = A.axes
57-
Base.getindex(A::OneElement{T,N}, i::Vararg{Int,N}) where {T,N} = ifelse(i==A.ind, A.val, zero(T))
58-
5944

6045
_zero(xs::AbstractArray{<:Number}, T::Type{Nothing}) = fill!(similar(xs), zero(eltype(xs)))
6146
_zero(xs::AbstractArray{<:Number}, T) = fill!(similar(xs, T), false)

0 commit comments

Comments
 (0)