Feature/intermediates cache prefetch #2392
GOavi101 wants to merge 4 commits into vllm-project:main from
Conversation
…loading

- IntermediatesCache.iter_prefetch() overlaps onload of the next batch with consumption of the current batch via a background thread
- AWQ _run_samples uses iter_prefetch when offload_device is set, overlapping CPU->device transfer with module forward passes
- Add test_iter_prefetch_matches_iter to verify prefetch yields the same results as iter

Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
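A minimal sketch of the pattern the commit describes: a single background worker onloads batch i+1 while the caller consumes batch i. The names `iter_prefetch` and `onload` follow the PR description, but the implementation here is illustrative, not the project's actual code.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")

def iter_prefetch(batches: Iterable[T], onload: Callable[[T], T]) -> Iterator[T]:
    """Yield onloaded batches, starting onload of batch i+1 before batch i is used."""
    with ThreadPoolExecutor(max_workers=1) as pool:
        pending = None
        for batch in batches:
            submitted = pool.submit(onload, batch)
            if pending is not None:
                # While the caller consumes this batch, the worker onloads the next.
                yield pending.result()
            pending = submitted
        if pending is not None:
            yield pending.result()
```

Because the worker holds at most one batch in flight, this keeps the memory overhead of prefetching to a single extra onloaded batch.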
👋 Hi! Thank you for contributing to llm-compressor. Please add the ready label when the PR is ready for review. Note: this is required to complete the testing suite; please only add the label once the PR is code complete and local testing has been performed.
Summary of Changes

Hello @GOavi101, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request enhances the performance of AWQ (Activation-aware Weight Quantization) offloading by implementing a prefetching mechanism for the intermediates cache. By loading the next batch of data in a background thread while the current batch is being processed, the changes aim to overlap data transfer with computation, reducing overall execution time, especially when offloading to the CPU.
Code Review
This pull request introduces an optional prefetching mechanism to the IntermediatesCache to optimize performance by overlapping CPU-to-device onload with the forward pass, particularly beneficial for AWQ when offloading to CPU. The new iter_prefetch() method in IntermediatesCache uses a background thread to prefetch the next batch, and the _run_samples() method in AWQModifier has been updated to leverage this feature conditionally. New unit tests have been added to ensure the correctness of the iter_prefetch() method, covering both empty cache scenarios and ensuring consistency with the regular iter() method. The changes are well-implemented and directly address the stated goal of reducing wall-clock time.
        module(**batch_kwargs) for batch_kwargs in self._parent_args_cache[module]
    ]
    cache = self._parent_args_cache[module]
    # When offloading, prefetch overlaps CPU->device onload with forward pass.
The comment on this line exceeds the recommended line length of 79 characters, as per PEP 8. Please consider rephrasing or splitting it into multiple lines for better readability.
    - # When offloading, prefetch overlaps CPU->device onload with forward pass.
    + # Prefetching overlaps CPU->device onload with forward pass when offloading.
Looks good. You probably want the same gating as the other PR, and also to unify the two to use the same util, with the tests made specific to the shared functionality.
The quality checks have failed. Please run |
Force-pushed from 457af3e to 9128c2b
    assert batch_dicts_equal(b_iter, b_prefetch), f"batch {i} differs"

    def deep_equal(a, b) -> bool:
Why is this moved? Better to leave it where it was.
This is still moved and adds a bunch of line changes that will alter the blame and history for no reason
Force-pushed from e125440 to 9128c2b
The quality checks have failed. Please run |
…tate

Signed-off-by: Avishek Goswami <avishek.goswami@ibm.com>
Force-pushed from fe3d424 to 595fd90
    ]

    def deep_equal(a, b) -> bool:
once you get rid of these line changes, we'll be good to land
HDCharles left a comment:
Need to still fix the deep_equal thing; otherwise this looks good!
kylesayrs left a comment:
The speedup from these changes needs to be tested before this can be merged. There are many ways in which this can look good but, due to implementation details, not achieve the desired outcome.
I'm also not 100% sure I understand the theory here. According to the torch.Tensor.to documentation:
> In general, the transfer is blocking on the device side (even if it isn't on the host side): the copy on the device cannot occur while another operation is being executed.
brian-dellabetta left a comment:
Looks great! Thanks for adding this. I have one nit on the variable naming convention, otherwise LGTM.
    else:
        session.state.loss_masks = None

    use_prefetch = getattr(dataset_args, "sequential_prefetch", False)
nit -- I prefer we use a single config var name for this for consistency
kylesayrs left a comment:
I think this just needs validation before landing.
We should also consider:
- Having a dedicated "prefetch" thread
- Prefetching multiple samples (lookahead 5)
- Whether this can be accomplished using `torch.Tensor.to(non_blocking=True)`
- Whether a similar technique can be done when storing activations (similar to `non_blocking=True`)
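The "dedicated prefetch thread" and "lookahead N" suggestions could be sketched as below. This is a pure-Python illustration under assumed names (`prefetch_lookahead`, `onload` are hypothetical, not the project's API): a worker thread keeps up to `lookahead` onloaded batches in a bounded queue while the consumer drains it.

```python
import queue
import threading
from typing import Callable, Iterable, Iterator, TypeVar

T = TypeVar("T")
_SENTINEL = object()  # marks the end of the batch stream

def prefetch_lookahead(
    batches: Iterable[T],
    onload: Callable[[T], T],
    lookahead: int = 5,
) -> Iterator[T]:
    """Yield onloaded batches; a worker thread stays up to `lookahead` ahead."""
    q: queue.Queue = queue.Queue(maxsize=lookahead)

    def worker() -> None:
        for batch in batches:
            # Blocks once `lookahead` onloaded batches are waiting, bounding memory.
            q.put(onload(batch))
        q.put(_SENTINEL)

    threading.Thread(target=worker, daemon=True).start()
    while True:
        item = q.get()
        if item is _SENTINEL:
            return
        yield item
```

The bounded queue is the key design choice: without `maxsize`, a fast worker could onload the whole dataset onto the device before the consumer catches up.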
Optional prefetch was added to the intermediates cache and wired into AWQ when offloading.
IntermediatesCache
New method iter_prefetch() iterates over batches like iter() but prefetches the next batch in a background thread so onload from the offload device overlaps with use of the current batch, reducing wall‑clock time when offloading to CPU.
AWQ
When offload_device is set, _run_samples() uses cache.iter_prefetch() instead of the cache iterator so CPU→device onload overlaps with the forward pass over cached parent args during smoothing.
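The conditional wiring described above might look like the following sketch. `DemoCache` and `run_samples` are stand-ins invented for illustration; only the `iter_prefetch` / `offload_device` names come from the PR description.

```python
class DemoCache:
    """Stand-in for IntermediatesCache: iterable, with an iter_prefetch() method."""

    def __init__(self, batches):
        self._batches = batches

    def __iter__(self):
        return iter(self._batches)

    def iter_prefetch(self):
        # The real implementation would onload batches in a background thread;
        # here we just iterate so the control flow is visible.
        return iter(self._batches)

def run_samples(module, cache, offload_device=None):
    # Prefetch only pays off when batches live on an offload device and must
    # be transferred back; otherwise iterate the cache directly.
    batches = cache.iter_prefetch() if offload_device is not None else cache
    return [module(**batch_kwargs) for batch_kwargs in batches]
```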
Tests
Two tests were added: one that prefetch yields the same batches as iter(), and one that prefetch on an empty cache yields nothing. No new public API; prefetch is used automatically when AWQ offloads.
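The consistency test could follow a pattern like this sketch. `deep_equal` mirrors the helper named in the review thread, but this version compares only plain Python containers (the real one would also need to compare tensors); `assert_prefetch_matches_iter` is a hypothetical harness, not the actual test.

```python
def deep_equal(a, b) -> bool:
    """Recursively compare nested dicts/lists/tuples by value."""
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(deep_equal(a[k], b[k]) for k in a)
    if isinstance(a, (list, tuple)) and isinstance(b, (list, tuple)):
        return len(a) == len(b) and all(deep_equal(x, y) for x, y in zip(a, b))
    return a == b

def assert_prefetch_matches_iter(make_iter, make_prefetch):
    """Check that the prefetching iterator yields the same batches as plain iteration."""
    batches_iter = list(make_iter())
    batches_prefetch = list(make_prefetch())
    assert len(batches_iter) == len(batches_prefetch), "batch counts differ"
    for i, (b_iter, b_prefetch) in enumerate(zip(batches_iter, batches_prefetch)):
        assert deep_equal(b_iter, b_prefetch), f"batch {i} differs"
```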
Fix: #2374