Skip to content

Commit 0c60f4e

Browse files
committed
Merge branch 'microsoft:main' into melkap01_implement_mt_qgemm
2 parents 11c856c + 2f4be9a commit 0c60f4e

File tree

54 files changed

+1436
-572
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+1436
-572
lines changed

cmake/adjust_global_compile_flags.cmake

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,10 +208,8 @@ endif()
208208

209209

210210
macro(check_nvcc_compiler_flag _FLAG _RESULT)
211-
execute_process(COMMAND ${CUDAToolkit_BIN_DIR}/nvcc "${_FLAG}" RESULT_VARIABLE NVCC_OUT ERROR_VARIABLE NVCC_ERROR)
212-
message("NVCC_ERROR = ${NVCC_ERROR}")
213-
message("NVCC_OUT = ${NVCC_OUT}")
214-
if ("${NVCC_OUT}" MATCHES "0")
211+
execute_process(COMMAND ${CMAKE_CUDA_COMPILER} --compiler-options "${_FLAG}" -c ${REPO_ROOT}/cmake/empty.c -o ${CMAKE_CURRENT_BINARY_DIR}/empty.o RESULT_VARIABLE NVCC_OUT ERROR_QUIET OUTPUT_QUIET)
212+
if (NVCC_OUT EQUAL 0)
215213
set(${_RESULT} 1)
216214
else()
217215
set(${_RESULT} 0)

cmake/empty.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
// This file is used by the check_nvcc_compiler_flag macro in adjust_global_compile_flags.cmake to test nvcc compiler flags.
2+
void empty(void) {}  // Intentionally empty; (void) makes this a real prototype in C rather than an old-style declaration.

cmake/onnxruntime.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ function(get_c_cxx_api_headers HEADERS_VAR)
2828
"${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_ep_device_ep_metadata_keys.h"
2929
"${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_float16.h"
3030
"${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_lite_custom_op.h"
31+
"${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_env_config_keys.h"
3132
"${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_run_options_config_keys.h"
3233
"${REPO_ROOT}/include/onnxruntime/core/session/onnxruntime_session_options_config_keys.h"
3334
)

cmake/onnxruntime_providers_cuda.cmake

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,17 @@
257257
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include)
258258
target_link_libraries(${target} PRIVATE Eigen3::Eigen)
259259
target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
260+
261+
# Handle CUDA 13.0 CCCL header directory move
262+
if (CMAKE_CUDA_COMPILER_VERSION VERSION_GREATER_EQUAL 13.0)
263+
foreach(inc_dir ${CUDAToolkit_INCLUDE_DIRS})
264+
if (EXISTS "${inc_dir}/cccl")
265+
# Add the cccl subdirectory to the include path so <cuda/std/utility> can be found
266+
target_include_directories(${target} PRIVATE "${inc_dir}/cccl")
267+
endif()
268+
endforeach()
269+
endif()
270+
260271
# ${CMAKE_CURRENT_BINARY_DIR} is so that #include "onnxruntime_config.h" inside tensor_shape.h is found
261272
set_target_properties(${target} PROPERTIES LINKER_LANGUAGE CUDA)
262273
set_target_properties(${target} PROPERTIES FOLDER "ONNXRuntime")

cmake/onnxruntime_unittests.cmake

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -597,6 +597,7 @@ set (onnxruntime_shared_lib_test_SRC
597597
if (NOT onnxruntime_MINIMAL_BUILD)
598598
list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_inference.cc)
599599
list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_model_builder_api.cc)
600+
list(APPEND onnxruntime_shared_lib_test_SRC ${ONNXRUNTIME_SHARED_LIB_TEST_SRC_DIR}/test_env_creation.cc)
600601
endif()
601602

602603
if(onnxruntime_RUN_ONNX_TESTS)

