
Distributed BERT Pre-Training Workflow Using Flyte 2.0 and AWS Trainium #223

Open
samhita-alla wants to merge 3 commits into awslabs:main from samhita-alla:add-flyte-trainium-blueprint

Conversation


samhita-alla commented on Nov 19, 2025

What does this PR do?

Issue: #158

This PR adds a complete Flyte 2.0 workflow for running distributed training on an EKS-deployed Flyte 2.0 backend (validated on a Union 2.0 deployment, but fully compatible with any Flyte 2.0 deployment).

The example demonstrates BERT-Large pre-training on the FineWeb dataset using AWS Trainium.

What this training pipeline enables

Distributed training on Trainium

Configured with the Elastic plugin. Extending to multi-node training is as simple as setting nnodes.

import flyte
# Elastic comes from the flyteplugins-pytorch package listed below; the exact import path is assumed here.
from flyteplugins.pytorch import Elastic

trainium_env = flyte.TaskEnvironment(
    name="bert-trainium-training",
    image=flyte.Image.from_base(
        image_uri="public.ecr.aws/neuron/pytorch-training-neuronx:2.8.0-neuronx-py311-sdk2.26.0-ubuntu22.04"
    )
    .clone(name="bert-trainium-training")
    .with_env_vars({"UV_PYTHON": "/usr/local/bin/python3.11"}).with_apt_packages("git")
    .with_pip_packages(
        "git+https://github.com/flyteorg/flyte-sdk@a70370bbe348d52351beb6c2f4684efa5f387d46",
        "flyteplugins-pytorch==2.0.0b29",
        "transformers==4.57.1",
        "datasets==4.4.1",
        "tokenizers==0.22.1",
        "huggingface-hub==0.35.3",
    ),
    resources=flyte.Resources(
        cpu=110,
        memory="400Gi",
        # Trainium accelerator configuration
        gpu="Trn1:16",
    ),
    plugin_config=Elastic(
        nnodes=1,  # 1 Trainium instance (trn1.32xlarge)
        nproc_per_node=32,  # 32 NeuronCores per instance
    ),
    env_vars={
        "NEURON_RT_NUM_CORES": "32",  # Use all Neuron cores
        "NEURON_CC_FLAGS": "--model-type=transformer --distribution-strategy=llm-training --enable-mixed-precision-accumulation",
        "NEURON_COMPILE_CACHE_URL": "/var/tmp/neuron-compile-cache",  # Persistent cache
        "NEURON_FUSE_SOFTMAX": "1",  # Enable softmax fusion
        "NEURON_RT_STOCHASTIC_ROUNDING_EN": "1",  # Enable stochastic rounding for BF16
    },
    cache="auto",
)
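
For reference, a task attaches to this environment through the decorator on the environment object. A minimal sketch (the function name, signature, and body are illustrative, not the code in this PR):

# Sketch only: the real training task lives in the PR's workflow code.
@trainium_env.task
async def train_bert(tokenized_data_path: str, checkpoint_dir: str) -> str:
    # Under the Elastic plugin, this body runs in every one of the
    # nnodes * nproc_per_node worker processes (32 per trn1.32xlarge).
    ...
    return checkpoint_dir

Scaling out is then just nnodes=2 (or more) in the Elastic config; the per-node nproc_per_node=32 stays the same.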

Cached data preprocessing

The preprocessing task stores preprocessed PyTorch tensors in S3 and returns the path. When the same inputs are used, the task fully resolves from cache.
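
A hedged sketch of the shape of that task (the environment, function names, and resource sizes below are assumptions for illustration; only the cache behavior mirrors the PR):

import flyte

# Hypothetical CPU environment for preprocessing; cache="auto" means re-runs with
# identical inputs resolve from the cache instead of recomputing.
preprocess_env = flyte.TaskEnvironment(
    name="bert-preprocessing",
    resources=flyte.Resources(cpu=8, memory="32Gi"),
    cache="auto",
)

@preprocess_env.task
async def preprocess_fineweb(dataset_name: str, tokenizer_name: str, output_prefix: str) -> str:
    # Tokenize the dataset and write PyTorch tensor shards under `output_prefix` (an S3 prefix),
    # then return that path so the training task can consume it.
    ...
    return output_prefix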

Drop-in distributed setup with the PyTorch Neuron SDK

The training task uses the Neuron SDK and defaults to a quick setup, but any configuration can be swapped in directly from the UI if desired.
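
For context, the standard PyTorch Neuron (torch-xla) setup inside such a task looks roughly like this; a sketch of the usual pattern from the Neuron documentation, not a copy of the PR's training code, with build-model and dataloader details left as placeholders:

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_backend  # registers the "xla" process group backend

def train_loop(model: torch.nn.Module, dataloader) -> None:
    # The Elastic plugin exports RANK / WORLD_SIZE / MASTER_ADDR, so no manual wiring is needed.
    torch.distributed.init_process_group("xla")
    device = xm.xla_device()  # one NeuronCore per worker process
    model = model.to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
    for batch in dataloader:
        optimizer.zero_grad()
        loss = model(**batch).loss  # assumes a Hugging Face-style model output
        loss.backward()
        xm.optimizer_step(optimizer)  # all-reduces gradients across workers, then steps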

Real-time metrics streaming in the UI

Loss curves and custom dashboards update live as training progresses (Trainium utilization metrics will be supported soon; CPU and memory metrics already appear in the UI).

image

Full visibility across the entire pipeline

Inspect inputs/outputs, view logs for both leader and worker processes, and trace every step end-to-end.

image

Built-in caching, retries, and error handling

Training tasks can be cached, retried at the task level, or retried via exception handling inside the task.
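
A sketch of the in-task variant (task-level caching is configured declaratively, as with cache="auto" on the environment above; run_training_phase below is a hypothetical helper):

@trainium_env.task
async def train_with_recovery(checkpoint_dir: str, max_attempts: int = 2) -> str:
    # In-task retry via exception handling; declarative task-level retries are also available.
    last_err: Exception | None = None
    for attempt in range(1, max_attempts + 1):
        try:
            return await run_training_phase(resume_from=checkpoint_dir)  # hypothetical helper
        except RuntimeError as err:
            last_err = err
            # Surface the checkpoint path so a rerun can resume from the same step.
            print(f"Attempt {attempt} failed: {err}; resume from {checkpoint_dir}")
    raise RuntimeError(f"Training failed after {max_attempts} attempts") from last_err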

Native AWS integrations

S3 for datasets + checkpoints, CloudWatch for logging, ECR for images, etc.

image

No manual torchrun configuration

The Elastic plugin automatically sets up and launches torch distributed; users just run the script with python train.py.
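
The launch pattern implied by python train.py is a short __main__ block that connects to the backend and submits the task. A sketch, where the config path and arguments are assumptions and call names follow the Flyte 2.0 SDK as documented:

if __name__ == "__main__":
    import flyte

    # Connect to the Flyte 2.0 backend (config file path is an assumption).
    flyte.init_from_config("config.yaml")

    # Submit the training task; no torchrun invocation or rendezvous flags are needed,
    # since the Elastic plugin configures torch distributed on the cluster side.
    run = flyte.run(train_bert, tokenized_data_path="s3://...", checkpoint_dir="s3://...")
    print(run.url)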

Crash-proof training end-to-end

Checkpoints and Neuron compilation cache are saved to blob storage every n steps. If an execution fails, users can resume from the exact step by simply providing the checkpoint + cache.
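
A sketch of that periodic checkpointing pattern (directory layout, helper name, and interval are illustrative):

import os
import torch_xla.core.xla_model as xm

def maybe_checkpoint(step: int, model, optimizer, ckpt_dir: str, every_n_steps: int = 500) -> None:
    """Save model/optimizer state every `every_n_steps` so a failed run can resume from this exact step."""
    if step == 0 or step % every_n_steps != 0:
        return
    state = {
        "step": step,
        "model": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    # xm.save moves XLA tensors to CPU and, by default, writes only from the master process.
    xm.save(state, os.path.join(ckpt_dir, f"ckpt_step_{step}.pt"))

Together with the Neuron compile cache (see NEURON_COMPILE_CACHE_URL in the environment above), this is what lets a rerun pick up from the exact step without recompilation.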

Historical metrics automatically restore and continue rendering from the resumed point.

image

Clear recovery guidance in logs

If training fails, checkpoint + cache paths are surfaced in the task logs.

Multi-phase training support

Phase 2 automatically consumes the output model from Phase 1. If a failure occurs mid-phase, training can resume directly from that point.
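
A sketch of how that chaining reads in a driver task (the phase task names and signatures are hypothetical, not lifted from the PR):

@trainium_env.task
async def pretrain_bert(tokenized_data_path: str) -> str:
    # Phase 1 produces a checkpoint; Phase 2 consumes it as its starting model.
    phase1_ckpt = await train_phase1(tokenized_data_path)               # hypothetical task
    phase2_ckpt = await train_phase2(tokenized_data_path, phase1_ckpt)  # hypothetical task
    # If Phase 2 fails mid-way, a rerun resumes from phase1_ckpt and Phase 2's own checkpoints.
    return phase2_ckpt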

Motivation

The goal of this example is to demonstrate that distributed training, whether pre-training or fine-tuning, can be both effortless to experiment with and robust enough for production. A key requirement for ML teams is the ability to run the same workflow locally, in a lightweight test environment, and at full production scale without rewriting code or reconfiguring infrastructure. Flyte 2.0, combined with EKS, delivers exactly that.

By setting up Flyte on EKS once (typically done by a platform engineer), ML engineers can run complex distributed training jobs on Trainium or GPUs without touching any infrastructure. This clean separation of concerns (platform setup vs. model development) ensures rapid iteration, consistent execution environments, and reliable scaling as workloads grow. The workflow in this PR showcases how seamless that experience can be.

More

  • Yes, I have tested the PR using my local account setup (Provide any test evidence report under Additional Notes)
  • Mandatory for new blueprints. Yes, I have added an example to support my blueprint PR
  • Mandatory for new blueprints. Yes, I have updated the website/docs or website/blog section for this feature
  • Yes, I ran pre-commit run -a with this PR. Link for installing pre-commit locally

For Moderators

  • E2E Test successfully complete before merge?

Additional Notes

I haven’t updated the docs yet since I wanted to get the team’s thoughts on the pipeline first. I can add a deeper dive into the platform setup to the documentation afterward. This PR currently includes only the training workflow code; the backend setup will be documented separately.

samhita-alla marked this pull request as draft on November 19, 2025 16:38
samhita-alla changed the title from "add BERT pre-training workflow with flyte 2.0 on EKS" to "Distributed BERT Pre-Training Workflow Using Flyte 2.0 and AWS Trainium" on Nov 20, 2025
samhita-alla marked this pull request as ready for review on November 20, 2025 10:26
@omrishiv (Contributor) commented:

Hi @samhita-alla, I messaged you on Slack. Just wanted to make sure you received it.
