
Conversation

@bzantium (Collaborator) commented Oct 16, 2025

Description

This PR introduces support for setting grain_worker_count: -1 to enable Grain's experimental pick_performance_config feature, which lets Grain automatically determine the optimal number of data loading workers. When enabled, this significantly improves training performance and prevents the data input pipeline from becoming a bottleneck.

Previously, users had to manually tune grain_worker_count through trial and error to achieve good hardware utilization.

The Problem

When tokenizing raw text data on the fly, the data input pipeline can easily become a bottleneck if the number of parallel workers (grain_worker_count) is too low. This leads to poor accelerator utilization (low TFLOP/s), slow training steps, and a frustrating user experience.

The Solution

By adding the option to set grain_worker_count: -1, we give users the ability to delegate the selection of the worker count to Grain's built-in auto-tuning mechanism. This provides a robust option that adapts to the user's specific hardware and data, ensuring the input pipeline can keep up with the accelerators.

The default value remains 1 to provide a stable, low-resource baseline. Users can now explicitly set grain_worker_count: -1 to leverage this automatic performance tuning for high-throughput training.
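
For illustration, the sketch below shows roughly how the -1 sentinel could be translated into Grain multiprocessing options. This is not the exact MaxText diff: the helper name, the plain `import grain` form, and the `multiprocessing_options` attribute on the object returned by pick_performance_config are assumptions based on Grain's autotune tutorial and the review thread further down.

    # Illustrative sketch only; assumes `dataset` is a grain IterDataset and
    # grain >= 0.2.13 (where pick_performance_config is available).
    import grain


    def make_grain_iter(dataset, grain_worker_count: int, ram_budget_mb: int = 1024):
        if grain_worker_count == -1:
            # Delegate worker selection to Grain's experimental auto-tuner.
            # Assumption: the returned performance config exposes
            # `multiprocessing_options`, as in Grain's autotune tutorial.
            perf_config = grain.experimental.pick_performance_config(
                ds=dataset,
                ram_budget_mb=ram_budget_mb,
            )
            multiprocessing_options = perf_config.multiprocessing_options
        else:
            # Fixed worker count supplied by the user (default: 1).
            multiprocessing_options = grain.MultiprocessingOptions(
                num_workers=grain_worker_count
            )
        return iter(dataset.mp_prefetch(multiprocessing_options))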

As shown in the performance tests below, this automatic configuration achieves stable, high-throughput training comparable to a manually optimized setting.

grain_worker_count | Average TFLOP/s/device | Average Time/Step (s) | Stability
------------------ | ---------------------- | --------------------- | ----------------
1                  | ~29                    | ~30.6                 | Unstable
2                  | ~60                    | ~13.5                 | Highly Unstable
4                  | ~195                   | ~4.3                  | Weakly Unstable
8                  | ~195                   | ~4.3                  | Stable
-1 (auto)          | ~195                   | ~4.3                  | Stable

This change simplifies the user workflow for performance optimization and makes it easier to achieve optimal throughput without manual intervention.

Tests

The effectiveness of this feature was verified by running the training command below on a v6e-32 pod with different values for grain_worker_count and observing the impact on TFLOP/s and step time.

To reproduce, run the command with grain_worker_count set to 1 (default), 2, 4, 8, and -1 (the new auto-tune option).

python3 -m MaxText.train src/MaxText/configs/base.yml \
    base_output_directory=${BASE_OUTPUT_DIRECTORY} \
    run_name=$RUN_NAME \
    dataset_type=grain \
    grain_train_files=${DATA_PATH} \
    grain_worker_count=-1 \
    per_device_batch_size=2 \
    model_name=llama3-8b \
    steps=10 \
    max_target_length=8192 \
    enable_checkpointing=false \
    attention=flash \
    dtype=bfloat16

Confirm that grain_worker_count=-1 results in stable and high TFLOP/s (~195 on v6e-32) and low step times (~4.3s), consistent with the performance of a manually tuned optimal value like 4 or 8.

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code. For an optional AI review, add the gemini-review label.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed, including adding new documentation pages to the relevant Table of Contents (toctree directive) as explained in our documentation.

fixes #2509

@bzantium (Collaborator Author)

This requires grain>=0.2.13; that requirement can be resolved by #2354.

@aireenmei self-assigned this Oct 22, 2025
Inline review comment on this excerpt from the diff:

    multiprocessing_options = (
        grain.experimental.pick_performance_config(
            ds=dataset,
            ram_budget_mb=1024,

Collaborator

Should we make ram_budget_mb configurable?

Collaborator Author

Thanks, that's a good point. I had set it to 1024 based on the Grain official tutorial, but I agree it's much better to make it configurable. This will allow users to tune the value for their specific environment. I'll make that change.
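
For illustration, the follow-up could look roughly like this; the config key name grain_ram_budget_mb and the helper below are hypothetical and not defined by this PR:

    # Hypothetical wiring: read the RAM budget from the MaxText config instead
    # of hard-coding 1024. The key name `grain_ram_budget_mb` is illustrative.
    import grain


    def pick_performance_options(dataset, config):
        ram_budget_mb = getattr(config, "grain_ram_budget_mb", 1024)
        return grain.experimental.pick_performance_config(
            ds=dataset,
            ram_budget_mb=ram_budget_mb,
        )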

@bzantium changed the title from "feat(data): Default grain_worker_count to -1 for automatic performance tuning" to "feat(data): Support auto-tune for the number of grain workers" on Nov 7, 2025