docs/OperatorKernels.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,7 @@ Do not modify directly.*
653653
|ArgMin|*in* data:**T**<br> *out* reduced:**tensor(int64)**|13+|**T** = tensor(double), tensor(float), tensor(float16)|
654654
|||12|**T** = tensor(double), tensor(float), tensor(float16)|
655655
|||[1, 11]|**T** = tensor(double), tensor(float), tensor(float16)|
656+
|Attention|*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *in* nonpad_kv_seqlen:**tensor(int64)**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**<br><br>or<br><br>*in* Q:**T1**<br> *in* K:**T1**<br> *in* V:**T2**<br> *in* attn_mask:**U**<br> *in* past_key:**T1**<br> *in* past_value:**T2**<br> *out* Y:**T1**<br> *out* present_key:**T1**<br> *out* present_value:**T2**<br> *out* qk_matmul_output:**T1**|23+|**T1** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **T2** = tensor(bfloat16), tensor(float), tensor(float16)<br/> **U** = tensor(bfloat16), tensor(bool), tensor(float), tensor(float16)|
656657
|AveragePool|*in* X:**T**<br> *out* Y:**T**|22+|**T** = tensor(bfloat16), tensor(double), tensor(float), tensor(float16)|
657658
|||[19, 21]|**T** = tensor(double), tensor(float), tensor(float16)|
658659
|||[11, 18]|**T** = tensor(double), tensor(float), tensor(float16)|

include/onnxruntime/core/session/environment.h

Lines changed: 39 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
#include <filesystem>
88
#include <memory>
99
#include <vector>
10+
#include <shared_mutex>
1011
#include <string>
1112

1213
#include "core/common/common.h"
@@ -20,6 +21,7 @@
2021
#include "core/platform/threadpool.h"
2122

2223
#include "core/session/abi_devices.h"
24+
#include "core/session/abi_key_value_pairs.h"
2325
#include "core/session/plugin_ep/ep_library.h"
2426
#include "core/session/onnxruntime_c_api.h"
2527

@@ -51,11 +53,13 @@ class Environment {
5153
@param tp_options optional set of parameters controlling the number of intra and inter op threads for the global
5254
threadpools.
5355
@param create_global_thread_pools determine if this function will create the global threadpools or not.
56+
@param config_entries Application-specified configuration entries.
5457
*/
5558
static Status Create(std::unique_ptr<logging::LoggingManager> logging_manager,
5659
std::unique_ptr<Environment>& environment,
5760
const OrtThreadingOptions* tp_options = nullptr,
58-
bool create_global_thread_pools = false);
61+
bool create_global_thread_pools = false,
62+
const OrtKeyValuePairs* config_entries = nullptr);
5963

