Skip to content

Commit c9e4973

Browse files
committed
[EP API] header-only adapter for EP API
1 parent 9486eed commit c9e4973

File tree

16 files changed

+1424
-0
lines changed

16 files changed

+1424
-0
lines changed

include/onnxruntime/ep/README.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
## EP adapter
2+
3+
This folder contains a set of C++ header files. They allow ONNX Runtime's internal kernel-based EPs to use the plugin-style EP API while keeping changes to the existing code minimal.
4+
5+
### Usage
6+
7+
Make sure that every source file in the EP implementation includes "ep/_pch.h" before any other header. Using it as a precompiled header (PCH) is recommended.

include/onnxruntime/ep/_pch.h

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "api.h"
7+
#include "common.h"
8+
9+
// This header is only used when building WebGPU/CUDA EP as a shared library.
10+
//
11+
// This header file is used as a precompiled header so it is always included first.
12+
13+
#pragma push_macro("ORT_EP_API_ADAPTER_HEADER_INCLUDED")
14+
#define ORT_EP_API_ADAPTER_HEADER_INCLUDED
15+
16+
#include "adapter/allocator.h"
17+
#include "adapter/logging.h"
18+
#include "adapter/ep.h"
19+
#include "adapter/kernel_registry.h"
20+
21+
#pragma pop_macro("ORT_EP_API_ADAPTER_HEADER_INCLUDED")
22+
23+
//
24+
// EP specific using declarations
25+
//
26+
27+
// Expands to a set of type aliases that map the names used by kernel code (OpKernel, KernelRegistry, ...)
// onto the corresponding types in onnxruntime::ep::adapter, plus a nested `logging::Logger` alias.
// Intended to be expanded inside a namespace. Do not put comments inside the macro body: the line
// continuations would splice the following lines into the comment.
#define EP_SPECIFIC_USING_DECLARATIONS                                              \
  using BuildKernelCreateInfoFn = onnxruntime::ep::adapter::KernelCreateInfo (*)(); \
  using DataTransferManager = onnxruntime::ep::adapter::DataTransferManager;        \
  using FuncManager = onnxruntime::ep::adapter::FuncManager;                        \
  using KernelCreateInfo = onnxruntime::ep::adapter::KernelCreateInfo;              \
  using KernelCreatePtrFn = onnxruntime::ep::adapter::KernelCreatePtrFn;            \
  using KernelDefBuilder = onnxruntime::ep::adapter::KernelDefBuilder;              \
  using KernelRegistry = onnxruntime::ep::adapter::KernelRegistry;                  \
  using OpKernel = onnxruntime::ep::adapter::OpKernel;                              \
  using OpKernelContext = onnxruntime::ep::adapter::OpKernelContext;                \
  using OpKernelInfo = onnxruntime::ep::adapter::OpKernelInfo;                      \
  namespace logging {                                                               \
  using Logger = onnxruntime::ep::adapter::Logger;                                  \
  }
