File tree 1 file changed +14
-2
lines changed
1 file changed +14
-2
lines changed Original file line number Diff line number Diff line change @@ -250,8 +250,20 @@ def enable_pytorch_expandable_segments():
250
250
251
251
252
252
def check_cuda_env():
    """Verify the environment variables required when running on the GPU backend.

    When the accelerator backend is ``AcceleratorType.GPU``, the environment
    variables ``CUDA_DEVICE_MAX_CONNECTIONS`` and
    ``TORCH_NCCL_AVOID_RECORD_STREAMS`` must both be set to the string ``"1"``.
    On any other backend the function does nothing.

    Raises:
        AssertionError: if either variable is unset, or is set to a value
            other than ``"1"``.
    """
    # NOTE(review): indentation was not preserved in the reviewed view of this
    # function; both variable checks are assumed to belong to the GPU-only
    # branch -- confirm against the repository.
    if internlm_accelerator.get_accelerator_backend() != AcceleratorType.GPU:
        return

    # Checked in this order so the first failure reported matches the
    # original sequence of checks.
    required_env_vars = ("CUDA_DEVICE_MAX_CONNECTIONS", "TORCH_NCCL_AVOID_RECORD_STREAMS")

    for env_name in required_env_vars:
        env_value = os.getenv(env_name)

        # Raise AssertionError explicitly instead of using `assert`: assert
        # statements are removed under `python -O`, which would silently skip
        # this validation. The exception type and messages are unchanged.
        if env_value is None:
            raise AssertionError("Env var {} has not been set, please set it to 1!".format(env_name))
        if env_value != "1":
            raise AssertionError("Env var {} is set to {}, it should be set to 1!".format(env_name, env_value))
255
267
256
268
257
269
class DummyProfile :
You can’t perform that action at this time.
0 commit comments