diff --git a/src/cr_api.jl b/src/cr_api.jl index ee7a350..3093f1d 100644 --- a/src/cr_api.jl +++ b/src/cr_api.jl @@ -158,7 +158,7 @@ const GENERATED_RRULE_CACHE = Dict() Generate `rrule` using Yota. """ function ChainRulesCore.rrule_via_ad(::YotaRuleConfig, f, args...) - res = rrule(f, args...) + res = rrule(YOTA_RULE_CONFIG, f, args...) !isnothing(res) && return res sig = map(typeof, (f, args...)) if haskey(GENERATED_RRULE_CACHE, sig)