Skip to content

Commit 9759239

Browse files
Merge pull request #1003 from mcabbott/ismutabletype
use `Base.ismutabletype`
2 parents 531da8b + ec1a41b commit 9759239

File tree

1 file changed

+13
-4
lines changed

1 file changed

+13
-4
lines changed

src/lib/lib.jl

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
using Base: RefValue
22

3+
if VERSION > v"1.7.0-DEV.204"
4+
using Base: ismutabletype
5+
else
6+
function ismutabletype(@nospecialize(t::Type))
7+
t = Base.unwrap_unionall(t)
8+
return isa(t, DataType) && t.mutable
9+
end
10+
end
11+
312
# Interfaces
413

514
accum() = nothing
@@ -278,19 +287,19 @@ Jnew{T}(g) where T = Jnew{T,typeof(g)}(g)
278287

279288
@adjoint! function __new__(T, args...)
280289
x = __new__(T, args...)
281-
g = !T.mutable || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
290+
g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
282291
x, Jnew{T,typeof(g),false}(g)
283292
end
284293

285294
@adjoint! function __splatnew__(T, args)
286295
x = __splatnew__(T, args)
287-
g = !T.mutable || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
296+
g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x)
288297
x, Jnew{T,typeof(g),true}(g)
289298
end
290299

291300
# TODO captured mutables + multiple calls to `back`
292301
@generated function (back::Jnew{T,G,false})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G}
293-
!T.mutable && Δ == Nothing && return :nothing
302+
!ismutabletype(T) && Δ == Nothing && return :nothing
294303
Δ = G == Nothing ? :
295304
Δ <: RefValue ? :(back.g[]) :
296305
:(accum(back.g[], Δ))
@@ -302,7 +311,7 @@ end
302311
end
303312

304313
@generated function (back::Jnew{T,G,true})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G}
305-
!T.mutable && Δ == Nothing && return :nothing
314+
!ismutabletype(T) && Δ == Nothing && return :nothing
306315
Δ = G == Nothing ? : :(back.g)
307316
quote
308317
= $Δ

0 commit comments

Comments
 (0)