As stated in the title, `accumulate_param_gradients!` does not support `scale_factor` for static generative functions. Calling `accumulate_param_gradients!` with a third argument throws `ERROR: Not implemented`, because the call falls through to the abstract GFI definition.
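For reference, the GFI documentation for `accumulate_param_gradients!` describes the third argument as a multiplier on the parameter-gradient increment. Paraphrasing it: for a trace with arguments $x$ and choices $t$, a return-value gradient $\nabla_y J$ passed as `retval_grad`, and $s$ = `scale_factor`, the accumulator of the trainable parameters $\Theta$ (written $G_\Theta$ here, notation introduced for this note) should be updated as:

```math
G_\Theta \leftarrow G_\Theta + s \, \nabla_\Theta \bigl( \log P(t; x) + J \bigr)
```

With $s = 1$ this reduces to the update the existing two-argument method performs.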
The failure is due to (1) the lack of a generated method with the three-argument signature `(trace, retval_grad, scale_factor)`; only the two-argument method is generated:
```julia
push!(generated_functions, quote
    @generated function $(GlobalRef(Gen, :accumulate_param_gradients!))(trace::T, retval_grad) where {T<:$(QuoteNode(StaticIRTrace))}
        $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad)
    end
end)
```
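To make the dispatch failure concrete, here is a self-contained toy model in plain Julia. The types `AbstractTrace`, `StaticTrace`, `PatchedStaticTrace` and the function `toy_accumulate_grads!` are hypothetical stand-ins, not Gen's API, and the fallback layout (the two-argument form forwarding to an unimplemented three-argument form) is an assumption about how the GFI defaults are arranged. The sketch shows why specializing only the two-argument signature leaves three-argument calls on the fallback, and how a three-argument specialization would fix it.

```julia
# Hypothetical stand-ins for illustration only; these are NOT Gen's types or functions.
abstract type AbstractTrace end
struct StaticTrace <: AbstractTrace end         # like the current static IR: two-argument method only
struct PatchedStaticTrace <: AbstractTrace end  # like a fix that also generates the three-argument method

# Assumed fallback layout: the two-argument form forwards with scale_factor = 1.0,
# and the generic three-argument form is unimplemented.
toy_accumulate_grads!(trace::AbstractTrace, retval_grad) =
    toy_accumulate_grads!(trace, retval_grad, 1.0)
toy_accumulate_grads!(trace::AbstractTrace, retval_grad, scale_factor) =
    error("Not implemented")

# Current situation: only the two-argument signature is specialized.
toy_accumulate_grads!(trace::StaticTrace, retval_grad) = :unscaled

# Patched situation: the three-argument signature is specialized; two-argument
# calls reach it through the generic forwarding method above.
toy_accumulate_grads!(trace::PatchedStaticTrace, retval_grad, scale_factor) =
    (:scaled, scale_factor)

println(toy_accumulate_grads!(StaticTrace(), nothing))              # prints unscaled
try
    toy_accumulate_grads!(StaticTrace(), nothing, 0.5)
catch e
    println(e.msg)                                                  # prints Not implemented
end
println(toy_accumulate_grads!(PatchedStaticTrace(), nothing, 0.5))  # prints (:scaled, 0.5)
```

In this sketch, making the three-argument method the specialized one keeps two-argument calls working through the generic forwarding method, which is one plausible way to structure the fix.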
And (2) the lack of logic to apply a scale factor in the backward pass for trainable parameter nodes, where the node's gradient is added to the stored parameter gradient unscaled:
```julia
function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::TrainableParameterNode, mode)

    # handle case when it is the return node
    if node === ir.return_node && node in fwd_marked
        @assert node in back_marked
        push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing")))
        push!(stmts, :($(gradient_var(node)) += retval_grad))
    end

    if node in fwd_marked && node in back_marked
        cur_param_grad = :($(QuoteNode(get_param_grad))(trace.$static_ir_gen_fn_ref,
            $(QuoteNode(node.name))))
        push!(stmts, :($(QuoteNode(set_param_grad!))(trace.$static_ir_gen_fn_ref,
            $(QuoteNode(node.name)),
            $cur_param_grad + $(gradient_var(node)))))
    end
end
```
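The change needed in `back_codegen!` can be sketched outside Gen with plain Julia metaprogramming. The gradient store and the names `get_grad`, `set_grad!`, `node_grad_var`, `node_name` and `scaled_accumulate` below are hypothetical stand-ins for `get_param_grad`, `set_param_grad!`, `gradient_var(node)` and `node.name`, and the sketch assumes the generated three-argument method binds a local named `scale_factor`. Under those assumptions, the accumulation expression would become `$cur_param_grad + scale_factor * $(gradient_var(node))` instead of `$cur_param_grad + $(gradient_var(node))`.

```julia
# Hypothetical gradient store standing in for get_param_grad / set_param_grad!.
const param_grads = Dict{Symbol,Float64}()
get_grad(name::Symbol) = get(param_grads, name, 0.0)
set_grad!(name::Symbol, g::Float64) = (param_grads[name] = g)

# Hypothetical stand-ins for gradient_var(node) and node.name.
node_grad_var = :theta_grad
node_name = :theta

# Build the accumulation statement in the style of back_codegen!, but multiply the
# node's gradient by scale_factor before adding it to the stored gradient.
cur_param_grad = :(get_grad($(QuoteNode(node_name))))
scaled_update = :(set_grad!($(QuoteNode(node_name)), $cur_param_grad + scale_factor * $node_grad_var))

# Compile the statement into a function of the node gradient and the scale factor.
scaled_accumulate = eval(:(($node_grad_var, scale_factor) -> begin
    $scaled_update
end))

scaled_accumulate(0.5, 4.0)   # stored gradient for :theta: 0.0 + 4.0 * 0.5 = 2.0
scaled_accumulate(0.25, 2.0)  # stored gradient for :theta: 2.0 + 2.0 * 0.25 = 2.5
println(param_grads[:theta])  # prints 2.5
```

With `scale_factor = 1.0` the scaled statement performs the same update as the current unscaled one, so a single code path could plausibly serve both the two- and three-argument methods.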
The two source excerpts quoted above are from `Gen.jl/src/static_ir/backprop.jl` at commit e5ed96f: the method definition is at lines 508 to 512, and the `TrainableParameterNode` backward pass is at lines 169 to 185.