Skip to content

Commit 884d63b

Browse files
committed
header-only adapter for EP API
1 parent 0409ba6 commit 884d63b

File tree

14 files changed

+1071
-0
lines changed

14 files changed

+1071
-0
lines changed

include/onnxruntime/ep/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
## EP adapter
2+
3+
This folder contains a set of C++ header files. They allow ONNX Runtime's internal kernel-based EPs to use the plugin-style EP API while keeping changes to the existing code minimal.
4+
5+
### Usage
6+
7+
Make sure every source file in the implementation includes "ep/_pch.h". Using it as a precompiled header (PCH) is recommended.

include/onnxruntime/ep/_pch.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "api.h"
7+
#include "common.h"
8+
9+
// This header is only used when building WebGPU/CUDA EP as a shared library.
10+
//
11+
// This header file is used as a precompiled header so it is always included first.
12+
13+
#pragma push_macro("ORT_EP_API_ADAPTER_HEADER_INCLUDED")
14+
#define ORT_EP_API_ADAPTER_HEADER_INCLUDED
15+
16+
#include "adapter/allocator.h"
17+
#include "adapter/logging.h"
18+
#include "adapter/ep.h"
19+
#include "adapter/kernel_registry.h"
20+
21+
#pragma pop_macro("ORT_EP_API_ADAPTER_HEADER_INCLUDED")
22+
23+
//
24+
// EP specific using declarations
25+
//
26+
27+
// Expands to a set of alias declarations that map the unqualified names
// FuncManager, KernelCreatePtrFn, KernelDefBuilder, KernelRegistry,
// KernelCreateInfo, BuildKernelCreateInfoFn, OpKernelInfo, OpKernelContext,
// OpKernel and DataTransferManager onto their onnxruntime::ep::adapter
// counterparts, followed by a nested `logging` namespace that aliases Logger
// the same way. Because the expansion opens a namespace, the macro can only be
// used at namespace scope.
#define EP_SPECIFIC_USING_DECLARATIONS \
28+
using FuncManager = onnxruntime::ep::adapter::FuncManager; \
29+
using KernelCreatePtrFn = onnxruntime::ep::adapter::KernelCreatePtrFn; \
30+
using KernelDefBuilder = onnxruntime::ep::adapter::KernelDefBuilder; \
31+
using KernelRegistry = onnxruntime::ep::adapter::KernelRegistry; \
32+
using KernelCreateInfo = onnxruntime::ep::adapter::KernelCreateInfo; \
33+
using BuildKernelCreateInfoFn = onnxruntime::ep::adapter::KernelCreateInfo (*)(); \
34+
using OpKernelInfo = onnxruntime::ep::adapter::OpKernelInfo; \
35+
using OpKernelContext = onnxruntime::ep::adapter::OpKernelContext; \
36+
using OpKernel = onnxruntime::ep::adapter::OpKernel; \
37+
using DataTransferManager = onnxruntime::ep::adapter::DataTransferManager; \
38+
namespace logging { \
39+
using Logger = onnxruntime::ep::adapter::Logger; \
40+
}
41+
42+
// Injects the adapter aliases from EP_SPECIFIC_USING_DECLARATIONS into the
// onnxruntime::webgpu and onnxruntime::cuda namespaces and, unless
// DISABLE_CONTRIB_OPS is defined, into onnxruntime::contrib::webgpu and
// onnxruntime::contrib::cuda as well.
namespace onnxruntime {
43+
namespace webgpu {
44+
EP_SPECIFIC_USING_DECLARATIONS
45+
} // namespace webgpu
46+
namespace cuda {
47+
EP_SPECIFIC_USING_DECLARATIONS
48+
} // namespace cuda
49+
50+
#ifndef DISABLE_CONTRIB_OPS
51+
namespace contrib {
52+
namespace webgpu {
53+
EP_SPECIFIC_USING_DECLARATIONS
54+
} // namespace webgpu
55+
namespace cuda {
56+
EP_SPECIFIC_USING_DECLARATIONS
57+
} // namespace cuda
58+
} // namespace contrib
59+
#endif
60+
61+
} // namespace onnxruntime
Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/framework/allocator.h"
7+
8+
namespace onnxruntime {
9+
namespace ep {
10+
namespace adapter {
11+
12+
/// <summary>
13+
/// A bridge class between the EP API OrtAllocator and an IAllocator implementation.
14+
/// </summary>
15+
class Allocator : public OrtAllocator {
16+
public:
17+
explicit Allocator(AllocatorPtr impl) : OrtAllocator{}, impl_(impl) {
18+
version = ORT_API_VERSION;
19+
Alloc = AllocImpl;
20+
Free = FreeImpl;
21+
Info = InfoImpl;
22+
}
23+
24+
private:
25+
static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept {
26+
auto* allocator = static_cast<Allocator*>(this_ptr);
27+
return allocator->impl_->Alloc(size);
28+
}
29+
30+
static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept {
31+
auto* allocator = static_cast<Allocator*>(this_ptr);
32+
allocator->impl_->Free(p);
33+
}
34+
35+
static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept {
36+
auto* allocator = static_cast<const Allocator*>(this_ptr);
37+
return &allocator->impl_->Info();
38+
}
39+
40+
AllocatorPtr impl_;
41+
};
42+
43+
} // namespace adapter
44+
} // namespace ep
45+
} // namespace onnxruntime
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
7+
#error "This header should not be included directly. Include ep/_pch.h instead."
8+
#endif
9+
10+
#include "core/common/status.h"
11+
#include "core/common/common.h"
12+
#include "core/framework/data_transfer.h"
13+
#include "core/framework/tensor.h"
14+
15+
namespace onnxruntime {
16+
namespace ep {
17+
namespace adapter {
18+
19+
/// <summary>
20+
/// </summary>
21+
struct DataTransferManager {
22+
explicit DataTransferManager(std::unique_ptr<IDataTransfer> impl) : impl_{std::move(impl)} {}
23+
24+
common::Status CopyTensor(const Tensor& src, Tensor& dst) const {
25+
if (src.Shape().Size() != dst.Shape().Size()) {
26+
return ORT_MAKE_STATUS(ONNXRUNTIME,
27+
FAIL,
28+
"Tensor size mismatch: source tensor size is ",
29+
src.Shape().Size(),
30+
", destination tensor size is ",
31+
dst.Shape().Size());
32+
}
33+
34+
if (impl_->CanCopy(src.Location().device, dst.Location().device)) {
35+
return impl_->CopyTensor(src, dst);
36+
}
37+
38+
return ORT_MAKE_STATUS(ONNXRUNTIME,
39+
FAIL,
40+
"There's no data transfer registered for copying tensors from ",
41+
src.Location().device.ToString(),
42+
" to ",
43+
dst.Location().device.ToString());
44+
}
45+
46+
private:
47+
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DataTransferManager);
48+
std::unique_ptr<IDataTransfer> impl_;
49+
};
50+
51+
} // namespace adapter
52+
} // namespace ep
53+
} // namespace onnxruntime
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
7+
#error "This header should not be included directly. Include ep/_pch.h instead."
8+
#endif
9+
10+
#include "data_transfer_manager.h"
11+
12+
#include "core/framework/execution_provider.h"
13+
14+
namespace onnxruntime {
15+
namespace ep {
16+
namespace adapter {
17+
18+
/// <summary>
19+
/// Wrapper around IExecutionProvider to expose via OrtEp.
20+
/// </summary>
21+
class Ep : public OrtEp {
22+
protected:
23+
explicit Ep(IExecutionProvider* impl, AllocatorPtr temp_space_cpu_allocator, AllocatorPtr temp_space_allocator)
24+
: OrtEp{},
25+
impl_(impl),
26+
data_transfer_manager_{impl->GetDataTransfer()},
27+
profiler_{impl->GetProfiler()},
28+
temp_space_cpu_allocator_{temp_space_cpu_allocator},
29+
temp_space_allocator_{temp_space_allocator} {
30+
}
31+
32+
public:
33+
inline IExecutionProvider* EpImpl() const noexcept {
34+
return impl_.get();
35+
}
36+
inline const DataTransferManager& GetDataTransferManager() const noexcept {
37+
return data_transfer_manager_;
38+
}
39+
[[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const {
40+
*output = temp_space_cpu_allocator_;
41+
return Status::OK();
42+
}
43+
[[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const {
44+
*output = temp_space_allocator_;
45+
return Status::OK();
46+
}
47+
48+
private:
49+
std::unique_ptr<IExecutionProvider> impl_;
50+
DataTransferManager data_transfer_manager_;
51+
std::unique_ptr<profiling::EpProfiler> profiler_;
52+
AllocatorPtr temp_space_cpu_allocator_;
53+
AllocatorPtr temp_space_allocator_;
54+
};
55+
56+
} // namespace adapter
57+
} // namespace ep
58+
} // namespace onnxruntime
Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
7+
#error "This header should not be included directly. Include ep/_pch.h instead."
8+
#endif
9+
10+
#include <limits>
#include <memory>
#include <utility>
#include <vector>
11+
12+
#include "core/framework/data_types.h"
13+
14+
namespace onnxruntime {
15+
namespace ep {
16+
namespace adapter {
17+
18+
/// <summary>
19+
/// Gets an OrtMLDataType for a tensor type. Throws on error.
20+
/// </summary>
21+
/// <param name="elem_type"></param>
22+
/// <returns></returns>
23+
inline const OrtDataType* GetTensorType(ONNXTensorElementDataType elem_type) {
24+
const OrtEpApi& ep_api = Ort::GetEpApi();
25+
const OrtDataType* result = nullptr;
26+
27+
Ort::ThrowOnError(ep_api.GetTensorDataType(elem_type, &result));
28+
return result;
29+
}
30+
31+
/// <summary>
/// Converts an MLDataType that describes a tensor type into the EP API OrtDataType for the same
/// element type. Non-tensor types, and tensor types whose element type is not a primitive data
/// type, are rejected through EP_ENFORCE.
/// </summary>
/// <param name="ml_type">Type to convert; must be non-null and a tensor type.</param>
/// <returns>The OrtDataType returned by GetTensorType for the tensor's element type.</returns>
inline const OrtDataType* MLDataTypeToOrtDataType(MLDataType ml_type) {
  EP_ENFORCE(ml_type != nullptr, "MLDataType must not be null.");

  const auto* tensor_type = ml_type->AsTensorType();
  EP_ENFORCE(tensor_type != nullptr, "EP Kernel registration only supports tensor types.");

  const auto* elem_type = tensor_type->GetElementType();
  EP_ENFORCE(elem_type != nullptr, "Tensor type has no element type.");

  // Use the checked accessor rather than an unchecked static_cast so that a
  // non-primitive element type is reported instead of being reinterpreted.
  const auto* primitive_type = elem_type->AsPrimitiveDataType();
  EP_ENFORCE(primitive_type != nullptr, "EP Kernel registration only supports primitive tensor element types.");

  const auto onnx_type = static_cast<ONNXTensorElementDataType>(primitive_type->GetDataType());
  return GetTensorType(onnx_type);
}
39+
40+
struct KernelDefBuilder {
41+
static std::unique_ptr<KernelDefBuilder> Create() { return std::make_unique<KernelDefBuilder>(); }
42+
43+
explicit KernelDefBuilder() {}
44+
45+
KernelDefBuilder& SetName(const char* op_name) {
46+
builder_.SetOperatorType(op_name);
47+
return *this;
48+
}
49+
50+
KernelDefBuilder& SetDomain(const char* domain) {
51+
builder_.SetDomain(domain);
52+
return *this;
53+
}
54+
55+
KernelDefBuilder& SinceVersion(int since_version) {
56+
return SinceVersion(since_version, INT_MAX);
57+
}
58+
59+
KernelDefBuilder& SinceVersion(int since_version_start, int since_version_end) {
60+
builder_.SetSinceVersion(since_version_start, since_version_end);
61+
return *this;
62+
}
63+
64+
KernelDefBuilder& Provider(const char* provider_type) {
65+
builder_.SetExecutionProvider(provider_type);
66+
return *this;
67+
}
68+
69+
KernelDefBuilder& TypeConstraint(const char* arg_name, std::vector<MLDataType> types) {
70+
std::vector<const OrtDataType*> ort_types;
71+
ort_types.reserve(types.size());
72+
for (const auto& type : types) {
73+
ort_types.push_back(MLDataTypeToOrtDataType(type));
74+
}
75+
builder_.AddTypeConstraint(arg_name, ort_types);
76+
return *this;
77+
}
78+
79+
KernelDefBuilder& TypeConstraint(const char* arg_name, MLDataType type) {
80+
builder_.AddTypeConstraint(arg_name, MLDataTypeToOrtDataType(type));
81+
return *this;
82+
}
83+
84+
KernelDefBuilder& MayInplace(const std::vector<std::pair<int, int>>& inplaces) {
85+
for (const auto& pair : inplaces) {
86+
builder_.AddInputOutputMutableAlias(pair.first, pair.second);
87+
}
88+
return *this;
89+
}
90+
KernelDefBuilder& MayInplace(int input_index, int output_index) {
91+
builder_.AddInputOutputMutableAlias(input_index, output_index);
92+
return *this;
93+
}
94+
95+
KernelDefBuilder& Alias(const std::vector<std::pair<int, int>>& aliases) {
96+
for (const auto& pair : aliases) {
97+
builder_.AddInputOutputAlias(pair.first, pair.second);
98+
}
99+
return *this;
100+
}
101+
KernelDefBuilder& Alias(int input_index, int output_index) {
102+
builder_.AddInputOutputAlias(input_index, output_index);
103+
return *this;
104+
}
105+
106+
KernelDefBuilder& InputMemoryType(OrtMemType type, int input_index) {
107+
builder_.SetInputMemType(input_index, type);
108+
return *this;
109+
}
110+
111+
KernelDefBuilder& InputMemoryType(OrtMemType type, const std::vector<int>& input_indexes) {
112+
for (int input_index : input_indexes) {
113+
builder_.SetInputMemType(input_index, type);
114+
}
115+
return *this;
116+
}
117+
118+
KernelDefBuilder& OutputMemoryType(OrtMemType type, int output_index) {
119+
builder_.SetOutputMemType(output_index, type);
120+
return *this;
121+
}
122+
123+
KernelDefBuilder& OutputMemoryType(OrtMemType type, const std::vector<int>& output_indexes) {
124+
for (int output_index : output_indexes) {
125+
builder_.SetOutputMemType(output_index, type);
126+
}
127+
return *this;
128+
}
129+
130+
KernelDefBuilder& ExecQueueId(int queue_id) { return *this; }
131+
132+
Ort::KernelDef Build() { return builder_.Build(); }
133+
134+
private:
135+
Ort::KernelDefBuilder builder_;
136+
};
137+
138+
} // namespace adapter
139+
} // namespace ep
140+
} // namespace onnxruntime

0 commit comments

Comments
 (0)