
Expose an API for the accumulated/total gradients for shared parameters #211

@jondeuce


Motivation and description

It would be useful to have a public API for each of the two steps that `update!` currently performs internally:

```julia
function update!(tree, model, grad, higher...)
    # First walk is to accumulate the gradient. This recursion visits every copy of
    # shared leaves, but stops when branches are absent from the gradient:
    grads = IdDict{Leaf, Any}()
    _grads!(grads, tree, model, grad, higher...)
    # Second walk is to update the model. The params cache indexed by (tree,x),
    # so that identified Leafs can tie isbits parameters, but setup won't do that for you:
    newmodel = _update!(tree, model; grads, params = IdDict())
    tree, newmodel # note that tree is guaranteed to be updated. Also that it's not necc a tree.
end
```
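To make the first walk concrete, here is a toy, flattened sketch of identity-keyed accumulation. It is not the library's `_grads!`: plain arrays stand in for `Leaf` objects, and `accumulate_by_identity` is a hypothetical helper. It shows why an `IdDict` fits here. A leaf reached twice gets a single summed gradient, while a different object with equal contents stays separate.

```julia
# Hypothetical toy helper (not part of any library): sums gradient contributions
# per *object*, so a leaf reached more than once ends up with a single total.
# Plain arrays stand in for the library's Leaf objects (a simplifying assumption).
function accumulate_by_identity(leaves, contributions)
    acc = IdDict{Any,Vector{Float64}}()
    for (leaf, g) in zip(leaves, contributions)
        # `get` returns the running total so far, or zeros on the first visit.
        acc[leaf] = get(acc, leaf, zero(g)) .+ g
    end
    return acc
end

w = zeros(2)   # a shared parameter...
v = zeros(2)   # ...and a distinct one that happens to have equal contents
leaves        = [w, w, v]                           # `w` is reached twice
contributions = [[1.0, 2.0], [3.0, 4.0], [10.0, 10.0]]

acc = accumulate_by_identity(leaves, contributions)
# acc[w] == [4.0, 6.0] and acc[v] == [10.0, 10.0]: `v` is kept separate from `w`
# because IdDict compares keys with `===`, not `==`.

# By contrast, a value-keyed Dict treats `w` and `v` as the same key:
by_value = Dict{Any,Int}(w => 1, v => 2)   # length(by_value) == 1
```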

In practice, the total gradient is typically more useful than the individual gradient contributions. The use case where this came up was tracking parameter gradient norms: near convergence, the contributions to a shared parameter's gradient may sum to zero even though each contribution is individually non-zero, so summing the norms of the contributions gives a misleading picture of the gradient's actual size.
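A small numeric sketch of this pitfall. The two contribution vectors are made-up values chosen so that they cancel exactly, as they might for a shared parameter near convergence.

```julia
using LinearAlgebra: norm

# Made-up gradient contributions from two uses of the same shared parameter.
g_first  = [0.5, -1.0, 2.0]
g_second = [-0.5, 1.0, -2.0]

# Norm of the total (accumulated) gradient for the shared parameter.
total_norm = norm(g_first .+ g_second)                # 0.0

# Summing the norms of the individual contributions instead.
summed_component_norms = norm(g_first) + norm(g_second)   # about 4.58
```

The total gradient is zero, yet the per-contribution view reports a gradient of size roughly 4.58.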

Possible Implementation

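A straightforward implementation would split `update!` into a gradient-accumulation step, `total_gradients`, and a new `update!` method that takes the precomputed gradients: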

```julia
function update!(tree, model, grad, higher...)
    # First walk: accumulate the total gradient for each leaf. This recursion visits
    # every copy of shared leaves, but stops when branches are absent from the gradient.
    grads = total_gradients(tree, model, grad, higher...)
    # Second walk: update the model using the accumulated gradients.
    return update!(tree, model, grads)
end

# Returns the IdDict filled by `_grads!`, keyed by Leaf, holding each leaf's
# accumulated gradient.
function total_gradients(tree, x, x̄s...)
    grads = IdDict{Leaf, Any}()
    _grads!(grads, tree, x, x̄s...)
    return grads
end

# Applies precomputed total gradients. The params cache is indexed by (tree, x),
# so that identified Leafs can tie isbits parameters, but setup won't do that for you.
function update!(tree, model, grads::IdDict)
    newmodel = _update!(tree, model; grads, params = IdDict())
    return tree, newmodel  # tree is guaranteed to be updated; note it is not necessarily a tree
end
```
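The benefit of the split is that the accumulated gradients can be inspected between the two walks. Below is a self-contained toy sketch of that workflow. `ToyLeaf`, `toy_total_gradients`, `toy_accumulate!` and `toy_update!` are hypothetical stand-ins, not library functions, and plain gradient descent stands in for whatever rule a real optimiser applies.

```julia
using LinearAlgebra: norm

# Hypothetical stand-in for an optimiser leaf: it simply owns a parameter vector.
# None of these toy names exist in any library.
mutable struct ToyLeaf
    param::Vector{Float64}
end

# Phase 1 (toy analogue of the proposed `total_gradients`): walk a nested tuple of
# leaves alongside a matching gradient tuple, summing contributions per leaf identity.
function toy_total_gradients(tree, grad)
    grads = IdDict{ToyLeaf,Vector{Float64}}()
    toy_accumulate!(grads, tree, grad)
    return grads
end

function toy_accumulate!(grads, leaf::ToyLeaf, g)
    grads[leaf] = haskey(grads, leaf) ? grads[leaf] .+ g : copy(g)
    return grads
end

function toy_accumulate!(grads, tree::Tuple, g::Tuple)
    for (t, gi) in zip(tree, g)
        gi === nothing && continue  # branch absent from the gradient: skip it
        toy_accumulate!(grads, t, gi)
    end
    return grads
end

# Phase 2 (toy analogue of `update!(tree, model, grads::IdDict)`): one plain
# gradient-descent step per leaf, using the accumulated totals.
function toy_update!(grads::IdDict{ToyLeaf,Vector{Float64}}; η = 0.1)
    for (leaf, g) in grads
        leaf.param .-= η .* g
    end
    return grads
end

shared = ToyLeaf([1.0, 1.0])
other  = ToyLeaf([2.0])
tree   = (shared, (shared, other))              # `shared` is used twice
grad   = ([0.5, -0.5], ([-0.5, 0.5], [1.0]))    # its two contributions cancel

grads = toy_total_gradients(tree, grad)
# Between the two phases, the total gradients can be inspected, e.g. their norms:
gradnorms = Dict(leaf => norm(g) for (leaf, g) in grads)
toy_update!(grads)
```

Here `gradnorms[shared]` reflects the cancelled total rather than the two non-zero contributions, which is the quantity the motivation above asks for.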

But anything along these lines would be great.
