Adding LIME and Kernel SHAP #468
Conversation
@vivekmig has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Just finished a pass; I might have missed some things since the PR is quite large.
Overall it LGTM. I left some general comments and nits.
captum/attr/_core/lime.py
Outdated
curr_inputs = []
...
if len(curr_inputs) > 0:
It might be better to refactor the inside of this into a utility function or a local function, to avoid duplicating the code.
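A minimal sketch of the suggested refactor pattern (the names `batch_results` and `_flush` are illustrative, not Captum's actual API): the batch-evaluation logic lives in one local function, so the main loop and the trailing partial batch share the same code instead of duplicating it.

```python
def batch_results(model, inputs, batch_size):
    # Hypothetical sketch: accumulate inputs, evaluate them in batches,
    # and reuse one local helper for both full and partial batches.
    results = []
    curr_inputs = []

    def _flush():
        # Evaluate whatever has accumulated and reset the buffer.
        if curr_inputs:
            results.extend(model(curr_inputs))
            curr_inputs.clear()

    for inp in inputs:
        curr_inputs.append(inp)
        if len(curr_inputs) == batch_size:
            _flush()
    _flush()  # handle the final partial batch without duplicated code
    return results

print(batch_results(lambda xs: [x * 2 for x in xs], [1, 2, 3, 4, 5], 2))
# → [2, 4, 6, 8, 10]
```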
combined_interp_inps = torch.cat(interpretable_inps)
combined_outputs = (
    torch.cat(outputs)
Does cat here assume batch size == 1?
No, it's just expected to be 1D, which should be the case since it's flattened above.
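An illustrative sketch of why this works (not Captum's actual code; the tensors here are made up): each forward pass may return outputs for a whole batch, but they are flattened to 1D before being collected, so `torch.cat` along dim 0 does not assume batch size == 1.

```python
import torch

# Hypothetical outputs from two forward passes with batch sizes 2 and 1.
raw = [torch.tensor([[0.1], [0.2]]), torch.tensor([[0.3]])]
outputs = [o.flatten() for o in raw]   # each becomes a 1D tensor
combined_outputs = torch.cat(outputs)  # 1D tensor of length 3
print(combined_outputs.shape)
```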
[73.3716, 193.3349, 113.3349],
perturbations_per_eval=(1, 2, 3),
n_samples=500,
expected_coefs_only=[73.3716, 193.3349, 113.3349],
Just commenting so I understand the tests better: is this dependent on n_samples + seed + the order in which these tests execute, since the random generator is not patched (due to the samples generated)?
Good question! For most of these, I computed the "exact" expected output manually. For KernelSHAP, these are exactly equal to the true Shapley values, which are straightforward to compute. For LIME, I manually constructed the expected distribution (containing all possible samples in the distribution, weighted appropriately) and trained a Lasso regression model on it to obtain the unique expected result.

The output here does depend on n_samples, seed, and order, but with sufficient samples it should converge to within epsilon of the exact expected result. I've generally set the delta threshold and n_samples large enough to avoid any issue with the random seed, but if these parameters are too low for some case, the variance could be an issue.
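For a small number of features, the exact Shapley values mentioned above can be computed by brute-force enumeration of coalitions. The sketch below is illustrative only (it is not the code used to derive the test constants); `value_fn`, the additive toy model, and its weights are assumptions for the example.

```python
from itertools import combinations
from math import factorial

def exact_shapley(value_fn, n_features):
    # Exact Shapley values by enumerating all coalitions; feasible only
    # for small n, which is how a reference value can be checked by hand.
    # value_fn maps a frozenset of feature indices to a scalar output.
    phi = [0.0] * n_features
    for i in range(n_features):
        others = [j for j in range(n_features) if j != i]
        for size in range(len(others) + 1):
            for subset in combinations(others, size):
                s = frozenset(subset)
                weight = (factorial(len(s)) * factorial(n_features - len(s) - 1)
                          / factorial(n_features))
                phi[i] += weight * (value_fn(s | {i}) - value_fn(s))
    return phi

# For an additive (linear) model, the Shapley values recover exactly
# the per-feature contributions (hypothetical weights shown here).
weights = [1.0, 2.0, 3.0]
additive = lambda s: sum(weights[j] for j in s)
print(exact_shapley(additive, 3))  # ≈ [1.0, 2.0, 3.0]
```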
Thanks for the reviews @NarineK and @miguelmartin75! I addressed the comments and answered the questions inline. I will switch to the interpretable model interface for training and add tests for both PyTorch and sklearn models in a separate diff once those changes land.
Looks good! Thank you for working on this. Left a couple of nits and questions.
Addressed all comments, thanks for the reviews :)!
@vivekmig has imported this pull request. If you are a Facebook employee, you can view this diff on Phabricator.
Summary: Adding LIME and Kernel SHAP to Captum. This includes all documentation, type hints, and data parallel / JIT tests.
Pull Request resolved: pytorch#468
Reviewed By: NarineK
Differential Revision: D23733322
Pulled By: vivekmig
fbshipit-source-id: 1ecc21306493ce4bd84ce175d4e08c21aaa49083