1212
1313#include " iree_ep_factory.h"
1414
15- #include < algorithm>
1615#include < memory>
1716#include < mutex>
1817
@@ -24,6 +23,9 @@ namespace onnxruntime::iree {
2423
2524namespace {
2625
26+ constexpr const char * kEpVendor = " IREE" ;
27+ constexpr const char * kEpVersion = " 0.1.0" ;
28+
2729// Stub kernel — never instantiated or executed. Exists only to satisfy the
2830// CustomOpBase<TOp, TKernel> template requirements.
2931struct ExternDispatchKernel {
@@ -110,89 +112,6 @@ IreeEpFactory::IreeEpFactory(const char* ep_name, ApiPtrs apis,
110112 iree_status_ignore (status);
111113 return ;
112114 }
113-
114- CreateIreeHwDevices ();
115- }
116-
117- void IreeEpFactory::CreateIreeHwDevices () {
118- iree_allocator_t allocator = iree_allocator_system ();
119- iree_hal_driver_registry_t * registry =
120- iree_runtime_instance_driver_registry (instance_.Get ());
121-
122- iree_host_size_t driver_count = 0 ;
123- IreeAllocatedPtr<iree_hal_driver_info_t > driver_infos (allocator);
124- iree_status_t status = iree_hal_driver_registry_enumerate (
125- registry, allocator, &driver_count, driver_infos.ForOutput ());
126- if (!iree_status_is_ok (status)) {
127- ORT_CXX_LOG_NOEXCEPT (logger_, ORT_LOGGING_LEVEL_WARNING,
128- " IREE EP: Failed to enumerate drivers" );
129- iree_status_ignore (status);
130- return ;
131- }
132-
133- size_t device_index = 0 ;
134- for (iree_host_size_t i = 0 ; i < driver_count; ++i) {
135- iree_string_view_t driver_name = driver_infos.Get ()[i].driver_name ;
136- std::string driver_name_str (driver_name.data , driver_name.size );
137-
138- HalDriverPtr driver;
139- status = iree_hal_driver_registry_try_create (registry, driver_name,
140- allocator, driver.ForOutput ());
141- if (!iree_status_is_ok (status)) {
142- iree_status_ignore (status);
143- continue ;
144- }
145-
146- iree_host_size_t device_count = 0 ;
147- IreeAllocatedPtr<iree_hal_device_info_t > device_infos (allocator);
148- status = iree_hal_driver_query_available_devices (
149- driver.Get (), allocator, &device_count, device_infos.ForOutput ());
150- if (!iree_status_is_ok (status)) {
151- iree_status_ignore (status);
152- continue ;
153- }
154-
155- // Create OrtHardwareDevice for each IREE device available.
156- //
157- // Given an OrtHardwareDevice, ORT can find which hardware device it is by
158- // checking its vendor id (should be kEpVendorId) and the device id (the
159- // index it's stored in hw_devices_).
160- //
161- // IREE can identify the device using the driver and device path stored in
162- // the hardware device metadata.
163- for (iree_host_size_t j = 0 ; j < device_count; ++j) {
164- iree_string_view_t device_path = device_infos.Get ()[j].path ;
165- std::string device_path_str (device_path.data , device_path.size );
166-
167- OrtKeyValuePairs* hw_metadata = nullptr ;
168- ort_api.CreateKeyValuePairs (&hw_metadata);
169- ort_api.AddKeyValuePair (hw_metadata, " iree.driver" ,
170- driver_name_str.c_str ());
171- ort_api.AddKeyValuePair (hw_metadata, " iree.device_path" ,
172- device_path_str.c_str ());
173- // Store device_id for later retrieval in CreateEpImpl.
174- ort_api.AddKeyValuePair (hw_metadata, " iree.device_id" ,
175- std::to_string (device_index).c_str ());
176-
177- // TODO: We are pretending here that all IREE devices are GPUs. This is
178- // because we follow a host <-> device model in IREE. How does this
179- // matter in practice? If we are using IREE APIs anyway, does it matter
180- // if we set this to CPU?
181- OrtHardwareDeviceType device_type = OrtHardwareDeviceType_GPU;
182-
183- OrtHardwareDevice* hw_device = nullptr ;
184- OrtStatus* ort_status = ep_api.CreateHardwareDevice (
185- device_type, kEpVendorId , static_cast <uint32_t >(device_index++),
186- kEpVendor , hw_metadata, &hw_device);
187- ort_api.ReleaseKeyValuePairs (hw_metadata);
188- if (ort_status != nullptr ) {
189- ort_api.ReleaseStatus (ort_status);
190- continue ;
191- }
192-
193- hw_devices_.push_back (hw_device);
194- }
195- }
196115}
197116
198117IreeEpFactory::~IreeEpFactory () {
@@ -309,44 +228,109 @@ OrtStatus* ORT_API_CALL IreeEpFactory::GetSupportedDevicesImpl(
309228 size_t & num_ep_devices = *p_num_ep_devices;
310229 num_ep_devices = 0 ;
311230
312- // Reserve capacity for memory_info objects to prevent reallocation during
313- // the loop. This is critical because EpDevice_AddAllocatorInfo stores
314- // pointers and reallocation would invalidate them.
315- size_t num_devices_to_add =
316- std::min (factory->hw_devices_ .size (), max_ep_devices);
317- factory->device_memory_infos_ .reserve (num_devices_to_add);
231+ // Lazily enumerate IREE devices on first call.
232+ if (factory->hw_devices_ .empty ()) {
233+ iree_allocator_t allocator = iree_allocator_system ();
234+ iree_hal_driver_registry_t * registry =
235+ iree_runtime_instance_driver_registry (factory->instance_ .Get ());
318236
319- // Create OrtEpDevice for each hardware device enumerated in constructor.
237+ iree_host_size_t driver_count = 0 ;
238+ IreeAllocatedPtr<iree_hal_driver_info_t > driver_infos (allocator);
239+ iree_status_t status = iree_hal_driver_registry_enumerate (
240+ registry, allocator, &driver_count, driver_infos.ForOutput ());
241+ if (!iree_status_is_ok (status)) {
242+ ORT_CXX_LOG_NOEXCEPT (factory->logger_ , ORT_LOGGING_LEVEL_WARNING,
243+ " IREE EP: Failed to enumerate drivers" );
244+ iree_status_ignore (status);
245+ return nullptr ;
246+ }
247+
248+ size_t device_index = 0 ;
249+ for (iree_host_size_t i = 0 ; i < driver_count; ++i) {
250+ iree_string_view_t driver_name = driver_infos.Get ()[i].driver_name ;
251+ std::string driver_name_str (driver_name.data , driver_name.size );
252+
253+ HalDriverPtr driver;
254+ status = iree_hal_driver_registry_try_create (
255+ registry, driver_name, allocator, driver.ForOutput ());
256+ if (!iree_status_is_ok (status)) {
257+ iree_status_ignore (status);
258+ continue ;
259+ }
260+
261+ iree_host_size_t device_count = 0 ;
262+ IreeAllocatedPtr<iree_hal_device_info_t > device_infos (allocator);
263+ status = iree_hal_driver_query_available_devices (
264+ driver.Get (), allocator, &device_count, device_infos.ForOutput ());
265+ if (!iree_status_is_ok (status)) {
266+ iree_status_ignore (status);
267+ continue ;
268+ }
269+
270+ // Create OrtHardwareDevice and MemoryInfo for each IREE device.
271+ //
272+ // Given an OrtHardwareDevice, ORT identifies it by vendor_id
273+ // (kEpVendorId) and device_id (index in hw_devices_). IREE identifies
274+ // the device using the driver and device path in metadata.
275+ for (iree_host_size_t j = 0 ; j < device_count; ++j) {
276+ iree_string_view_t device_path = device_infos.Get ()[j].path ;
277+ std::string device_path_str (device_path.data , device_path.size );
278+
279+ OrtKeyValuePairs* hw_metadata = nullptr ;
280+ factory->ort_api .CreateKeyValuePairs (&hw_metadata);
281+ factory->ort_api .AddKeyValuePair (hw_metadata, " iree.driver" ,
282+ driver_name_str.c_str ());
283+ factory->ort_api .AddKeyValuePair (hw_metadata, " iree.device_path" ,
284+ device_path_str.c_str ());
285+ factory->ort_api .AddKeyValuePair (hw_metadata, " iree.device_id" ,
286+ std::to_string (device_index).c_str ());
287+
288+ // TODO: We are pretending here that all IREE devices are GPUs.
289+ // This is because we follow a host <-> device model in IREE.
290+ OrtHardwareDeviceType device_type = OrtHardwareDeviceType_GPU;
291+
292+ OrtHardwareDevice* hw_device = nullptr ;
293+ OrtStatus* ort_status = factory->ep_api .CreateHardwareDevice (
294+ device_type, kEpVendorId , static_cast <uint32_t >(device_index),
295+ kEpVendor , hw_metadata, &hw_device);
296+ factory->ort_api .ReleaseKeyValuePairs (hw_metadata);
297+ if (ort_status != nullptr ) {
298+ factory->ort_api .ReleaseStatus (ort_status);
299+ continue ;
300+ }
301+
302+ factory->hw_devices_ .push_back (hw_device);
303+
304+ // Create MemoryInfo for device-local memory. These must stay alive
305+ // for the factory lifetime since EpDevice_AddAllocatorInfo stores
306+ // the pointer without copying.
307+ factory->device_memory_infos_ .emplace_back (
308+ " IREE" , // name
309+ OrtMemoryInfoDeviceType_GPU, // device_type
310+ kEpVendorId , // vendor_id
311+ static_cast <uint32_t >(device_index), // device_id
312+ OrtDeviceMemoryType_DEFAULT, // mem_type
313+ 0 , // alignment (default)
314+ OrtDeviceAllocator); // allocator_type
315+
316+ ++device_index;
317+ }
318+ }
319+ }
320+
321+ // Create OrtEpDevice for each enumerated hardware device.
322+ // ORT takes ownership of OrtEpDevice objects.
320323 for (size_t i = 0 ;
321324 i < factory->hw_devices_ .size () && num_ep_devices < max_ep_devices;
322325 ++i) {
323326 OrtEpDevice* ep_device = nullptr ;
324327 ORT_RETURN_IF_ERROR (factory->ep_api .CreateEpDevice (
325328 factory, factory->hw_devices_ [i], nullptr , nullptr , &ep_device));
326329
327- // Register allocator info for device-local memory.
328- // This tells ORT that this device has an allocator available for
329- // device-local memory allocations.
330- //
331- // IMPORTANT: EpDevice_AddAllocatorInfo does NOT copy the OrtMemoryInfo,
332- // it just stores the pointer. We must keep the memory_info alive for
333- // the lifetime of the factory by storing it in device_memory_infos_.
334- factory->device_memory_infos_ .emplace_back (
335- " IREE" , // name
336- OrtMemoryInfoDeviceType_GPU, // device_type
337- kEpVendorId , // vendor_id
338- static_cast <uint32_t >(i), // device_id
339- OrtDeviceMemoryType_DEFAULT, // mem_type
340- 0 , // alignment (default)
341- OrtDeviceAllocator); // allocator_type
342-
343- // Get raw pointer to pass to ORT.
344- const OrtMemoryInfo* mem_info_ptr = factory->device_memory_infos_ .back ();
345-
330+ const OrtMemoryInfo* mem_info_ptr = factory->device_memory_infos_ [i];
346331 OrtStatus* add_alloc_status =
347332 factory->ep_api .EpDevice_AddAllocatorInfo (ep_device, mem_info_ptr);
348333 if (add_alloc_status != nullptr ) {
349- factory->device_memory_infos_ .pop_back ();
350334 factory->ep_api .ReleaseEpDevice (ep_device);
351335 return add_alloc_status;
352336 }
@@ -439,13 +423,11 @@ OrtStatus* ORT_API_CALL IreeEpFactory::CreateEpImpl(
439423 .release ();
440424 }
441425
442- // Require target_arch for non-CPU devices.
443- OrtHardwareDeviceType device_type =
444- factory->ort_api .HardwareDevice_Type (&hardware_device);
445- if (device_type != OrtHardwareDeviceType_CPU && config.target_arch .empty ()) {
426+ // Require target_arch for GPU backends (llvm-cpu defaults to host).
427+ if (config.backend != " llvm-cpu" && config.target_arch .empty ()) {
446428 return Ort::Status (
447- " IREE EP: 'target_arch' option must be specified for non-CPU "
448- " devices " ,
429+ " IREE EP: 'target_arch' option must be specified for "
430+ " hip, cuda, and vulkan backends " ,
449431 ORT_INVALID_ARGUMENT)
450432 .release ();
451433 }
0 commit comments