@@ -13,24 +13,6 @@ using Zygote:
1313 withgradient,
1414 withjacobian
1515
16- struct ZygoteNothingError <: Exception
17- f
18- x
19- contexts
20- end
21-
22- function Base. showerror (io:: IO , e:: ZygoteNothingError )
23- (; f, x, contexts) = e
24- sig = (typeof (x), map (typeof ∘ DI. unwrap, contexts)... )
25- return print (
26- io,
27- " Zygote failed to differentiate function `$f ` with argument types `$sig ` (the pullback returned `nothing`)." ,
28- )
29- end
30-
31- check_nothing (:: Nothing , f, x, contexts) = throw (ZygoteNothingError (f, x, contexts))
32- check_nothing (:: Any , f, x, contexts) = nothing
33-
3416DI. check_available (:: AutoZygote ) = true
3517DI. inplace_support (:: AutoZygote ) = DI. InPlaceNotSupported ()
3618
@@ -64,7 +46,6 @@ function DI.value_and_pullback(
6446 tx = map (ty) do dy
6547 first (pb (dy))
6648 end
67- check_nothing (first (tx), f, x, contexts)
6849 return y, tx
6950end
7051
@@ -80,7 +61,6 @@ function DI.value_and_pullback(
8061 tx = map (ty) do dy
8162 first (pb (dy))
8263 end
83- check_nothing (first (tx), f, x, contexts)
8464 return copy (y), tx
8565end
8666
@@ -96,7 +76,6 @@ function DI.pullback(
9676 tx = map (ty) do dy
9777 first (pb (dy))
9878 end
99- check_nothing (first (tx), f, x, contexts)
10079 return tx
10180end
10281
@@ -110,15 +89,13 @@ function DI.value_and_gradient(
11089 f, :: DI.NoGradientPrep , :: AutoZygote , x, contexts:: Vararg{DI.Context,C}
11190) where {C}
11291 (; val, grad) = withgradient (f, x, map (translate, contexts)... )
113- check_nothing (first (grad), f, x, contexts)
11492 return val, first (grad)
11593end
11694
11795function DI. gradient (
11896 f, :: DI.NoGradientPrep , :: AutoZygote , x, contexts:: Vararg{DI.Context,C}
11997) where {C}
12098 grad = gradient (f, x, map (translate, contexts)... )
121- check_nothing (first (grad), f, x, contexts)
12299 return first (grad)
123100end
124101
@@ -147,15 +124,13 @@ function DI.value_and_jacobian(
147124 y = f (x, map (translate, contexts)... )
148125 # https://github.com/FluxML/Zygote.jl/issues/1506
149126 jac = jacobian (f, x, map (translate, contexts)... )
150- check_nothing (first (jac), f, x, contexts)
151127 return y, first (jac)
152128end
153129
154130function DI. jacobian (
155131 f, :: DI.NoJacobianPrep , :: AutoZygote , x, contexts:: Vararg{DI.Context,C}
156132) where {C}
157133 jac = jacobian (f, x, map (translate, contexts)... )
158- check_nothing (first (jac), f, x, contexts)
159134 return first (jac)
160135end
161136
@@ -242,7 +217,6 @@ function DI.hessian(
242217) where {C}
243218 fc = DI. with_contexts (f, contexts... )
244219 hess = hessian (fc, x)
245- check_nothing (hess, f, x, contexts)
246220 return hess
247221end
248222
0 commit comments