Skip to content

Commit 74d8c7d

Browse files
committed
fix check CUDA_DEVICE_MAX_CONNECTIONS
1 parent 6b7df0b commit 74d8c7d

File tree

1 file changed

+16
-2
lines changed

1 file changed

+16
-2
lines changed

internlm/utils/common.py

+16-2
Original file line numberDiff line numberDiff line change
@@ -250,8 +250,22 @@ def enable_pytorch_expandable_segments():
250250

251251

252252
def check_cuda_env():
    """Validate the CUDA/NCCL environment variables required on the GPU backend.

    The check only runs when ``internlm_accelerator.get_accelerator_backend()``
    equals ``AcceleratorType.GPU``; for any other backend the function returns
    without reading the environment.

    On GPU, ``CUDA_DEVICE_MAX_CONNECTIONS`` and then
    ``TORCH_NCCL_AVOID_RECORD_STREAMS`` must each be present and equal to the
    exact string ``"1"``.

    Raises:
        AssertionError: if either variable is unset or is not exactly ``"1"``.
            The first failing check (in the order listed above) is reported.
    """
    if internlm_accelerator.get_accelerator_backend() != AcceleratorType.GPU:
        return

    required_env_vars = ("CUDA_DEVICE_MAX_CONNECTIONS", "TORCH_NCCL_AVOID_RECORD_STREAMS")
    for env_name in required_env_vars:
        env_value = os.getenv(env_name)
        # Raise explicitly instead of using `assert`: assert statements are
        # removed under `python -O`, which would silently disable this check.
        # AssertionError is kept so existing callers observe the same exception.
        if env_value is None:
            raise AssertionError("Env var {} has not been set, please set it to 1!".format(env_name))
        if env_value != "1":
            raise AssertionError(
                "Env var {} is set to {}, it should be set to 1!".format(env_name, env_value)
            )
255269

256270

257271
class DummyProfile:

0 commit comments

Comments
 (0)