41+
42+
namespace onnxruntime {
43+
namespace webgpu {
44+
EP_SPECIFIC_USING_DECLARATIONS
45+
} // namespace webgpu
46+
namespace cuda {
47+
EP_SPECIFIC_USING_DECLARATIONS
48+
} // namespace cuda
49+
50+
#ifndef DISABLE_CONTRIB_OPS
51+
namespace contrib {
52+
namespace webgpu {
53+
EP_SPECIFIC_USING_DECLARATIONS
54+
} // namespace webgpu
55+
namespace cuda {
56+
EP_SPECIFIC_USING_DECLARATIONS
57+
} // namespace cuda
58+
} // namespace contrib
59+
#endif
60+
61+
} // namespace onnxruntime
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#include "core/framework/allocator.h"
7+
8+
namespace onnxruntime {
9+
namespace ep {
10+
namespace adapter {
11+
12+
/// <summary>
13+
/// A bridge class between the EP API OrtAllocator and an IAllocator implementation.
14+
/// </summary>
15+
class Allocator : public OrtAllocator {
16+
public:
17+
explicit Allocator(const OrtMemoryInfo* memory_info, AllocatorPtr impl)
18+
: OrtAllocator{}, memory_info_(memory_info), impl_(impl) {
19+
version = ORT_API_VERSION;
20+
Alloc = AllocImpl;
21+
Free = FreeImpl;
22+
Info = InfoImpl;
23+
}
24+
25+
private:
26+
static void* ORT_API_CALL AllocImpl(OrtAllocator* this_ptr, size_t size) noexcept {
27+
auto* allocator = static_cast<Allocator*>(this_ptr);
28+
return allocator->impl_->Alloc(size);
29+
}
30+
31+
static void ORT_API_CALL FreeImpl(OrtAllocator* this_ptr, void* p) noexcept {
32+
auto* allocator = static_cast<Allocator*>(this_ptr);
33+
allocator->impl_->Free(p);
34+
}
35+
36+
static const OrtMemoryInfo* ORT_API_CALL InfoImpl(const OrtAllocator* this_ptr) noexcept {
37+
auto* allocator = static_cast<const Allocator*>(this_ptr);
38+
return allocator->memory_info_;
39+
}
40+
41+
const OrtMemoryInfo* memory_info_;
42+
AllocatorPtr impl_;
43+
};
44+
45+
} // namespace adapter
46+
} // namespace ep
47+
} // namespace onnxruntime
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
7+
#error "This header should not be included directly. Include ep/_pch.h instead."
8+
#endif
9+
10+
#include "core/common/status.h"
11+
#include "core/common/common.h"
12+
#include "core/framework/data_transfer.h"
13+
#include "core/framework/tensor.h"
14+
15+
namespace onnxruntime {
16+
namespace ep {
17+
namespace adapter {
18+
19+
/// <summary>
20+
/// An adapter class partially implementing the facade of `onnxruntime::DataTransferManager`.
21+
/// </summary>
22+
struct DataTransferManager {
23+
explicit DataTransferManager(std::unique_ptr<IDataTransfer> impl) : impl_{std::move(impl)} {}
24+
25+
common::Status CopyTensor(const Tensor& src, Tensor& dst) const {
26+
if (src.Shape().Size() != dst.Shape().Size()) {
27+
return ORT_MAKE_STATUS(ONNXRUNTIME,
28+
FAIL,
29+
"Tensor size mismatch: source tensor size is ",
30+
src.Shape().Size(),
31+
", destination tensor size is ",
32+
dst.Shape().Size());
33+
}
34+
35+
if (impl_->CanCopy(src.Location().device, dst.Location().device)) {
36+
return impl_->CopyTensor(src, dst);
37+
}
38+
39+
return ORT_MAKE_STATUS(ONNXRUNTIME,
40+
FAIL,
41+
"There's no data transfer registered for copying tensors from ",
42+
src.Location().device.ToString(),
43+
" to ",
44+
dst.Location().device.ToString());
45+
}
46+
47+
private:
48+
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(DataTransferManager);
49+
std::unique_ptr<IDataTransfer> impl_;
50+
};
51+
52+
} // namespace adapter
53+
} // namespace ep
54+
} // namespace onnxruntime
Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
7+
#error "This header should not be included directly. Include ep/_pch.h instead."
8+
#endif
9+
10+
#include "data_transfer_manager.h"
11+
12+
#include "core/framework/execution_provider.h"
13+
14+
namespace onnxruntime {
15+
namespace ep {
16+
namespace adapter {
17+
18+
/// <summary>
19+
/// Wrapper around IExecutionProvider to expose via OrtEp.
20+
/// </summary>
21+
class Ep : public OrtEp {
22+
protected:
23+
explicit Ep(IExecutionProvider* impl, AllocatorPtr temp_space_cpu_allocator, AllocatorPtr temp_space_allocator)
24+
: OrtEp{},
25+
impl_(impl),
26+
data_transfer_manager_{impl->GetDataTransfer()},
27+
profiler_{impl->GetProfiler()},
28+
temp_space_cpu_allocator_{temp_space_cpu_allocator},
29+
temp_space_allocator_{temp_space_allocator} {
30+
}
31+
32+
public:
33+
inline IExecutionProvider* EpImpl() const noexcept {
34+
return impl_.get();
35+
}
36+
inline const DataTransferManager& GetDataTransferManager() const noexcept {
37+
return data_transfer_manager_;
38+
}
39+
[[nodiscard]] Status GetTempSpaceCPUAllocator(AllocatorPtr* output) const {
40+
*output = temp_space_cpu_allocator_;
41+
return Status::OK();
42+
}
43+
[[nodiscard]] Status GetTempSpaceAllocator(AllocatorPtr* output) const {
44+
*output = temp_space_allocator_;
45+
return Status::OK();
46+
}
47+
48+
private:
49+
std::unique_ptr<IExecutionProvider> impl_;
50+
DataTransferManager data_transfer_manager_;
51+
std::unique_ptr<profiling::EpProfiler> profiler_;
52+
AllocatorPtr temp_space_cpu_allocator_;
53+
AllocatorPtr temp_space_allocator_;
54+
};
55+
56+
} // namespace adapter
57+
} // namespace ep
58+
} // namespace onnxruntime
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright (c) Microsoft Corporation. All rights reserved.
2+
// Licensed under the MIT License.
3+
4+
#pragma once
5+
6+
#if !defined(ORT_EP_API_ADAPTER_HEADER_INCLUDED)
7+
#error "This header should not be included directly. Include ep/_pch.h instead."
8+
#endif
9+
10+
#include <memory>
11+
12+
namespace onnxruntime {
13+
namespace ep {
14+
namespace adapter {
15+
16+
/// <summary>
17+
/// An adapter class partially implementing the facade of `onnxruntime::KernelDef`.
18+
/// </summary>
19+
class KernelDef {
20+
public:
21+
explicit KernelDef(const OrtKernelInfo* kernel_info) : kernel_info_{kernel_info} {}
22+
23+
const std::string OpName() const {
24+
return kernel_info_.GetNodeName();
25+
}
26+
27+
const std::string Domain() const {
28+
return kernel_info_.GetOperatorDomain();
29+
}
30+
31+
private:
32+
const Ort::ConstKernelInfo kernel_info_;
33+
};
34+
35+
} // namespace adapter
36+
} // namespace ep
37+
} // namespace onnxruntime

0 commit comments

Comments
 (0)