diff --git a/include/onnxruntime/core/graph/graph.h b/include/onnxruntime/core/graph/graph.h index 804f4557fd321..58473a79ddaa6 100644 --- a/include/onnxruntime/core/graph/graph.h +++ b/include/onnxruntime/core/graph/graph.h @@ -779,6 +779,13 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi bool IsSparseInitializer(const std::string& name) const; #endif +#if !defined(ORT_MINIMAL_BUILD) + /** Gets the frequency count of weight data types in this graph. */ + gsl::span GetWeightDataTypeFrequency() const noexcept { + return weight_data_type_freq_; + } +#endif + /** Gets an initializer tensor with the provided name. @param[out] value Set to the TensorProto* if the initializer is found, or nullptr if not. @returns True if found. @@ -1608,9 +1615,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph); + private: int32_t weight_data_type_freq_[ONNX_NAMESPACE::TensorProto_DataType_DataType_ARRAYSIZE] = {0}; - private: void InitializeStateFromModelFileGraphProto(); // Add node with specified . diff --git a/onnxruntime/core/session/inference_session.cc b/onnxruntime/core/session/inference_session.cc index 0944be87591e2..c00d63d0be8a2 100644 --- a/onnxruntime/core/session/inference_session.cc +++ b/onnxruntime/core/session/inference_session.cc @@ -1818,15 +1818,16 @@ static bool ModelHasFP16Inputs(const Graph& graph) { #if !defined(ORT_MINIMAL_BUILD) [[maybe_unused]] static std::string ModelWeightDataType(const Graph& graph) { std::string data_type_list; + auto weight_freq = graph.GetWeightDataTypeFrequency(); - for (int i = 0; i < ONNX_NAMESPACE::TensorProto_DataType_DataType_ARRAYSIZE; ++i) { - if (graph.weight_data_type_freq_[i] > 0) { + for (size_t i = 0; i < weight_freq.size(); ++i) { + if (weight_freq[i] > 0) { if (!data_type_list.empty()) { data_type_list += ", "; } - data_type_list += TensorProto_DataType_Name(i); + data_type_list += TensorProto_DataType_Name(static_cast(i)); data_type_list += ": "; - data_type_list += std::to_string(graph.weight_data_type_freq_[i]); + data_type_list += std::to_string(weight_freq[i]); } }