Skip to content

Commit f0dea04

Browse files
rebel-jonghewk
and authored
fix: use container-aware thread count to avoid host op performance degradation in pods (#512)
Co-authored-by: rebel-jonghewk <jonghewk@rebellions.in>
1 parent 672ead5 commit f0dea04

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

vllm_rbln/v1/worker/optimum_worker.py

Lines changed: 39 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -12,10 +12,12 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import os
1516
from contextlib import nullcontext
1617
from types import NoneType
1718
from typing import Any
1819

20+
import numba
1921
import torch
2022
import torch.distributed
2123
import torch.nn as nn
@@ -38,6 +40,7 @@
3840
from vllm_rbln.utils.optimum.cache_blocks import sync_num_blocks
3941
from vllm_rbln.utils.optimum.rbln_params import get_rbln_params
4042
from vllm_rbln.v1.worker.optimum_model_runner import RBLNOptimumModelRunner
43+
from vllm_rbln.v1.worker.utils import set_cpu_affinity, set_omp_num_threads
4144

4245
logger = init_logger(__name__)
4346

@@ -96,6 +99,42 @@ def __init__(
9699
self.profiler = None
97100

98101
def init_device(self) -> None:
102+
allocated_cpus = len(os.sched_getaffinity(0))
103+
reported_cpus = os.cpu_count() or allocated_cpus
104+
105+
if allocated_cpus < reported_cpus:
106+
# Use physical cores only (exclude HT siblings).
107+
num_threads = max(2, allocated_cpus // 2)
108+
logger.info(
109+
"Container cpuset detected (%d/%d CPUs). "
110+
"Skipping set_cpu_affinity, setting threads to %d "
111+
"(physical cores only, excluding HT).",
112+
allocated_cpus,
113+
reported_cpus,
114+
num_threads,
115+
)
116+
117+
# Set all thread pool environment variables
118+
os.environ["OMP_NUM_THREADS"] = str(num_threads)
119+
os.environ["MKL_NUM_THREADS"] = str(num_threads)
120+
os.environ["OPENBLAS_NUM_THREADS"] = str(num_threads)
121+
os.environ["NUMEXPR_MAX_THREADS"] = str(num_threads)
122+
os.environ["RBLN_NUM_THREADS"] = str(num_threads)
123+
124+
# Directly set PyTorch thread counts
125+
torch.set_num_threads(num_threads)
126+
127+
set_omp_num_threads(self.rank, self.local_rank, num_threads)
128+
else:
129+
# Bare metal: use NUMA-aware binding
130+
set_cpu_affinity(self.rank, self.local_rank, self.parallel_config)
131+
allocated_cpus = len(os.sched_getaffinity(0))
132+
set_omp_num_threads(self.rank, self.local_rank, max(2, allocated_cpus))
133+
134+
# Sync numba and torch thread settings to avoid recompilation
135+
# caused by global state mismatch between the two runtimes
136+
numba.set_num_threads(torch.get_num_threads())
137+
torch.set_num_threads(numba.get_num_threads())
99138
# Initialize the distributed environment.
100139
init_worker_distributed_environment(
101140
self.vllm_config, self.rank, self.distributed_init_method, self.local_rank

0 commit comments

Comments (0)