Description
We are interested in enabling a 'hook' (also called an 'interceptor') into calls to Flax modules, to add functionality for various purposes. For example, hooks would make it possible to intercept module method calls by changing the inputs or perturbing the outputs. They can also be used to access or modify module gradients. The main motivation is to avoid forking existing off-the-shelf models and creating code duplication when it could be avoided. It is also not always feasible to change existing models; for example, a user might want to access a certain intermediate output or gradient for a large number of models.
The main use cases for this functionality involve perturbing/pruning models or explaining model behavior by inspecting gradients. Often the perturbation parameters are learnable and injected into particular layers of the model. For example:
- Differential Masking (ArXiv) (GitHub example with a hook)
- Parameter space noise for exploration (ArXiv) (GitHub example with a hook)
- SparseML (Documentation) (GitHub example with a hook)
- Class Activation Mapping (ArXiv) (GitHub example with a hook)
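As a concrete illustration of the output-perturbation use case (e.g. parameter space noise), here is a minimal framework-agnostic sketch. The `Dense` stand-in class and the manual wiring are hypothetical, not Flax API; only the interceptor's shape matters:

```python
import random

def noise_interceptor(mdl, fun, *args, **kwargs):
    """Hook that perturbs a module's output, without touching the module code."""
    y = fun(mdl, *args, **kwargs)
    # Add small Gaussian noise to each output element.
    return [v + random.gauss(0.0, 0.01) for v in y]

class Dense:
    """Stand-in for an off-the-shelf module; doubles its inputs."""
    def __call__(self, xs):
        return [2.0 * x for x in xs]

mdl = Dense()
xs = [1.0, 2.0]
plain = mdl(xs)                                      # uninstrumented call
hooked = noise_interceptor(mdl, Dense.__call__, xs)  # intercepted call
```

The point is that the perturbation lives entirely in the hook, so the original module is reused unmodified.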
Note that the hook is a generalization of `capture_intermediates`, in the sense that `capture_intermediates` can be implemented as an interceptor, as in the following example for the case of `capture_intermediates=True`:
```python
def intercept_method(mdl, fun, *args, **kwargs):
  # Capture the output of __call__, mirroring capture_intermediates=True.
  if fun.__name__ == '__call__':
    y = fun(mdl, *args, **kwargs)
    mdl.sow('intermediates', fun.__name__, y)
    return y
  return fun(mdl, *args, **kwargs)
```
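To check the behavior of that interceptor, here is a self-contained simulation with a toy module providing a `sow` method. The `Toy` class below is illustrative only, not a Flax module:

```python
def intercept_method(mdl, fun, *args, **kwargs):
    # Only capture outputs of __call__, mirroring capture_intermediates=True.
    if fun.__name__ == '__call__':
        y = fun(mdl, *args, **kwargs)
        mdl.sow('intermediates', fun.__name__, y)
        return y
    return fun(mdl, *args, **kwargs)

class Toy:
    """Minimal stand-in with a sow() that appends into named collections."""
    def __init__(self):
        self.collections = {}
    def sow(self, col, name, value):
        self.collections.setdefault(col, {}).setdefault(name, []).append(value)
    def __call__(self, x):
        return x + 1

mdl = Toy()
y = intercept_method(mdl, Toy.__call__, 41)
# y is the normal output, and the intermediate value was also recorded.
```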
It might be convenient for users if `capture_intermediates` stays, but it could be implemented in terms of an intercept method.
Other frameworks:
- PyTorch allows a user to register hooks using `register_forward_hook` and `register_forward_pre_hook`. It also has `register_full_backward_hook`, which is called during the backward pass.
- In the JAX ecosystem, Haiku has an `intercept_methods` function.
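For comparison, the PyTorch-style pre-/post-hook pattern can be sketched without PyTorch itself. The `Layer` class and its hook registry below are illustrative stand-ins, not PyTorch code, but they follow the same convention: a hook may return a replacement value or `None` to leave things unchanged:

```python
class Layer:
    """Toy layer mimicking PyTorch's forward-hook registration pattern."""
    def __init__(self):
        self._pre_hooks = []   # analogous to register_forward_pre_hook
        self._post_hooks = []  # analogous to register_forward_hook
    def register_forward_pre_hook(self, hook):
        self._pre_hooks.append(hook)
    def register_forward_hook(self, hook):
        self._post_hooks.append(hook)
    def forward(self, x):
        return 2 * x
    def __call__(self, x):
        for hook in self._pre_hooks:
            out = hook(self, x)
            if out is not None:   # a hook may replace the input
                x = out
        y = self.forward(x)
        for hook in self._post_hooks:
            out = hook(self, x, y)
            if out is not None:   # a hook may replace the output
                y = out
        return y

layer = Layer()
layer.register_forward_pre_hook(lambda mod, x: x + 1)   # change the input
layer.register_forward_hook(lambda mod, x, y: y * 10)   # change the output
```

A Flax interceptor subsumes both hook kinds, since it wraps the whole call and can rewrite inputs and outputs in one place.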