Add an implementation to intercept method calls. #1356
base: main
Changes from all commits: 4a73d73, 0aeef0d, 1c9a14c, 1a1100a, 1788c79, b023beb, 6c5dbce
@@ -25,6 +25,7 @@
import jax
from jax import random
from jax import lax
from jax._src.api import vmap
from jax.nn import initializers
import jax.numpy as jnp
@@ -1240,6 +1241,164 @@ def __call__(self, x):
        'intermediates': {'Bar_0': {'test': (2,)}}
    })
  def test_intercept_method(self):
    """Test that intercept_method is called correctly."""
    class Bar(nn.Module):
      def test(self, x):
        return x + 1

    class Foo(nn.Module):
      @nn.compact
      def __call__(self, x):
        return Bar().test(x) + 1

    def intercept_method(mdl, fun, args, kwargs):
      # Add 1 to the output of Bar.test().
      if isinstance(mdl, Bar):
        return fun(mdl, *args, **kwargs) + 1
      return fun(mdl, *args, **kwargs)

    output = Foo().apply({}, 1, intercept_method=intercept_method)
    self.assertEqual(output, 4)
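For orientation, the interceptor contract exercised above is f(module, original_function, args, kwargs), where original_function is the unbound method. A minimal sketch of how a method wrapper could route calls through such an interceptor (hypothetical: the names wrap_method and _intercept_method are illustrative, not the PR's actual internals):

import functools

def wrap_method(fun):
  @functools.wraps(fun)
  def wrapped(self, *args, **kwargs):
    # Hypothetical storage for the interceptor passed to apply/init.
    interceptor = getattr(self, '_intercept_method', None)
    if interceptor is None:
      return fun(self, *args, **kwargs)
    # The interceptor receives the module instance, the unbound method,
    # and the packed positional/keyword arguments.
    return interceptor(self, fun, args, kwargs)
  return wrapped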
  def test_intercept_method_works_for_gradients(self):
    """Test that intercept_method works correctly for taking gradients."""
    class MyModel(nn.Module):
      hidden_size: int = 10
      output_size: int = 1

      @nn.compact
      def __call__(self, inputs):
        hidden = nn.Dense(self.hidden_size, name='layer1')(inputs)
        hidden = nn.Dense(self.hidden_size, name='layer2')(hidden)
        output = nn.Dense(self.output_size, name='output_layer')(hidden)
        return output

    class Foo(nn.Module):
      @nn.compact
      def __call__(self, x):
        output = MyModel()(x)
        return jnp.squeeze(output)

    input_size = 5
    inputs = np.random.RandomState(0).normal(size=(input_size,))

    # This interceptor makes it possible to capture the gradients with
    # respect to the inputs of every nn.Dense layer.
    def intercept_method(mdl, fun, args, kwargs):
      # Add eps to the input of nn.Dense.__call__.
      if isinstance(mdl, nn.Dense):
        inputs, = args
        eps = mdl.variable('inter_grads', 'activation',
                           lambda: jnp.zeros_like(inputs, dtype=jnp.float32))
        inputs = inputs + eps.value
        y = fun(mdl, inputs, **kwargs)
        return y
      return fun(mdl, *args, **kwargs)

    model = Foo()
    rng_key = jax.random.PRNGKey(0)
    variables = model.init(rng_key, inputs, intercept_method=intercept_method)

    def with_respect_to_inputs(inputs):
      output = model.apply(
          variables, inputs, intercept_method=intercept_method)
      return output

    def with_respect_to_vars(variables):
      output = model.apply(
          variables, inputs, intercept_method=intercept_method)
      return output

    expected_grads = jax.grad(with_respect_to_inputs)(inputs)
    grads = jax.grad(with_respect_to_vars)(variables)
    # Since we have 3 Dense layers, we should get 3 vectors of gradients.
    self.assertEqual(len(grads['inter_grads']['MyModel_0']), 3)
    # Check that the gradients with respect to the inputs of the first Dense
    # layer match the gradients with respect to the inputs of the Foo model.
    np.testing.assert_array_equal(
        expected_grads,
        grads['inter_grads']['MyModel_0']['layer1']['activation'])
    # Check the dimensionality of the computed gradients with respect to the
    # inputs of the 2nd and 3rd Dense layers; it should match the
    # dimensionalities of the inputs to those layers.
    self.assertEqual(grads['inter_grads']['MyModel_0']
                     ['layer2']['activation'].shape, (10,))
    self.assertEqual(grads['inter_grads']['MyModel_0']
                     ['output_layer']['activation'].shape, (10,))
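The test above relies on the zero-perturbation trick: adding a zero-valued variable eps to an intermediate value and differentiating with respect to eps yields the gradient with respect to that intermediate, without changing the forward computation. A standalone sketch in plain JAX (the function f is illustrative):

import jax
import jax.numpy as jnp

def f(x, eps):
  # eps == 0 leaves the forward value unchanged.
  h = x * 2.0 + eps
  return jnp.sum(h ** 2)

x = jnp.array([1.0, 2.0])
# d/d(eps) of sum((2x + eps)^2) at eps = 0 is 2 * (2x) = [4., 8.],
# i.e. the gradient of the loss with respect to the intermediate h.
d_deps = jax.grad(f, argnums=1)(x, jnp.zeros_like(x))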
Reviewer comment: Please also test intercept_method with transformations (vmap and/or scan).

Author reply: I have added the vmap test. I wasn't sure whether you wanted to wrap the apply method with vmap or to wrap the whole module with the decorator, so I went with apply. Let me know if that works for you.

  def test_intercept_method_with_multiple_functions(self):
    """Tests that intercept_method is called on each function."""
    class Bar(nn.Module):

      def __call__(self, x):
        return x + 2

      def test(self, x):
        return x + 1

    class Foo(nn.Module):
      @nn.compact
      def __call__(self, x):
        return Bar().test(x) + Bar()(x) + 1

    def intercept_method(mdl, fun, args, kwargs):
      # Add 1 to the output of Bar methods.
      if isinstance(mdl, Bar):
        return fun(mdl, *args, **kwargs) + 1
      return fun(mdl, *args, **kwargs)

    output = Foo().apply({}, 1, intercept_method=intercept_method)
    self.assertEqual(output, 8)
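Because the interceptor receives the unbound method itself, it can also dispatch on the method name rather than only on the module type. A hypothetical variant of the interceptor above:

def intercept_only_test(mdl, fun, args, kwargs):
  # Only Bar.test is modified; Bar.__call__ passes through unchanged.
  if isinstance(mdl, Bar) and fun.__name__ == 'test':
    return fun(mdl, *args, **kwargs) + 1
  return fun(mdl, *args, **kwargs)

# With this variant, Foo().apply({}, 1, intercept_method=intercept_only_test)
# would return (2 + 1) + 3 + 1 == 7 instead of 8.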
  def test_intercept_method_with_vmap(self):
    """Tests that intercept_method works with vmap."""

    class Foo(nn.Module):
      @nn.compact
      def __call__(self, x):
        return x + 1

    def intercept_method(mdl, fun, args, kwargs):
      if isinstance(mdl, Foo):
        output = fun(mdl, *args, **kwargs)
        return jnp.dot(output, output)
      return fun(mdl, *args, **kwargs)

    def model_apply(inputs):
      return Foo().apply({}, inputs, intercept_method=intercept_method)

    vmapped_apply = jax.vmap(model_apply, in_axes=1)
    output = vmapped_apply(np.array([[0, 1, 2], [0, 1, 2]]))
    np.testing.assert_array_equal(output, [2, 8, 18])
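For reference, in_axes=1 maps over the second axis, so each mapped call sees one column of the batch; the expected values [2, 8, 18] follow directly. A small check of that arithmetic, independent of Flax:

import numpy as np

x = np.array([[0, 1, 2], [0, 1, 2]])
expected = []
for col in x.T:               # columns: [0, 0], [1, 1], [2, 2]
  out = col + 1               # Foo.__call__: [1, 1], [2, 2], [3, 3]
  expected.append(out @ out)  # interceptor: dot(out, out)
# expected == [2, 8, 18]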
  def test_intercept_method_that_captures_intermediate_output(self):
    """Applies self.sow() to the output of a call using intercept_method."""
    class Bar(nn.Module):
      def test(self, x):
        return x + 1

    class Foo(nn.Module):
      @nn.compact
      def __call__(self, x):
        return Bar().test(x) + 1

    def interceptor(mdl, fun, args, kwargs):
      y = fun(mdl, *args, **kwargs)
      mdl.sow('intermediates', fun.__name__, y)
      return y

    def intercept_method(mdl, fun, args, kwargs):
      # Sow the output of Bar functions.
      if isinstance(mdl, Bar):
        return interceptor(mdl, fun, args, kwargs)
      return fun(mdl, *args, **kwargs)

    _, state = Foo().apply({}, 1, intercept_method=intercept_method,
                           mutable=['intermediates'])
    self.assertEqual(state, {
        'intermediates': {'Bar_0': {'test': (2,)}}
    })
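Note that the sowed value appears as a tuple, (2,), because Module.sow appends into a tuple by default rather than overwriting, so an interceptor that fires several times on the same method accumulates every output. A small standalone illustration (the Counter module is hypothetical):

import flax.linen as nn

class Counter(nn.Module):
  @nn.compact
  def __call__(self, x):
    # Two sows under the same name accumulate: 'y' becomes (x + 1, x + 2).
    self.sow('intermediates', 'y', x + 1)
    self.sow('intermediates', 'y', x + 2)
    return x

_, state = Counter().apply({}, 0, mutable=['intermediates'])
assert state['intermediates']['y'] == (1, 2)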
  def test_functional_apply(self):
    class Foo(nn.Module):
      def setup(self):
Reviewer comment: This signature will fail if users write an intercept_method whose kwargs clash with the first two arguments. Names like mdl, mod, fn, or fun are all legitimate names for the first two parameters, but could just as easily appear in kwargs. An alternative is to make the signature

f(module, original_function, args, kwargs)

so there can be no clashes.

Author reply: Thanks, I went with your suggestion, so the interceptor now takes f(module, original_function, args, kwargs).
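A quick illustration of the clash the reviewer describes, with hypothetical interceptor names (bad_interceptor and good_interceptor are illustrative only):

def bad_interceptor(mdl, fun, *args, **kwargs):
  return fun(mdl, *args, **kwargs)

# If the intercepted method itself takes a keyword argument named 'fun',
# the call collides with the interceptor's positional parameter:
#   bad_interceptor(module, method, x, fun=my_fn)
#   TypeError: bad_interceptor() got multiple values for argument 'fun'

def good_interceptor(module, original_function, args, kwargs):
  # Packing args and kwargs into explicit tuples/dicts removes any
  # possibility of a name clash with the method's own keyword arguments.
  return original_function(module, *args, **kwargs)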