Skip to content

Commit 4f595c0

Browse files
Groverkssclaude
andauthored
Address code review issues (#26)
- Clean up constants: move kEpVendor/kEpVersion to .cc file (only used there), delete unused VendorIds namespace, keep kEpVendorId in header (used across 3 files). - Defer device enumeration: move driver/device enumeration from constructor (CreateIreeHwDevices) into GetSupportedDevicesImpl with lazy init on first call. Create OrtHardwareDevice and MemoryInfo together in one pass. Remove CreateIreeHwDevices method. - Remove dead CPU device type check: replace device_type check (always GPU) with backend-based validation — require target_arch for hip, cuda, vulkan backends; make it optional for llvm-cpu. - Fix CPU detection in data transfer: use MemoryDevice_GetDeviceType instead of unreliable vendor_id == 0 check, since ORT assigns real vendor IDs to CPU devices. Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 58cc82c commit 4f595c0

File tree

3 files changed

+107
-139
lines changed

3 files changed

+107
-139
lines changed

src/iree_data_transfer.cc

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,24 +40,23 @@ bool ORT_API_CALL IreeDataTransfer::CanCopyImpl(
4040
const OrtMemoryDevice* dst_memory_device) noexcept {
4141
const auto* self = static_cast<const IreeDataTransfer*>(this_ptr);
4242

43-
// Get vendor IDs from both memory devices.
4443
uint32_t src_vendor_id =
4544
self->factory_.ep_api.MemoryDevice_GetVendorId(src_memory_device);
4645
uint32_t dst_vendor_id =
4746
self->factory_.ep_api.MemoryDevice_GetVendorId(dst_memory_device);
47+
OrtMemoryInfoDeviceType src_type =
48+
self->factory_.ep_api.MemoryDevice_GetDeviceType(src_memory_device);
49+
OrtMemoryInfoDeviceType dst_type =
50+
self->factory_.ep_api.MemoryDevice_GetDeviceType(dst_memory_device);
4851

49-
// We can copy if either source or destination is our IREE device,
50-
// and the other is either our IREE device or CPU (vendor_id == 0).
51-
// TODO: Is the vendor id check for CPU actually correct? Maybe we can check
52-
// the hardware device for and determine if it's HOST/CPU.
5352
bool src_is_iree = (src_vendor_id == kEpVendorId);
5453
bool dst_is_iree = (dst_vendor_id == kEpVendorId);
55-
bool src_is_cpu = (src_vendor_id == 0);
56-
bool dst_is_cpu = (dst_vendor_id == 0);
54+
bool src_is_cpu = (src_type == OrtMemoryInfoDeviceType_CPU);
55+
bool dst_is_cpu = (dst_type == OrtMemoryInfoDeviceType_CPU);
5756

5857
// Supported transfers:
5958
// - IREE <-> CPU (H2D, D2H)
60-
// - IREE <-> IREE (D2D on same device)
59+
// - IREE <-> IREE (D2D)
6160
return (src_is_iree && (dst_is_iree || dst_is_cpu)) ||
6261
(dst_is_iree && (src_is_iree || src_is_cpu));
6362
}

src/iree_ep_factory.cc

Lines changed: 99 additions & 117 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
#include "iree_ep_factory.h"
1414

15-
#include <algorithm>
1615
#include <memory>
1716
#include <mutex>
1817

@@ -24,6 +23,9 @@ namespace onnxruntime::iree {
2423

2524
namespace {
2625

26+
constexpr const char* kEpVendor = "IREE";
27+
constexpr const char* kEpVersion = "0.1.0";
28+
2729
// Stub kernel — never instantiated or executed. Exists only to satisfy the
2830
// CustomOpBase<TOp, TKernel> template requirements.
2931
struct ExternDispatchKernel {
@@ -110,89 +112,6 @@ IreeEpFactory::IreeEpFactory(const char* ep_name, ApiPtrs apis,
110112
iree_status_ignore(status);
111113
return;
112114
}
113-
114-
CreateIreeHwDevices();
115-
}
116-
117-
void IreeEpFactory::CreateIreeHwDevices() {
118-
iree_allocator_t allocator = iree_allocator_system();
119-
iree_hal_driver_registry_t* registry =
120-
iree_runtime_instance_driver_registry(instance_.Get());
121-
122-
iree_host_size_t driver_count = 0;
123-
IreeAllocatedPtr<iree_hal_driver_info_t> driver_infos(allocator);
124-
iree_status_t status = iree_hal_driver_registry_enumerate(
125-
registry, allocator, &driver_count, driver_infos.ForOutput());
126-
if (!iree_status_is_ok(status)) {
127-
ORT_CXX_LOG_NOEXCEPT(logger_, ORT_LOGGING_LEVEL_WARNING,
128-
"IREE EP: Failed to enumerate drivers");
129-
iree_status_ignore(status);
130-
return;
131-
}
132-
133-
size_t device_index = 0;
134-
for (iree_host_size_t i = 0; i < driver_count; ++i) {
135-
iree_string_view_t driver_name = driver_infos.Get()[i].driver_name;
136-
std::string driver_name_str(driver_name.data, driver_name.size);
137-
138-
HalDriverPtr driver;
139-
status = iree_hal_driver_registry_try_create(registry, driver_name,
140-
allocator, driver.ForOutput());
141-
if (!iree_status_is_ok(status)) {
142-
iree_status_ignore(status);
143-
continue;
144-
}
145-
146-
iree_host_size_t device_count = 0;
147-
IreeAllocatedPtr<iree_hal_device_info_t> device_infos(allocator);
148-
status = iree_hal_driver_query_available_devices(
149-
driver.Get(), allocator, &device_count, device_infos.ForOutput());
150-
if (!iree_status_is_ok(status)) {
151-
iree_status_ignore(status);
152-
continue;
153-
}
154-
155-
// Create OrtHardwareDevice for each IREE device available.
156-
//
157-
// Given an OrtHardwareDevice, ORT can find which hardware device it is by
158-
// checking its vendor id (should be kEpVendorId) and the device id (the
159-
// index it's stored in hw_devices_).
160-
//
161-
// IREE can identify the device using the driver and device path stored in
162-
// the hardware device metadata.
163-
for (iree_host_size_t j = 0; j < device_count; ++j) {
164-
iree_string_view_t device_path = device_infos.Get()[j].path;
165-
std::string device_path_str(device_path.data, device_path.size);
166-
167-
OrtKeyValuePairs* hw_metadata = nullptr;
168-
ort_api.CreateKeyValuePairs(&hw_metadata);
169-
ort_api.AddKeyValuePair(hw_metadata, "iree.driver",
170-
driver_name_str.c_str());
171-
ort_api.AddKeyValuePair(hw_metadata, "iree.device_path",
172-
device_path_str.c_str());
173-
// Store device_id for later retrieval in CreateEpImpl.
174-
ort_api.AddKeyValuePair(hw_metadata, "iree.device_id",
175-
std::to_string(device_index).c_str());
176-
177-
// TODO: We are pretending here that all IREE devices are GPUs. This is
178-
// because we follow a host <-> device model in IREE. How does this
179-
// matter in practice? If we are using IREE APIs anyway, does it matter
180-
// if we set this to CPU?
181-
OrtHardwareDeviceType device_type = OrtHardwareDeviceType_GPU;
182-
183-
OrtHardwareDevice* hw_device = nullptr;
184-
OrtStatus* ort_status = ep_api.CreateHardwareDevice(
185-
device_type, kEpVendorId, static_cast<uint32_t>(device_index++),
186-
kEpVendor, hw_metadata, &hw_device);
187-
ort_api.ReleaseKeyValuePairs(hw_metadata);
188-
if (ort_status != nullptr) {
189-
ort_api.ReleaseStatus(ort_status);
190-
continue;
191-
}
192-
193-
hw_devices_.push_back(hw_device);
194-
}
195-
}
196115
}
197116

198117
IreeEpFactory::~IreeEpFactory() {
@@ -309,44 +228,109 @@ OrtStatus* ORT_API_CALL IreeEpFactory::GetSupportedDevicesImpl(
309228
size_t& num_ep_devices = *p_num_ep_devices;
310229
num_ep_devices = 0;
311230

312-
// Reserve capacity for memory_info objects to prevent reallocation during
313-
// the loop. This is critical because EpDevice_AddAllocatorInfo stores
314-
// pointers and reallocation would invalidate them.
315-
size_t num_devices_to_add =
316-
std::min(factory->hw_devices_.size(), max_ep_devices);
317-
factory->device_memory_infos_.reserve(num_devices_to_add);
231+
// Lazily enumerate IREE devices on first call.
232+
if (factory->hw_devices_.empty()) {
233+
iree_allocator_t allocator = iree_allocator_system();
234+
iree_hal_driver_registry_t* registry =
235+
iree_runtime_instance_driver_registry(factory->instance_.Get());
318236

319-
// Create OrtEpDevice for each hardware device enumerated in constructor.
237+
iree_host_size_t driver_count = 0;
238+
IreeAllocatedPtr<iree_hal_driver_info_t> driver_infos(allocator);
239+
iree_status_t status = iree_hal_driver_registry_enumerate(
240+
registry, allocator, &driver_count, driver_infos.ForOutput());
241+
if (!iree_status_is_ok(status)) {
242+
ORT_CXX_LOG_NOEXCEPT(factory->logger_, ORT_LOGGING_LEVEL_WARNING,
243+
"IREE EP: Failed to enumerate drivers");
244+
iree_status_ignore(status);
245+
return nullptr;
246+
}
247+
248+
size_t device_index = 0;
249+
for (iree_host_size_t i = 0; i < driver_count; ++i) {
250+
iree_string_view_t driver_name = driver_infos.Get()[i].driver_name;
251+
std::string driver_name_str(driver_name.data, driver_name.size);
252+
253+
HalDriverPtr driver;
254+
status = iree_hal_driver_registry_try_create(
255+
registry, driver_name, allocator, driver.ForOutput());
256+
if (!iree_status_is_ok(status)) {
257+
iree_status_ignore(status);
258+
continue;
259+
}
260+
261+
iree_host_size_t device_count = 0;
262+
IreeAllocatedPtr<iree_hal_device_info_t> device_infos(allocator);
263+
status = iree_hal_driver_query_available_devices(
264+
driver.Get(), allocator, &device_count, device_infos.ForOutput());
265+
if (!iree_status_is_ok(status)) {
266+
iree_status_ignore(status);
267+
continue;
268+
}
269+
270+
// Create OrtHardwareDevice and MemoryInfo for each IREE device.
271+
//
272+
// Given an OrtHardwareDevice, ORT identifies it by vendor_id
273+
// (kEpVendorId) and device_id (index in hw_devices_). IREE identifies
274+
// the device using the driver and device path in metadata.
275+
for (iree_host_size_t j = 0; j < device_count; ++j) {
276+
iree_string_view_t device_path = device_infos.Get()[j].path;
277+
std::string device_path_str(device_path.data, device_path.size);
278+
279+
OrtKeyValuePairs* hw_metadata = nullptr;
280+
factory->ort_api.CreateKeyValuePairs(&hw_metadata);
281+
factory->ort_api.AddKeyValuePair(hw_metadata, "iree.driver",
282+
driver_name_str.c_str());
283+
factory->ort_api.AddKeyValuePair(hw_metadata, "iree.device_path",
284+
device_path_str.c_str());
285+
factory->ort_api.AddKeyValuePair(hw_metadata, "iree.device_id",
286+
std::to_string(device_index).c_str());
287+
288+
// TODO: We are pretending here that all IREE devices are GPUs.
289+
// This is because we follow a host <-> device model in IREE.
290+
OrtHardwareDeviceType device_type = OrtHardwareDeviceType_GPU;
291+
292+
OrtHardwareDevice* hw_device = nullptr;
293+
OrtStatus* ort_status = factory->ep_api.CreateHardwareDevice(
294+
device_type, kEpVendorId, static_cast<uint32_t>(device_index),
295+
kEpVendor, hw_metadata, &hw_device);
296+
factory->ort_api.ReleaseKeyValuePairs(hw_metadata);
297+
if (ort_status != nullptr) {
298+
factory->ort_api.ReleaseStatus(ort_status);
299+
continue;
300+
}
301+
302+
factory->hw_devices_.push_back(hw_device);
303+
304+
// Create MemoryInfo for device-local memory. These must stay alive
305+
// for the factory lifetime since EpDevice_AddAllocatorInfo stores
306+
// the pointer without copying.
307+
factory->device_memory_infos_.emplace_back(
308+
"IREE", // name
309+
OrtMemoryInfoDeviceType_GPU, // device_type
310+
kEpVendorId, // vendor_id
311+
static_cast<uint32_t>(device_index), // device_id
312+
OrtDeviceMemoryType_DEFAULT, // mem_type
313+
0, // alignment (default)
314+
OrtDeviceAllocator); // allocator_type
315+
316+
++device_index;
317+
}
318+
}
319+
}
320+
321+
// Create OrtEpDevice for each enumerated hardware device.
322+
// ORT takes ownership of OrtEpDevice objects.
320323
for (size_t i = 0;
321324
i < factory->hw_devices_.size() && num_ep_devices < max_ep_devices;
322325
++i) {
323326
OrtEpDevice* ep_device = nullptr;
324327
ORT_RETURN_IF_ERROR(factory->ep_api.CreateEpDevice(
325328
factory, factory->hw_devices_[i], nullptr, nullptr, &ep_device));
326329

327-
// Register allocator info for device-local memory.
328-
// This tells ORT that this device has an allocator available for
329-
// device-local memory allocations.
330-
//
331-
// IMPORTANT: EpDevice_AddAllocatorInfo does NOT copy the OrtMemoryInfo,
332-
// it just stores the pointer. We must keep the memory_info alive for
333-
// the lifetime of the factory by storing it in device_memory_infos_.
334-
factory->device_memory_infos_.emplace_back(
335-
"IREE", // name
336-
OrtMemoryInfoDeviceType_GPU, // device_type
337-
kEpVendorId, // vendor_id
338-
static_cast<uint32_t>(i), // device_id
339-
OrtDeviceMemoryType_DEFAULT, // mem_type
340-
0, // alignment (default)
341-
OrtDeviceAllocator); // allocator_type
342-
343-
// Get raw pointer to pass to ORT.
344-
const OrtMemoryInfo* mem_info_ptr = factory->device_memory_infos_.back();
345-
330+
const OrtMemoryInfo* mem_info_ptr = factory->device_memory_infos_[i];
346331
OrtStatus* add_alloc_status =
347332
factory->ep_api.EpDevice_AddAllocatorInfo(ep_device, mem_info_ptr);
348333
if (add_alloc_status != nullptr) {
349-
factory->device_memory_infos_.pop_back();
350334
factory->ep_api.ReleaseEpDevice(ep_device);
351335
return add_alloc_status;
352336
}
@@ -439,13 +423,11 @@ OrtStatus* ORT_API_CALL IreeEpFactory::CreateEpImpl(
439423
.release();
440424
}
441425

442-
// Require target_arch for non-CPU devices.
443-
OrtHardwareDeviceType device_type =
444-
factory->ort_api.HardwareDevice_Type(&hardware_device);
445-
if (device_type != OrtHardwareDeviceType_CPU && config.target_arch.empty()) {
426+
// Require target_arch for GPU backends (llvm-cpu defaults to host).
427+
if (config.backend != "llvm-cpu" && config.target_arch.empty()) {
446428
return Ort::Status(
447-
"IREE EP: 'target_arch' option must be specified for non-CPU "
448-
"devices",
429+
"IREE EP: 'target_arch' option must be specified for "
430+
"hip, cuda, and vulkan backends",
449431
ORT_INVALID_ARGUMENT)
450432
.release();
451433
}

src/iree_ep_factory.h

Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,8 @@ namespace onnxruntime::iree {
2626
class IreeAllocator;
2727
class IreeDataTransfer;
2828

29-
// EP configuration constants
30-
inline constexpr const char* kEpVendor = "IREE";
29+
// EP vendor ID used across the codebase to identify IREE devices.
3130
inline constexpr uint32_t kEpVendorId = 0x1EEE; // "IREE" in hex-ish
32-
inline constexpr const char* kEpVersion = "0.1.0";
33-
34-
// Hardware vendor IDs for device matching.
35-
// These match OrtDevice::VendorIds from onnxruntime/core/framework/ortdevice.h.
36-
namespace VendorIds {
37-
inline constexpr uint32_t kAmd = 0x1002; // AMD: ROCm, MIGraphX EPs
38-
inline constexpr uint32_t kNvidia = 0x10DE; // NVIDIA: CUDA/TensorRT
39-
inline constexpr uint32_t kIntel = 0x8086; // Intel: OpenVINO
40-
} // namespace VendorIds
4131

4232
// Helper struct to pass API pointers
4333
struct ApiPtrs {
@@ -121,9 +111,6 @@ class IreeEpFactory : public OrtEpFactory, public ApiPtrs {
121111
GetCustomOpDomainsImpl(OrtEpFactory* this_ptr, OrtCustomOpDomain** domains,
122112
size_t num_domains) noexcept;
123113

124-
// Enumerate IREE devices and create OrtHardwareDevice instances.
125-
void CreateIreeHwDevices();
126-
127114
// Member variables
128115
Ort::Logger logger_;
129116
const std::string ep_name_;

0 commit comments

Comments
 (0)