6064
/**
6165
* Set the global threading options for the environment, if no global thread pools have been created yet.
@@ -170,14 +174,26 @@ class Environment {
170174
// return a shared allocator from a plugin EP or custom allocator added with RegisterAllocator
171175
Status GetSharedAllocator(const OrtMemoryInfo& mem_info, OrtAllocator*& allocator);
172176

177+
/// <summary>
178+
/// Returns a copy of the configuration entries set by the application on environment creation.
179+
///
180+
/// Primarily used by EP libraries to retrieve environment-level configurations, but could be used
181+
/// more generally to specify global settings.
182+
///
183+
/// Refer to OrtApi::CreateEnvWithOptions().
184+
/// </summary>
185+
/// <returns>A copy of the environment's configuration entries.</returns>
186+
OrtKeyValuePairs GetConfigEntries() const;
187+
173188
~Environment();
174189

175190
private:
176191
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Environment);
177192

178193
Status Initialize(std::unique_ptr<logging::LoggingManager> logging_manager,
179194
const OrtThreadingOptions* tp_options = nullptr,
180-
bool create_global_thread_pools = false);
195+
bool create_global_thread_pools = false,
196+
const OrtKeyValuePairs* config_entries = nullptr);
181197

182198
Status RegisterAllocatorImpl(AllocatorPtr allocator);
183199
Status UnregisterAllocatorImpl(const OrtMemoryInfo& mem_info, bool error_if_not_found = true);
@@ -186,6 +202,13 @@ class Environment {
186202
const OrtKeyValuePairs* allocator_options, OrtAllocator** allocator,
187203
bool replace_existing);
188204

205+
// Inserts (or assigns) a config entry into `config_entries_`. Locks `config_entries_mutex_`.
206+
void InsertOrAssignConfigEntry(std::string key, std::string value);
207+
208+
// Removes a config entry from `config_entries_`. Does nothing if the key does not exist.
209+
// Locks `config_entries_mutex_`.
210+
void RemoveConfigEntry(const std::string& key);
211+
189212
std::unique_ptr<logging::LoggingManager> logging_manager_;
190213
std::unique_ptr<onnxruntime::concurrency::ThreadPool> intra_op_thread_pool_;
191214
std::unique_ptr<onnxruntime::concurrency::ThreadPool> inter_op_thread_pool_;
@@ -254,6 +277,20 @@ class Environment {
254277
DataTransferManager data_transfer_mgr_; // plugin EP IDataTransfer instances
255278

256279
#endif // !defined(ORT_MINIMAL_BUILD)
280+
281+
// Application-specified environment configuration entries
282+
// The environment may add or remove an entry on EP library registration and unregistration, respectively.
283+
OrtKeyValuePairs config_entries_;
284+
mutable std::shared_mutex config_entries_mutex_; // Should be locked when accessing config_entries_
285+
286+
// Tracks the number of registered EP libraries that can create virtual devices.
287+
// It is incremented when an EP library is registered with a name that ends in ".virtual".
288+
// It is decremented when that EP library is unregistered.
289+
// If it reaches 0, the config entry "allow_virtual_devices" is removed.
290+
//
291+
// This starts at 1 if the user created an OrtEnv with the config "allow_virtual_devices" set to "1"
292+
// to prevent removal of the config entry in that case.
293+
size_t num_allow_virtual_device_uses_{};
257294
};
258295

259296
} // namespace onnxruntime

include/onnxruntime/core/session/onnxruntime_c_api.h

Lines changed: 109 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -965,14 +965,6 @@ typedef OrtStatus*(ORT_API_CALL* RegisterCustomOpsFn)(OrtSessionOptions* options
965965
*/
966966
typedef void (*RunAsyncCallbackFn)(void* user_data, OrtValue** outputs, size_t num_outputs, OrtStatusPtr status);
967967

968-
/** \brief The C API
969-
*
970-
* All C API functions are defined inside this structure as pointers to functions.
971-
* Call OrtApiBase::GetApi to get a pointer to it
972-
*
973-
* \nosubgrouping
974-
*/
975-
976968
/** \addtogroup Global
977969
* @{
978970
*/
@@ -1056,6 +1048,101 @@ typedef enum OrtCompiledModelCompatibility {
10561048
OrtCompiledModelCompatibility_EP_UNSUPPORTED,
10571049
} OrtCompiledModelCompatibility;
10581050

