@@ -122,7 +122,9 @@ int gemm_bias_act_lt(
       reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
   // setting this to 1M.
-  size_t workspaceSize = 1024 * 1024;
+  // However, Apex sets it to 4M and TransformerEngine sets it to 32M for Hopper and 4M for other GPUs:
+  // https://github.com/NVIDIA/TransformerEngine/blob/a0f0065498bbcfc1da78cf9e8b166f5381613fbc/transformer_engine/pytorch/module.py#L91
+  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
   void* workspace = at::empty(
       {static_cast<int64_t>(workspaceSize)},
       at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
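
For context, a minimal sketch (not part of the patch; the helper names are illustrative) of the same arch-dependent sizing factored into small ATen helpers, mirroring the 32M/4M values added above:

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>

// 32 MiB on Hopper (compute capability 9.x), 4 MiB on other GPUs, matching
// the values TransformerEngine uses.
static size_t lt_workspace_bytes() {
  const bool is_hopper = at::cuda::getCurrentDeviceProperties()->major >= 9;
  return 1024 * 1024 * (is_hopper ? 32 : 4);
}

// Byte tensor on the current CUDA device; the tensor owns the allocation, so
// it must stay alive for the duration of the cuBLASLt call that uses it.
static at::Tensor allocate_lt_workspace() {
  return at::empty({static_cast<int64_t>(lt_workspace_bytes())},
                   at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte));
}
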
@@ -296,7 +298,8 @@ int gemm_bgradb_lt(
       reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
   // setting this to 1M.
-  size_t workspaceSize = 1024 * 1024;
+  // However, Apex sets it to 4M and TransformerEngine sets it to 32M for Hopper and 4M for other GPUs
+  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
   void* workspace = at::empty(
       {static_cast<int64_t>(workspaceSize)},
       at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
@@ -449,7 +452,8 @@ int gemm_dact_bgradb_lt(
       reinterpret_cast<cublasLtHandle_t>(at::cuda::getCurrentCUDABlasHandle());
   // See https://github.com/pytorch/pytorch/issues/73328 for reasoning behind
   // setting this to 1M.
-  size_t workspaceSize = 1024 * 1024;
+  // However, Apex sets it to 4M and TransformerEngine sets it to 32M for Hopper and 4M for other GPUs
+  size_t workspaceSize = 1024 * 1024 * (at::cuda::getCurrentDeviceProperties()->major >= 9 ? 32 : 4);
   void* workspace = at::empty(
       {static_cast<int64_t>(workspaceSize)},
       at::device({at::kCUDA, at::cuda::current_device()}).dtype(at::kByte)).data_ptr();
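
The workspace size matters because it is advertised to the cuBLASLt heuristic as a scratch budget: algorithms that need more workspace than the limit (e.g. split-K variants) are filtered out, which is why a 32M budget can unlock faster kernels on Hopper. A sketch of that pattern, using the standard cuBLASLt API rather than code copied from this file:

#include <cublasLt.h>

// Queries a matmul algorithm under a given workspace budget; the descriptors
// are assumed to be set up elsewhere.
cublasStatus_t pick_algo(cublasLtHandle_t ltHandle,
                         cublasLtMatmulDesc_t opDesc,
                         cublasLtMatrixLayout_t aDesc,
                         cublasLtMatrixLayout_t bDesc,
                         cublasLtMatrixLayout_t cDesc,
                         size_t workspaceSize,
                         cublasLtMatmulHeuristicResult_t* result) {
  cublasLtMatmulPreference_t preference = nullptr;
  cublasStatus_t status = cublasLtMatmulPreferenceCreate(&preference);
  if (status != CUBLAS_STATUS_SUCCESS) return status;
  // Advertise the scratch budget; algorithms needing more are excluded.
  status = cublasLtMatmulPreferenceSetAttribute(
      preference, CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
      &workspaceSize, sizeof(workspaceSize));
  int returnedResults = 0;
  if (status == CUBLAS_STATUS_SUCCESS) {
    status = cublasLtMatmulAlgoGetHeuristic(
        ltHandle, opDesc, aDesc, bDesc, cDesc, cDesc,
        preference, /*requestedAlgoCount=*/1, result, &returnedResults);
  }
  cublasLtMatmulPreferenceDestroy(preference);
  if (status == CUBLAS_STATUS_SUCCESS && returnedResults == 0) {
    status = CUBLAS_STATUS_NOT_SUPPORTED;
  }
  return status;
}
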