Skip to content

Commit bd621e7

Browse files
authored
Merge branch 'main' into abstract-distributed-apis-checkpoint-loader
2 parents b1b168d + f25fca6 commit bd621e7

File tree

18 files changed

+599
-109
lines changed

18 files changed

+599
-109
lines changed

.github/workflows/build-and-test.yml

Lines changed: 24 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,6 @@ jobs:
3232
CPP_FAIL_UNDER: 80
3333
permissions:
3434
contents: read # Required for actions/checkout
35-
pull-requests: write # Required to post the comment
3635
steps:
3736
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
3837

@@ -65,12 +64,14 @@ jobs:
6564
- name: Check Python test coverage
6665
run: |
6766
# Verify python coverage thresholds
67+
echo -e "\n##### Generating Python coverage XML #####"
68+
coverage xml -o python-coverage.xml
6869
echo -e "\n##### Verifying Python coverage thresholds #####"
6970
coverage report --fail-under=${{ env.PYTHON_FAIL_UNDER }}
70-
coverage xml -o python-coverage.xml
7171
7272
- name: Python Coverage Summary
7373
uses: irongut/CodeCoverageSummary@51cc3a756ddcd398d447c044c02cb6aa83fdae95 # ratchet:irongut/CodeCoverageSummary@v1.3.0
74+
if: always() # Run even if threshold check above fails
7475
with:
7576
filename: python-coverage.xml
7677
badge: true
@@ -83,12 +84,16 @@ jobs:
8384
thresholds: '${{ env.PYTHON_FAIL_UNDER }} 95'
8485

8586
- name: Add Python Coverage Title
87+
if: always()
8688
run: |
87-
echo '### Python Code Coverage Summary' | cat - code-coverage-results.md > temp && mv temp python-code-coverage-results.md
89+
# Only run if the summary was actually generated
90+
if [ -f code-coverage-results.md ]; then
91+
echo '### Python Code Coverage Summary' | cat - code-coverage-results.md > temp && mv temp python-code-coverage-results.md
92+
fi
8893
8994
- name: Add Python Coverage PR Comment
9095
uses: marocchino/sticky-pull-request-comment@773744901bac0e8cbb5a0dc842800d45e9b2b405 # ratchet:marocchino/sticky-pull-request-comment@v2
91-
if: github.event_name == 'pull_request'
96+
if: false # TODO remove once new workflow confirmed to work
9297
with:
9398
recreate: true
9499
path: python-code-coverage-results.md
@@ -116,6 +121,7 @@ jobs:
116121
117122
- name: C++ Coverage Summary
118123
uses: irongut/CodeCoverageSummary@51cc3a756ddcd398d447c044c02cb6aa83fdae95 # ratchet:irongut/CodeCoverageSummary@v1.3.0
124+
if: always() # Run even if threshold check above fails
119125
with:
120126
filename: cxx-coverage.xml
121127
badge: true
@@ -128,27 +134,39 @@ jobs:
128134
thresholds: '${{ env.CPP_FAIL_UNDER }} 40'
129135

130136
- name: Add C++ Coverage Title
137+
if: always()
131138
run: |
132-
echo '### C++ Code Coverage Summary' | cat - code-coverage-results.md > temp && mv temp cpp-code-coverage-results.md
139+
if [ -f code-coverage-results.md ]; then
140+
echo '### C++ Code Coverage Summary' | cat - code-coverage-results.md > temp && mv temp cpp-code-coverage-results.md
141+
fi
133142
134143
- name: Add C++ Coverage PR Comment
135144
uses: marocchino/sticky-pull-request-comment@773744901bac0e8cbb5a0dc842800d45e9b2b405 # ratchet:marocchino/sticky-pull-request-comment@v2
136-
if: github.event_name == 'pull_request'
145+
if: false # TODO: remove when new workflow confirmed to work
137146
with:
138147
header: cpp-coverage
139148
recreate: true
140149
path: cpp-code-coverage-results.md
141150

