Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion include/onnxruntime/core/graph/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -779,6 +779,11 @@
bool IsSparseInitializer(const std::string& name) const;
#endif

/** Gets the frequency count of weight data types in this graph. */
gsl::span<const int32_t> GetWeightDataTypeFrequency() const noexcept {
return weight_data_type_freq_;

Check failure on line 784 in include/onnxruntime/core/graph/graph.h

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'weight_data_type_freq_': undeclared identifier

Check failure on line 784 in include/onnxruntime/core/graph/graph.h

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'weight_data_type_freq_': undeclared identifier

Check failure on line 784 in include/onnxruntime/core/graph/graph.h

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'weight_data_type_freq_': undeclared identifier

Check failure on line 784 in include/onnxruntime/core/graph/graph.h

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'weight_data_type_freq_': undeclared identifier

Check failure on line 784 in include/onnxruntime/core/graph/graph.h

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'weight_data_type_freq_': undeclared identifier

Check failure on line 784 in include/onnxruntime/core/graph/graph.h

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'weight_data_type_freq_': undeclared identifier

Check failure on line 784 in include/onnxruntime/core/graph/graph.h

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'weight_data_type_freq_': undeclared identifier

Check failure on line 784 in include/onnxruntime/core/graph/graph.h

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'weight_data_type_freq_': undeclared identifier

Check failure on line 784 in include/onnxruntime/core/graph/graph.h

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'weight_data_type_freq_': undeclared identifier

Check failure on line 784 in include/onnxruntime/core/graph/graph.h

View workflow job for this annotation

GitHub Actions / webgpu_minimal_build_edge_build_x64_RelWithDebInfo

'weight_data_type_freq_': undeclared identifier
}

/** Gets an initializer tensor with the provided name.
@param[out] value Set to the TensorProto* if the initializer is found, or nullptr if not.
@returns True if found.
Expand Down Expand Up @@ -1608,9 +1613,9 @@

ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph);

private:
int32_t weight_data_type_freq_[ONNX_NAMESPACE::TensorProto_DataType_DataType_ARRAYSIZE] = {0};

private:
void InitializeStateFromModelFileGraphProto();

// Add node with specified <node_proto>.
Expand Down
9 changes: 5 additions & 4 deletions onnxruntime/core/session/inference_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1818,15 +1818,16 @@ static bool ModelHasFP16Inputs(const Graph& graph) {
#if !defined(ORT_MINIMAL_BUILD)
[[maybe_unused]] static std::string ModelWeightDataType(const Graph& graph) {
std::string data_type_list;
auto weight_freq = graph.GetWeightDataTypeFrequency();

for (int i = 0; i < ONNX_NAMESPACE::TensorProto_DataType_DataType_ARRAYSIZE; ++i) {
if (graph.weight_data_type_freq_[i] > 0) {
for (size_t i = 0; i < weight_freq.size(); ++i) {
if (weight_freq[i] > 0) {
if (!data_type_list.empty()) {
data_type_list += ", ";
}
data_type_list += TensorProto_DataType_Name(i);
data_type_list += TensorProto_DataType_Name(static_cast<int>(i));
data_type_list += ": ";
data_type_list += std::to_string(graph.weight_data_type_freq_[i]);
data_type_list += std::to_string(weight_freq[i]);
}
}

Expand Down
Loading