33from dataclasses import dataclass , field
44
55import logging
6+ import os
67from typing import Dict , Optional , Sequence , Type , Any , List
78from pydantic import BaseModel , create_model
89
1516env = Environment ()
1617
1718
19+ def _parse_tp_size_from_env () -> int :
20+ """Parse tensor parallelism size from SLURM_GPUS_ON_NODE environment variable.
21+
22+ Defensively parses the environment variable, handling invalid values:
23+ - Returns 1 if the variable is None or empty
24+ - Strips whitespace before parsing
25+ - Returns 1 for non-integer values
26+ - Returns 1 for values <= 0
27+
28+ Returns:
29+ Tensor parallelism size (>= 1), defaults to 1 on any parsing error.
30+ """
31+ env_value = os .environ .get ("SLURM_GPUS_ON_NODE" )
32+ if not env_value :
33+ return 1
34+
35+ try :
36+ tp_size = int (env_value .strip ())
37+ # Ensure tp_size is positive (must be >= 1)
38+ if tp_size <= 0 :
39+ logger .warning (
40+ f"Invalid SLURM_GPUS_ON_NODE value '{ env_value } ' (must be > 0), defaulting tp_size to 1"
41+ )
42+ return 1
43+ return tp_size
44+ except ValueError :
45+ # ValueError: invalid integer format
46+ logger .warning (
47+ f"Invalid SLURM_GPUS_ON_NODE value '{ env_value } ', defaulting tp_size to 1"
48+ )
49+ return 1
50+
51+
1852@dataclass
1953class SGLangServerArgs :
2054 """Server arguments for SGLang engine.
@@ -27,7 +61,7 @@ class SGLangServerArgs:
2761 """
2862
2963 model_path : str = "none"
30- tp_size : int = 1
64+ tp_size : int = field ( default_factory = _parse_tp_size_from_env )
3165 trust_remote_code : bool = True
3266 disable_custom_all_reduce : bool = False
3367
0 commit comments