Skip to content

Commit 60accc8

Browse files
ermilovmaxim and Google-ML-Automation
authored and committed
implement optional threadpool support in CubinCustomKernelCompiler
PiperOrigin-RevId: 903290909
1 parent adea5dc commit 60accc8

4 files changed

Lines changed: 56 additions & 10 deletions

File tree

xla/backends/gpu/codegen/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,7 @@ cc_library(
387387
"//xla/codegen/emitters:kernel_arguments",
388388
"//xla/service/gpu:launch_dimensions",
389389
"//xla/stream_executor:device_description",
390+
"//xla/tsl/platform:env",
390391
"//xla/tsl/platform:status_macros",
391392
"@com_google_absl//absl/functional:any_invocable",
392393
"@com_google_absl//absl/status:statusor",

xla/backends/gpu/codegen/cubin_custom_kernel_compiler.cc

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ limitations under the License.
2121
#include <utility>
2222
#include <vector>
2323

24+
#include "absl/status/statusor.h"
2425
#include "xla/tsl/platform/status_macros.h"
2526
#include "llvm/ExecutionEngine/Orc/ThreadSafeModule.h"
2627
#include "llvm/IR/Module.h"
@@ -40,6 +41,27 @@ xla::Future<std::unique_ptr<Thunk>> CubinCustomKernelCompiler::Compile(
4041
const std::string& sanitized_kernel_name,
4142
const emitters::KernelArguments& kernel_arguments,
4243
const LaunchDimensions& launch_dimensions) {
44+
if (!thread_pool_) {
45+
return CompileImpl(std::move(thunk_info), std::move(kernel_source),
46+
sanitized_kernel_name, kernel_arguments,
47+
launch_dimensions);
48+
}
49+
return tsl::MakeFutureOn(
50+
*thread_pool_->AsExecutor(),
51+
[this, thunk_info = std::move(thunk_info),
52+
kernel_source = std::move(kernel_source), sanitized_kernel_name,
53+
kernel_arguments, launch_dimensions]() mutable {
54+
return CompileImpl(std::move(thunk_info), std::move(kernel_source),
55+
sanitized_kernel_name, kernel_arguments,
56+
launch_dimensions);
57+
});
58+
}
59+
60+
absl::StatusOr<std::unique_ptr<Thunk>> CubinCustomKernelCompiler::CompileImpl(
61+
Thunk::ThunkInfo thunk_info, LlvmKernelSource kernel_source,
62+
const std::string& sanitized_kernel_name,
63+
const emitters::KernelArguments& kernel_arguments,
64+
const LaunchDimensions& launch_dimensions) {
4365
llvm::orc::ThreadSafeModule thread_safe_module =
4466
std::move(kernel_source).thread_safe_module();
4567
llvm::Module* llvm_module = thread_safe_module.getModuleUnlocked();

xla/backends/gpu/codegen/cubin_custom_kernel_compiler.h

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ limitations under the License.
3232
#include "xla/future.h"
3333
#include "xla/service/gpu/launch_dimensions.h"
3434
#include "xla/stream_executor/device_description.h"
35+
#include "xla/tsl/platform/threadpool.h"
3536

3637
namespace xla::gpu {
3738

@@ -45,17 +46,20 @@ using LlvmIrCompiler = absl::AnyInvocable<absl::StatusOr<std::vector<uint8_t>>(
4546
// Implementation of KernelCompiler that compiles LLVM IR to CUBIN format using
4647
// a provided compilation function.
4748
//
48-
// Note: This implementation is currently synchronous. The compilation happens
49+
// Note: CubinCustomKernelCompiler uses the provided thread pool.
50+
// If no thread pool is provided, the compilation happens
4951
// fully within this call, and the result is returned as an immediately ready
5052
// Future.
5153
class CubinCustomKernelCompiler : public KernelCompiler {
5254
public:
5355
CubinCustomKernelCompiler(LlvmIrCompiler compiler,
5456
const se::DeviceDescription& gpu_device_info,
55-
const DebugOptions& debug_options)
57+
const DebugOptions& debug_options,
58+
tsl::thread::ThreadPool* thread_pool = nullptr)
5659
: compiler_(std::move(compiler)),
5760
device_info_(gpu_device_info),
58-
debug_options_(debug_options) {}
61+
debug_options_(debug_options),
62+
thread_pool_(thread_pool) {}
5963

6064
xla::Future<std::unique_ptr<Thunk>> Compile(
6165
Thunk::ThunkInfo thunk_info, LlvmKernelSource kernel_source,
@@ -64,9 +68,16 @@ class CubinCustomKernelCompiler : public KernelCompiler {
6468
const LaunchDimensions& launch_dimensions) override;
6569

6670
private:
71+
absl::StatusOr<std::unique_ptr<Thunk>> CompileImpl(
72+
Thunk::ThunkInfo thunk_info, LlvmKernelSource kernel_source,
73+
const std::string& sanitized_kernel_name,
74+
const emitters::KernelArguments& kernel_arguments,
75+
const LaunchDimensions& launch_dimensions);
76+
6777
LlvmIrCompiler compiler_;
6878
const se::DeviceDescription device_info_;
6979
const DebugOptions debug_options_;
80+
tsl::thread::ThreadPool* thread_pool_;
7081
};
7182

7283
} // namespace xla::gpu

xla/service/gpu/gpu_compiler.cc

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717

1818
#include <algorithm>
1919
#include <array>
20+
#include <atomic>
2021
#include <cstdint>
2122
#include <functional>
2223
#include <memory>
@@ -2519,6 +2520,13 @@ GpuCompiler::CompileToBackendResult(
25192520
module, schedule_metadata.scheduler_mem_limit,
25202521
gpu_topology.gpu_target_config().device_description, alias_info.get()));
25212522

2523+
MaybeOwningThreadPool thread_pool = CreateMaybeOwningThreadPool(
2524+
/*parallelism=*/module->config()
2525+
.debug_options()
2526+
.xla_gpu_force_compilation_parallelism(),
2527+
/*default_thread_pool=*/options.thread_pool,
2528+
/*default_parallelism=*/tsl::port::MaxParallelism());
2529+
25222530
ASSIGN_OR_RETURN(
25232531
bool can_use_link_modules,
25242532
CanUseLinkModules(module->config(),
@@ -2531,6 +2539,7 @@ GpuCompiler::CompileToBackendResult(
25312539
.xla_gpu_enable_llvm_module_compilation_parallelism();
25322540

25332541
CompileModuleResults compile_module_results;
2542+
std::atomic<int> shard_number = 0;
25342543

25352544
{
25362545
xla::llvm_ir::LLVMCommandLineOptionsReleasableLock llvm_options_lock(
@@ -2541,15 +2550,16 @@ GpuCompiler::CompileToBackendResult(
25412550
auto llvm_compiler =
25422551
[&](llvm::Module& llvm_module, const se::DeviceDescription& descr,
25432552
const DebugOptions& opts) -> absl::StatusOr<std::vector<uint8_t>> {
2544-
ASSIGN_OR_RETURN(BackendCompileResult result,
2545-
CompileSingleModule(module->config(), descr, module,
2546-
&llvm_module, false, std::nullopt));
2553+
ASSIGN_OR_RETURN(
2554+
BackendCompileResult result,
2555+
CompileSingleModule(module->config(), descr, module, &llvm_module,
2556+
false, shard_number.fetch_add(1)));
25472557
return std::move(result.binary);
25482558
};
25492559
CubinCustomKernelCompiler kernel_compiler(
25502560
std::move(llvm_compiler),
25512561
gpu_topology.gpu_target_config().device_description,
2552-
module->config().debug_options());
2562+
module->config().debug_options(), thread_pool.get_mutable());
25532563
kernel_compiler.SetPreOptimizationHook([&](const llvm::Module& module) {
25542564
CallUserPreOptimizationHook(module);
25552565
});
@@ -2571,7 +2581,8 @@ GpuCompiler::CompileToBackendResult(
25712581
for (const std::unique_ptr<llvm::Module>& llvm_module :
25722582
compile_module_results.llvm_modules) {
25732583
llvm_ir::DumpIrIfEnabled(*module, *llvm_module,
2574-
/*optimized=*/false);
2584+
/*optimized=*/false,
2585+
std::to_string(shard_number.fetch_add(1)));
25752586
CallUserPreOptimizationHook(*llvm_module);
25762587
}
25772588
if (compile_module_results.llvm_module_constants != nullptr) {
@@ -2613,7 +2624,7 @@ GpuCompiler::CompileToBackendResult(
26132624
gpu_topology.gpu_target_config().device_description,
26142625
module, &*compile_module_results.llvm_modules[0],
26152626
/*relocatable=*/false,
2616-
/*shard_number=*/std::nullopt));
2627+
/*shard_number=*/shard_number.fetch_add(1)));
26172628
}
26182629

26192630
if (!backend_result.asm_text.empty()) {
@@ -3198,13 +3209,14 @@ GpuCompiler::LoadExecutableFromAotResult(
31983209
BufferAssignment::FromProto(proto.buffer_assignment(), hlo_module.get(),
31993210
BufferSizeBytesFunction(), alias_info.get()));
32003211

3212+
std::atomic<int> shard_number = 0;
32013213
auto llvm_compiler =
32023214
[&](llvm::Module& llvm_module, const se::DeviceDescription& descr,
32033215
const DebugOptions& opts) -> absl::StatusOr<std::vector<uint8_t>> {
32043216
ASSIGN_OR_RETURN(
32053217
BackendCompileResult result,
32063218
CompileSingleModule(hlo_module->config(), descr, hlo_module.get(),
3207-
&llvm_module, false, std::nullopt));
3219+
&llvm_module, false, shard_number.fetch_add(1)));
32083220
return std::move(result.binary);
32093221
};
32103222
CubinCustomKernelCompiler kernel_compiler(

0 commit comments

Comments
 (0)