
perf(core): Explicitly set PyTorch intra-op threads.#21

Merged
Leahlijuan merged 19 commits into main from feature/torchthreads
Feb 9, 2026
Conversation

@Leahlijuan
Collaborator

Set the torch number of threads to max(1, num_cpus // 2 // num_ranks // thread_count). This helps improve writing performance and also resolves the runtime error that occurs when tensor.copy_() is called from multiple write threads.
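For concreteness, the formula from the description can be sketched as a small helper. Function and parameter names here are illustrative, not the PR's actual code; the real change reads num_cpus from os.cpu_count() and num_ranks from the number of CUDA devices:

```python
def compute_torch_thread_count(num_cpus: int, num_ranks: int, thread_count: int) -> int:
    """Split roughly half the CPU cores evenly across all writer threads on all ranks."""
    num_cpus = num_cpus or 1            # os.cpu_count() can return None
    num_ranks = max(num_ranks, 1)       # avoid dividing by zero on CPU-only hosts
    thread_count = max(thread_count, 1)
    return max(1, num_cpus // 2 // num_ranks // thread_count)

# Example: a 96-core host, 8 ranks, 2 write threads per rank
# -> 96 // 2 // 8 // 2 = 3 intra-op threads per writer
print(compute_torch_thread_count(96, 8, 2))
```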
@gemini-code-assist
Contributor

Summary of Changes

Hello @Leahlijuan, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces a critical enhancement to how PyTorch's intra-operation parallelism is managed during data writing operations. By explicitly setting the number of PyTorch threads based on system resources and the number of concurrent writers, the change aims to significantly improve writing performance and resolve specific runtime errors that can occur when tensor.copy_() is invoked by multiple threads simultaneously. The implementation also ensures robust resource management by restoring the original PyTorch thread configuration upon completion.

Highlights

  • PyTorch Intra-Op Thread Management: Explicitly sets the number of PyTorch intra-operation threads within the write_data function to optimize performance and prevent runtime errors when using tensor.copy_() with multiple write threads.
  • Dynamic Thread Calculation: The PyTorch thread count is dynamically calculated based on the number of CPU cores, CUDA device ranks, and the specified thread_count for the write operation, using the formula max(1, num_cpus // 2 // num_ranks // thread_count).
  • Resource Cleanup and Safety: The original PyTorch thread count is saved and then restored in a try...finally block, ensuring that the global PyTorch thread settings are reverted after the write_data operation completes, regardless of success or failure.
  • Enhanced Logging: Added detailed debug logging to show the original, calculated, and applied PyTorch thread counts, along with CPU and rank information.


@gemini-code-assist gemini-code-assist bot left a comment
Code Review

This pull request introduces a change to explicitly set the number of intra-op threads for PyTorch during checkpoint writing. This is a good optimization to improve performance and prevent potential race conditions with multiple writer threads. The use of a try...finally block to restore the original thread count is also a good practice.

I've found one critical issue that could lead to a ZeroDivisionError on CPU-only machines. Please see my detailed comment.
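The hazard flagged here can be reproduced in isolation: if get_accelerator_count() returns 0 on a CPU-only machine, the unguarded formula divides by zero. Clamping each divisor to at least 1 (as the merged snippet below does) avoids it. A minimal illustration with hypothetical helper names:

```python
def unsafe_thread_count(num_cpus, num_ranks, thread_count):
    # No guard: raises ZeroDivisionError when num_ranks == 0
    return max(1, num_cpus // 2 // num_ranks // thread_count)

def safe_thread_count(num_cpus, num_ranks, thread_count):
    # Clamp divisors to >= 1 before dividing
    return max(1, num_cpus // 2 // max(num_ranks, 1) // max(thread_count, 1))

try:
    unsafe_thread_count(16, 0, 4)   # num_ranks == 0 on a CPU-only host
except ZeroDivisionError:
    print("unguarded formula raises ZeroDivisionError")

print(safe_thread_count(16, 0, 4))  # -> 2
```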

@Leahlijuan Leahlijuan self-assigned this Feb 2, 2026
@Leahlijuan Leahlijuan requested review from g-husam and kkkapu February 3, 2026 16:44
Collaborator

@g-husam g-husam left a comment


left some minor comments

g-husam and others added 7 commits February 4, 2026 15:53
Clarifying where the improved read/write speedups were exactly. Also
adding note on infra-agnosticism, and removing the WIP label from the
doc site.
#19)

- Extracts context recovery logic from
DefaultMLFlashpointCheckpointLoader

- Introduces _get_extra_local_objects and _get_extra_needed_objects to
DefaultMLFlashpointCheckpointLoader

- Updates NeMo wrapper to instantiate the new
NeMoMLFlashpointCheckpointLoader

---------

Co-authored-by: g-husam <husameldawi@google.com>
@Leahlijuan Leahlijuan requested a review from g-husam February 5, 2026 21:57
g-husam
g-husam previously approved these changes Feb 5, 2026
Collaborator

@g-husam g-husam left a comment


small tweak to get_accelerator_count and comment placing and good to go

thread_count = max(thread_count, 1)
num_cpus = os.cpu_count() or 1
num_ranks = max(get_accelerator_count(), 1)
torch_thread_count = max(1, num_cpus // 2 // num_ranks // thread_count)
Collaborator
as a future improvement, we can make the percentage of CPUs used configurable. here we are assuming 50% usage (by dividing by 2), but some scenarios may desire using more or less.
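The suggested future improvement could look like the following sketch. The cpu_fraction parameter and the function name are hypothetical, not part of this PR; the merged code hard-codes the 50% budget via the `// 2`:

```python
def torch_threads_for_writers(num_cpus, num_ranks, thread_count, cpu_fraction=0.5):
    """Like the merged formula, but with the CPU budget made configurable."""
    budget = int(num_cpus * cpu_fraction)   # cores allotted to checkpoint writing
    return max(1, budget // max(num_ranks, 1) // max(thread_count, 1))

print(torch_threads_for_writers(96, 8, 2))                     # default 50% -> 3
print(torch_threads_for_writers(96, 8, 2, cpu_fraction=1.0))   # all cores   -> 6
```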

@g-husam
Collaborator

g-husam commented Feb 5, 2026

small tweak to get_accelerator_count and comment placing and good to go

actually one more important thing - we should add a test that catches the failure this is fixing. So if we comment out this fix, that test should fail, and with it, it should pass. I think the failure was whenever we use multiple write threads or something?

@g-husam g-husam self-requested a review February 5, 2026 22:10
@g-husam g-husam dismissed their stale review February 5, 2026 22:10

confirm test that catches the failure this is fixing

@Leahlijuan Leahlijuan closed this Feb 5, 2026
@g-husam g-husam changed the title feat(core): Explicitly set PyTorch intra-op threads. perf(core): Explicitly set PyTorch intra-op threads. Feb 5, 2026
@Leahlijuan
Collaborator Author

small tweak to get_accelerator_count and comment placing and good to go

actually one more important thing - we should add a test that catches the failure this is fixing. So if we comment out this fix, that test should fail, and with it, it should pass. I think the failure was whenever we use multiple write threads or something?

This can't be caught through unit tests. The error only happens when running actual training.

@Leahlijuan Leahlijuan reopened this Feb 5, 2026
@Leahlijuan Leahlijuan requested a review from g-husam February 9, 2026 17:17
@Leahlijuan Leahlijuan enabled auto-merge (squash) February 9, 2026 18:11
@Leahlijuan Leahlijuan merged commit d64060d into main Feb 9, 2026
5 checks passed
@Leahlijuan Leahlijuan deleted the feature/torchthreads branch February 9, 2026 18:18