-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathenvs.py
More file actions
56 lines (51 loc) · 2.33 KB
/
envs.py
File metadata and controls
56 lines (51 loc) · 2.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# SPDX-License-Identifier: MIT
# Copyright (C) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
import os
from typing import Callable, Any
# Registry of ATOM_* settings. Each spec row is (variable name, default
# string, converter); the converter is applied to the raw string returned by
# os.getenv on every read, so a lookup always reflects the current environment.


def _is_one(raw: str) -> bool:
    """Interpret a flag value: only the exact string "1" counts as enabled."""
    return raw == "1"


def _make_reader(
    var: str, default: str, convert: Callable[[str], Any]
) -> Callable[[], Any]:
    """Return a zero-argument getter for ``var``.

    The getter reads ``var`` from the environment (falling back to
    ``default`` when unset) and passes the raw string through ``convert``.
    A factory function is used so each getter captures its own ``var`` /
    ``default`` / ``convert`` instead of late-binding loop variables.
    """

    def read() -> Any:
        return convert(os.getenv(var, default))

    return read


_ENV_SPECS: tuple[tuple[str, str, Callable[[str], Any]], ...] = (
    # Data parallel environment variables
    ("ATOM_DP_RANK", "0", int),
    ("ATOM_DP_RANK_LOCAL", "0", int),
    ("ATOM_DP_SIZE", "1", int),
    # os.getenv already yields a str here, so str() is an identity conversion.
    ("ATOM_DP_MASTER_IP", "127.0.0.1", str),
    ("ATOM_DP_MASTER_PORT", "29500", int),
    # Boolean flags below are enabled only when set to exactly "1".
    ("ATOM_ENFORCE_EAGER", "0", _is_one),
    # add qk-norm-rope-cache-quant fusion for Qwen3-Moe model, default disabled,
    # Qwen3-Moe model should enable this for better performance.
    ("ATOM_ENABLE_QK_NORM_ROPE_CACHE_QUANT_FUSION", "0", _is_one),
    ("ATOM_USE_TRITON_GEMM", "0", _is_one),
    ("ATOM_USE_TRITON_MLA_DECODE", "0", _is_one),
    ("ATOM_USE_TRITON_MXFP4_BMM", "0", _is_one),
    ("ATOM_ENABLE_DS_INPUT_RMSNORM_QUANT_FUSION", "1", _is_one),
    ("ATOM_ENABLE_DS_QKNORM_QUANT_FUSION", "1", _is_one),
    ("ATOM_ENABLE_ALLREDUCE_RMSNORM_FUSION", "1", _is_one),
    ("ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_RMSNORM_QUANT", "1", _is_one),
    ("ATOM_LLAMA_ENABLE_AITER_TRITON_FUSED_SILU_MUL_QUANT", "1", _is_one),
    ("ATOM_USE_FLYDSL_MOE", "0", _is_one),
    ("ATOM_CK_FREE", "0", _is_one),
)

# Setting name -> zero-argument getter, in the same order as _ENV_SPECS.
environment_variables: dict[str, Callable[[], Any]] = {
    var: _make_reader(var, default, convert)
    for var, default, convert in _ENV_SPECS
}
def __getattr__(name: str) -> Any:
    """Resolve registered settings lazily (PEP 562 module ``__getattr__``).

    Python calls this hook only when ordinary module attribute lookup fails.
    For a name registered in ``environment_variables`` it invokes the
    matching zero-argument getter on every access (nothing is cached), so
    the returned value reflects the environment at the time of the lookup.

    Raises:
        AttributeError: if ``name`` is not a registered setting, matching the
            message a normal missing module attribute would produce.
    """
    if name not in environment_variables:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
    return environment_variables[name]()


def __dir__() -> list[str]:
    """List the module's real globals plus the lazily served setting names.

    PEP 562 pairs module ``__getattr__`` with ``__dir__``. Without this hook,
    ``dir()`` and interactive tab-completion would never show the names that
    only ``__getattr__`` can resolve.
    """
    return sorted(set(globals()) | set(environment_variables))