[Feature request] To be able to intercept module calls for an existing model. #1355

Closed
@zpolina

Description

We are interested in enabling a ‘hook’ (‘interceptor’) into calls to Flax modules to add functionality for various purposes. For example, hooks would make it possible to intercept module method calls and change the input or perturb the output. They could also be used to access or modify module gradients. The main motivation is to avoid forking existing off-the-shelf models and creating code duplication that could otherwise be avoided. It is also not always feasible to change existing models: for example, a user might want to access a certain intermediate output or gradient across a large number of models.

The main use cases for this functionality are perturbing or pruning models, and explaining model behavior by inspecting gradients. Perturbation parameters can often be learned and injected into particular layers of the model. For example:

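```python
# Proposed interceptor: receives the module instance, the unbound
# method being called, and that method's arguments.
def intercept_method(mdl, fun, *args, **kwargs):
  if fun.__name__ == '__call__':
    y = fun(mdl, *args, **kwargs)
    # Record the output of every __call__ in the 'intermediates' collection.
    mdl.sow('intermediates', fun.__name__, y)
    return y
  # All other methods run unchanged.
  return fun(mdl, *args, **kwargs)
```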

It might be convenient for users to keep capture_intermediates, but it could be reimplemented as an interceptor like the one above.

Other frameworks:

  • PyTorch allows a user to register hooks using register_forward_hook and register_forward_pre_hook. It also has a register_full_backward_hook that is called during the backward pass.
  • In the JAX ecosystem, Haiku has an intercept_methods function.

Metadata

Assignees

No one assigned

Labels

Priority: P2 - no schedule (Best effort response and resolution. We have no plan to work on this at the moment.)