Skip to content

Commit 765b180

Browse files
authored
Improve encapsulation of weight_data_type_freq_ (#27292)
### Description Move weight_data_type_freq_ to private section and add public getter method GetWeightDataTypeFrequency() to provide controlled read-only access. Update inference_session.cc to use the new getter instead of direct member access. ### Motivation and Context To ensure we follow best practices when accessing class variables.
1 parent 9adf238 commit 765b180

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

include/onnxruntime/core/graph/graph.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,13 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
779779
bool IsSparseInitializer(const std::string& name) const;
780780
#endif
781781

782+
#if !defined(ORT_MINIMAL_BUILD)
783+
/** Gets the frequency count of weight data types in this graph. */
784+
gsl::span<const int32_t> GetWeightDataTypeFrequency() const noexcept {
785+
return weight_data_type_freq_;
786+
}
787+
#endif
788+
782789
/** Gets an initializer tensor with the provided name.
783790
@param[out] value Set to the TensorProto* if the initializer is found, or nullptr if not.
784791
@returns True if found.
@@ -1608,9 +1615,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
16081615

16091616
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph);
16101617

1618+
private:
16111619
int32_t weight_data_type_freq_[ONNX_NAMESPACE::TensorProto_DataType_DataType_ARRAYSIZE] = {0};
16121620

1613-
private:
16141621
void InitializeStateFromModelFileGraphProto();
16151622

16161623
// Add node with specified <node_proto>.

onnxruntime/core/session/inference_session.cc

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,15 +1818,16 @@ static bool ModelHasFP16Inputs(const Graph& graph) {
18181818
#if !defined(ORT_MINIMAL_BUILD)
18191819
[[maybe_unused]] static std::string ModelWeightDataType(const Graph& graph) {
18201820
std::string data_type_list;
1821+
auto weight_freq = graph.GetWeightDataTypeFrequency();
18211822

1822-
for (int i = 0; i < ONNX_NAMESPACE::TensorProto_DataType_DataType_ARRAYSIZE; ++i) {
1823-
if (graph.weight_data_type_freq_[i] > 0) {
1823+
for (size_t i = 0; i < weight_freq.size(); ++i) {
1824+
if (weight_freq[i] > 0) {
18241825
if (!data_type_list.empty()) {
18251826
data_type_list += ", ";
18261827
}
1827-
data_type_list += TensorProto_DataType_Name(i);
1828+
data_type_list += TensorProto_DataType_Name(static_cast<int>(i));
18281829
data_type_list += ": ";
1829-
data_type_list += std::to_string(graph.weight_data_type_freq_[i]);
1830+
data_type_list += std::to_string(weight_freq[i]);
18301831
}
18311832
}
18321833

0 commit comments

Comments
 (0)