151+
- name: Save PR number
152+
# Use always() so this runs even if previous coverage/test steps failed.
153+
if: always() && github.event_name == 'pull_request'
154+
run: |
155+
echo ${{ github.event.number }} > pr_number.txt
156+
142157
- name: Archive coverage reports
143158
uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # ratchet:actions/upload-artifact@v4
159+
if: always()
144160
with:
145161
name: coverage-reports
162+
if-no-files-found: warn # Default, but setting explicitly for awareness as non-PRs won't have pr_number.txt
146163
path: |
147164
htmlcov/
148165
python-coverage.xml
149166
cxx-coverage.xml
150167
python-code-coverage-results.md
151168
cpp-code-coverage-results.md
169+
pr_number.txt
152170
153171
lint-code:
154172
runs-on: ubuntu-latest
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
name: Post Coverage Comment
16+
17+
on:
18+
workflow_run:
19+
workflows: ["Build and Test"]
20+
types:
21+
- completed
22+
23+
jobs:
24+
post-comment:
25+
runs-on: ubuntu-latest
26+
# This workflow runs in the context of the base repository, so it has write permissions
27+
# even when the triggering workflow was from a fork.
28+
# We run even if the build or thresholds failed, so long as it wasn't cancelled.
29+
if: >
30+
github.event.workflow_run.event == 'pull_request' &&
31+
github.event.workflow_run.conclusion != 'cancelled'
32+
permissions:
33+
pull-requests: write
34+
steps:
35+
- name: Download coverage reports artifact
36+
uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # ratchet:actions/download-artifact@v7
37+
continue-on-error: true # Artifact might be missing on very early failures
38+
with:
39+
name: coverage-reports
40+
github-token: ${{ secrets.GITHUB_TOKEN }}
41+
run-id: ${{ github.event.workflow_run.id }}
42+
43+
- name: Check for coverage files
44+
# We use an explicit step to check for file existence and set outputs.
45+
# This is more robust than using hashFiles() in an 'if' expression,
46+
# as hashFiles() is primarily intended for cache keys and lacks
47+
# a dedicated file_exists() equivalent in GitHub Actions expressions.
48+
id: check_files
49+
run: |
50+
if [ -f pr_number.txt ]; then
51+
echo "pr_number=$(cat pr_number.txt)" >> $GITHUB_OUTPUT
52+
echo "pr_found=true" >> $GITHUB_OUTPUT
53+
else
54+
echo "pr_found=false" >> $GITHUB_OUTPUT
55+
fi
56+
57+
if [ -f python-code-coverage-results.md ]; then
58+
echo "python_found=true" >> $GITHUB_OUTPUT
59+
else
60+
echo "python_found=false" >> $GITHUB_OUTPUT
61+
fi
62+
63+
if [ -f cpp-code-coverage-results.md ]; then
64+
echo "cpp_found=true" >> $GITHUB_OUTPUT
65+
else
66+
echo "cpp_found=false" >> $GITHUB_OUTPUT
67+
fi
68+
69+
- name: Post Python Coverage PR Comment
70+
if: steps.check_files.outputs.pr_found == 'true' && steps.check_files.outputs.python_found == 'true'
71+
uses: marocchino/sticky-pull-request-comment@773744901bac0e8cbb5a0dc842800d45e9b2b405 # ratchet:marocchino/sticky-pull-request-comment@v2
72+
with:
73+
recreate: true
74+
number: ${{ steps.check_files.outputs.pr_number }}
75+
path: python-code-coverage-results.md
76+
77+
- name: Post C++ Coverage PR Comment
78+
if: steps.check_files.outputs.pr_found == 'true' && steps.check_files.outputs.cpp_found == 'true'
79+
uses: marocchino/sticky-pull-request-comment@773744901bac0e8cbb5a0dc842800d45e9b2b405 # ratchet:marocchino/sticky-pull-request-comment@v2
80+
with:
81+
header: cpp-coverage
82+
recreate: true
83+
number: ${{ steps.check_files.outputs.pr_number }}
84+
path: cpp-code-coverage-results.md

