
Commit 2e25d13

Only set CUDA_DEVICE_MAX_CONNECTIONS=1 for Hopper/cc9.0 runs (#1249)

Cherry-pick of #1236 to main.

1 parent: 505b741

File tree: 3 files changed (+6, −3)

.github/container/Dockerfile.jax (0 additions & 1 deletion)

@@ -85,7 +85,6 @@ ENV BUILD_DATE=${BUILD_DATE}
 ENV XLA_FLAGS=""
 ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_latency_hiding_scheduler=true"
 ENV XLA_FLAGS="${XLA_FLAGS} --xla_gpu_enable_triton_gemm=false"
-ENV CUDA_DEVICE_MAX_CONNECTIONS=1
 ENV NCCL_NVLS_ENABLE=0
 
 COPY --from=builder ${BUILD_PATH_JAXLIB} ${BUILD_PATH_JAXLIB}
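With the ENV line removed, the variable is no longer baked into the image. Users who still want the old behavior on non-Hopper GPUs can set it per container; a hypothetical invocation (the image tag and script name are illustrative, not taken from this commit):

# Illustrative only: opt back into the old behavior at runtime.
docker run --gpus all \
  -e CUDA_DEVICE_MAX_CONNECTIONS=1 \
  ghcr.io/nvidia/jax:latest \
  python my_training_script.py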

.github/container/test-maxtext.sh (6 additions & 1 deletion)

@@ -224,7 +224,12 @@ pushd ${MAXTEXT_DIR}
 
 export NVTE_FUSED_ATTN=${ENABLE_FUSED_ATTN}
 export XLA_PYTHON_CLIENT_MEM_FRACTION=${MEM_FRACTION}
-export CUDA_DEVICE_MAX_CONNECTIONS=1
+
+local_arch=$(local_cuda_arch)
+if [[ "${local_arch}" == "9.0" ]]; then
+    echo "Setting CUDA_DEVICE_MAX_CONNECTIONS=1 for cc${local_arch} devices"
+    export CUDA_DEVICE_MAX_CONNECTIONS=1
+fi
 
 export BASE_XLA_FLAGS=${BASE_XLA_FLAGS:---xla_gpu_enable_latency_hiding_scheduler=true
                         --xla_gpu_enable_triton_gemm=false
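The conditional relies on a `local_cuda_arch` helper defined elsewhere in the repository's test scripts. A minimal sketch of what such a helper might look like, assuming `nvidia-smi` exposes the `compute_cap` query field (available in driver r510 and later):

# Hypothetical sketch; the repository's actual helper may differ.
local_cuda_arch() {
    # Print the compute capability of the first visible GPU,
    # e.g. "9.0" on Hopper (H100) or "8.0" on Ampere (A100).
    nvidia-smi --query-gpu=compute_cap --format=csv,noheader -i 0
}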

README.md (0 additions & 1 deletion)

@@ -315,7 +315,6 @@ The [JAX image](https://github.com/NVIDIA/JAX-Toolbox/pkgs/container/jax) is emb
 
 | Environment Variable | Value | Explanation |
 | -------------------- | ----- | ----------- |
-| `CUDA_DEVICE_MAX_CONNECTIONS` | `1` | use a single queue for GPU work to lower latency of stream operations; OK since XLA already orders launches |
 | `NCCL_NVLS_ENABLE` | `0` | Disables NVLink SHARP ([1](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/env.html#nccl-nvls-enable)). Future releases will re-enable this feature. |
 
 There are various other XLA flags users can set to improve performance. For a detailed explanation of these flags, please refer to the [GPU performance](./rosetta/docs/GPU_performance.md) doc. XLA flags can be tuned per workflow. For example, each script in [contrib/gpu/scripts_gpu](https://github.com/google/paxml/tree/main/paxml/contrib/gpu/scripts_gpu) sets its own [XLA flags](https://github.com/google/paxml/blob/93fbc8010dca95af59ab615c366d912136b7429c/paxml/contrib/gpu/scripts_gpu/benchmark_gpt_multinode.sh#L30-L33).
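As a concrete illustration of per-workflow tuning, a single run could override the image's defaults like so (the script name is hypothetical; the flag values are the ones from this repository's Dockerfile, not a recommendation):

# Illustrative only: override XLA flags for one run.
XLA_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false" \
python train.py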
