Skip to content

Commit bbc2a3f

Browse files
fabnemEPFLCopilot
andauthored
Default tp_size based on slurm number of GPUs (#17)
Co-authored-by: fabnemEPFL <117652591+fabnemEPFL@users.noreply.github.com> Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com>
1 parent f6bcac8 commit bbc2a3f

1 file changed

Lines changed: 35 additions & 1 deletion

File tree

  • src/mmirage/core/process/processors/llm

src/mmirage/core/process/processors/llm/config.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from dataclasses import dataclass, field
44

55
import logging
6+
import os
67
from typing import Dict, Optional, Sequence, Type, Any, List
78
from pydantic import BaseModel, create_model
89

@@ -15,6 +16,39 @@
1516
env = Environment()
1617

1718

19+
def _parse_tp_size_from_env() -> int:
20+
"""Parse tensor parallelism size from SLURM_GPUS_ON_NODE environment variable.
21+
22+
Defensively parses the environment variable, handling invalid values:
23+
- Returns 1 if the variable is None or empty
24+
- Strips whitespace before parsing
25+
- Returns 1 for non-integer values
26+
- Returns 1 for values <= 0
27+
28+
Returns:
29+
Tensor parallelism size (>= 1), defaults to 1 on any parsing error.
30+
"""
31+
env_value = os.environ.get("SLURM_GPUS_ON_NODE")
32+
if not env_value:
33+
return 1
34+
35+
try:
36+
tp_size = int(env_value.strip())
37+
# Ensure tp_size is positive (must be >= 1)
38+
if tp_size <= 0:
39+
logger.warning(
40+
f"Invalid SLURM_GPUS_ON_NODE value '{env_value}' (must be > 0), defaulting tp_size to 1"
41+
)
42+
return 1
43+
return tp_size
44+
except ValueError:
45+
# ValueError: invalid integer format
46+
logger.warning(
47+
f"Invalid SLURM_GPUS_ON_NODE value '{env_value}', defaulting tp_size to 1"
48+
)
49+
return 1
50+
51+
1852
@dataclass
1953
class SGLangServerArgs:
2054
"""Server arguments for SGLang engine.
@@ -27,7 +61,7 @@ class SGLangServerArgs:
2761
"""
2862

2963
model_path: str = "none"
30-
tp_size: int = 1
64+
tp_size: int = field(default_factory=_parse_tp_size_from_env)
3165
trust_remote_code: bool = True
3266
disable_custom_all_reduce: bool = False
3367

0 commit comments

Comments
 (0)