@@ -17,6 +17,7 @@ limitations under the License.
1717
1818#include < algorithm>
1919#include < array>
20+ #include < atomic>
2021#include < cstdint>
2122#include < functional>
2223#include < memory>
@@ -2519,6 +2520,13 @@ GpuCompiler::CompileToBackendResult(
25192520 module , schedule_metadata.scheduler_mem_limit ,
25202521 gpu_topology.gpu_target_config ().device_description , alias_info.get ()));
25212522
2523+ MaybeOwningThreadPool thread_pool = CreateMaybeOwningThreadPool (
2524+ /* parallelism=*/ module ->config ()
2525+ .debug_options ()
2526+ .xla_gpu_force_compilation_parallelism (),
2527+ /* default_thread_pool=*/ options.thread_pool ,
2528+ /* default_parallelism=*/ tsl::port::MaxParallelism ());
2529+
25222530 ASSIGN_OR_RETURN (
25232531 bool can_use_link_modules,
25242532 CanUseLinkModules (module ->config (),
@@ -2531,6 +2539,7 @@ GpuCompiler::CompileToBackendResult(
25312539 .xla_gpu_enable_llvm_module_compilation_parallelism ();
25322540
25332541 CompileModuleResults compile_module_results;
2542+ std::atomic<int > shard_number = 0 ;
25342543
25352544 {
25362545 xla::llvm_ir::LLVMCommandLineOptionsReleasableLock llvm_options_lock (
@@ -2541,15 +2550,16 @@ GpuCompiler::CompileToBackendResult(
25412550 auto llvm_compiler =
25422551 [&](llvm::Module& llvm_module, const se::DeviceDescription& descr,
25432552 const DebugOptions& opts) -> absl::StatusOr<std::vector<uint8_t >> {
2544- ASSIGN_OR_RETURN (BackendCompileResult result,
2545- CompileSingleModule (module ->config (), descr, module ,
2546- &llvm_module, false , std::nullopt ));
2553+ ASSIGN_OR_RETURN (
2554+ BackendCompileResult result,
2555+ CompileSingleModule (module ->config (), descr, module , &llvm_module,
2556+ false , shard_number.fetch_add (1 )));
25472557 return std::move (result.binary );
25482558 };
25492559 CubinCustomKernelCompiler kernel_compiler (
25502560 std::move (llvm_compiler),
25512561 gpu_topology.gpu_target_config ().device_description ,
2552- module ->config ().debug_options ());
2562+ module ->config ().debug_options (), thread_pool. get_mutable () );
25532563 kernel_compiler.SetPreOptimizationHook ([&](const llvm::Module& module ) {
25542564 CallUserPreOptimizationHook (module );
25552565 });
@@ -2571,7 +2581,8 @@ GpuCompiler::CompileToBackendResult(
25712581 for (const std::unique_ptr<llvm::Module>& llvm_module :
25722582 compile_module_results.llvm_modules ) {
25732583 llvm_ir::DumpIrIfEnabled (*module , *llvm_module,
2574- /* optimized=*/ false );
2584+ /* optimized=*/ false ,
2585+ std::to_string (shard_number.fetch_add (1 )));
25752586 CallUserPreOptimizationHook (*llvm_module);
25762587 }
25772588 if (compile_module_results.llvm_module_constants != nullptr ) {
@@ -2613,7 +2624,7 @@ GpuCompiler::CompileToBackendResult(
26132624 gpu_topology.gpu_target_config ().device_description ,
26142625 module , &*compile_module_results.llvm_modules [0 ],
26152626 /* relocatable=*/ false ,
2616- /* shard_number=*/ std:: nullopt ));
2627+ /* shard_number=*/ shard_number. fetch_add ( 1 ) ));
26172628 }
26182629
26192630 if (!backend_result.asm_text .empty ()) {
@@ -3198,13 +3209,14 @@ GpuCompiler::LoadExecutableFromAotResult(
31983209 BufferAssignment::FromProto (proto.buffer_assignment (), hlo_module.get (),
31993210 BufferSizeBytesFunction (), alias_info.get ()));
32003211
3212+ std::atomic<int > shard_number = 0 ;
32013213 auto llvm_compiler =
32023214 [&](llvm::Module& llvm_module, const se::DeviceDescription& descr,
32033215 const DebugOptions& opts) -> absl::StatusOr<std::vector<uint8_t >> {
32043216 ASSIGN_OR_RETURN (
32053217 BackendCompileResult result,
32063218 CompileSingleModule (hlo_module->config (), descr, hlo_module.get (),
3207- &llvm_module, false , std:: nullopt ));
3219+ &llvm_module, false , shard_number. fetch_add ( 1 ) ));
32083220 return std::move (result.binary );
32093221 };
32103222 CubinCustomKernelCompiler kernel_compiler (
0 commit comments