Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add an implementation to intercept method calls. #1356

Open
wants to merge 7 commits into
base: main
Choose a base branch
from

Conversation

bastings
Copy link
Contributor

@bastings bastings commented Jun 7, 2021

What does this PR do?

Fixes #1355

In particular, it adds an intercept_method kwarg to Module.apply that allows users to specify a method that can intercept calls to module methods.

Checklist

  • This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other
    checks if that's the case).
  • This change is discussed in a Github issue/
    discussion (please add a
    link).
  • The documentation and docstrings adhere to the
    documentation guidelines.
  • This change includes necessary high-coverage tests.
    (No quality testing = no merge!)

@google-cla google-cla bot added the cla: yes label Jun 7, 2021
@bastings bastings requested a review from avital June 7, 2021 14:05
@avital
Copy link
Contributor

avital commented Jun 7, 2021

Nice @bastings -- this looks very promising! I skimmed it but haven't yet read the implementation in full detail.

A few questions:

  1. Can capture_intermediates be implemented in terms of intercept_method?
  2. Can you please remove the formatting changes in module_test.py? That'll make it easier to review and less likely to have merge conflicts.

@avital avital requested a review from jheek June 7, 2021 18:36
@bastings
Copy link
Contributor Author

bastings commented Jun 8, 2021

hi!

  1. Yes! We left that question for you to decide. We currently kept both to not break things, but you could definitely deprecate capture_intermediates. We even wrote a test that shows how to do it.
  2. Yes :) the autoformatter did this though, so please consider a separate commit that does the same to make it possible to format code :)

@jheek
Copy link
Member

jheek commented Jun 8, 2021

Yes! We left that question for you to decide. We currently kept both to not break things, but you could definitely deprecate capture_intermediates. We even wrote a test that shows how to do it.

We probably want to keep the separate APIs (at least for now). But I think it's better to keep just one implementation so under the hood I think sow should be implemented as a hook.

self.assertEqual(grads['inter_grads']['MyModel_0']
['output_layer']['activation'].shape, (10,))

def test_intercept_method_with_multiple_functions(self):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

please also test intercept method with transformations (vmap and/or scan)

Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I have added the vmap test, wasn't sure if you wanted to wrap the apply method with vmap or the whole module with decorator, went with apply. But let me know if that works for you.

intercept_method: An optional hook that intercepts all function calls and
can be used to modify args and kwargs before calling each original
method, and also to modify its output. The provided method needs to
take the form `f(module, original_function, *args, **kwargs)` and is
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This signature will fail if users write an intercept_method where kwargs clash with the first 2 arguments. Like (eg.: mdl, mod, fn, fun, etc.) these are all legit names for the first 2 args but could easily happen in kwargs too. An alternative if to make the signature f(module, original_function, args, kwargs) so we can't have any clashes

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, I went with your suggestion, so the interceptor now takes f(module, original_function, args, kwargs).

@codecov-commenter
Copy link

Codecov Report

Merging #1356 (1788c79) into master (8846461) will increase coverage by 0.02%.
The diff coverage is 90.90%.

Impacted file tree graph

@@            Coverage Diff             @@
##           master    #1356      +/-   ##
==========================================
+ Coverage   82.34%   82.37%   +0.02%     
==========================================
  Files          65       65              
  Lines        5318     5327       +9     
==========================================
+ Hits         4379     4388       +9     
  Misses        939      939              
Impacted Files Coverage Δ
flax/linen/module.py 95.12% <90.90%> (+0.10%) ⬆️

Continue to review full report at Codecov.

Legend - Click here to learn more
Δ = absolute <relative> (impact), ø = not affected, ? = missing data
Powered by Codecov. Last update 8846461...1788c79. Read the comment docs.

bastings and others added 2 commits June 9, 2021 11:30
@jheek jheek added Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment. Status: blocked The issue/PR is blocked by another issue/PR. labels Jun 17, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
cla: yes Priority: P2 - no schedule Best effort response and resolution. We have no plan to work on this at the moment. Status: blocked The issue/PR is blocked by another issue/PR.
Projects
None yet
Development

Successfully merging this pull request may close these issues.

[Feature request] To be able to intercept module calls for an existing model.
5 participants