Distributed BERT Pre-Training Workflow Using Flyte 2.0 and AWS Trainium #223
Open
samhita-alla wants to merge 3 commits into awslabs:main
Conversation
Signed-off-by: Samhita Alla <aallasamhita@gmail.com>
Contributor
Hi @samhita-alla, I messaged you on Slack. Just wanted to make sure you received it.
What does this PR do?
Issue: #158
This PR adds a complete Flyte 2.0 workflow for running distributed training on an EKS-deployed Flyte 2.0 backend (validated on a Union 2.0 deployment, but fully compatible with any Flyte 2.0 deployment).
The example demonstrates BERT-Large pre-training on the FineWeb dataset using AWS Trainium.
What this training pipeline enables
Distributed training on Trainium
Configured with the `Elastic` plugin. Extending to multi-node training is as simple as setting `nnodes` (an illustrative configuration sketch appears after this list).
Cached data preprocessing
The preprocessing task stores preprocessed PyTorch tensors in S3 and returns the path. When the same inputs are used, the task fully resolves from cache.
Drop-in distributed setup with the PyTorch Neuron SDK
The training task uses the Neuron SDK and defaults to a quick setup, but any configuration can be swapped in directly from the UI if desired.
Real-time metrics streaming in the UI
Loss curves and custom dashboards update live as training progresses (Trainium utilization metrics will be supported soon; CPU and memory metrics already appear in the UI).
Full visibility across the entire pipeline
Inspect inputs/outputs, view logs for both leader and worker processes, and trace every step end-to-end.
Built-in caching, retries, and error handling
Training tasks can be cached, retried at the task level, or retried via exception handling inside the task.
Native AWS integrations
S3 for datasets + checkpoints, CloudWatch for logging, ECR for images, etc.
No manual torchrun configuration
The
Elasticplugin automatically sets up and launches torch distributed; users just run the script withpython train.py.Crash-proof training end-to-end
Checkpoints and the Neuron compilation cache are saved to blob storage every n steps. If an execution fails, users can resume from the exact step by simply providing the checkpoint + cache (see the resume sketch after this list).
Historical metrics automatically restore and continue rendering from the resumed point.
Clear recovery guidance in logs
If training fails, checkpoint + cache paths are surfaced in the task logs.
Multi-phase training support
Phase 2 automatically consumes the output model from Phase 1. If a failure occurs mid-phase, training can resume directly from that point.
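
To make the items above concrete, here is an illustrative configuration sketch of how the pieces could fit together: an Elastic-configured training task, a cached preprocessing task that returns an S3 prefix, and phase 1 → phase 2 chaining. It uses the flytekit-style API (`flytekitplugins.kfpytorch.Elastic`, task-level `cache` and `retries`) purely for illustration; the actual Flyte 2.0 code in this PR may use different syntax, and the bucket paths, hyperparameters, and process counts are placeholders.

```python
# Hypothetical sketch of the pipeline shape; not the exact Flyte 2.0 code in this PR.
from flytekit import task, workflow
from flytekitplugins.kfpytorch import Elastic


@task(cache=True, cache_version="1.0")
def preprocess(dataset: str, max_seq_len: int) -> str:
    """Tokenize FineWeb shards and write PyTorch tensors to S3.

    Returns the S3 prefix so identical inputs resolve fully from cache.
    """
    output_prefix = f"s3://my-bucket/preprocessed/{dataset}/seq{max_seq_len}"  # placeholder bucket
    # ... tokenization + torch.save of tensor shards under output_prefix ...
    return output_prefix


@task(
    task_config=Elastic(nnodes=1, nproc_per_node=32),  # bump nnodes for multi-node Trainium
    cache=True,
    cache_version="1.0",
    retries=3,
)
def pretrain(data_prefix: str, phase: int, resume_checkpoint: str = "") -> str:
    """Run train.py under torch distributed (set up by the Elastic plugin)
    and return the S3 path of the resulting checkpoint."""
    # ... the plugin injects the torchrun-style env vars; train.py just runs ...
    return f"s3://my-bucket/checkpoints/phase{phase}/final"  # placeholder


@workflow
def bert_pretraining(dataset: str = "fineweb") -> str:
    phase1_data = preprocess(dataset=dataset, max_seq_len=128)
    phase1_ckpt = pretrain(data_prefix=phase1_data, phase=1)
    phase2_data = preprocess(dataset=dataset, max_seq_len=512)
    # Phase 2 consumes the phase 1 output; on failure, rerun with resume_checkpoint set.
    return pretrain(data_prefix=phase2_data, phase=2, resume_checkpoint=phase1_ckpt)
```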
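And a rough train.py sketch of the drop-in distributed setup with the PyTorch Neuron SDK (torch-neuronx / torch_xla): the Elastic plugin injects the torchrun-style environment variables, so the script only initializes the xla process group and steps the optimizer through torch_xla. The toy model and synthetic batches are stand-ins, not the real BERT-Large setup.

```python
# Illustrative only: the torch_xla / Neuron pattern the training task relies on.
import torch
import torch.nn as nn
import torch.distributed as dist
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the "xla" process-group backend

# RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT are injected by the Elastic plugin
# (torchrun-equivalent), so no manual torchrun invocation is needed.
dist.init_process_group("xla")

device = xm.xla_device()  # a NeuronCore exposed as an XLA device

# Toy stand-in for BERT-Large so the sketch is self-contained.
VOCAB = 30522
model = nn.Sequential(nn.Embedding(VOCAB, 128), nn.Linear(128, VOCAB)).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
loss_fn = nn.CrossEntropyLoss()

for step in range(10):  # placeholder for a dataloader over the preprocessed tensors
    tokens = torch.randint(0, VOCAB, (8, 128), device=device)
    optimizer.zero_grad()
    logits = model(tokens)                                 # (batch, seq, vocab)
    loss = loss_fn(logits.view(-1, VOCAB), tokens.view(-1))
    loss.backward()
    xm.optimizer_step(optimizer)  # all-reduce gradients across NeuronCores, then step
    xm.mark_step()                # execute the lazily traced graph on the device
```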
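Finally, a resume sketch of the save-every-n-steps pattern from the "Crash-proof training end-to-end" item, continuing from the train.py sketch above (model, optimizer, and xm are reused). The local checkpoint path, save interval, and the assumption that the task syncs this directory to/from S3 are all illustrative.

```python
# Continuation of the sketch above; paths and intervals are assumptions.
import os

CKPT_PATH = "/tmp/checkpoints/latest.pt"  # assumed local path synced to/from S3 by the task
SAVE_EVERY = 500                          # "every n steps"
TOTAL_STEPS = 10_000                      # placeholder

start_step = 0
if os.path.exists(CKPT_PATH):  # resume: the checkpoint was downloaded from S3 beforehand
    state = torch.load(CKPT_PATH, map_location="cpu")
    model.load_state_dict(state["model"])
    optimizer.load_state_dict(state["optimizer"])
    start_step = state["step"] + 1  # continue from the exact step

for step in range(start_step, TOTAL_STEPS):
    # ... forward / backward / xm.optimizer_step as in the sketch above ...
    if step > 0 and step % SAVE_EVERY == 0:
        os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
        # xm.save moves XLA tensors to CPU and writes only from the master rank
        xm.save(
            {"step": step, "model": model.state_dict(), "optimizer": optimizer.state_dict()},
            CKPT_PATH,
        )
```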
Motivation
The goal of this example is to demonstrate that distributed training, whether pre-training or fine-tuning, can be both effortless to experiment with and robust enough for production. A key requirement for ML teams is the ability to run the same workflow locally, in a lightweight test environment, and at full production scale without rewriting code or reconfiguring infrastructure. Flyte 2.0, combined with EKS, delivers exactly that.
By setting up Flyte on EKS once (typically done by a platform engineer), ML engineers can run complex distributed training jobs on Trainium or GPUs without touching any infrastructure. This clean separation of concerns (platform setup vs. model development) ensures rapid iteration, consistent execution environments, and reliable scaling as workloads grow. The workflow in this PR showcases how seamless that experience can be.
More
- website/docs or website/blog section for this feature
- `pre-commit run -a` with this PR. Link for installing pre-commit locally

For Moderators
Additional Notes
Haven't updated the docs yet since I wanted to get the team's thoughts on the pipeline first. I can add a deeper dive into the platform setup in the documentation afterward. This PR currently includes only the training workflow code; the backend setup will be documented separately.