1051+
/** \brief Configuration options for creating an OrtEnv.
1052+
*
1053+
* \note The version field must be set to ORT_API_VERSION.
1054+
* This ensures forward compatibility as fields may be added in future versions.
1055+
*
1056+
* \since Version 1.24.
1057+
*/
1058+
typedef struct OrtEnvCreationOptions {
1059+
uint32_t version; ///< Must be set to ORT_API_VERSION
1060+
1061+
/** \brief The logging severity level for the environment. Must be set to a value from OrtLoggingLevel.
1062+
*
1063+
* \note Logging messages which are less severe than the `logging_severity_level` are not emitted.
1064+
*
1065+
* \note Serves as the default logging severity level for session creation and runs.
1066+
* Use ::SetSessionLogSeverityLevel() to set a logging severity level for the creation of a specific session.
1067+
* Use ::RunOptionsSetRunLogSeverityLevel() to set a logging severity level for a specific session run.
1068+
*
1069+
* \since Version 1.24.
1070+
*/
1071+
int32_t logging_severity_level;
1072+
1073+
/** \brief The log identifier. Must be set to a valid UTF-8 null-terminated string.
1074+
*
1075+
* \note This string identifier is copied by ORT.
1076+
*
1077+
* \since Version 1.24.
1078+
*/
1079+
const char* log_id;
1080+
1081+
/** \brief Optional custom logging function. May be set to NULL.
1082+
*
1083+
* \note The OrtEnvCreationOptions::custom_logging_param is provided as the first argument to this logging function.
1084+
* This allows passing custom state into the logging function.
1085+
*
1086+
* \note This function is only called when a message's severity meets or exceeds the set logging severity level.
1087+
*
1088+
* \since Version 1.24.
1089+
*/
1090+
OrtLoggingFunction custom_logging_function;
1091+
1092+
/** \brief Optional state to pass as the first argument to OrtEnvCreationOptions::custom_logging_function.
1093+
* May be set to NULL.
1094+
*
1095+
* \since Version 1.24.
1096+
*/
1097+
void* custom_logging_param;
1098+
1099+
/** \brief Optional threading options for creating an environment with global thread pools shared across sessions.
1100+
* May be set to NULL.
1101+
*
1102+
* \note The OrtThreadingOptions instance is copied by ORT.
1103+
*
1104+
* \note Use OrtApi::CreateThreadingOptions() to create an instance of OrtThreadingOptions.
1105+
*
1106+
* \note Use this in conjunction with OrtApi::DisablePerSessionThreads or else the session will use its own
1107+
* thread pools.
1108+
*
1109+
* \since Version 1.24.
1110+
*/
1111+
const OrtThreadingOptions* threading_options;
1112+
1113+
/** \brief Optional environment configuration entries represented as string key-value pairs. May be set to NULL.
1114+
*
1115+
* \note The OrtKeyValuePairs instance is copied by ORT.
1116+
*
1117+
* \note Refer to onnxruntime_env_config_keys.h for common config entry keys and their supported values.
1118+
*
1119+
* \note An application provides environment-level configuration options for execution provider libraries by
1120+
* using keys with the prefix 'ep_factory.<ep_name>.'. Ex: the key 'ep_factory.my_ep.some_ep_key' represents
1121+
* a key named 'some_ep_key' that is meant to be consumed by an execution provider named 'my_ep'. Refer to
1122+
* the specific execution provider's documentation for valid keys and values.
1123+
*
1124+
* \note An application may separately set session-level configuration options for execution providers via other APIs
1125+
* such as SessionOptionsAppendExecutionProvider_V2, which store configuration entries within OrtSessionOptions.
1126+
* If an environment-level configuration conflicts with a session-level configuration, then
1127+
* precedence is determined by the execution provider library itself.
1128+
*
1129+
* \since Version 1.24.
1130+
*/
1131+
const OrtKeyValuePairs* config_entries;
1132+
1133+
//
1134+
// End of fields available in ORT 1.24
1135+
//
1136+
1137+
} OrtEnvCreationOptions;
1138+
1139+
/** \brief The C API
1140+
*
1141+
* All C API functions are defined inside this structure as pointers to functions.
1142+
* Call OrtApiBase::GetApi to get a pointer to it
1143+
*
1144+
* \nosubgrouping
1145+
*/
10591146
struct OrtApi {
10601147
/// \name OrtStatus
10611148
/// @{
@@ -6912,6 +6999,20 @@ struct OrtApi {
69126999
ORT_CLASS_RELEASE(DeviceEpIncompatibilityDetails);
69137000

69147001
/// @}
7002+
7003+
/** \brief Create an OrtEnv instance with the given options.
7004+
*
7005+
* \note Invoking this function will return the same instance of the environment as that returned by a previous call
7006+
* to another env creation function; all arguments to this function will be ignored.
7007+
*
7008+
* \param[in] options The OrtEnvCreationOptions instance that contains creation options.
7009+
* \param[out] out Output parameter set to the new OrtEnv instance. Must be freed with OrtApi::ReleaseEnv.
7010+
*
7011+
* \snippet{doc} snippets.dox OrtStatus Return Value
7012+
*
7013+
* \since Version 1.24
7014+
*/
7015+
ORT_API2_STATUS(CreateEnvWithOptions, _In_ const OrtEnvCreationOptions* options, _Outptr_ OrtEnv** out);
69157016
};
69167017

69177018
/*

include/onnxruntime/core/session/onnxruntime_cxx_api.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1185,6 +1185,9 @@ struct Env : detail::Base<OrtEnv> {
11851185
Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction logging_function, void* logger_param,
11861186
OrtLoggingLevel logging_level = ORT_LOGGING_LEVEL_WARNING, _In_ const char* logid = "");
11871187

1188+
/// \brief Wraps OrtApi::CreateEnvWithOptions
1189+
explicit Env(const OrtEnvCreationOptions* options);
1190+
11881191
/// \brief C Interop Helper
11891192
explicit Env(OrtEnv* p) : Base<OrtEnv>{p} {}
11901193

@@ -3431,5 +3434,8 @@ struct SharedPrePackedWeightCacheImpl : Ort::detail::Base<T> {
34313434
*/
34323435
using UnownedSharedPrePackedWeightCache =
34333436
detail::SharedPrePackedWeightCacheImpl<Ort::detail::Unowned<OrtSharedPrePackedWeightCache>>;
3437+
3438+
/// \brief Wraps OrtEpApi::GetEnvConfigEntries()
3439+
Ort::KeyValuePairs GetEnvConfigEntries();
34343440
} // namespace Ort
34353441
#include "onnxruntime_cxx_inline.h"

