|
88 | 88 |
|
89 | 89 | cache = Mooncake.prepare_gradient_cache(f, x) |
90 | 90 | v, dx = Mooncake.value_and_gradient!!(cache, f, x) |
91 | | - @test dx[2] isa Mooncake.Tangent{@NamedTuple{x1::Float64, x2::Float64}} |
| 91 | + @test dx[2] isa Mooncake.Tangent{@NamedTuple{x1::Float64,x2::Float64}} |
92 | 92 | @test dx[2].fields == (; x1=2 * x.x1, x2=cos(x.x2)) |
93 | 93 |
|
94 | 94 | cache = Mooncake.prepare_gradient_cache( |
|
101 | 101 | rule = build_rrule(f, x) |
102 | 102 |
|
103 | 103 | v, dx = Mooncake.value_and_gradient!!(rule, f, x) |
104 | | - @test dx[2] isa Mooncake.Tangent{@NamedTuple{x1::Float64, x2::Float64}} |
| 104 | + @test dx[2] isa Mooncake.Tangent{@NamedTuple{x1::Float64,x2::Float64}} |
105 | 105 | @test dx[2].fields == (; x1=2 * x.x1, x2=cos(x.x2)) |
106 | 106 |
|
107 | 107 | v, dx = Mooncake.value_and_gradient!!(rule, f, x; friendly_tangents=true) |
|
271 | 271 |
|
272 | 272 | @testset "__exclude_unsupported_output , $(test_set)" for test_set in |
273 | 273 | additional_test_set |
| 274 | + |
274 | 275 | try |
275 | 276 | Mooncake.__exclude_unsupported_output(test_set[2]) |
276 | 277 | catch err |
|
280 | 281 |
|
281 | 282 | @testset "_copy_output & _copy_to_output!!, $(test_set)" for test_set in |
282 | 283 | additional_test_set |
| 284 | + |
283 | 285 | original = test_set[2] |
284 | 286 | try |
285 | 287 | if isnothing(Mooncake.__exclude_unsupported_output(original)) |
@@ -351,15 +353,12 @@ end |
351 | 353 | cache_sp_unfriendly = Mooncake.prepare_derivative_cache( |
352 | 354 | fx_sp...; config=Mooncake.Config(; friendly_tangents=false, kwargs...) |
353 | 355 | ) |
354 | | - if get(kwargs, :debug_mode, false) |
355 | | - @test_throws ErrorException Mooncake.value_and_derivative!!( |
356 | | - cache_sp_unfriendly, zip(fx_sp, dfx_sp)... |
357 | | - ) |
358 | | - else |
359 | | - @test_throws TypeError Mooncake.value_and_derivative!!( |
360 | | - cache_sp_unfriendly, zip(fx_sp, dfx_sp)... |
361 | | - ) |
362 | | - end |
| 356 | + @test_throws ArgumentError Mooncake.value_and_derivative!!( |
| 357 | + cache_sp_unfriendly, zip(fx_sp, dfx_sp)... |
| 358 | + ) |
| 359 | + @test_throws "Tangent types do not match primal types:" Mooncake.value_and_derivative!!( |
| 360 | + cache_sp_unfriendly, zip(fx_sp, dfx_sp)... |
| 361 | + ) |
363 | 362 | end |
364 | 363 | end |
365 | 364 |
|
|
0 commit comments