
accum_param_gradients! does not support scale_factor for static functions #387

@ztangent


As stated in the title, `accumulate_param_gradients!` does not support the `scale_factor` argument for static generative functions. Calling it with a third argument throws `ERROR: Not implemented`, because the call falls through to the abstract generative function interface (GFI) definition.

There are two causes. First, no method is generated with the three-argument signature. The static IR code only generates a two-argument method taking `trace` and `retval_grad`:

```julia
push!(generated_functions, quote
    @generated function $(GlobalRef(Gen, :accumulate_param_gradients!))(trace::T, retval_grad) where {T<:$(QuoteNode(StaticIRTrace))}
        $(QuoteNode(codegen_accumulate_param_gradients!))(trace, retval_grad)
    end
end)
```
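
To make the dispatch problem concrete, here is a minimal Julia sketch that does not depend on Gen. `ToyTrace`, `accumulate_grads!`, `BrokenStaticTrace`, and `FixedStaticTrace` are hypothetical stand-ins. The sketch assumes the abstract GFI fallback behaves like a three-argument method that errors, with the two-argument form forwarding a default scale factor of `1.0`.

```julia
# Gen-independent toy of the dispatch problem. All names here are hypothetical
# stand-ins; the abstract fallback is assumed to mirror the generic GFI method.
abstract type ToyTrace end

# Generic 3-argument fallback: errors unless a concrete trace type overrides it.
accumulate_grads!(trace::ToyTrace, retval_grad, scale_factor) =
    error("Not implemented")

# Generic 2-argument form: forwards with a default scale factor of 1.0.
accumulate_grads!(trace::ToyTrace, retval_grad) =
    accumulate_grads!(trace, retval_grad, 1.0)

# Like the current StaticIRTrace code: only the 2-argument method is specialized,
# so a 3-argument call falls through to the erroring fallback.
mutable struct BrokenStaticTrace <: ToyTrace
    grad::Float64
end
accumulate_grads!(t::BrokenStaticTrace, retval_grad) = (t.grad += retval_grad; t)

# Fixed variant: specialize the 3-argument method and apply the scale factor.
# 2-argument calls still reach it through the generic forwarding method.
mutable struct FixedStaticTrace <: ToyTrace
    grad::Float64
end
accumulate_grads!(t::FixedStaticTrace, retval_grad, scale_factor) =
    (t.grad += scale_factor * retval_grad; t)

b = BrokenStaticTrace(0.0)
accumulate_grads!(b, 2.0)          # uses the specialized 2-argument method
# accumulate_grads!(b, 2.0, 0.5)   # would throw ErrorException("Not implemented")

f = FixedStaticTrace(0.0)
accumulate_grads!(f, 2.0, 0.5)     # f.grad becomes 0.5 * 2.0 = 1.0
accumulate_grads!(f, 2.0)          # forwarded with scale 1.0; f.grad becomes 3.0
```

By analogy, the Gen fix would presumably generate the `StaticIRTrace` method with a third `scale_factor` argument and pass it through to the code generator. The exact signature used upstream is not confirmed here.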

Second, the backward pass for trainable parameter nodes has no logic to apply a scale factor. The node's gradient is added to the stored parameter gradient unscaled:

```julia
function back_codegen!(stmts, ir, selected_calls, fwd_marked, back_marked, node::TrainableParameterNode, mode)
    # handle case when it is the return node
    if node === ir.return_node && node in fwd_marked
        @assert node in back_marked
        push!(stmts, :(isnothing(retval_grad) && error("Required return value gradient but got nothing")))
        push!(stmts, :($(gradient_var(node)) += retval_grad))
    end
    if node in fwd_marked && node in back_marked
        cur_param_grad = :($(QuoteNode(get_param_grad))(trace.$static_ir_gen_fn_ref,
                                                         $(QuoteNode(node.name))))
        push!(stmts, :($(QuoteNode(set_param_grad!))(trace.$static_ir_gen_fn_ref,
                                                     $(QuoteNode(node.name)),
                                                     $cur_param_grad + $(gradient_var(node)))))
    end
end
```
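
One possible shape for the second fix is to multiply the node's accumulated gradient by the scale factor before adding it to the stored parameter gradient. The sketch below mimics the quoting style of `back_codegen!` in a self-contained way. `PARAM_GRADS`, `toy_get_param_grad`, `toy_set_param_grad!`, `emit_param_update`, and `toy_accumulate!` are hypothetical, and the sketch assumes `scale_factor` is bound as a variable in the generated method's scope, as it would be if that method took it as a third argument.

```julia
# Self-contained toy of the parameter-update step in back_codegen!. All names
# are hypothetical stand-ins, not Gen's get_param_grad/set_param_grad!.
const PARAM_GRADS = Dict{Symbol,Float64}(:theta => 0.0)
toy_get_param_grad(name::Symbol) = PARAM_GRADS[name]
toy_set_param_grad!(name::Symbol, g) = (PARAM_GRADS[name] = g)

# Build the update statement for one trainable parameter. `grad_var` names the
# variable holding the node's backward-pass gradient. `scale_factor` is assumed
# to be a variable bound in the generated method (e.g. its third argument).
function emit_param_update(name::Symbol, grad_var::Symbol)
    cur_param_grad = :($(QuoteNode(toy_get_param_grad))($(QuoteNode(name))))
    return :($(QuoteNode(toy_set_param_grad!))($(QuoteNode(name)),
                                                $cur_param_grad + scale_factor * $grad_var))
end

# Splice the generated statement into a method that takes the scale factor.
update_stmt = emit_param_update(:theta, :theta_grad)
@eval function toy_accumulate!(theta_grad, scale_factor)
    $update_stmt
    return nothing
end

toy_accumulate!(2.0, 0.5)   # PARAM_GRADS[:theta] becomes 0.0 + 0.5 * 2.0 = 1.0
toy_accumulate!(2.0, 1.0)   # PARAM_GRADS[:theta] becomes 1.0 + 1.0 * 2.0 = 3.0
```

Applying the factor at this single update point scales the node's whole contribution, including any `retval_grad` term added earlier in the return-node case. That matches the reading of `scale_factor` as scaling the entire gradient increment, though the upstream fix may place it differently.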
