feat(data): Support auto-tune for the number of grain workers #2510
base: main
Conversation
```python
multiprocessing_options = (
    grain.experimental.pick_performance_config(
        ds=dataset,
        ram_budget_mb=1024,
```
Should we make ram_budget_mb configurable?
Thanks, that's a good point. I had set it to 1024 based on the Grain official tutorial, but I agree it's much better to make it configurable. This will allow users to tune the value for their specific environment. I'll make that change.
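For illustration, a minimal sketch of what the agreed change could look like, assuming a Grain release that exposes `grain.experimental.pick_performance_config` as shown in the diff above. The helper name and the idea of reading the budget from a config key (e.g. a hypothetical `grain_ram_budget_mb`) are not taken from this PR.

```python
import grain  # assumes a Grain release exposing grain.experimental.pick_performance_config


def auto_tuned_options(dataset, ram_budget_mb: int = 1024):
  """Illustrative helper: ask Grain to pick worker settings within a RAM budget.

  Only the `ds` and `ram_budget_mb` arguments appear in the reviewed diff;
  everything else here is a sketch of how the budget could be made configurable.
  """
  return grain.experimental.pick_performance_config(
      ds=dataset,
      ram_budget_mb=ram_budget_mb,
  )
```

A caller could then pass something like `config.grain_ram_budget_mb` (or whatever key the final change uses) instead of relying on the hard-coded default.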
Force-pushed from 004a09b to 077ef4f.
Allow setting grain_worker_count to -1 for automatic performance tuning
Description
This PR introduces support for setting `grain_worker_count: -1` to enable Grain's experimental `pick_performance_config` feature. This allows Grain to automatically determine the optimal number of data loading workers, significantly improving training performance and preventing data pipeline bottlenecks when enabled. Previously, users had to manually tune `grain_worker_count` through trial and error to achieve good hardware utilization.

The Problem
When tokenizing raw text data on the fly, the data input pipeline can easily become a bottleneck if the number of parallel workers (`grain_worker_count`) is too low. This leads to poor accelerator utilization (low TFLOP/s), slow training steps, and a frustrating user experience.

The Solution
By adding the option to set `grain_worker_count: -1`, we give users the ability to delegate the selection of the worker count to Grain's built-in auto-tuning mechanism. This provides a robust option that adapts to the user's specific hardware and data, ensuring the input pipeline can keep up with the accelerators.

The default value remains `1` to provide a stable, low-resource baseline. Users can now explicitly set `grain_worker_count: -1` to leverage this automatic performance tuning for high-throughput training. As shown in the performance tests below, this automatic configuration achieves stable, high-throughput training comparable to a manually optimized `grain_worker_count` setting.

This change simplifies the user workflow for performance optimization and makes it easier to achieve optimal throughput without manual intervention.
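To make the two modes concrete, here is a minimal, hypothetical sketch of how a Grain input pipeline might branch on the setting. The function name, the `mp_prefetch` call, and the handling of the auto-tuned result are illustrative assumptions, not code from this PR.

```python
import grain  # assumes a Grain release with the Dataset API and experimental auto-tuning


def apply_worker_settings(dataset, grain_worker_count: int, ram_budget_mb: int = 1024):
  """Illustrative only: apply auto-tuned or explicit multiprocessing settings."""
  if grain_worker_count == -1:
    # Auto-tune: let Grain inspect the pipeline and pick the worker setup.
    perf = grain.experimental.pick_performance_config(
        ds=dataset, ram_budget_mb=ram_budget_mb
    )
    # Depending on the Grain version, the result may wrap the multiprocessing
    # options rather than being them directly.
    options = getattr(perf, "multiprocessing_options", perf)
  else:
    # Explicit: honor the user-provided worker count (default 1 in this PR).
    options = grain.MultiprocessingOptions(num_workers=grain_worker_count)
  # Prefetch with child processes according to the chosen options.
  return dataset.mp_prefetch(options)
```

The real integration lives in MaxText's Grain input pipeline; this sketch only shows the intended semantics of `-1` versus an explicit worker count.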
Tests
The effectiveness of this feature was verified by running the training command below on a v6e-32 pod with different values for `grain_worker_count` and observing the impact on TFLOP/s and step time. To reproduce, run the command with `grain_worker_count` set to `1` (default), `4`, `8`, and `-1` (the new auto-tune option).

```bash
python3 -m MaxText.train src/MaxText/configs/base.yml \
  base_output_directory=${BASE_OUTPUT_DIRECTORY} \
  run_name=$RUN_NAME \
  dataset_type=grain \
  grain_train_files=${DATA_PATH} \
  grain_worker_count=-1 \
  per_device_batch_size=2 \
  model_name=llama3-8b \
  steps=10 \
  max_target_length=8192 \
  enable_checkpointing=false \
  attention=flash \
  dtype=bfloat16
```

Confirm that `grain_worker_count=-1` results in stable and high TFLOP/s (~195 on v6e-32) and low step times (~4.3 s), consistent with the performance of a manually tuned optimal value like `4` or `8`.

Checklist
Before submitting this PR, please make sure (put X in square brackets):
gemini-review label.

Fixes #2509