File tree 1 file changed +16
-2
lines changed
1 file changed +16
-2
lines changed Original file line number Diff line number Diff line change @@ -250,8 +250,22 @@ def enable_pytorch_expandable_segments():
250
250
251
251
252
252
def check_cuda_env():
    """Validate the CUDA/NCCL environment variables required on the GPU backend.

    When the accelerator backend is ``AcceleratorType.GPU``, both
    ``CUDA_DEVICE_MAX_CONNECTIONS`` and ``TORCH_NCCL_AVOID_RECORD_STREAMS``
    must be set to the string ``"1"``. On any other backend nothing is checked.

    Raises:
        AssertionError: If either variable is unset, or is set to a value
            other than ``"1"``.
    """
    if internlm_accelerator.get_accelerator_backend() != AcceleratorType.GPU:
        return

    # Checked in this order so the first failing variable reported is stable.
    required_vars = ("CUDA_DEVICE_MAX_CONNECTIONS", "TORCH_NCCL_AVOID_RECORD_STREAMS")
    for name in required_vars:
        value = os.getenv(name)
        # Raise explicitly instead of using `assert`: assert statements are
        # removed under `python -O`, which would silently skip this check.
        # AssertionError is kept so the exception type seen by callers is
        # the same as with a plain `assert`.
        if value is None:
            raise AssertionError(f"Env var {name} has not been set, please set it to 1!")
        if value != "1":
            raise AssertionError(f"Env var {name} is set to {value}, it should be set to 1!")
255
269
256
270
257
271
class DummyProfile :
You can’t perform that action at this time.
0 commit comments