Skip to content

Commit 27019c2

Browse files
committed
Improve encapsulation of weight_data_type_freq_
Move weight_data_type_freq_ to private section and add public getter method GetWeightDataTypeFrequency() to provide controlled read-only access. Update inference_session.cc to use the new getter instead of direct member access.
1 parent a3749f1 commit 27019c2

File tree

2 files changed

+9
-3
lines changed

2 files changed

+9
-3
lines changed

include/onnxruntime/core/graph/graph.h

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -779,6 +779,11 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
779779
bool IsSparseInitializer(const std::string& name) const;
780780
#endif
781781

782+
/** Gets the frequency count of weight data types in this graph. */
783+
const int32_t* GetWeightDataTypeFrequency() const noexcept {
784+
return weight_data_type_freq_;
785+
}
786+
782787
/** Gets an initializer tensor with the provided name.
783788
@param[out] value Set to the TensorProto* if the initializer is found, or nullptr if not.
784789
@returns True if found.
@@ -1608,9 +1613,9 @@ class Graph { // NOLINT(clang-analyzer-optin.performance.Padding): preserve exi
16081613

16091614
ORT_DISALLOW_COPY_ASSIGNMENT_AND_MOVE(Graph);
16101615

1616+
private:
16111617
int32_t weight_data_type_freq_[ONNX_NAMESPACE::TensorProto_DataType_DataType_ARRAYSIZE] = {0};
16121618

1613-
private:
16141619
void InitializeStateFromModelFileGraphProto();
16151620

16161621
// Add node with specified <node_proto>.

onnxruntime/core/session/inference_session.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1818,15 +1818,16 @@ static bool ModelHasFP16Inputs(const Graph& graph) {
18181818
#if !defined(ORT_MINIMAL_BUILD)
18191819
[[maybe_unused]] static std::string ModelWeightDataType(const Graph& graph) {
18201820
std::string data_type_list;
1821+
const int32_t* weight_freq = graph.GetWeightDataTypeFrequency();
18211822

18221823
for (int i = 0; i < ONNX_NAMESPACE::TensorProto_DataType_DataType_ARRAYSIZE; ++i) {
1823-
if (graph.weight_data_type_freq_[i] > 0) {
1824+
if (weight_freq[i] > 0) {
18241825
if (!data_type_list.empty()) {
18251826
data_type_list += ", ";
18261827
}
18271828
data_type_list += TensorProto_DataType_Name(i);
18281829
data_type_list += ": ";
1829-
data_type_list += std::to_string(graph.weight_data_type_freq_[i]);
1830+
data_type_list += std::to_string(weight_freq[i]);
18301831
}
18311832
}
18321833

0 commit comments

Comments
 (0)