Add an implementation to intercept method calls. #1356

Open · wants to merge 7 commits into base: main
41 changes: 38 additions & 3 deletions flax/linen/module.py
@@ -114,6 +114,13 @@ def capture_stack(self):
self._thread_data.capture_stack = []
return self._thread_data.capture_stack

@property
def intercept_stack(self):
"""Keeps track of the active intercept functions."""
if not hasattr(self._thread_data, 'intercept_stack'):
self._thread_data.intercept_stack = []
return self._thread_data.intercept_stack

# The global context
_context = _DynamicContext()

@@ -275,7 +282,14 @@ def wrapped_module_method(*args, **kwargs):
self._state.in_compact_method = True
_context.module_stack.append(self)
try:
y = fun(self, *args, **kwargs)
# Apply intercept_method if one was provided to apply().
if _context.intercept_stack and _context.intercept_stack[-1]:
intercept_method = _context.intercept_stack[-1]
# The intercept method is responsible for calling `fun`.
y = intercept_method(self, fun, args, kwargs)
else:
y = fun(self, *args, **kwargs)
# Sow the output if `capture_intermediates` was provided to apply().
if _context.capture_stack:
filter_fn = _context.capture_stack[-1]
if filter_fn and filter_fn(self, fun.__name__):
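For reference, the contract this dispatch assumes: the hook receives the unbound method plus the call's `args` tuple and `kwargs` dict, and must invoke the method itself. A hook that simply forwards the call behaves exactly like the default branch; a minimal sketch (not part of the diff):

```python
def passthrough_interceptor(module, original_fn, args, kwargs):
  # Forward the call unchanged; `original_fn` is unbound, so the module
  # instance is passed explicitly as the first argument.
  return original_fn(module, *args, **kwargs)
```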
@@ -890,6 +904,7 @@ def apply(self,
rngs: Optional[RNGSequences] = None,
method: Callable[..., Any] = None,
mutable: CollectionFilter = False,
intercept_method: Optional[Callable[..., Any]] = None,
capture_intermediates: Union[bool, Callable[['Module', str],
bool]] = False,
**kwargs) -> Union[Any, Tuple[Any, FrozenVariableDict]]:
@@ -931,6 +946,13 @@ def other_fn(instance, ...):
treated as mutable: ``bool``: all/no collections are mutable.
``str``: The name of a single mutable collection. ``list``: A
list of names of mutable collections.
intercept_method: An optional hook that intercepts all method calls and
can be used to modify the args and kwargs before calling each original
method, as well as to modify its output. The provided function needs to
have the signature `f(module, original_function, args, kwargs)` and is

Member:

This signature (the earlier `f(module, original_function, *args, **kwargs)` form) will fail if users write an intercept_method whose kwargs clash with the first two arguments. Names such as mdl, mod, fn, or fun are all legitimate choices for the first two args but could easily show up in kwargs too. An alternative is to make the signature f(module, original_function, args, kwargs) so there can't be any clashes.

Contributor Author:

Thanks, I went with your suggestion, so the interceptor now takes f(module, original_function, args, kwargs).

responsible for calling `original_function(module, *args, **kwargs)`.
Can be used to temporarily modify a model without making changes to
the underlying code.
capture_intermediates: If `True`, captures intermediate return values
of all Modules inside the "intermediates" collection. By default only
the return values of all ``__call__`` methods are stored. A function can
@@ -947,7 +969,9 @@ def other_fn(instance, ...):
method = _get_unbound_fn(method)
return apply(
method, self,
mutable=mutable, capture_intermediates=capture_intermediates
mutable=mutable,
intercept_method=intercept_method,
capture_intermediates=capture_intermediates
)(variables, *args, **kwargs, rngs=rngs)
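A usage sketch for the new argument as documented above; the `MLP` module and `log_outputs` hook are invented for illustration and are not part of the PR:

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):  # illustrative module, not from the PR
  @nn.compact
  def __call__(self, x):
    return nn.Dense(4)(nn.relu(nn.Dense(8)(x)))

def log_outputs(mdl, original_fn, args, kwargs):
  # `args` is a tuple and `kwargs` a dict, passed positionally, so user
  # kwargs named e.g. `fn` or `mdl` cannot clash with the hook's own
  # parameters, which is the point settled in the review thread above.
  y = original_fn(mdl, *args, **kwargs)
  print(type(mdl).__name__, original_fn.__name__, jnp.shape(y))
  return y

x = jnp.ones((3,))
variables = MLP().init(jax.random.PRNGKey(0), x)
y = MLP().apply(variables, x, intercept_method=log_outputs)
```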

def init_with_output(self,
@@ -1153,6 +1177,7 @@ def __call__(self, train: Optional[bool] = None):

def apply(fn: Callable[..., Any], module: Module,
mutable: CollectionFilter = False,
intercept_method: Optional[Callable[..., Any]] = None,
capture_intermediates: Union[bool, Callable[[Module, str], bool]] = False) -> Callable[..., Any]:
"""Creates an apply function to call ``fn`` with a bound module.

@@ -1185,22 +1210,32 @@ def f(foo, x):
treated as mutable: ``bool``: all/no collections are mutable.
``str``: The name of a single mutable collection. ``list``: A
list of names of mutable collections.
intercept_method: An optional hook that intercepts all method calls that
occur while `module` is applied. It can be used to modify the args and
kwargs of each call, as well as its output. The provided function needs
to have the signature `f(module, original_function, args, kwargs)` and
is responsible for calling `original_function(module, *args, **kwargs)`.
It can be used to temporarily modify a model without making changes to
the underlying code.
capture_intermediates: If `True`, captures intermediate return values
of all Modules inside the "intermediates" collection. By default only
the return values of all `__call__` methods are stored. A function can
be passed to change the filter behavior. The filter function takes
the Module instance and method name and returns a bool indicating
whether the output of that method invocation should be stored.

Returns:
The apply function wrapping ``fn``.
"""
@functools.wraps(fn)
def scope_fn(scope, *args, **kwargs):
_context.capture_stack.append(capture_intermediates)
_context.intercept_stack.append(intercept_method)
try:
return fn(module.clone(parent=scope), *args, **kwargs)
finally:
_context.capture_stack.pop()
_context.intercept_stack.pop()

if capture_intermediates is True:
capture_intermediates = capture_call_intermediates
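The functional `apply` takes the same hook; a hedged sketch mirroring the docstring above (the `Tiny` module, `double_outputs` hook, and `run` function are invented for illustration):

```python
import jax
import jax.numpy as jnp
import flax.linen as nn

class Tiny(nn.Module):  # illustrative module, not from the PR
  @nn.compact
  def __call__(self, x):
    return nn.Dense(2)(x)

def double_outputs(module, original_fn, args, kwargs):
  # Rescale the output of every intercepted method call.
  return 2 * original_fn(module, *args, **kwargs)

def run(tiny, x):
  return tiny(x)  # `tiny` arrives already bound to the variables

x = jnp.ones((3,))
variables = Tiny().init(jax.random.PRNGKey(0), x)
y = nn.apply(run, Tiny(), intercept_method=double_outputs)(variables, x)
```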
@@ -1295,4 +1330,4 @@ def f(foo, x):
@functools.wraps(init_fn)
def init_wrapper(*args, **kwargs):
return init_fn(*args, **kwargs)[1]
return init_wrapper
159 changes: 159 additions & 0 deletions tests/linen/module_test.py
@@ -25,6 +25,7 @@
import jax
from jax import random
from jax import lax
from jax._src.api import vmap
from jax.nn import initializers
import jax.numpy as jnp

@@ -1240,6 +1241,164 @@ def __call__(self, x):
'intermediates': {'Bar_0': {'test': (2,)}}
})

def test_intercept_method(self):
"""Test that intercept_method is called correctly."""
class Bar(nn.Module):
def test(self, x):
return x + 1

class Foo(nn.Module):
@nn.compact
def __call__(self, x):
return Bar().test(x) + 1

def intercept_method(mdl, fun, args, kwargs):
# Add 1 to the output of Bar.test().
if isinstance(mdl, Bar):
return fun(mdl, *args, **kwargs) + 1
return fun(mdl, *args, **kwargs)

output = Foo().apply({}, 1, intercept_method=intercept_method)
self.assertEqual(output, 4)

def test_intercept_method_works_for_gradients(self):
"""Test that intercept_method works correctly for taking gradients."""
class MyModel(nn.Module):
hidden_size: int = 10
output_size: int = 1

@nn.compact
def __call__(self, inputs):
hidden = nn.Dense(self.hidden_size, name='layer1')(inputs)
hidden = nn.Dense(self.hidden_size, name='layer2')(hidden)
output = nn.Dense(self.output_size, name='output_layer')(hidden)
return output

class Foo(nn.Module):
@nn.compact
def __call__(self, x):
output = MyModel()(x)
return jnp.squeeze(output)

input_size = 5
inputs = np.random.RandomState(0).normal(size=(input_size))
# This intercept method makes it possible to capture the gradients
# with respect to the inputs of every nn.Dense call.

def intercept_method(mdl, fun, args, kwargs):
# Add eps to the input of nn.Dense __call__.
if isinstance(mdl, nn.Dense):
inputs, = args
eps = mdl.variable('inter_grads', 'activation',
lambda: jnp.zeros_like(inputs, dtype=jnp.float32))
inputs = inputs + eps.value
y = fun(mdl, inputs, **kwargs)
return y
return fun(mdl, *args, **kwargs)

model = Foo()
rng_key = jax.random.PRNGKey(0)
variables = model.init(rng_key, inputs, intercept_method=intercept_method)

def with_respect_to_inputs(inputs):
output = model.apply(
variables, inputs, intercept_method=intercept_method)
return output

def with_respect_to_vars(variables):
output = model.apply(
variables, inputs, intercept_method=intercept_method)
return output

expected_grads = jax.grad(with_respect_to_inputs)(inputs)
grads = jax.grad(with_respect_to_vars)(variables)
# Since we have 3 Dense layers we should get 3 vectors of gradients.
self.assertEqual(len(grads['inter_grads']['MyModel_0']), 3)
# Check that the gradients with respect to inputs of the first Dense layer
# match the gradients with respect to inputs to the Foo model.
np.testing.assert_array_equal(
expected_grads,
grads['inter_grads']['MyModel_0']['layer1']['activation'])
# Check the dimensionality of the computed grads with respect to the
# inputs of the 2nd and 3rd Dense layers; it should match the
# dimensionalities of the inputs to those Dense layers.
self.assertEqual(grads['inter_grads']['MyModel_0']
['layer2']['activation'].shape, (10,))
self.assertEqual(grads['inter_grads']['MyModel_0']
['output_layer']['activation'].shape, (10,))

def test_intercept_method_with_multiple_functions(self):

Member:

Please also test the intercept method with transformations (vmap and/or scan).

Reply:

I have added the vmap test. I wasn't sure whether you wanted to wrap the apply method with vmap or the whole module with the decorator, so I went with apply. Let me know if that works for you.

"""Tests that intercept_method is called on each function."""
class Bar(nn.Module):

def __call__(self, x):
return x + 2

def test(self, x):
return x + 1

class Foo(nn.Module):
@nn.compact
def __call__(self, x):
return Bar().test(x) + Bar()(x) + 1

def intercept_method(mdl, fun, args, kwargs):
# Add 1 to the output of Bar methods.
if isinstance(mdl, Bar):
return fun(mdl, *args, **kwargs) + 1
return fun(mdl, *args, **kwargs)

output = Foo().apply({}, 1, intercept_method=intercept_method)
self.assertEqual(output, 8)

def test_intercept_method_with_vmap(self):
"""Tests that intercept_method works with vmap."""

class Foo(nn.Module):
@nn.compact
def __call__(self, x):
return x + 1

def intercept_method(mdl, fun, args, kwargs):
if isinstance(mdl, Foo):
output = fun(mdl, *args, **kwargs)
return jnp.dot(output, output)
return fun(mdl, *args, **kwargs)

def model_apply(inputs):
return Foo().apply({}, inputs, intercept_method=intercept_method)

vmapped_apply = jax.vmap(model_apply, in_axes=1)
output = vmapped_apply(np.array([[0, 1, 2], [0, 1, 2]]))
np.testing.assert_array_equal(output, [2, 8, 18])

def test_intercept_method_that_captures_intermediate_output(self):
"""Applies self.sow() to the output of a call using intercept_method."""
class Bar(nn.Module):
def test(self, x):
return x + 1

class Foo(nn.Module):
@nn.compact
def __call__(self, x):
return Bar().test(x) + 1

def interceptor(mdl, fun, args, kwargs):
y = fun(mdl, *args, **kwargs)
mdl.sow('intermediates', fun.__name__, y)
return y

def intercept_method(mdl, fun, args, kwargs):
# Sow the output of Bar functions.
if isinstance(mdl, Bar):
return interceptor(mdl, fun, args, kwargs)
return fun(mdl, *args, **kwargs)

_, state = Foo().apply({}, 1, intercept_method=intercept_method,
mutable=['intermediates'])
self.assertEqual(state, {
'intermediates': {'Bar_0': {'test': (2,)}}
})

def test_functional_apply(self):
class Foo(nn.Module):
def setup(self):