|
| 1 | +using Enzyme |
| 2 | +using Test |
| 3 | + |
| 4 | +mutable struct MutableWrapper{T} |
| 5 | + x::T |
| 6 | +end |
| 7 | + |
| 8 | +Base.:(==)(a::MutableWrapper, b::MutableWrapper) = (a === b) || isequal(a.x, b.x) |
| 9 | + |
| 10 | +struct Incomplete{T} |
| 11 | + s::String |
| 12 | + x::Float64 |
| 13 | + w::T |
| 14 | + z # not initialized |
| 15 | + Incomplete(s, x, w) = new{typeof(w)}(s, x, w) |
| 16 | +end |
| 17 | + |
| 18 | +function Base.:(==)(a::Incomplete, b::Incomplete) |
| 19 | + (a === b) && return true |
| 20 | + (isequal(a.s, b.s) && isequal(a.x, b.x) && isequal(a.w, b.w)) || return false |
| 21 | + if isdefined(a, :z) && isdefined(b, :z) |
| 22 | + isequal(a.z, b.z) || return false |
| 23 | + elseif isdefined(a, :z) || isdefined(b, :z) |
| 24 | + return false |
| 25 | + end |
| 26 | + return true |
| 27 | +end |
| 28 | + |
| 29 | +mutable struct MutableIncomplete{T} |
| 30 | + s::String |
| 31 | + const x::Float64 |
| 32 | + y::Float64 |
| 33 | + z # not initialized |
| 34 | + w::T |
| 35 | + function MutableIncomplete(s, x, y, w) |
| 36 | + ret = new{typeof(w)}(s, x, y) |
| 37 | + ret.w = w |
| 38 | + return ret |
| 39 | + end |
| 40 | +end |
| 41 | + |
| 42 | +function Base.:(==)(a::MutableIncomplete, b::MutableIncomplete) |
| 43 | + (a === b) && return true |
| 44 | + if !isequal(a.s, b.s) || !isequal(a.x, b.x) || !isequal(a.y, b.y) || !isequal(a.w, b.w) |
| 45 | + return false |
| 46 | + end |
| 47 | + if isdefined(a, :z) && isdefined(b, :z) |
| 48 | + isequal(a.z, b.z) || return false |
| 49 | + elseif isdefined(a, :z) || isdefined(b, :z) |
| 50 | + return false |
| 51 | + end |
| 52 | + return true |
| 53 | +end |
| 54 | + |
| 55 | +@testset "make_zero" begin |
| 56 | + # floats |
| 57 | + @test make_zero(1.0) == 0.0 |
| 58 | + @test make_zero(1.0im) == 0.0im |
| 59 | + |
| 60 | + # float arrays + multiple references |
| 61 | + rearr = [1.0] |
| 62 | + imarr = [1.0im] |
| 63 | + rearr0 = make_zero(rearr) |
| 64 | + imarr0 = make_zero(imarr) |
| 65 | + @test typeof(rearr0) === typeof(rearr) |
| 66 | + @test typeof(imarr0) === typeof(imarr) |
| 67 | + @test rearr == [1.0] # no mutation |
| 68 | + @test imarr == [1.0im] # no mutation |
| 69 | + @test rearr0 == [0.0] |
| 70 | + @test imarr0 == [0.0im] |
| 71 | + rearrs0 = make_zero((rearr, rearr)) |
| 72 | + imarrs0 = make_zero((imarr, imarr)) |
| 73 | + @test typeof(rearrs0) === typeof((rearr, rearr)) |
| 74 | + @test typeof(imarrs0) === typeof((imarr, imarr)) |
| 75 | + @test rearr == [1.0] # no mutation |
| 76 | + @test imarr == [1.0im] # no mutation |
| 77 | + @test rearrs0[1] === rearrs0[2] |
| 78 | + @test imarrs0[1] === imarrs0[2] |
| 79 | + @test rearrs0[1] == [0.0] |
| 80 | + @test imarrs0[1] == [0.0im] |
| 81 | + |
| 82 | + # floats in structs |
| 83 | + rewrapped = MutableWrapper(1.0) |
| 84 | + imwrapped = MutableWrapper(1.0im) |
| 85 | + rewrapped0 = make_zero(rewrapped) |
| 86 | + imwrapped0 = make_zero(imwrapped) |
| 87 | + @test typeof(rewrapped0) === typeof(rewrapped) |
| 88 | + @test typeof(imwrapped0) === typeof(imwrapped) |
| 89 | + @test rewrapped == MutableWrapper(1.0) # no mutation |
| 90 | + @test imwrapped == MutableWrapper(1.0im) # no mutation |
| 91 | + @test rewrapped0 == MutableWrapper(0.0) |
| 92 | + @test imwrapped0 == MutableWrapper(0.0im) |
| 93 | + |
| 94 | + # generic array + multiple references |
| 95 | + wrapped = MutableWrapper(1.0) |
| 96 | + mixarr = ["a", 1.0, wrapped] |
| 97 | + mixarr0 = make_zero(mixarr) |
| 98 | + @test typeof(mixarr0) === typeof(mixarr) |
| 99 | + @test view(mixarr, 1:2) == ["a", 1.0] # no mutation |
| 100 | + @test mixarr[3] === wrapped # no mutation |
| 101 | + @test mixarr0 == ["a", 0.0, MutableWrapper(0.0)] |
| 102 | + mixarrs0 = make_zero((mixarr, mixarr)) |
| 103 | + @test typeof(mixarrs0) === typeof((mixarr, mixarr)) |
| 104 | + @test view(mixarr, 1:2) == ["a", 1.0] # no mutation |
| 105 | + @test mixarr[3] === wrapped # no mutation |
| 106 | + @test mixarrs0[1] === mixarrs0[2] |
| 107 | + @test mixarrs0[1] == ["a", 0.0, MutableWrapper(0.0)] |
| 108 | + |
| 109 | + # non-differentiable array + copy_if_inactive |
| 110 | + constarr = ["a"] |
| 111 | + constarr0 = make_zero(constarr) |
| 112 | + @test typeof(constarr0) === typeof(constarr) |
| 113 | + @test constarr == ["a"] # no mutation |
| 114 | + @test constarr0 === constarr |
| 115 | + constarr0copy = make_zero(constarr, #=copy_if_inactive=#Val(true)) |
| 116 | + @test typeof(constarr0copy) === typeof(constarr0) |
| 117 | + @test constarr == ["a"] # no mutation |
| 118 | + @test constarr0copy !== constarr |
| 119 | + @test constarr0copy == constarr |
| 120 | + |
| 121 | + # Tuple |
| 122 | + tup = ("a", 1.0, MutableWrapper(1.0)) |
| 123 | + tup0 = make_zero(tup) |
| 124 | + @test typeof(tup0) === typeof(tup) |
| 125 | + @test tup == ("a", 1.0, MutableWrapper(1.0)) # no mutation |
| 126 | + @test tup0 == ("a", 0.0, MutableWrapper(0.0)) |
| 127 | + |
| 128 | + # NamedTuple |
| 129 | + ntup = (a="a", b=1.0, c=MutableWrapper(1.0)) |
| 130 | + ntup0 = make_zero(ntup) |
| 131 | + @test typeof(ntup0) === typeof(ntup) |
| 132 | + @test ntup == (a="a", b=1.0, c=MutableWrapper(1.0)) # no mutation |
| 133 | + @test ntup0 == (a="a", b=0.0, c=MutableWrapper(0.0)) |
| 134 | + |
| 135 | + # Box + multiple references |
| 136 | + box = Core.Box(1.0) |
| 137 | + box0 = make_zero(box) |
| 138 | + @test typeof(box0) === typeof(box) |
| 139 | + @test box.contents == 1.0 # no mutation |
| 140 | + @test box0.contents == 0.0 |
| 141 | + boxes0 = make_zero((box, box)) |
| 142 | + @test typeof(boxes0) === typeof((box, box)) |
| 143 | + @test box.contents == 1.0 # no mutation |
| 144 | + @test boxes0[1] === boxes0[2] |
| 145 | + @test boxes0[1].contents == 0.0 |
| 146 | + |
| 147 | + # differentiable custom type + multiple references |
| 148 | + wrapped = MutableWrapper(1.0) |
| 149 | + wrapped0 = make_zero(wrapped) |
| 150 | + @test typeof(wrapped0) === typeof(wrapped) |
| 151 | + @test wrapped == MutableWrapper(1.0) # no mutation |
| 152 | + @test wrapped0 == MutableWrapper(0.0) |
| 153 | + wrappeds0 = make_zero((wrapped, wrapped)) |
| 154 | + @test typeof(wrappeds0) === typeof((wrapped, wrapped)) |
| 155 | + @test wrapped == MutableWrapper(1.0) # no mutation |
| 156 | + @test wrappeds0[1] === wrappeds0[2] |
| 157 | + @test wrappeds0[1] == MutableWrapper(0.0) |
| 158 | + |
| 159 | + # non-differentiable custom type + copy_if_inactive |
| 160 | + constwrapped = MutableWrapper("a") |
| 161 | + constwrapped0 = make_zero(constwrapped) |
| 162 | + @test typeof(constwrapped0) === typeof(constwrapped) |
| 163 | + @test constwrapped == MutableWrapper("a") # no mutation |
| 164 | + @test constwrapped0 === constwrapped |
| 165 | + constwrapped0copy = make_zero(constwrapped, #=copy_if_inactive=#Val(true)) |
| 166 | + @test typeof(constwrapped0copy) === typeof(constwrapped0) |
| 167 | + @test constwrapped == MutableWrapper("a") # no mutation |
| 168 | + @test constwrapped0copy !== constwrapped |
| 169 | + @test constwrapped0copy == constwrapped |
| 170 | + |
| 171 | + # immutable struct with active, mutable, inactive and undefined fields |
| 172 | + incomplete = Incomplete("a", 1.0, MutableWrapper(1.0)) |
| 173 | + incomplete0 = make_zero(incomplete) |
| 174 | + @test typeof(incomplete0) === typeof(incomplete) |
| 175 | + @test incomplete == Incomplete("a", 1.0, MutableWrapper(1.0)) # no mutation |
| 176 | + @test incomplete0 == Incomplete("a", 0.0, MutableWrapper(0.0)) |
| 177 | + |
| 178 | + # mutable struct with inactive, active, undefined, and mutable fields |
| 179 | + # + multiple references |
| 180 | + incompletemut = MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) |
| 181 | + incompletemut0 = make_zero(incompletemut) |
| 182 | + @test typeof(incompletemut0) === typeof(incompletemut) |
| 183 | + @test incompletemut == MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) # no mutation |
| 184 | + @test incompletemut0 == MutableIncomplete("a", 0.0, 0.0, MutableWrapper(0.0)) |
| 185 | + incompletemuts0 = make_zero((incompletemut, incompletemut)) |
| 186 | + @test typeof(incompletemuts0) === typeof((incompletemut, incompletemut)) |
| 187 | + @test incompletemut == MutableIncomplete("a", 1.0, 1.0, MutableWrapper(1.0)) # no mutation |
| 188 | + @test incompletemuts0[1] === incompletemuts0[2] |
| 189 | + @test incompletemuts0[1] == MutableIncomplete("a", 0.0, 0.0, MutableWrapper(0.0)) |
| 190 | +end |
| 191 | + |
| 192 | +@testset "make_zero!" begin |
| 193 | + # floats in mutable struct |
| 194 | + rewrapped, imwrapped = MutableWrapper(1.0), MutableWrapper(1.0im) |
| 195 | + make_zero!(rewrapped) |
| 196 | + make_zero!(imwrapped) |
| 197 | + @test rewrapped == MutableWrapper(0.0) |
| 198 | + @test imwrapped == MutableWrapper(0.0im) |
| 199 | + |
| 200 | + # mixed tuple in mutable container |
| 201 | + wrapped = MutableWrapper(1.0) |
| 202 | + tuparr = [(1.0, wrapped)] |
| 203 | + make_zero!(tuparr) |
| 204 | + @test tuparr[1] === (0.0, wrapped) |
| 205 | + @test wrapped == MutableWrapper(0.0) |
| 206 | + |
| 207 | + # mixed namedtuple in mutable container |
| 208 | + wrapped = MutableWrapper(1.0) |
| 209 | + ntuparr = [(a=1.0, b=wrapped)] |
| 210 | + make_zero!(ntuparr) |
| 211 | + @test ntuparr[1] === (a=0.0, b=wrapped) |
| 212 | + @test wrapped == MutableWrapper(0.0) |
| 213 | + |
| 214 | + # immutable struct with active, mutable, inactive and undefined fields in mutable container |
| 215 | + wrapped = MutableWrapper(1.0) |
| 216 | + incompletearr = [Incomplete("a", 1.0, wrapped)] |
| 217 | + make_zero!(incompletearr) |
| 218 | + @test incompletearr[1] == Incomplete("a", 0.0, wrapped) |
| 219 | + @test wrapped == MutableWrapper(0.0) |
| 220 | + |
| 221 | + # floats in Ref |
| 222 | + reref, imref = Ref(1.0), Ref(1.0im) |
| 223 | + make_zero!(reref) |
| 224 | + make_zero!(imref) |
| 225 | + @test reref[] == 0.0 |
| 226 | + @test imref[] == 0.0im |
| 227 | + |
| 228 | + # float arrays |
| 229 | + rearr, imarr = [1.0], [1.0im] |
| 230 | + make_zero!(rearr) |
| 231 | + make_zero!(imarr) |
| 232 | + @test rearr[1] == 0.0 |
| 233 | + @test imarr[1] == 0.0im |
| 234 | + |
| 235 | + # non-differentiable array |
| 236 | + constarr = ["a"] |
| 237 | + make_zero!(constarr) |
| 238 | + @test constarr[1] == "a" |
| 239 | + |
| 240 | + # array with active, mutable, inactive and unassigned elements + multiple references |
| 241 | + wrapped = MutableWrapper(1.0) |
| 242 | + genericarr = Vector(undef, 4) |
| 243 | + genericarr[1:3] .= ("a", 1.0, wrapped) |
| 244 | + genericarrs = [genericarr, genericarr] |
| 245 | + make_zero!(genericarrs) |
| 246 | + @test genericarrs[1] === genericarrs[2] |
| 247 | + @test genericarrs[1] === genericarr |
| 248 | + @test view(genericarr, 1:2) == ["a", 0.0] |
| 249 | + @test genericarr[3] === wrapped |
| 250 | + @test wrapped == MutableWrapper(0.0) |
| 251 | + @test !isassigned(genericarr, 4) |
| 252 | + |
| 253 | + # Ref with multiple references |
| 254 | + genericref = Ref((1.0,)) |
| 255 | + genericrefs = [genericref, genericref] |
| 256 | + make_zero!(genericrefs) |
| 257 | + @test genericrefs[1] === genericrefs[2] |
| 258 | + @test genericrefs[1] === genericref |
| 259 | + @test genericref[] == (0.0,) |
| 260 | + |
| 261 | + # Ref with mutable value |
| 262 | + wrapped = MutableWrapper(1.0) |
| 263 | + mutref = Ref(wrapped) |
| 264 | + make_zero!(mutref) |
| 265 | + @test mutref[] === wrapped |
| 266 | + @test wrapped == MutableWrapper(0.0) |
| 267 | + |
| 268 | + # Ref with non-differentiable value |
| 269 | + constref = Ref("a") |
| 270 | + make_zero!(constref) |
| 271 | + @test constref[] == "a" |
| 272 | + |
| 273 | + # Box with multiple references |
| 274 | + box = Core.Box(1.0) |
| 275 | + boxes = [box, box] |
| 276 | + make_zero!(boxes) |
| 277 | + @test boxes[1] === boxes[2] |
| 278 | + @test boxes[1] === box |
| 279 | + @test box.contents == 0.0 |
| 280 | + |
| 281 | + # Box with mutable value |
| 282 | + wrapped = MutableWrapper(1.0) |
| 283 | + mutbox = Core.Box(wrapped) |
| 284 | + make_zero!(mutbox) |
| 285 | + @test mutbox.contents === wrapped |
| 286 | + @test wrapped == MutableWrapper(0.0) |
| 287 | + |
| 288 | + # Box with non-differentiable value |
| 289 | + constbox = Core.Box("a") |
| 290 | + make_zero!(constbox) |
| 291 | + @test constbox.contents == "a" |
| 292 | + |
| 293 | + # mutable struct with inactive, active, const active, undefined, and mutable fields |
| 294 | + # + multiple references |
| 295 | + wrapped = MutableWrapper(1.0) |
| 296 | + incompletemut = MutableIncomplete("a", #=const=#1.0, 1.0, wrapped) |
| 297 | + incompletemuts = [incompletemut, incompletemut] |
| 298 | + make_zero!(incompletemuts) |
| 299 | + @test incompletemuts[1] === incompletemuts[2] |
| 300 | + @test incompletemuts[1] === incompletemut |
| 301 | + @test incompletemut == MutableIncomplete("a", #=const=#0.0, 0.0, MutableWrapper(0.0)) |
| 302 | + @test incompletemut.w === wrapped |
| 303 | + |
| 304 | + # wrapped differentiable array |
| 305 | + arr = [1.0] |
| 306 | + arrwrapped = MutableWrapper(arr) |
| 307 | + make_zero!(arrwrapped) |
| 308 | + @test arrwrapped.x === arr |
| 309 | + @test arr == [0.0] |
| 310 | + |
| 311 | + # early error on active/mixed type |
| 312 | + @test_throws ArgumentError make_zero!(1.0) |
| 313 | + @test_throws ArgumentError make_zero!((1.0, MutableWrapper(1.0))) |
| 314 | + |
| 315 | + # immutable struct with both active and undefined fields in immutable container |
| 316 | + # (the previous implementation would fail due to #1935) |
| 317 | + wrapped = MutableWrapper(1.0) |
| 318 | + incompletetuparr = [(Incomplete("a", 1.0, wrapped),)] |
| 319 | + make_zero!(incompletetuparr) |
| 320 | + @test incompletetuparr[1][1] == Incomplete("a", 0.0, MutableWrapper(0.0)) |
| 321 | + @test incompletetuparr[1][1].w === wrapped |
| 322 | +end |
0 commit comments