Skip to content

Commit 399e85b

Browse files
authored
CINN CustomDevice (PaddlePaddle#77158)
* Arch for CINN CustomDevice. * Add CustomDevice Interface GetMaxBlocksPerMultiProcessor * Build paddle-cpu with: cmake .. -GNinja -DPY_VERSION=3.10 -DWITH_GPU=OFF -DWITH_DISTRIBUTE=ON -DWITH_CINN=ON -DWITH_CUSTOM_DEVICE=ON * Add /paddle/cinn/backends/custom_device/ and ../paddle/cinn/runtime/custom_device/ * Add CinnCustomDevicePlugin into ../paddle/cinn/runtime/custom_device/custom_device_backend_api.h(.cc) * [CINN] Refactor custom_device runtime to be hardware-agnostic via Plugin API - Abstract hardware-specific logic into CinnCustomDevicePlugin. - Remove vendor-specific files (HIP/MACA modules and source headers). - Update CustomBackendAPI and utils to dispatch tasks via Plugin and Phi DeviceManager. - Decouple CINN runtime from specific GPU backends (HIP/DCU/MACA). * Add customModule interface. * Add GetWarpSize, GetMaxRegistersPerMultiProcessor, GetPreferredVectorWidth into interfaces. Remove hard code in group_tile_config.cc. * Remove magic numbers in group_tile_config.cc * Fix some bugs. * Fix undefined proto problem. * Macro to ForceRegisterCinnCustomDeviceIntrinsics(). * cinn_jit_instruction.cc support CustomDevice. * Fix undefined proto for custom_device_intrinsics_reduce.cc & custom_device_intrinsics_float16.cc * Update VLOG level for CustomDevice host module and device module. * CodeStyle Check. * Add CINN_WITH_CUSTOM_DEVICE on custom_device_*.cc * Translate Chinese comment to English. * Fix CustomBackendAPI::free(void* data) and DeviceManager->MemoryDeallocate(data, size). * Fix pybind.cc and cinn_is_available() * Fix cinn_is_available() * Support argidx ArgMin/ArgMax Block Reduce for CustomDevice CINN. * Bugfix custom_device_intrinsics_reduce.cc add endif. * Bugfix CustomDevice PHI use_cudnn = false. * Fix group_tile_config.cc and tile_first_general_tactic.cc, get warp_size from custom_device. * Add int64 abs. * Add float16.h bfloat16.h float8e4m3.h to .gitignore * Fix group_tile_config.cc 32->warp_size. 
* Support CINN group_schedule tactic with CustomDevice. * Fix .gitmodules * Fix pre-commit CodeStyle. * Fix const_cast, PADDLE_WITH_CUSTOM_DEVICE, annotation, Copilot suggestion, pybind.cc utils.py. * Fix tile_broadcast_tactic.cc GetMaxThreadsPerBlock() to context_->target.max_num_threads * Fix according to CodeReview. * Fix cinn_is_available * Fix CodeReview 2.
1 parent 82d1e0e commit 399e85b

File tree

77 files changed

+3397
-137
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

77 files changed

+3397
-137
lines changed

.gitignore

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,3 +123,6 @@ python/paddle/base/dygraph/generated_tensor_methods_patch.py
123123
#fp8
124124
paddle/fluid/fp8/deep_gemm/include/cute/*
125125
paddle/fluid/fp8/deep_gemm/include/cutlass/*
126+
paddle/cinn/runtime/custom_device/float16.h
127+
paddle/cinn/runtime/custom_device/bfloat16.h
128+
paddle/cinn/runtime/custom_device/float8e4m3.h

cmake/cinn.cmake

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,26 @@ if(WITH_ROCM)
133133
file(COPY paddle/cinn/common/float16.h DESTINATION $ENV{runtime_include_dir})
134134
endif()
135135

136+
# Custom-device support for CINN: defines CINN_WITH_CUSTOM_DEVICE, picks a
# runtime include directory, and copies float16.h / bfloat16.h / float8e4m3.h
# from paddle/cinn/common into it.
if(WITH_CUSTOM_DEVICE)
137+
message(STATUS "CINN Compile with custom device support")
138+
139+
# Enables the `#ifdef CINN_WITH_CUSTOM_DEVICE` code paths in the C++ sources.
add_definitions(-DCINN_WITH_CUSTOM_DEVICE)
140+
141+
# Only fall back to the in-tree custom_device runtime directory when the
# environment variable has not already been set.
# NOTE(review): RUNTIME_INCLUDE_DIR is defined only inside this branch, i.e.
# only when runtime_include_dir was not set beforehand - confirm another
# branch defines it in that case.
if(NOT DEFINED ENV{runtime_include_dir})
142+
set(ENV{runtime_include_dir}
143+
"${CMAKE_SOURCE_DIR}/paddle/cinn/runtime/custom_device")
144+
add_definitions(
145+
-DRUNTIME_INCLUDE_DIR="${CMAKE_SOURCE_DIR}/paddle/cinn/runtime/custom_device"
146+
)
147+
endif()
148+
149+
message(STATUS "copy float16 headers for custom device")
150+
# Make sure the destination exists, then copy the three type headers there.
# Presumably this lets runtime-compiled device code include them - TODO
# confirm.
file(MAKE_DIRECTORY $ENV{runtime_include_dir})
151+
file(COPY paddle/cinn/common/float16.h paddle/cinn/common/bfloat16.h
152+
paddle/cinn/common/float8e4m3.h
153+
DESTINATION $ENV{runtime_include_dir})
154+
endif()
155+
136156
set(cinnapi_src CACHE INTERNAL "" FORCE)
137157
set(core_src CACHE INTERNAL "" FORCE)
138158
set(core_includes CACHE INTERNAL "" FORCE)

paddle/cinn/backends/CMakeLists.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ if(WITH_SYCL)
3030
add_subdirectory(sycl)
3131
endif()
3232

33+
# Build the custom_device backend subdirectory only when custom-device support
# is enabled.
if(WITH_CUSTOM_DEVICE)
34+
add_subdirectory(custom_device)
35+
endif()
36+
3337
if(WITH_OPENMP)
3438
cinn_cc_library(__x86_source_fake_lib SRCS _x86_builtin_source.cc)
3539
endif()

paddle/cinn/backends/codegen_cuda_host.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,13 @@ class CodeGenGpuHost : public CodeGenHost {
6363
} else {
6464
return CodeGenHost::Visit(op);
6565
}
66+
},
67+
[&](common::CustomDeviceArch) {
68+
if (op->name == runtime::intrinsic::call_custom_device_kernel) {
69+
return LowerGPUKernelCall(op);
70+
} else {
71+
return CodeGenHost::Visit(op);
72+
}
6673
});
6774
}
6875

paddle/cinn/backends/codegen_device_util.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,11 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
257257
[&](common::HygonDCUArchSYCL) {
258258
#ifdef CINN_WITH_SYCL
259259
shared_mem_bytes = Expr(0);
260+
#endif
261+
},
262+
[&](common::CustomDeviceArch) {
263+
#ifdef CINN_WITH_CUSTOM_DEVICE
264+
shared_mem_bytes = CalculateSharedMemory(func);
260265
#endif
261266
});
262267

@@ -283,6 +288,9 @@ void detail::CollectBucketStrategyHostFunctionVisitor::ProcessLoweredFunc(
283288
},
284289
[&](common::HygonDCUArchSYCL) {
285290
call_kernel = runtime::intrinsic::call_sycl_kernel;
291+
},
292+
[&](common::CustomDeviceArch) {
293+
call_kernel = runtime::intrinsic::call_custom_device_kernel;
286294
});
287295
// TODO(Dmovic): use new ir when backend update done.
288296
// Author(liujinnan): Copy args instead of use func args directly in host

paddle/cinn/backends/codegen_device_util.h

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@
2626
#ifdef CINN_WITH_SYCL
2727
#include "paddle/cinn/backends/sycl/codegen_sycl_dev.h"
2828
#endif
29+
#ifdef CINN_WITH_CUSTOM_DEVICE
30+
#include "paddle/cinn/backends/custom_device/codegen_custom_device_dev.h"
31+
#endif
2932
#include "paddle/cinn/cinn.h"
3033
#include "paddle/cinn/ir/ir.h"
3134
#include "paddle/cinn/ir/ir_mutator.h"
@@ -127,6 +130,14 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
127130
[&](std::variant<common::UnknownArch,
128131
common::X86Arch,
129132
common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
133+
[&](common::CustomDeviceArch) {
134+
#ifdef CINN_WITH_CUSTOM_DEVICE
135+
custom_device::CodeGenCustomDevice codegen_dev(
136+
cinn::common::DefaultCustomDeviceTarget());
137+
codegen_dev.Compile(ir::LoweredFunc(func));
138+
shared_mem_bytes = codegen_dev.GetDynSharedMemOffset();
139+
#endif
140+
},
130141
[&](common::NVGPUArch) {
131142
#ifdef CINN_WITH_CUDA
132143
CodeGenCudaDev codegen_dev(cinn::common::DefaultNVGPUTarget());
@@ -165,6 +176,9 @@ struct CollectHostFunctionVisitor : public ir::IRMutator<> {
165176
[&](std::variant<common::UnknownArch,
166177
common::X86Arch,
167178
common::ARMArch>) { CINN_NOT_IMPLEMENTED; },
179+
[&](common::CustomDeviceArch) {
180+
call_kernel = runtime::intrinsic::call_custom_device_kernel;
181+
},
168182
[&](common::NVGPUArch) {
169183
call_kernel = runtime::intrinsic::call_cuda_kernel;
170184
},

paddle/cinn/backends/compiler.cc

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,12 @@
3737
#include "paddle/cinn/runtime/cuda/cuda_util.h"
3838
#include "paddle/cinn/runtime/flags.h"
3939
#endif
40+
#ifdef CINN_WITH_CUSTOM_DEVICE
41+
#include "paddle/cinn/backends/custom_device/codegen_custom_device_dev.h"
42+
#include "paddle/cinn/backends/custom_device/compiler_custom_device.h"
43+
#include "paddle/cinn/runtime/custom_device/custom_device_backend_api.h"
44+
#include "paddle/phi/backends/device_manager.h"
45+
#endif
4046
#ifdef CINN_WITH_HIP
4147
#include "paddle/cinn/backends/hip/codegen_hip_dev.h"
4248
#include "paddle/cinn/backends/hip/compiler_hip.h"
@@ -253,7 +259,10 @@ void Compiler::Build(const Module& module, const std::string& code) {
253259
[&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
254260
[&](common::NVGPUArch) { CompileCudaModule(module, code); },
255261
[&](common::HygonDCUArchHIP) { CompileHipModule(module, code); },
256-
[&](common::HygonDCUArchSYCL) { CompileSyclModule(module, code); });
262+
[&](common::HygonDCUArchSYCL) { CompileSyclModule(module, code); },
263+
[&](common::CustomDeviceArch) {
264+
CompileCustomDeviceModule(module, code);
265+
});
257266
}
258267

259268
void Compiler::AppendCX86(const Module& module) {
@@ -344,6 +353,19 @@ std::string Compiler::GetSourceCode(const ir::Module& module) {
344353
[&](common::UnknownArch) -> std::string { CINN_NOT_IMPLEMENTED; },
345354
[&](common::X86Arch) -> std::string { CINN_NOT_IMPLEMENTED; },
346355
[&](common::ARMArch) -> std::string { CINN_NOT_IMPLEMENTED; },
356+
[&](common::CustomDeviceArch) -> std::string {
357+
#ifdef CINN_WITH_CUSTOM_DEVICE
358+
auto _host_module_device_module_ =
359+
SplitDeviceAndHostModule(module); // NOLINT
360+
auto& host_module = std::get<0>(_host_module_device_module_);
361+
auto& device_module = std::get<1>(_host_module_device_module_);
362+
custom_device::CodeGenCustomDevice codegen(target_);
363+
auto source_code = codegen.Compile(device_module);
364+
return source_code;
365+
#else
366+
CINN_NOT_IMPLEMENTED
367+
#endif
368+
},
347369
[&](common::NVGPUArch) -> std::string {
348370
#ifdef CINN_WITH_CUDA
349371
auto _host_module_device_module_ =
@@ -390,6 +412,7 @@ void Compiler::BuildDefault(const Module& module) {
390412
[&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
391413
[&](common::X86Arch) { CompileX86Module(module); },
392414
[&](common::ARMArch) { CINN_NOT_IMPLEMENTED; },
415+
[&](common::CustomDeviceArch) { CompileCustomDeviceModule(module); },
393416
[&](common::NVGPUArch) { CompileCudaModule(module); },
394417
[&](common::HygonDCUArchHIP) { CompileHipModule(module); },
395418
[&](common::HygonDCUArchSYCL) { CompileSyclModule(module); });
@@ -418,6 +441,7 @@ void Compiler::RegisterDeviceModuleSymbol() {
418441
[&](common::UnknownArch) { CINN_NOT_IMPLEMENTED; },
419442
[&](common::X86Arch) { return; },
420443
[&](common::ARMArch) { return; },
444+
[&](common::CustomDeviceArch) { RegisterCustomDeviceModuleSymbol(); },
421445
[&](common::NVGPUArch) { RegisterCudaModuleSymbol(); },
422446
[&](common::HygonDCUArchHIP) { RegisterHipModuleSymbol(); },
423447
[&](common::HygonDCUArchSYCL) { RegisterSyclModuleSymbol(); });
@@ -526,6 +550,60 @@ void Compiler::RegisterCudaModuleSymbol() {
526550
#endif
527551
}
528552

553+
// Compiles the accumulated device source (device_fn_code_) with
// cdrtc::Compiler, loads the result through the custom-device plugin runtime,
// and registers one runtime symbol per kernel listed in device_fn_name_ under
// the name "<kernel>_ptr_".
//
// Enforced conditions (each reports a Paddle error when violated):
//   - at least one custom device type is registered in phi::DeviceManager;
//   - the compiler returns a non-empty library path;
//   - LoadModule returns a non-null module;
//   - every kernel name resolves to a non-null function handle.
//
// When CINN_WITH_CUSTOM_DEVICE is not defined the body is
// CINN_NOT_IMPLEMENTED.
void Compiler::RegisterCustomDeviceModuleSymbol() {
554+
#ifdef CINN_WITH_CUSTOM_DEVICE
555+
// 1. Get the plugin instance (needed for LoadModule later)
556+
auto dev_types = phi::DeviceManager::GetAllCustomDeviceTypes();
557+
PADDLE_ENFORCE_EQ(!dev_types.empty(),
558+
true,
559+
::common::errors::NotFound(
560+
"No custom device registered in DeviceManager."));
561+
// Only the first registered device type is used, with device id 0.
std::string dev_type = dev_types[0];
562+
auto place = phi::CustomPlace(dev_type, 0);
563+
auto& plugin =
564+
cinn::runtime::custom_device::CinnCustomDevicePlugin::GetInstance(place);
565+
566+
// 2. Invoke cdrtc::Compiler to compile source → shared lib
567+
// Note: the compiler is built from the default custom-device target rather
// than from the target_ member.
common::Target target = common::DefaultCustomDeviceTarget();
568+
cdrtc::Compiler compiler(target);
569+
std::string lib_path = compiler(device_fn_code_);
570+
571+
// An empty path is treated as a toolchain failure.
PADDLE_ENFORCE_EQ(
572+
!lib_path.empty(),
573+
true,
574+
::common::errors::External("Custom Device Toolchain compile failed."));
575+
576+
// 3. Invoke the plugin runtime to load the module
577+
// The loaded module is kept in device_module_, which the kernel lookups
// below go through.
this->device_module_ = plugin.GetRuntime()->LoadModule(lib_path);
578+
PADDLE_ENFORCE_NOT_NULL(
579+
this->device_module_,
580+
::common::errors::External(
581+
"Custom Device Runtime failed to load module from %s",
582+
lib_path.c_str()));
583+
584+
// 4. Register Kernel symbols
585+
// Retrieve the device function pointers (handles) and register them
586+
// as [kernel_name]_ptr_
587+
RuntimeSymbols symbols;
588+
for (const auto& kernel_fn_name : device_fn_name_) {
589+
void* fn_kernel = this->device_module_->GetFunction(kernel_fn_name);
590+
591+
PADDLE_ENFORCE_NOT_NULL(fn_kernel,
592+
::common::errors::NotFound(
593+
"Custom Device Runtime cannot find kernel: %s",
594+
kernel_fn_name.c_str()));
595+
596+
// 5. Store the pointer for use by the ExecutionEngine
597+
fn_ptr_.push_back(fn_kernel);
598+
symbols.RegisterVar(kernel_fn_name + "_ptr_", fn_kernel);
599+
}
600+
601+
// Hand all collected "<kernel>_ptr_" symbols to the engine in one call.
engine_->RegisterModuleRuntimeSymbols(std::move(symbols));
602+
#else
603+
CINN_NOT_IMPLEMENTED
604+
#endif
605+
}
606+
529607
void Compiler::RegisterHipModuleSymbol() {
530608
#ifdef CINN_WITH_HIP
531609
hiprtc::Compiler compiler;
@@ -632,6 +710,46 @@ void Compiler::CompileCudaModule(const Module& module,
632710
#endif
633711
}
634712

713+
// Splits `module` into host and device parts, obtains the device source code,
// appends it to device_fn_code_, records every device function name in
// device_fn_name_, and links the host part into the engine via CodeGenGpuHost.
//
// Source selection, in priority order:
//   1. the file named by FLAGS_cinn_debug_custom_code_path, when that flag is
//      non-empty (this wins even if `code` is supplied);
//   2. code generated by custom_device::CodeGenCustomDevice, when `code` is
//      empty;
//   3. `code` itself.
//
// @param module  IR module to split and compile.
// @param code    Optional pre-generated device source; see the order above.
//
// Reports InvalidArgument when the selected source is empty. When
// CINN_WITH_CUSTOM_DEVICE is not defined the body is CINN_NOT_IMPLEMENTED.
void Compiler::CompileCustomDeviceModule(const Module& module,
714+
const std::string& code) {
715+
#ifdef CINN_WITH_CUSTOM_DEVICE
716+
auto _host_module_device_module_ =
717+
SplitDeviceAndHostModule(module); // NOLINT
718+
// Element 0 is the host module, element 1 the device module.
auto& host_module = std::get<0>(_host_module_device_module_);
719+
auto& device_module = std::get<1>(_host_module_device_module_);
720+
VLOG(3) << "[CustomDevice] host module:\n" << host_module;
721+
722+
VLOG(3) << "[CustomDevice] device module:\n" << device_module;
723+
std::string source_code;
724+
725+
// The debug flag overrides both the caller-supplied `code` and codegen.
if (!FLAGS_cinn_debug_custom_code_path.empty()) {
726+
std::string file_path = FLAGS_cinn_debug_custom_code_path;
727+
source_code = GetFileContent(file_path);
728+
} else if (code.empty()) {
729+
custom_device::CodeGenCustomDevice codegen(target_);
730+
source_code = codegen.Compile(device_module);
731+
} else {
732+
source_code = code;
733+
}
734+
735+
PADDLE_ENFORCE_EQ(!source_code.empty(),
736+
true,
737+
::common::errors::InvalidArgument(
738+
"Compile CustomDevice code failed from device module"));
739+
VLOG(1) << "[CustomDevice] Source:\n" << source_code;
740+
SourceCodePrint::GetInstance()->write(source_code);
741+
// The source is appended, not assigned, so repeated calls accumulate code in
// device_fn_code_.
device_fn_code_ += source_code;
742+
743+
// Record each device function name in device_fn_name_.
for (auto& fn : device_module.functions()) {
744+
std::string kernel_fn_name = fn->name;
745+
device_fn_name_.emplace_back(kernel_fn_name);
746+
}
747+
engine_->Link<CodeGenGpuHost>(host_module);
748+
#else
749+
CINN_NOT_IMPLEMENTED
750+
#endif
751+
}
752+
635753
void Compiler::CompileHipModule(const Module& module, const std::string& code) {
636754
#ifdef CINN_WITH_HIP
637755
auto _host_module_device_module_ =

paddle/cinn/backends/compiler.h

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
#ifdef CINN_WITH_CUDA
3030
#include "paddle/cinn/runtime/cuda/cuda_module.h"
3131
#endif
32+
#ifdef CINN_WITH_CUSTOM_DEVICE
33+
#include "paddle/cinn/runtime/custom_device/custom_device_backend_api.h"
34+
#endif
3235
#ifdef CINN_WITH_HIP
3336
#include "paddle/cinn/runtime/hip/hip_module.h"
3437
#endif
@@ -174,13 +177,18 @@ class Compiler final {
174177

175178
void RegisterCudaModuleSymbol();
176179

180+
void RegisterCustomDeviceModuleSymbol();
181+
177182
void RegisterHipModuleSymbol();
178183

179184
void RegisterSyclModuleSymbol();
180185

181186
void CompileCudaModule(const ir::Module& module,
182187
const std::string& code = "");
183188

189+
void CompileCustomDeviceModule(const ir::Module& module,
190+
const std::string& code = "");
191+
184192
void CompileHipModule(const ir::Module& module, const std::string& code = "");
185193

186194
void CompileSyclModule(const ir::Module& module,
@@ -211,6 +219,11 @@ class Compiler final {
211219
std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
212220
void* cuda_module_handle_{nullptr};
213221
#endif
222+
223+
#ifdef CINN_WITH_CUSTOM_DEVICE
224+
std::unique_ptr<runtime::CustomModule> device_module_;
225+
#endif
226+
214227
#ifdef CINN_WITH_HIP
215228
std::unique_ptr<runtime::hip::HIPModule> hip_module_;
216229
#endif
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
# Build inputs for the custom-device backend.
# NOTE(review): core_gather_headers and gather_srcs are project macros defined
# elsewhere; presumably they collect this directory's headers and add the
# listed sources to cinnapi_src - confirm.
core_gather_headers()
2+
3+
gather_srcs(cinnapi_src SRCS codegen_custom_device_dev.cc
4+
compiler_custom_device.cc)

0 commit comments

Comments
 (0)