|
1 | 1 | using Base: RefValue |
2 | 2 |
|
| 3 | +if VERSION > v"1.7.0-DEV.204" |
| 4 | + using Base: ismutabletype |
| 5 | +else |
| 6 | + function ismutabletype(@nospecialize(t::Type)) |
| 7 | + t = Base.unwrap_unionall(t) |
| 8 | + return isa(t, DataType) && t.mutable |
| 9 | + end |
| 10 | +end |
| 11 | + |
3 | 12 | # Interfaces |
4 | 13 |
|
5 | 14 | accum() = nothing |
@@ -278,19 +287,19 @@ Jnew{T}(g) where T = Jnew{T,typeof(g)}(g) |
278 | 287 |
|
279 | 288 | @adjoint! function __new__(T, args...) |
280 | 289 | x = __new__(T, args...) |
281 | | - g = !T.mutable || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x) |
| 290 | + g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x) |
282 | 291 | x, Jnew{T,typeof(g),false}(g) |
283 | 292 | end |
284 | 293 |
|
285 | 294 | @adjoint! function __splatnew__(T, args) |
286 | 295 | x = __splatnew__(T, args) |
287 | | - g = !T.mutable || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x) |
| 296 | + g = !ismutabletype(T) || fieldcount(T) == 0 ? nothing : grad_mut(__context__, x) |
288 | 297 | x, Jnew{T,typeof(g),true}(g) |
289 | 298 | end |
290 | 299 |
|
291 | 300 | # TODO captured mutables + multiple calls to `back` |
292 | 301 | @generated function (back::Jnew{T,G,false})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G} |
293 | | - !T.mutable && Δ == Nothing && return :nothing |
| 302 | + !ismutabletype(T) && Δ == Nothing && return :nothing |
294 | 303 | Δ = G == Nothing ? :Δ : |
295 | 304 | Δ <: RefValue ? :(back.g[]) : |
296 | 305 | :(accum(back.g[], Δ)) |
|
302 | 311 | end |
303 | 312 |
|
304 | 313 | @generated function (back::Jnew{T,G,true})(Δ::Union{NamedTuple,Nothing,RefValue}) where {T,G} |
305 | | - !T.mutable && Δ == Nothing && return :nothing |
| 314 | + !ismutabletype(T) && Δ == Nothing && return :nothing |
306 | 315 | Δ = G == Nothing ? :Δ : :(back.g) |
307 | 316 | quote |
308 | 317 | x̄ = $Δ |
|
0 commit comments