Skip to content

Commit a63aa3f

Browse files
authored
Merge branch 'main' into alechan/upgrade-xpk-v0.13.0
2 parents fe5639a + d39fa23 commit a63aa3f

File tree

5 files changed

+381
-9
lines changed

5 files changed

+381
-9
lines changed

.github/container/Dockerfile.axlearn

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,35 @@
11
# syntax=docker/dockerfile:1-labs
22
ARG BASE_IMAGE=ghcr.io/nvidia/jax-mealkit:jax
3-
ARG URLREF_AXLEARN=https://github.com/Steboss/axlearn.git#main
3+
ARG URLREF_AXLEARN=https://github.com/apple/axlearn.git
44
ARG SRC_PATH_AXLEARN=/opt/axlearn
5+
ARG DEST_MANIFEST_DIR=/opt/manifest.d
6+
ARG GIT_USER_NAME=JAX Toolbox
7+
58

69
###############################################################################
710
## Download source and configure dependencies
811
###############################################################################
912
FROM ${BASE_IMAGE} AS mealkit
13+
ARG DEST_MANIFEST_DIR
14+
ARG SRC_PATH_AXLEARN
15+
ARG GIT_USER_NAME
16+
ARG GIT_USER_EMAIL
1017
ARG URLREF_AXLEARN
1118
ARG SRC_PATH_AXLEARN
1219

13-
RUN git-clone.sh "${URLREF_AXLEARN}" "${SRC_PATH_AXLEARN}"
20+
# Run the patch with cloning
21+
RUN <<"EOF" bash -exu
22+
git config --global user.email "${GIT_USER_EMAIL}"
23+
git config --global user.name "${GIT_USER_NAME}"
24+
git-clone.sh "${URLREF_AXLEARN}" "${SRC_PATH_AXLEARN}"
25+
${DEST_MANIFEST_DIR}/create-distribution.sh \
26+
--manifest ${DEST_MANIFEST_DIR}/manifest.yaml \
27+
--package axlearn
28+
EOF
1429

1530
# these packages are needed to run axlearn tests
1631
# https://github.com/apple/axlearn/blob/main/pyproject.toml as reference
32+
WORKDIR /opt/axlearn
1733
RUN <<"EOF" bash -ex
1834
echo "-e ${SRC_PATH_AXLEARN}" > /opt/pip-tools.d/requirements-axlearn.in
1935
cat <<REQUIREMENTS >> /opt/pip-tools.d/requirements-axlearn.in

.github/container/fuji-train-perf.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
HybridMeshShape,
2323
)
2424

25+
import ast
26+
2527
FLAGS = flags.FLAGS
2628

2729
MODELS_INFO = {
@@ -129,10 +131,9 @@
129131

130132
parser.add_argument(
131133
"--trace_steps",
132-
nargs="+",
133-
type=int,
134+
type=ast.literal_eval,
134135
default=None,
135-
help="Steps to trace (e.g. [1, 20, 50]). To profile the training give `--jax_profiler_port 9999` to the script.",
136+
help="Steps where we want JAX tracing. Example: [1,20,30]. AXLearn will pick up the 3 consecutive steps. Please, remember to add also `--jax_profiler_port 9999` to the input command",
136137
)
137138

138139
parser.add_argument("--world_size", type=int, help="Total number of GPUs")
@@ -253,8 +254,8 @@ def main(parsed_args):
253254
save_checkpoint_steps = parsed_args.save_checkpoint_steps
254255
write_summary_steps = parsed_args.write_summary_steps
255256
output_log_file = parsed_args.output_log_file
256-
trace_steps = parsed_args.trace_steps
257257
world_size = parsed_args.world_size
258+
trace_steps = parsed_args.trace_steps
258259

259260
print(
260261
f"=== Parameter Check ===\n"

.github/container/git-clone.sh

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ while [ : ]; do
4141
;;
4242
--)
4343
shift;
44-
break
44+
break
4545
;;
4646
esac
4747
done
@@ -80,7 +80,6 @@ git submodule update --init --recursive
8080
popd
8181

8282
## update the manifest file
83-
8483
mkdir -p $(dirname ${MANIFEST})
8584
touch ${MANIFEST}
8685
PACKAGE=$(basename "${DESTINATION}")

.github/container/manifest.yaml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,9 +96,12 @@ pathwaysutils:
9696
latest_verified_commit: 359776d454940ffaa337c36d1df16308d44a95a9
9797
mode: pip-vcs
9898
axlearn:
99-
url: https://github.com/Steboss/axlearn.git
99+
url: https://github.com/apple/axlearn.git
100+
mirror_url: https://github.com/nvjax-svc-0/axlearn.git
100101
tracking_ref: main
101102
mode: git-clone
103+
patches:
104+
pull/1339/head: file://patches/axlearn/PR-1339.patch
102105
qwix:
103106
url: https://github.com/google/qwix.git
104107
tracking_ref: main

0 commit comments

Comments
 (0)