docs/README.md

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,16 @@ When comparing
3636

3737
We observe:
3838

39-
* Data write times that are up to 20-30x faster for ML Flashpoint specifically, with little to no optimization.
40-
This is expected to further improve with additional optimizations.
41-
* Total checkpoint recovery times that are ~7-10x faster for ML Flashpoint specifically (includes the time it takes to do checkpoint detection, cross-node coordination, replication, read into model state and be ready to resume training).
39+
* Data write times that are up to 120x faster for ML Flashpoint specifically, currently reaching up to ~30 GB/s/node write throughput (scales linearly with cluster size).
40+
* Total checkpoint recovery times that are ~7-12x faster for ML Flashpoint specifically, depending on the number of nodes lost (includes the time it takes to do checkpoint detection, cross-node coordination, replication, reading into model state, and being ready to resume training).
4241
* For _async_ checkpointing:
4342
* Improvements averaging **3%** (Gemma 27B) & **6%** (Llama 70B) for _overall job time_ in the hybrid approach.
4443
* Improvements reach **5%** (Gemma 27B) & **10%** (Llama 70B) when NeMo checkpointing is deferred to the end (300th step) instead of being done every 50 steps.
4544
* These improvements only account for checkpoint _save_ efficiency, so they represent a "lower bound": they do not include the speedups in _recovery_ time.
4645
* Any job interruptions would also benefit from ML Flashpoint's recovery performance gains.
4746

