perf(core): Explicitly set PyTorch intra-op threads. #21
Conversation
Set the PyTorch intra-op thread count to max(1, num_cpus // 2 // num_ranks // thread_count). This improves writing performance and also resolves a runtime error that occurs when tensor.copy_() is called from multiple write threads.
Code Review
This pull request introduces a change to explicitly set the number of intra-op threads for PyTorch during checkpoint writing. This is a good optimization to improve performance and prevent potential race conditions with multiple writer threads. The use of a try...finally block to restore the original thread count is also a good practice.
I've found one critical issue that could lead to a ZeroDivisionError on CPU-only machines. Please see my detailed comment.
g-husam
left a comment
left some minor comments
Fixes #18. Confirmed it works by removing the license header for `CMakeLists.txt` here: https://github.com/google/ml-flashpoint/pull/27/checks?sha=7bb36ee0cf3c9d14ff2a74e532f6e39afacf0350
Clarifying exactly where the improved read/write speedups apply. Also adding a note on infra-agnosticism and removing the WIP label from the doc site.
#19)
- Extracts context recovery logic from DefaultMLFlashpointCheckpointLoader
- Introduces _get_extra_local_objects and _get_extra_needed_objects to DefaultMLFlashpointCheckpointLoader
- Updates NeMo wrapper to instantiate the new NeMoMLFlashpointCheckpointLoader

Co-authored-by: g-husam <husameldawi@google.com>
g-husam
left a comment
small tweak to get_accelerator_count and comment placing and good to go
```python
thread_count = max(thread_count, 1)
num_cpus = os.cpu_count() or 1
num_ranks = max(get_accelerator_count(), 1)
torch_thread_count = max(1, num_cpus // 2 // num_ranks // thread_count)
```
As a future improvement, we could make the percentage of CPUs used configurable. Here we assume 50% usage (by dividing by 2), but some scenarios may want more or less.
Actually, one more important thing: we should add a test that catches the failure this is fixing. If we comment out this fix, that test should fail; with the fix in place, it should pass. I think the failure occurred whenever we used multiple write threads?
confirm test that catches the failure this is fixing
This can't be caught by unit tests. The error only happens when running actual training.
Co-authored-by: g-husam <husameldawi@google.com>