include/onnxruntime/core/session/onnxruntime_cxx_inline.h

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,15 @@ inline Env::Env(const OrtThreadingOptions* tp_options, OrtLoggingFunction loggin
784784
}
785785
}
786786

787+
// Creates the environment through OrtApi::CreateEnvWithOptions and then records the language
// projection on it: ORT_PROJECTION_NODEJS when the log id is exactly "onnxruntime-node",
// ORT_PROJECTION_CPLUSPLUS otherwise. Each C API call is routed through ThrowOnError.
// NOTE(review): `options` and `options->log_id` are dereferenced without a null check; this
// presumably relies on CreateEnvWithOptions failing first for invalid input — confirm.
inline Env::Env(const OrtEnvCreationOptions* options) {
788+
ThrowOnError(GetApi().CreateEnvWithOptions(options, &p_));
789+
if (strcmp(options->log_id, "onnxruntime-node") == 0) {
790+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_NODEJS));
791+
} else {
792+
ThrowOnError(GetApi().SetLanguageProjection(p_, OrtLanguageProjection::ORT_PROJECTION_CPLUSPLUS));
793+
}
794+
}
795+
787796
inline Env& Env::EnableTelemetryEvents() {
788797
ThrowOnError(GetApi().EnableTelemetryEvents(p_));
789798
return *this;
@@ -3779,4 +3788,11 @@ inline Status SharedPrePackedWeightCacheImpl<T>::StoreWeightData(void** buffer_d
37793788
num_buffers)};
37803789
}
37813790
} // namespace detail
3791+
3792+
// Fetches the environment configuration entries via OrtEpApi::GetEnvConfigEntries() and returns
// them wrapped in an Ort::KeyValuePairs. The C API status is routed through Ort::ThrowOnError.
// NOTE(review): the returned wrapper presumably takes ownership of `entries` and releases it — confirm.
inline Ort::KeyValuePairs GetEnvConfigEntries() {
3793+
OrtKeyValuePairs* entries = nullptr;  // out-parameter filled in by the C API call below
3794+
Ort::ThrowOnError(GetEpApi().GetEnvConfigEntries(&entries));
3795+
3796+
return Ort::KeyValuePairs{entries};
3797+
}
37823798
} // namespace Ort

0 commit comments

Comments
 (0)