|
37 | 37 | #include "paddle/cinn/runtime/cuda/cuda_util.h" |
38 | 38 | #include "paddle/cinn/runtime/flags.h" |
39 | 39 | #endif |
| 40 | +#ifdef CINN_WITH_CUSTOM_DEVICE |
| 41 | +#include "paddle/cinn/backends/custom_device/codegen_custom_device_dev.h" |
| 42 | +#include "paddle/cinn/backends/custom_device/compiler_custom_device.h" |
| 43 | +#include "paddle/cinn/runtime/custom_device/custom_device_backend_api.h" |
| 44 | +#include "paddle/phi/backends/device_manager.h" |
| 45 | +#endif |
40 | 46 | #ifdef CINN_WITH_HIP |
41 | 47 | #include "paddle/cinn/backends/hip/codegen_hip_dev.h" |
42 | 48 | #include "paddle/cinn/backends/hip/compiler_hip.h" |
@@ -253,7 +259,10 @@ void Compiler::Build(const Module& module, const std::string& code) { |
253 | 259 | [&](common::ARMArch) { CINN_NOT_IMPLEMENTED; }, |
254 | 260 | [&](common::NVGPUArch) { CompileCudaModule(module, code); }, |
255 | 261 | [&](common::HygonDCUArchHIP) { CompileHipModule(module, code); }, |
256 | | - [&](common::HygonDCUArchSYCL) { CompileSyclModule(module, code); }); |
| 262 | + [&](common::HygonDCUArchSYCL) { CompileSyclModule(module, code); }, |
| 263 | + [&](common::CustomDeviceArch) { |
| 264 | + CompileCustomDeviceModule(module, code); |
| 265 | + }); |
257 | 266 | } |
258 | 267 |
|
259 | 268 | void Compiler::AppendCX86(const Module& module) { |
@@ -344,6 +353,19 @@ std::string Compiler::GetSourceCode(const ir::Module& module) { |
344 | 353 | [&](common::UnknownArch) -> std::string { CINN_NOT_IMPLEMENTED; }, |
345 | 354 | [&](common::X86Arch) -> std::string { CINN_NOT_IMPLEMENTED; }, |
346 | 355 | [&](common::ARMArch) -> std::string { CINN_NOT_IMPLEMENTED; }, |
| 356 | + [&](common::CustomDeviceArch) -> std::string { |
| 357 | +#ifdef CINN_WITH_CUSTOM_DEVICE |
| 358 | + auto _host_module_device_module_ = |
| 359 | + SplitDeviceAndHostModule(module); // NOLINT |
| 360 | + auto& host_module = std::get<0>(_host_module_device_module_); |
| 361 | + auto& device_module = std::get<1>(_host_module_device_module_); |
| 362 | + custom_device::CodeGenCustomDevice codegen(target_); |
| 363 | + auto source_code = codegen.Compile(device_module); |
| 364 | + return source_code; |
| 365 | +#else |
| 366 | + CINN_NOT_IMPLEMENTED |
| 367 | +#endif |
| 368 | + }, |
347 | 369 | [&](common::NVGPUArch) -> std::string { |
348 | 370 | #ifdef CINN_WITH_CUDA |
349 | 371 | auto _host_module_device_module_ = |
@@ -390,6 +412,7 @@ void Compiler::BuildDefault(const Module& module) { |
390 | 412 | [&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; }, |
391 | 413 | [&](common::X86Arch) { CompileX86Module(module); }, |
392 | 414 | [&](common::ARMArch) { CINN_NOT_IMPLEMENTED; }, |
| 415 | + [&](common::CustomDeviceArch) { CompileCustomDeviceModule(module); }, |
393 | 416 | [&](common::NVGPUArch) { CompileCudaModule(module); }, |
394 | 417 | [&](common::HygonDCUArchHIP) { CompileHipModule(module); }, |
395 | 418 | [&](common::HygonDCUArchSYCL) { CompileSyclModule(module); }); |
@@ -418,6 +441,7 @@ void Compiler::RegisterDeviceModuleSymbol() { |
418 | 441 | [&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; }, |
419 | 442 | [&](common::X86Arch) { return; }, |
420 | 443 | [&](common::ARMArch) { return; }, |
| 444 | + [&](common::CustomDeviceArch) { RegisterCustomDeviceModuleSymbol(); }, |
421 | 445 | [&](common::NVGPUArch) { RegisterCudaModuleSymbol(); }, |
422 | 446 | [&](common::HygonDCUArchHIP) { RegisterHipModuleSymbol(); }, |
423 | 447 | [&](common::HygonDCUArchSYCL) { RegisterSyclModuleSymbol(); }); |
@@ -526,6 +550,60 @@ void Compiler::RegisterCudaModuleSymbol() { |
526 | 550 | #endif |
527 | 551 | } |
528 | 552 |
|
| 553 | +void Compiler::RegisterCustomDeviceModuleSymbol() { |
| 554 | +#ifdef CINN_WITH_CUSTOM_DEVICE |
| 555 | + // 1. Get the plugin instance (needed for LoadModule later) |
| 556 | + auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes(); |
| 557 | + PADDLE_ENFORCE_EQ(!dev_types.empty(), |
| 558 | + true, |
| 559 | + ::common::errors::NotFound( |
| 560 | + "No custom device registered in DeviceManager.")); |
| 561 | + std::string dev_type = dev_types[0]; |
| 562 | + auto place = phi::CustomPlace(dev_type, 0); |
| 563 | + auto& plugin = |
| 564 | + cinn::runtime::custom_device::CinnCustomDevicePlugin::GetInstance(place); |
| 565 | + |
| 566 | + // 2. Invoke cdrtc::Compiler to compile source → shared lib |
| 567 | + common::Target target = common::DefaultCustomDeviceTarget(); |
| 568 | + cdrtc::Compiler compiler(target); |
| 569 | + std::string lib_path = compiler(device_fn_code_); |
| 570 | + |
| 571 | + PADDLE_ENFORCE_EQ( |
| 572 | + !lib_path.empty(), |
| 573 | + true, |
| 574 | + ::common::errors::External("Custom Device Toolchain compile failed.")); |
| 575 | + |
| 576 | + // 3. Invoke the plugin runtime to load the module |
| 577 | + this->device_module_ = plugin.GetRuntime()->LoadModule(lib_path); |
| 578 | + PADDLE_ENFORCE_NOT_NULL( |
| 579 | + this->device_module_, |
| 580 | + ::common::errors::External( |
| 581 | + "Custom Device Runtime failed to load module from %s", |
| 582 | + lib_path.c_str())); |
| 583 | + |
| 584 | + // 4. Register Kernel symbols |
| 585 | + // Retrieve the device function pointers (handles) and register them |
| 586 | + // as [kernel_name]_ptr_ |
| 587 | + RuntimeSymbols symbols; |
| 588 | + for (const auto& kernel_fn_name : device_fn_name_) { |
| 589 | + void* fn_kernel = this->device_module_->GetFunction(kernel_fn_name); |
| 590 | + |
| 591 | + PADDLE_ENFORCE_NOT_NULL(fn_kernel, |
| 592 | + ::common::errors::NotFound( |
| 593 | + "Custom Device Runtime cannot find kernel: %s", |
| 594 | + kernel_fn_name.c_str())); |
| 595 | + |
| 596 | + // 5. Store the pointer for use by the ExecutionEngine |
| 597 | + fn_ptr_.push_back(fn_kernel); |
| 598 | + symbols.RegisterVar(kernel_fn_name + "_ptr_", fn_kernel); |
| 599 | + } |
| 600 | + |
| 601 | + engine_->RegisterModuleRuntimeSymbols(std::move(symbols)); |
| 602 | +#else |
| 603 | + CINN_NOT_IMPLEMENTED |
| 604 | +#endif |
| 605 | +} |
| 606 | + |
529 | 607 | void Compiler::RegisterHipModuleSymbol() { |
530 | 608 | #ifdef CINN_WITH_HIP |
531 | 609 | hiprtc::Compiler compiler; |
@@ -632,6 +710,46 @@ void Compiler::CompileCudaModule(const Module& module, |
632 | 710 | #endif |
633 | 711 | } |
634 | 712 |
|
| 713 | +void Compiler::CompileCustomDeviceModule(const Module& module, |
| 714 | + const std::string& code) { |
| 715 | +#ifdef CINN_WITH_CUSTOM_DEVICE |
| 716 | + auto _host_module_device_module_ = |
| 717 | + SplitDeviceAndHostModule(module); // NOLINT |
| 718 | + auto& host_module = std::get<0>(_host_module_device_module_); |
| 719 | + auto& device_module = std::get<1>(_host_module_device_module_); |
| 720 | + VLOG(3) << "[CustomDevice] host module:\n" << host_module; |
| 721 | + |
| 722 | + VLOG(3) << "[CustomDevice] device module:\n" << device_module; |
| 723 | + std::string source_code; |
| 724 | + |
| 725 | + if (!FLAGS_cinn_debug_custom_code_path.empty()) { |
| 726 | + std::string file_path = FLAGS_cinn_debug_custom_code_path; |
| 727 | + source_code = GetFileContent(file_path); |
| 728 | + } else if (code.empty()) { |
| 729 | + custom_device::CodeGenCustomDevice codegen(target_); |
| 730 | + source_code = codegen.Compile(device_module); |
| 731 | + } else { |
| 732 | + source_code = code; |
| 733 | + } |
| 734 | + |
| 735 | + PADDLE_ENFORCE_EQ(!source_code.empty(), |
| 736 | + true, |
| 737 | + ::common::errors::InvalidArgument( |
| 738 | + "Compile CustomDevice code failed from device module")); |
| 739 | + VLOG(1) << "[CustomDevice] Source:\n" << source_code; |
| 740 | + SourceCodePrint::GetInstance()->write(source_code); |
| 741 | + device_fn_code_ += source_code; |
| 742 | + |
| 743 | + for (auto& fn : device_module.functions()) { |
| 744 | + std::string kernel_fn_name = fn->name; |
| 745 | + device_fn_name_.emplace_back(kernel_fn_name); |
| 746 | + } |
| 747 | + engine_->Link<CodeGenGpuHost>(host_module); |
| 748 | +#else |
| 749 | + CINN_NOT_IMPLEMENTED |
| 750 | +#endif |
| 751 | +} |
| 752 | + |
635 | 753 | void Compiler::CompileHipModule(const Module& module, const std::string& code) { |
636 | 754 | #ifdef CINN_WITH_HIP |
637 | 755 | auto _host_module_device_module_ = |
|
0 commit comments