47+
Stay tuned and watch the [repository](https://github.com/google/ml-flashpoint) for updates on future improvements!
48+
4849
!!! info
4950

5051
While [ML runtime goodput](https://cloud.google.com/blog/products/ai-machine-learning/goodput-metric-as-measure-of-ml-productivity) is important, we focus on overall job time as an end-to-end metric, as it is simpler and allows for straightforward _total_ cost comparisons.
@@ -69,8 +70,7 @@ To use ML Flashpoint, the basic requirements for the training environment are:
6970
* This is enforced so that the pairwise strategy doesn't put a higher memory burden on one node than the others, and so the general capacity requirements are roughly consistent across nodes.
7071
1. A `tmpfs` mount that is separate from `/dev/shm` is strongly recommended for the container base path.
7172
E.g. a `/tmp` mount, which can be added to `/etc/fstab` on Linux machines to mount it persistently (A3-Mega example):
72-
1. `tmpfs /tmp tmpfs rw,nosuid,nodev,size=1024G,mode=1777,noswap,huge=within_size 0 0`
73-
1. `huge=within_size` is recommended to use huge pages for any files large enough, since checkpoint data is on the order of many GBs.
73+
1. `tmpfs /tmp tmpfs rw,nosuid,nodev,size=1024G,mode=1777,noswap 0 0`
7474
1. `noswap` is recommended to avoid degrading performance.
7575
This can be omitted if you prefer to allow transparent disk swapping to accommodate more checkpoint storage than can fit in memory, at the cost of poorer checkpointing performance.
7676
1. The amount of memory needed is at least 4x the per-node checkpoint size, to account for replicas and in-progress checkpoints.

docs/user-guide.md

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@ from ml_flashpoint.replication.replication_manager import ReplicationManager
133133

134134
# Megatron Checkpointing
135135
from megatron.core import dist_checkpointing as mcore_dist_checkpointing
136+
from ml_flashpoint.adapter.megatron.save_utils import save_local_aware_megatron_checkpoint
136137
```
137138

138139
#### Save Strategy
@@ -150,8 +151,19 @@ megatron_save_strategy = MLFlashpointMegatronAsyncSaveStrategy(
150151
)
151152
```
152153

153-
Because Megatron's `dist_checkpointing.save()` function writes "common" data only on global rank 0, which does not align with local checkpointing, you can orchestrate saves using the save strategy the same way it's done in [`MLFlashpointCheckpointIO.save_checkpoint()`](https://github.com/google/ml-flashpoint/blob/b9767583520106f59743b9e8050769523cfbef6e/src/ml_flashpoint/adapter/nemo/checkpoint_io.py#L137-L171) in the `ml_flashpoint.adapter.nemo` package.
154-
You'll notice that the logic there aims to mimic `dist_checkpointing.save`, but it saves common data on each node (via local rank 0) as opposed to solely on the coordinator node (global rank 0).
154+
Because Megatron's `dist_checkpointing.save()` function writes "common" data only on global rank 0, which does not align with local checkpointing, use the provided helper function `save_local_aware_megatron_checkpoint()` from the `ml_flashpoint.adapter.megatron.save_utils` module.
155+
156+
This helper mimics `dist_checkpointing.save()`, but saves common data on each node (via local rank 0) rather than solely on the coordinator node (global rank 0).
157+
158+
```python
159+
# In your save loop
160+
async_request = save_local_aware_megatron_checkpoint(
161+
checkpoint=state_dict,
162+
checkpoint_dir=str(curr_step_checkpoint_id),
163+
save_strategy=megatron_save_strategy,
164+
async_save=True,
165+
)
166+
```
155167

156168
!!! note
157169

pyproject.toml

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,22 @@
1919
# ===================================================================
2020
[project]
2121
name = "ml-flashpoint"
22-
version = "0.0.0"
22+
dynamic = [ "version" ]
2323
description = "A memory-first, lightning fast, easy-to-use ML checkpointing library."
24+
readme = "README.md"
25+
license = { file = "LICENSE" }
26+
classifiers = [
27+
"Intended Audience :: Developers",
28+
"Intended Audience :: Science/Research",
29+
"License :: OSI Approved :: Apache Software License",
30+
"Operating System :: POSIX :: Linux",
31+
"Programming Language :: Python :: 3",
32+
"Programming Language :: Python :: 3.10",
33+
"Programming Language :: Python :: 3.11",
34+
"Programming Language :: Python :: 3.12",
35+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
36+
"Topic :: Software Development :: Libraries :: Python Modules",
37+
]
2438

2539
# Specifies the minimum version of Python required to install and run this package.
2640
requires-python = ">=3.10"
@@ -101,6 +115,7 @@ requires = [
101115
"scikit-build-core==0.11.6",
102116
"cmake==3.31.10",
103117
"ninja==1.11.1.3",
118+
"setuptools-scm==9.2.2",
104119
]
105120

106121
# The Python object that `pip` will call to execute the build.
@@ -114,6 +129,9 @@ build-backend = "scikit_build_core.build"
114129
# ===================================================================
115130
[tool.scikit-build]
116131

132+
# Tells scikit-build-core to use setuptools-scm to retrieve the version from git.
133+
metadata.version.provider = "scikit_build_core.metadata.setuptools_scm"
134+
117135
# Specifies the minimum version of CMake that must be present on the system.
118136
cmake.version = ">=3.18"
119137

@@ -134,6 +152,14 @@ cmake.source-dir = "."
134152
# https://scikit-build-core.readthedocs.io/en/latest/configuration/index.html#customizing-the-built-wheel
135153
# wheel.packages = ["src/ml_flashpoint"]
136154

155+
# ===================================================================
156+
# Tool-specific Configuration for setuptools-scm
157+
# ===================================================================
158+
[tool.setuptools_scm]
159+
# Fallback version to use if git is not available or the directory is not a git repo.
160+
# This prevents build failures in environments like some CI runners or /tmp clones.
161+
fallback_version = "0.0.0"
162+
137163
# ===================================================================
138164
# Tool-specific Configuration for Ruff
139165
# ===================================================================

src/ml_flashpoint/adapter/megatron/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,3 +12,6 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from ml_flashpoint.adapter.megatron.save_utils import (
16+
save_local_aware_megatron_checkpoint as save_local_aware_megatron_checkpoint,
17+
)

src/ml_flashpoint/adapter/megatron/save_strategies.py

Lines changed: 33 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from pathlib import Path
2020
from typing import Union
2121

22+
import torch
2223
from megatron.core.dist_checkpointing.mapping import ShardedStateDict
2324
from megatron.core.dist_checkpointing.strategies.async_utils import AsyncRequest
2425
from megatron.core.dist_checkpointing.strategies.base import AsyncSaveShardedStrategy
@@ -32,7 +33,7 @@
3233

3334
from ml_flashpoint.adapter.pytorch import custom_state_dict_saver as statedictsaver
3435
from ml_flashpoint.adapter.pytorch.memory_storage_writer import MemoryStorageWriter
35-
from ml_flashpoint.core import utils
36+
from ml_flashpoint.core import mlf_logging, utils
3637
from ml_flashpoint.core.checkpoint_id_types import CheckpointContainerId
3738
from ml_flashpoint.core.checkpoint_saver import MLFlashpointCheckpointSaver, ObjectWriteBucket
3839
from ml_flashpoint.core.mlf_logging import get_logger
@@ -41,6 +42,26 @@
4142
_LOGGER = get_logger(__name__)
4243

4344

45+
def _save_checkpoint(
46+
staged_buckets: list[ObjectWriteBucket],
47+
checkpoint_id: CheckpointContainerId,
48+
storage_writer: MemoryStorageWriter,
49+
rank: int,
50+
step: int,
51+
):
52+
"""
53+
This function is the 'async_fn' run in Megatron's :class:`AsyncRequest`.
54+
"""
55+
56+
mlf_logging.setup_worker_logging(rank, step)
57+
statedictsaver.write_data(
58+
checkpoint_id=checkpoint_id,
59+
storage_writer=storage_writer,
60+
staged_write_buckets=staged_buckets,
61+
replicate_after_write=False,
62+
)
63+
64+
4465
def default_backend_format_name() -> str:
4566
return "ml_flashpoint"
4667

@@ -105,7 +126,7 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
105126
# 1b. Re-initialize the StorageWriter to use a new instance per save to avoid hangs from shared state.
106127
self._storage_writer = MemoryStorageWriter(
107128
checkpoint_saver=self._checkpoint_saver,
108-
mp_manager=self._storage_writer._mp_manager,
129+
mp_manager=self._storage_writer._main_process_torchmp_manager,
109130
thread_count=self._storage_writer._thread_count,
110131
)
111132
# 1c. Reset the StorageWriter for this checkpoint version.
@@ -156,17 +177,6 @@ def async_save(self, sharded_state_dict: ShardedStateDict, checkpoint_dir: Union
156177
with open(os.path.join(checkpoint_dir, "metadata.json"), "w") as f:
157178
json.dump(metadata, f)
158179

159-
def _save_checkpoint(staged_buckets: list[ObjectWriteBucket]):
160-
"""
161-
This function is the 'async_fn' run in Megatron's :class:`AsyncRequest`.
162-
"""
163-
statedictsaver.write_data(
164-
checkpoint_id=checkpoint_id,
165-
storage_writer=self._storage_writer,
166-
staged_write_buckets=staged_buckets,
167-
replicate_after_write=False,
168-
)
169-
170180
finalize_fns = [
171181
# Replicate written objects
172182
partial(
@@ -188,9 +198,18 @@ def _save_checkpoint(staged_buckets: list[ObjectWriteBucket]):
188198
),
189199
]
190200

201+
current_rank = torch.distributed.get_rank() if torch.distributed.is_initialized() else -1
202+
current_step = mlf_logging.get_current_step()
203+
191204
return AsyncRequest(
192205
async_fn=_save_checkpoint,
193206
async_fn_args=(),
194-
async_fn_kwargs={"staged_buckets": staged_write_buckets},
207+
async_fn_kwargs={
208+
"staged_buckets": staged_write_buckets,
209+
"checkpoint_id": checkpoint_id,
210+
"storage_writer": self._storage_writer,
211+
"rank": current_rank,
212+
"step": current_step,
213+
},
195214
finalize_fns=finalize_fns,
196215
)

0 commit comments

Comments
 (0)